
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

def plot_fun(f,dom):
    
    # affichage du graphe d'une fonction de R^2 dans R,
    # de ses lignes de niveau,
    # et valeur de son minimum sur une grille discrète.
    # f : fonction
    # dom : domaine de la grille
    
    # Definition des valeurs x1 et x2, des grilles correspondantes et evaluation de f
    x, y = np.linspace(dom[0],dom[1],200), np.linspace(dom[2],dom[3],200)
    x, y = np.meshgrid(x,y)
    z = f(x,y)
    
    # valeur et position du minimum sur la grille:
    imin = np.unravel_index(np.argmin(z),z.shape)
    zmin = z[imin]
    print('valeur du minimum sur la grille : ',zmin)
    xmin, ymin = x[imin], y[imin]
    print('position du minimiseur : (',xmin,',',ymin,')')
    
    # Graphique 3D:
    fig = plt.figure()
    ax = plt.axes(projection="3d")
    ax.plot([xmin,xmin],[ymin,ymin],[0,zmin],marker='o')
    ax.plot_surface(x, y, z)
    ax.set_title('Graphe de la fonction')
    
    # Graphique des lignes de niveaux
    fig, ax = plt.subplots()
    CS = ax.contour(x, y, z)
    ax.clabel(CS, inline=1, fontsize=10)
    ax.set_title('Lignes de niveau')
    
    
def f1(x,y):
    return x**2 + 2*y**2 + x*y + x - 3*y + 30
dom1 = [-3,3,-3,3]
plot_fun(f1,dom1)
plt.show()

def f2(x,y):
    return 100*(y-x**2)**2 + (1-x)**2
dom2 = [-3,3,-1,4]
plot_fun(f2,dom2)
plt.show()

def f3(x,y):
    return (x**2 + y - 11)**2 + (x+y**2-7)**2;
dom3 = [-5,5,-5,5]
plot_fun(f3,dom3)
plt.show()

def f4(x,y):
    return - np.log(x) - np.log(y) - np.log(1-x) - np.log(1.5-y)
dom4 = [0.001,0.999,0.001,1.499]
plot_fun(f4,dom4)
plt.show()

def plot_grad(f,gradf,dom,xstar,ystar):
    
    # Affichage des lignes de niveau et du champ de gradients
    # d'une fonction de R^2 dans R
    # f : fonction
    # gradf : gradient de la fonction
    # dom : domaine de d?finition
    # xstar,ystar : minimiseur de la fonction sur le domaine
    
    x, y = np.linspace(dom[0],dom[1],200), np.linspace(dom[2],dom[3],200)
    x, y = np.meshgrid(x,y)
    z = f(x,y)
    xg, yg = np.linspace(dom[0],dom[1],20), np.linspace(dom[2],dom[3],20) # grille plus grossiere
    xg, yg = np.meshgrid(xg,yg)
    u, v = gradf(xg,yg)
    fig, ax = plt.subplots()
    CS = ax.contour(x, y, z)
    ax.clabel(CS, inline=1, fontsize=10)
    plt.plot(xstar,ystar,'*r')
    plt.quiver(xg,yg,u,v)


def gradf1(x,y):
    gx = 2*x + y + 1
    gy = 4*y + x - 3
    return gx, gy
xstar, ystar = -1,1
plot_grad(f1,gradf1,dom1,xstar,ystar)
plt.show()

