Skip to content

Cheatsheet : seaborn

seaborn est une bibliothèque de visualisation statistique basée sur matplotlib. Elle est particulièrement adaptée pour explorer des données et visualiser des distributions.

Configuration de base

python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Style et palette
sns.set_theme(style="darkgrid")      # styles : darkgrid, whitegrid, dark, white, ticks
sns.set_palette("tab10")             # palettes : tab10, Set2, husl, coolwarm, ...

Visualiser une distribution

Histogramme et densité estimée (KDE)

python
x = stats.norm(loc=5, scale=2).rvs(size=500, random_state=42)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Histogramme seul
sns.histplot(x, bins=30, ax=axes[0])
axes[0].set_title("Histogramme")

# KDE (estimation par noyau)
sns.kdeplot(x, fill=True, ax=axes[1])
axes[1].set_title("Densité estimée (KDE)")

# Les deux superposés
sns.histplot(x, bins=30, stat="density", alpha=0.5, ax=axes[2])
sns.kdeplot(x, color="red", linewidth=2, ax=axes[2])
axes[2].set_title("Histogramme + KDE")

Comparer plusieurs distributions

python
import pandas as pd

# Générer des données
data = pd.DataFrame({
    'valeur': np.concatenate([
        stats.norm(0, 1).rvs(300),
        stats.norm(2, 1.5).rvs(300),
        stats.norm(-1, 0.5).rvs(300)
    ]),
    'groupe': ['N(0,1)'] * 300 + ['N(2,1.5²)'] * 300 + ['N(-1,0.25)'] * 300
})

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# KDE superposées
sns.kdeplot(data=data, x='valeur', hue='groupe', fill=True,
            alpha=0.3, ax=axes[0])
axes[0].set_title("KDE par groupe")

# Boxplot
sns.boxplot(data=data, x='groupe', y='valeur', ax=axes[1])
axes[1].set_title("Boxplot")

# Violinplot
sns.violinplot(data=data, x='groupe', y='valeur', ax=axes[2])
axes[2].set_title("Violin plot")

Distribution empirique (ECDF)

python
# Fonction de répartition empirique
sns.ecdfplot(x)
plt.title("Fonction de répartition empirique")

Visualiser des relations bivariées

Nuage de points avec régression

python
# Données corrélées
n = 200
x = np.random.randn(n)
y = 2 * x + 1 + np.random.randn(n) * 0.5
data = pd.DataFrame({'x': x, 'y': y})

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Nuage simple
sns.scatterplot(data=data, x='x', y='y', alpha=0.5, ax=axes[0])
axes[0].set_title("Nuage de points")

# Régression linéaire
sns.regplot(data=data, x='x', y='y', ax=axes[1])
axes[1].set_title("Régression linéaire")

# Régression polynomiale
sns.regplot(data=data, x='x', y='y', order=2, ax=axes[2])
axes[2].set_title("Régression polynomiale (degré 2)")

Distribution jointe

python
# Jointplot : nuage + marginales
sns.jointplot(data=data, x='x', y='y', kind='kde')
# kind : 'scatter', 'kde', 'hex', 'reg', 'hist'

Matrice de corrélation (heatmap)

python
# Données multivariées
mu = [0, 1, 2]
Sigma = [[1, 0.8, 0.3],
         [0.8, 1, 0.5],
         [0.3, 0.5, 1]]
samples = np.random.multivariate_normal(mu, Sigma, 500)
df = pd.DataFrame(samples, columns=['X1', 'X2', 'X3'])

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Matrice de corrélation
corr = df.corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1,
            center=0, ax=axes[0])
axes[0].set_title("Matrice de corrélation")

# Pairplot (toutes les combinaisons)
# sns.pairplot(df)  # à utiliser dans une cellule séparée

Visualiser des estimateurs

Distribution d'un estimateur par simulation

python
mu_vrai = 5
sigma_vrai = 2
tailles = [10, 30, 100, 500]
N_sim = 2000

resultats = []
for n in tailles:
    estimations = [np.mean(stats.norm(mu_vrai, sigma_vrai).rvs(n))
                   for _ in range(N_sim)]
    resultats.extend([{'n': n, 'mu_hat': est} for est in estimations])

df = pd.DataFrame(resultats)

