Optimisation et différentiation automatique avec Pytorch

Auteur·rice

Tanguy Lefort, Amélie Vernay

Objectifs de ce TP
  • Découverte de Pytorch et de la différentiation automatique
  • Passer des arrays aux tensors, méthode .backward() et module optim

Préambule

Un cas pratique d’optimisation de fonction hautement non convexe est la minimisation du coût d’erreur pendant l’entraînement d’un modèle de décision (référé plus communément dans le langage en “une IA”). Une des grandes librairies permettant de traiter de réseaux de neurones et optimiser leurs paramètres est pytorch.

Introduction à pytorch

Dans un premier temps, importons les packages nécessaires de la manière classique suivante :

# Ceci est un commentaire pour l'import des packages.
import numpy as np  # package de calcul scientifique
import torch  # librairie pytorch
import matplotlib.pyplot as plt  # graphiques
Important

La librairie numpy utilise des arrays. La librairie pytorch utilise des tensors qui sont des array optimisés pour faire du calcul en grande dimension.

Mon premier tensor

Commençons par créer un tenseur de la même façon qu’avec numpy.

data = [[1, 2], [3, 4]]
x_data = torch.tensor(data)
print(x_data, x_data.shape)
tensor([[1, 2],
        [3, 4]]) torch.Size([2, 2])

On peut même transformer un array de numpy en tensor de pytorch:

numpy_array = np.array(data)
print(type(numpy_array))
tensor_from_numpy = torch.from_numpy(numpy_array)
print(type(tensor_from_numpy))
<class 'numpy.ndarray'>
<class 'torch.Tensor'>

Les fonctions np.ones, np.zeros, np.eye ont aussi leurs équivalents en torch avec les mêmes noms.

Question : Initialiser des tensors

Créer un tensor rempli de 1 de taille (20,5) sans utiliser numpy. Afficher la première colonne. Remplacer l’élément (5,2) par un 0. Afficher la taille du tenseur, le type des nombres contenus (.dtype) et faire le produit matriciel @ avec un tensor de taille (5,1) contenant seulement des 2

A = ...
b = ...
print(A[...])
A[...] = ...
print(A ... b)

Graphe computationnel

Mathématiquement, chaque opération d’une fonction peut se voir comme un graphe. Pour calculer le gradient de la fonction, on a souvent recours à la règle de dérivation en chaîne. Prenons un exemple avec la fonction f définie sur [-1,6]\times [1,6]:

f(x,y) = \exp\left(1 - \frac{2\log(y)}{\cos(x)+2}\right)

Question : Visualiser la fonction f

A partir du code suivant, visualiser f:

def f(x, y):
    ...


x = np.arange(-1, 6, 0.1)
y = np.arange(1, 6, 0.1)
X, Y = np.meshgrid(...)
Z = ...

fig = plt.figure(figsize=(12, 10))
ax = plt.axes(projection="3d")
surf = ax.plot_surface(X, Y, Z, cmap=plt.cm.cividis)
ax.set_xlabel("x", labelpad=20)
ax.set_ylabel("y", labelpad=20)
ax.set_zlabel("z", labelpad=20)
ax.view_init(10, 25)  # élévation de 10 degrés et déplacement horizontal de 25 degrés
fig.colorbar(surf, shrink=0.5, aspect=8)
plt.show()
2024-04-08T14:05:00.947004 image/svg+xml Matplotlib v3.5.3, https://matplotlib.org/

Graphe et différentiation

La représentation de la définition de f par un graphe computationnel est :

flowchart LR
  A[x] --> B(a = 2 + cos x);
  C[y] --> D(b = 2 * log y);
  B --> E(c = b / a);
  D --> E;
  E --> F(d = 1 - c);
  F --> G(exp d);
  G --> H{f}

Pour obtenir le gradient, on rétropropage (backpropagation) avec la règle de dérivation en chaîne. Pour cela on calcule toutes les dérivées partielles de gauche à droite dans le graphe computationnel :

\frac{\partial f}{\partial d} = \exp d,\ \frac{\partial d}{\partial c} = -1,\ \frac{\partial c}{\partial b} = \frac{1}{a}, \frac{\partial c}{\partial a} = - \frac{b}{a^2},\ \frac{\partial a}{\partial x} = -\sin x,\ \frac{\partial b}{\partial y} = \frac{2}{y}.

Et ainsi par exemple \frac{\partial f}{\partial x} = \frac{\partial f}{\partial d}\frac{\partial d}{\partial c}\frac{\partial c}{\partial a}\frac{\partial a}{\partial x} = \exp(d) \cdot (-1) \cdot \left(-\frac{b}{a^2}\right)\cdot (-\sin x),

Et on finit en remplaçant pour obtenir la formule dépendante de x. On fera de même pour calculer \frac{\partial f}{\partial y}. En pratique, on ne calcule pas la formule générale, ce qui serait très coûteux, mais on évalue le gradient en le point courant.

Avec pytorch

