
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/applications/plot_over_sampling_benchmark_lfw.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_applications_plot_over_sampling_benchmark_lfw.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_applications_plot_over_sampling_benchmark_lfw.py:


==========================================================
Benchmark over-sampling methods in a face recognition task
==========================================================

In this face recognition example, two faces from the LFW
(Labeled Faces in the Wild) dataset are used. Several over-sampling
methods are combined with a 3NN classifier to examine how much
an over-sampler improves the quality of the classifier's output.

.. GENERATED FROM PYTHON SOURCE LINES 12-17

.. code-block:: default


    # Authors: Christos Aridas
    #          Guillaume Lemaitre <g.lemaitre58@gmail.com>
    # License: MIT








.. GENERATED FROM PYTHON SOURCE LINES 18-24

.. code-block:: default

    print(__doc__)

    import seaborn as sns

    sns.set_context("poster")









.. GENERATED FROM PYTHON SOURCE LINES 25-31

Load the dataset
----------------

We will use a dataset containing images of known people and build a model
to recognize the person in each image. We will make this a binary problem
by keeping only the pictures of George W. Bush and Bill Clinton.

.. GENERATED FROM PYTHON SOURCE LINES 33-42

.. code-block:: default

    import numpy as np
    from sklearn.datasets import fetch_lfw_people

    data = fetch_lfw_people()
    george_bush_id = 1871  # Photos of George W. Bush
    bill_clinton_id = 531  # Photos of Bill Clinton
    classes = [george_bush_id, bill_clinton_id]
    classes_name = np.array(["B. Clinton", "G.W. Bush"], dtype=object)




.. GENERATED FROM PYTHON SOURCE LINES 43-48

.. code-block:: default

    mask_photos = np.isin(data.target, classes)
    X, y = data.data[mask_photos], data.target[mask_photos]
    y = (y == george_bush_id).astype(np.int8)
    y = classes_name[y]








.. GENERATED FROM PYTHON SOURCE LINES 49-50

We can check the ratio between the two classes.

.. GENERATED FROM PYTHON SOURCE LINES 52-60

.. code-block:: default

    import pandas as pd

    class_distribution = pd.Series(y).value_counts(normalize=True)
    ax = class_distribution.plot.barh()
    ax.set_title("Class distribution")
    pos_label = class_distribution.idxmin()
    print(f"The positive label considered as the minority class is {pos_label}")




.. image:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_001.png
    :alt: Class distribution
    :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    The positive label considered as the minority class is B. Clinton




.. GENERATED FROM PYTHON SOURCE LINES 61-72

We see that we have an imbalanced classification problem with ~95% of the
data belonging to the class G.W. Bush.

Compare over-sampling approaches
--------------------------------

We will use different over-sampling approaches and use a kNN classifier
to check if we can recognize the 2 presidents. The evaluation will be
performed through cross-validation and we will plot the mean ROC curve.

We will create different pipelines and evaluate them.
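
To make the difference between these samplers concrete, here is a minimal
NumPy sketch of random over-sampling and of SMOTE-style interpolation. It is
an illustration under simplifying assumptions, not the imbalanced-learn
implementation: the toy data and the names ``X_minority``, ``X_random`` and
``X_interpolated`` are hypothetical. ADASYN also interpolates, but it draws
more new samples around minority points that are harder to learn, which this
sketch does not attempt to reproduce.

.. code-block:: default

    import numpy as np

    rng = np.random.default_rng(0)

    # Toy minority-class samples in a 2D feature space (illustrative only).
    X_minority = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    n_new = 6

    # Random over-sampling: draw existing minority samples with replacement,
    # so every new row is an exact copy of an original row.
    idx = rng.integers(0, len(X_minority), size=n_new)
    X_random = X_minority[idx]

    # SMOTE-style interpolation: pick a sample, pick one of its k nearest
    # minority neighbours, and create a point on the segment joining them.
    k = 2
    distances = np.linalg.norm(
        X_minority[:, None, :] - X_minority[None, :, :], axis=-1
    )
    # Column 0 of the argsort is the sample itself (distance 0), so skip it.
    neighbours = np.argsort(distances, axis=1)[:, 1 : k + 1]

    base = rng.integers(0, len(X_minority), size=n_new)
    chosen = neighbours[base, rng.integers(0, k, size=n_new)]
    gap = rng.random(size=(n_new, 1))  # position along the segment, in [0, 1)
    X_interpolated = X_minority[base] + gap * (X_minority[chosen] - X_minority[base])

Random over-sampling only repeats points that already exist, whereas the
interpolated points fall between existing minority samples.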

.. GENERATED FROM PYTHON SOURCE LINES 74-88

.. code-block:: default

    from imblearn import FunctionSampler
    from imblearn.over_sampling import ADASYN, RandomOverSampler, SMOTE
    from imblearn.pipeline import make_pipeline
    from sklearn.neighbors import KNeighborsClassifier

    classifier = KNeighborsClassifier(n_neighbors=3)

    pipeline = [
        make_pipeline(FunctionSampler(), classifier),
        make_pipeline(RandomOverSampler(random_state=42), classifier),
        make_pipeline(ADASYN(random_state=42), classifier),
        make_pipeline(SMOTE(random_state=42), classifier),
    ]








.. GENERATED FROM PYTHON SOURCE LINES 89-93

.. code-block:: default

    from sklearn.model_selection import StratifiedKFold

    cv = StratifiedKFold(n_splits=3)








.. GENERATED FROM PYTHON SOURCE LINES 94-97

We will compute the mean ROC curve for each pipeline using the different
splits provided by the :class:`~sklearn.model_selection.StratifiedKFold`
cross-validation.
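
Each fold yields a ROC curve evaluated at its own set of false positive
rates, so the curves cannot be averaged element by element. One way to handle
this, which is the approach taken in the cell below, is to interpolate each
curve onto a shared grid before averaging. The following NumPy-only sketch
shows the idea on two made-up fold curves; the values in ``fold_curves`` are
illustrative and do not come from the LFW data.

.. code-block:: default

    import numpy as np

    # Toy (fpr, tpr) curves from two hypothetical folds.
    fold_curves = [
        (np.array([0.0, 0.2, 1.0]), np.array([0.0, 0.6, 1.0])),
        (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.9, 1.0])),
    ]

    # Shared false positive rate grid on which all folds are compared.
    mean_fpr = np.linspace(0, 1, 11)

    # Interpolate every fold onto the shared grid, then average point-wise.
    mean_tpr = np.mean(
        [np.interp(mean_fpr, fpr, tpr) for fpr, tpr in fold_curves], axis=0
    )

    # Area under the averaged curve using the trapezoidal rule.
    mean_auc = np.sum(np.diff(mean_fpr) * (mean_tpr[1:] + mean_tpr[:-1]) / 2)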

.. GENERATED FROM PYTHON SOURCE LINES 99-132

.. code-block:: default

    import matplotlib.pyplot as plt
    from sklearn.metrics import RocCurveDisplay, roc_curve, auc

    disp = []
    for model in pipeline:
        # compute the mean fpr/tpr to get the mean ROC curve
        mean_tpr, mean_fpr = 0.0, np.linspace(0, 1, 100)
        for train, test in cv.split(X, y):
            model.fit(X[train], y[train])
            y_proba = model.predict_proba(X[test])

            pos_label_idx = np.flatnonzero(model.classes_ == pos_label)[0]
            fpr, tpr, thresholds = roc_curve(
                y[test], y_proba[:, pos_label_idx], pos_label=pos_label
            )
            mean_tpr += np.interp(mean_fpr, fpr, tpr)
            mean_tpr[0] = 0.0

        mean_tpr /= cv.get_n_splits(X, y)
        mean_tpr[-1] = 1.0
        mean_auc = auc(mean_fpr, mean_tpr)

        # Create a display that we will reuse to make the aggregated plots for
        # all methods
        disp.append(
            RocCurveDisplay(
                fpr=mean_fpr,
                tpr=mean_tpr,
                roc_auc=mean_auc,
                estimator_name=f"{model[0].__class__.__name__}",
            )
        )








.. GENERATED FROM PYTHON SOURCE LINES 133-135

In the previous cell, we created the mean ROC curve of each pipeline; we can
now plot them on the same axes.

.. GENERATED FROM PYTHON SOURCE LINES 137-148

.. code-block:: default

    fig, ax = plt.subplots(figsize=(9, 9))
    for d in disp:
        d.plot(ax=ax, linestyle="--")
    ax.plot([0, 1], [0, 1], linestyle="--", color="k")
    ax.axis("square")
    fig.suptitle("Comparison of over-sampling methods with a 3NN classifier")
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    sns.despine(offset=10, ax=ax)
    plt.show()




.. image:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_002.png
    :alt: Comparison of over-sampling methods with a 3NN classifier
    :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 149-152

We see that, for this task, methods that generate new samples by
interpolation (i.e. ADASYN and SMOTE) perform better than random
over-sampling or no resampling.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  1.231 seconds)


.. _sphx_glr_download_auto_examples_applications_plot_over_sampling_benchmark_lfw.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_over_sampling_benchmark_lfw.py <plot_over_sampling_benchmark_lfw.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_over_sampling_benchmark_lfw.ipynb <plot_over_sampling_benchmark_lfw.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