# Visualisation
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.kdeplot(data=df, x='mu_hat', hue='n', fill=True,
            alpha=0.3, palette='viridis', ax=axes[0])
axes[0].axvline(mu_vrai, color='red', linestyle='--', label='μ vrai')
axes[0].set_title("Distribution de μ̂ selon n")
axes[0].legend()

sns.boxplot(data=df, x='n', y='mu_hat', ax=axes[1])
axes[1].axhline(mu_vrai, color='red', linestyle='--')
axes[1].set_title("Boxplot de μ̂ selon n")

Biais-Variance avec barres d'erreur

python
# Résumé par taille d'échantillon
resume = df.groupby('n')['mu_hat'].agg(['mean', 'std']).reset_index()

fig, ax = plt.subplots(figsize=(8, 5))
ax.errorbar(resume['n'], resume['mean'], yerr=1.96 * resume['std'],
            fmt='o-', capsize=5, linewidth=2, markersize=8)
ax.axhline(mu_vrai, color='red', linestyle='--', label='μ vrai')
ax.set_xlabel('Taille d\'échantillon n')
ax.set_ylabel('μ̂ ± 1.96σ')
ax.set_title('Convergence de l\'estimateur')
ax.legend()

Visualiser la régression

Résidus

python
from sklearn.linear_model import LinearRegression

# Données
n = 100
X = np.random.randn(n, 1)
y = 3 * X.ravel() + 1 + np.random.randn(n) * 0.5

# Ajustement
model = LinearRegression().fit(X, y)
y_pred = model.predict(X)
residus = y - y_pred

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Résidus vs valeurs ajustées
sns.scatterplot(x=y_pred, y=residus, alpha=0.5, ax=axes[0])
axes[0].axhline(0, color='red', linestyle='--')
axes[0].set_xlabel('Valeurs ajustées ŷ')
axes[0].set_ylabel('Résidus e')
axes[0].set_title('Résidus vs ajustées')

# Distribution des résidus
sns.histplot(residus, kde=True, ax=axes[1])
axes[1].set_title('Distribution des résidus')

# QQ-plot
from scipy import stats as sp_stats
sp_stats.probplot(residus, plot=axes[2])
axes[2].set_title('QQ-plot des résidus')

Effet de la régularisation

python
from sklearn.linear_model import Ridge

lambdas = np.logspace(-2, 3, 50)
coefs = []

for lam in lambdas:
    model = Ridge(alpha=lam, fit_intercept=False)
    model.fit(X_multi, y)  # X_multi : matrice de design
    coefs.append(model.coef_)

coefs = np.array(coefs)

# Chemin de régularisation
fig, ax = plt.subplots(figsize=(10, 6))
for j in range(coefs.shape[1]):
    ax.plot(np.log10(lambdas), coefs[:, j], linewidth=2)
ax.set_xlabel('log₁₀(λ)')
ax.set_ylabel('Coefficients')
ax.set_title('Chemin de régularisation Ridge')
ax.axhline(0, color='black', linestyle='--', alpha=0.5)

Palettes et styles utiles

Palettes recommandées

python
# Pour catégories distinctes
sns.color_palette("tab10")      # 10 couleurs distinctes
sns.color_palette("Set2")       # tons pastel

# Pour gradients continus
sns.color_palette("viridis", 5)   # perceptuellement uniforme
sns.color_palette("coolwarm", 5)  # divergente (bleu → rouge)

Personnalisation rapide

python
# Taille des figures par défaut
sns.set_theme(rc={'figure.figsize': (10, 6)})

# Contextes (taille du texte)
sns.set_context("paper")      # petit (articles)
sns.set_context("notebook")   # moyen (par défaut)
sns.set_context("talk")       # grand (présentations)
sns.set_context("poster")     # très grand

Résumé des fonctions clés

FonctionUsage
sns.histplot()Histogramme
sns.kdeplot()Densité estimée par noyau
sns.ecdfplot()Fonction de répartition empirique
sns.boxplot()Boîte à moustaches
sns.violinplot()Violin plot
sns.scatterplot()Nuage de points
sns.regplot()Régression avec IC
sns.heatmap()Carte de chaleur
sns.jointplot()Distribution jointe + marginales
sns.pairplot()Toutes les paires de variables