Les tenseurs peuvent être initialisés avec requires_grad=True pour indiquer que l’on cherchera à optimiser ces paramètres. Si on souhaite ne pas le faire au moment de l’initialisation, on peut plus tard appeler la méthode x.requires_grad_(True) au besoin.

xy = torch.randn(2, requires_grad=True)  # point initial en (x,y) aléatoir
out_1 = 1 - 2 * torch.log(xy[1]) / (torch.cos(xy[0]) + 2)
out_f = torch.exp(out_1)
print(f"Gradient au cours de la chaîne = {out_1.grad_fn}")
print(f"Gradient final = {out_f.grad_fn}")
Gradient au cours de la chaîne = <RsubBackward1 object at 0x7f73b77ffe50>
Gradient final = <ExpBackward0 object at 0x7f73b77ff4f0>

Le gradient n’est pas calculé directement, mais pytorch sauvegarde l’arbre des opérations effectuées. Les fonctions de pytorch ont une méthode associée appelée .backward qui permet de calculer les dérivées partielles.

Pour obtenir les valeurs des gradients, on appelle .backward() sur le tenseur obtenu après avoir appliqué les opérations. Et on les affiche avec l’attribut .grad.

out_f.backward()
print(xy.grad)
tensor([ -0.2038, -22.2255])

Question : Calcul de gradient

Calculer le gradient de la fonction f au point (1.5,2) et au point (0,1)

pt1 = ...
pt2 = ...
print(...)
Gradient en (1.5, 2): tensor([-0.4488, -0.6721])
Gradient en (0, 1): tensor([ 0.0000, -1.8122])

Descente de gradient avec pytorch

Pytorch a un module d’optimisation appelé torch.optim permettant d’appliquer différentes méthodes d’optimisation comme la descente de gradient, Adam,… Pour ce faire, on définit un optimiseur avec ses hyperparamètres. Cet optimiseur doit être réinitialisé à chaque pas à l’aide de la méthode .zero_grad(). Ensuite on calcule le gradient avec la méthode .backward() et finalement on applique le pas de notre méthode avec la méthode .step().

Un exemple avec la fonction f:x\mapsto x^2

x = torch.randn(1, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.1)  # descente de gradient
for i in range(101):
    optimizer.zero_grad()  # on remet à 0 l'arbre des gradients
    fx = x**2
    fx.backward()  # calcul des gradients
    optimizer.step()  # pas de la descente
    if i % 10 == 0:
        print(x)  # itérés succesifs toutes les 10 itérations
print(x.detach().numpy())  # afficher en numpy la solution
tensor([-0.7283], requires_grad=True)
tensor([-0.0782], requires_grad=True)
tensor([-0.0084], requires_grad=True)
tensor([-0.0009], requires_grad=True)
tensor([-9.6802e-05], requires_grad=True)
tensor([-1.0394e-05], requires_grad=True)
tensor([-1.1161e-06], requires_grad=True)
tensor([-1.1984e-07], requires_grad=True)
tensor([-1.2867e-08], requires_grad=True)
tensor([-1.3816e-09], requires_grad=True)
tensor([-1.4835e-10], requires_grad=True)
[-1.4834915e-10]

Question : Optimisation avec pytorch

A l’aide de pytorch, afficher les valeurs successives de (f(x_k, y_k))_{k=1,\dots,100} au cours de l’optimisation à l’aide de la descente de gradient à pas fixe \eta=0.001 de la fonction f(x,y)=100(y-x^2)^2 + (1-x)^2 + 2. On utilisera une échelle logarithmique et on vérifiera que l’on est bien arrivé à un minimum local.

def f(x, y):
    return ...

ll = []
x = torch.randn(2, ...)
optimizer = torch.optim.SGD(..., lr=...)  # descente de gradient
for i in range(...):
    # XXX
    fx = ...
    ll.append(...)
    # XXX
    # XXX

print("Le dernier gradient calculé vaut:", ...)
print("Le dernier itéré est:", ...)

plt.figure()
plt.plot(list(range(len(ll))), ll, color="blue")
plt.xlabel(...)
plt.ylabel(...)
# XXX TODO échelle log en y
plt.tight_layout()
plt.show()
Le dernier gradient calculé vaut : tensor([-0.0699, -0.1207])
Le dernier itéré est : [0.86114675 0.7409704 ]
2024-04-08T14:05:02.057080 image/svg+xml Matplotlib v3.5.3, https://matplotlib.org/

Pour aller plus loin

  • Montrer que le gradient de f s’annule pour les points ((-1)^{k+1}, k\pi)_{k\in\mathbb{Z}} et (0, \frac{\pi}{2}+k\pi)_{k\in\mathbb{Z}}.
  • À l’aide des conditions du second ordre, montrer que les conditions suffisantes d’optimalité sont seulement vérifiées pour les points ((-1)^{k+1}, k\pi)_{k\in\mathbb{Z}}.
  • Afficher la fonction f et ses courbes de niveaux.

Ressources utiles