
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/ensemble/plot_comparison_ensemble_classifier.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_ensemble_plot_comparison_ensemble_classifier.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_ensemble_plot_comparison_ensemble_classifier.py:


=============================================
Compare ensemble classifiers using resampling
=============================================

Ensemble classifiers have been shown to improve classification performance
compared to a single learner. However, they are affected by class imbalance.
This example shows the benefit of balancing the training set before training
the learners, by comparing balanced ensemble methods with their non-balanced
counterparts.

We make the comparison using the balanced accuracy and the geometric mean,
which are metrics widely used in the literature to evaluate models learned on
imbalanced datasets.
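
Both metrics summarize the per-class recalls. The balanced accuracy is their
arithmetic mean. With its default settings, the geometric mean score is their
geometric mean. The following minimal sketch computes both from scratch on a
toy prediction. The helper names are hypothetical and are not the scikit-learn
or imbalanced-learn API. The sketch also assumes that every class appears in
``y_true``.

.. code-block:: python

    import math
    from collections import Counter


    def per_class_recall(y_true, y_pred):
        """Map each class of ``y_true`` to its recall (true positive rate)."""
        support = Counter(y_true)
        hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
        return {label: hits[label] / n for label, n in support.items()}


    def balanced_accuracy(y_true, y_pred):
        """Arithmetic mean of the per-class recalls."""
        recalls = list(per_class_recall(y_true, y_pred).values())
        return sum(recalls) / len(recalls)


    def geometric_mean(y_true, y_pred):
        """Geometric mean of the per-class recalls."""
        recalls = list(per_class_recall(y_true, y_pred).values())
        return math.prod(recalls) ** (1 / len(recalls))


    # Toy data: 8 majority samples (class 0) and 2 minority samples (class 1)
    y_true = [0] * 8 + [1] * 2
    # A degenerate model that always predicts the majority class
    y_pred = [0] * 10

    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    print(f"Accuracy: {accuracy:.2f}")  # 0.80
    print(f"Balanced accuracy: {balanced_accuracy(y_true, y_pred):.2f}")  # 0.50
    print(f"Geometric mean: {geometric_mean(y_true, y_pred):.2f}")  # 0.00

Plain accuracy rewards the degenerate model. In contrast, a zero recall on the
minority class drives the geometric mean to zero.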

.. GENERATED FROM PYTHON SOURCE LINES 15-19

.. code-block:: default


    # Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
    # License: MIT








.. GENERATED FROM PYTHON SOURCE LINES 20-22

.. code-block:: default

    print(__doc__)








.. GENERATED FROM PYTHON SOURCE LINES 23-29

Load an imbalanced dataset
--------------------------

We load the UCI SatImage dataset, which has an imbalance ratio of 9.3:1
(number of majority samples per minority sample). The data are then split
into training and testing sets. The split is stratified on the target so that
both sets keep roughly the same class proportions.
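
The imbalance ratio and the effect of ``stratify=y`` can be illustrated without
downloading the data. The sketch below uses hypothetical toy labels with the
same 9.3:1 ratio. It also uses a simplified, hand-written stratified split,
which is an assumption standing in for ``train_test_split`` and may round
differently. Splitting class by class keeps the ratio unchanged in both
subsets.

.. code-block:: python

    import random
    from collections import Counter


    def imbalance_ratio(y):
        """Number of majority-class samples per minority-class sample."""
        counts = Counter(y).values()
        return max(counts) / min(counts)


    def stratified_split(y, test_size=0.25, seed=0):
        """Return (train_idx, test_idx), splitting each class separately.

        Simplified stand-in for ``train_test_split(..., stratify=y)``.
        """
        rng = random.Random(seed)
        train_idx, test_idx = [], []
        for label in sorted(set(y)):
            idx = [i for i, value in enumerate(y) if value == label]
            rng.shuffle(idx)
            n_test = round(len(idx) * test_size)
            test_idx += idx[:n_test]
            train_idx += idx[n_test:]
        return train_idx, test_idx


    # Hypothetical toy labels with a 9.3:1 ratio (930 majority, 100 minority)
    y = [-1] * 930 + [1] * 100
    train_idx, test_idx = stratified_split(y, test_size=0.2)

    print(f"Full set:  {imbalance_ratio(y):.2f}")  # 9.30
    print(f"Train set: {imbalance_ratio([y[i] for i in train_idx]):.2f}")  # 9.30
    print(f"Test set:  {imbalance_ratio([y[i] for i in test_idx]):.2f}")  # 9.30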

.. GENERATED FROM PYTHON SOURCE LINES 31-38

.. code-block:: default

    from imblearn.datasets import fetch_datasets
    from sklearn.model_selection import train_test_split

    satimage = fetch_datasets()["satimage"]
    X, y = satimage.data, satimage.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)








.. GENERATED FROM PYTHON SOURCE LINES 39-48

Classification using a single decision tree
-------------------------------------------

We train a decision tree classifier, which will be used as a baseline for the
rest of this example.

The results are reported in terms of balanced accuracy and geometric mean,
which are metrics widely used in the literature to validate models trained on
imbalanced datasets.

.. GENERATED FROM PYTHON SOURCE LINES 50-56

.. code-block:: default

    from sklearn.tree import DecisionTreeClassifier

    tree = DecisionTreeClassifier()
    tree.fit(X_train, y_train)
    y_pred_tree = tree.predict(X_test)








.. GENERATED FROM PYTHON SOURCE LINES 57-66

.. code-block:: default

    from sklearn.metrics import balanced_accuracy_score
    from imblearn.metrics import geometric_mean_score

    print("Decision tree classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_tree):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_tree):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Decision tree classifier performance:
    Balanced accuracy: 0.75 - Geometric mean 0.73




.. GENERATED FROM PYTHON SOURCE LINES 67-75

.. code-block:: default

    import seaborn as sns
    from sklearn.metrics import plot_confusion_matrix

    sns.set_context("poster")

    disp = plot_confusion_matrix(tree, X_test, y_test, colorbar=False)
    _ = disp.ax_.set_title("Decision tree")




.. image:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_001.png
    :alt: Decision tree
    :class: sphx-glr-single-img






.. GENERATED FROM PYTHON SOURCE LINES 76-83

Classification using bagging classifier with and without sampling
-----------------------------------------------------------------

Instead of using a single tree, we check whether an ensemble of decision trees
can alleviate the issue induced by the class imbalance. First, we use a
bagging classifier and its counterpart, which internally uses random
under-sampling to balance each bootstrap sample.
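
The idea of balancing each bootstrap sample can be sketched in a few lines.
This is a simplified illustration under our own assumptions, not
imbalanced-learn's code. It draws a bootstrap sample (sampling with
replacement, as in bagging). It then randomly under-samples every class of
that sample down to the size of its smallest class. Each resulting index set
would be used to fit one tree of the ensemble.

.. code-block:: python

    import random
    from collections import Counter


    def balanced_bootstrap_indices(y, rng):
        """Bootstrap ``y`` and randomly under-sample the result to balance it."""
        n_samples = len(y)
        # Bagging step: draw n_samples indices with replacement
        bootstrap = [rng.randrange(n_samples) for _ in range(n_samples)]

        # Group the drawn indices by class label
        by_class = {}
        for i in bootstrap:
            by_class.setdefault(y[i], []).append(i)

        # Random under-sampling step: every class keeps as many indices as the
        # smallest class of this bootstrap sample (a class absent from the
        # bootstrap sample is simply not represented)
        n_min = min(len(indices) for indices in by_class.values())
        balanced = []
        for indices in by_class.values():
            balanced += rng.sample(indices, n_min)
        return balanced


    rng = random.Random(0)
    y = [0] * 90 + [1] * 10
    # One balanced index set per ensemble member
    subsets = [balanced_bootstrap_indices(y, rng) for _ in range(5)]
    for indices in subsets:
        print(Counter(y[i] for i in indices))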

.. GENERATED FROM PYTHON SOURCE LINES 85-97

.. code-block:: default

    from sklearn.ensemble import BaggingClassifier
    from imblearn.ensemble import BalancedBaggingClassifier

    bagging = BaggingClassifier(n_estimators=50, random_state=0)
    balanced_bagging = BalancedBaggingClassifier(n_estimators=50, random_state=0)

    bagging.fit(X_train, y_train)
    balanced_bagging.fit(X_train, y_train)

    y_pred_bc = bagging.predict(X_test)
    y_pred_bbc = balanced_bagging.predict(X_test)








.. GENERATED FROM PYTHON SOURCE LINES 98-100

Balancing each bootstrap sample significantly increases the balanced accuracy
(from 0.73 to 0.86) and the geometric mean (from 0.68 to 0.86).

.. GENERATED FROM PYTHON SOURCE LINES 102-113

.. code-block:: default

    print("Bagging classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bc):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_bc):.2f}"
    )
    print("Balanced Bagging classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bbc):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_bbc):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Bagging classifier performance:
    Balanced accuracy: 0.73 - Geometric mean 0.68
    Balanced Bagging classifier performance:
    Balanced accuracy: 0.86 - Geometric mean 0.86




.. GENERATED FROM PYTHON SOURCE LINES 114-125

.. code-block:: default

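    import matplotlib.pyplot as plt
    from sklearn.metrics import ConfusionMatrixDisplay

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
    # ``ConfusionMatrixDisplay.from_estimator`` replaces ``plot_confusion_matrix``,
    # which is deprecated in scikit-learn 1.0 and removed in 1.2
    ConfusionMatrixDisplay.from_estimator(
        bagging, X_test, y_test, ax=axs[0], colorbar=False
    )
    axs[0].set_title("Bagging")

    ConfusionMatrixDisplay.from_estimator(
        balanced_bagging, X_test, y_test, ax=axs[1], colorbar=False
    )
    axs[1].set_title("Balanced Bagging")

    fig.tight_layout()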




.. image:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_002.png
    :alt: Bagging, Balanced Bagging
    :class: sphx-glr-single-img






.. GENERATED FROM PYTHON SOURCE LINES 126-132

Classification using random forest classifier with and without sampling
-----------------------------------------------------------------------

Random forest is another popular ensemble method and it usually outperforms
bagging. Here, we use a vanilla random forest and its balanced counterpart, in
which each bootstrap sample is balanced.
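
Besides bootstrapping, a random forest decorrelates its trees by letting each
split consider only a random subset of the features. For classification,
scikit-learn's default keeps the square root of the number of features. The
hypothetical helper below is only a sketch of that draw, not scikit-learn's
implementation. The feature count of 36 is an arbitrary choice for
illustration.

.. code-block:: python

    import math
    import random


    def candidate_features(n_features, rng):
        """Draw the feature indices that one split is allowed to consider.

        Uses the square-root rule: max(1, int(sqrt(n_features))) features,
        sampled without replacement.
        """
        n_candidates = max(1, int(math.sqrt(n_features)))
        return sorted(rng.sample(range(n_features), n_candidates))


    rng = random.Random(0)
    # Three successive splits draw three independent feature subsets
    splits = [candidate_features(36, rng) for _ in range(3)]
    for features in splits:
        print(features)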

.. GENERATED FROM PYTHON SOURCE LINES 134-146

.. code-block:: default

    from sklearn.ensemble import RandomForestClassifier
    from imblearn.ensemble import BalancedRandomForestClassifier

    rf = RandomForestClassifier(n_estimators=50, random_state=0)
    brf = BalancedRandomForestClassifier(n_estimators=50, random_state=0)

    rf.fit(X_train, y_train)
    brf.fit(X_train, y_train)

    y_pred_rf = rf.predict(X_test)
    y_pred_brf = brf.predict(X_test)








.. GENERATED FROM PYTHON SOURCE LINES 147-150

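Similarly to the previous experiment, the balanced classifier outperforms the
classifier that learns from imbalanced bootstrap samples. In addition, the
balanced random forest slightly outperforms the balanced bagging classifier
(balanced accuracy of 0.88 versus 0.86). The two non-balanced ensembles obtain
the same scores.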

.. GENERATED FROM PYTHON SOURCE LINES 152-163

.. code-block:: default

    print("Random Forest classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rf):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_rf):.2f}"
    )
    print("Balanced Random Forest classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_brf):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_brf):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Random Forest classifier performance:
    Balanced accuracy: 0.73 - Geometric mean 0.68
    Balanced Random Forest classifier performance:
    Balanced accuracy: 0.88 - Geometric mean 0.88




.. GENERATED FROM PYTHON SOURCE LINES 164-173

.. code-block:: default

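    from sklearn.metrics import ConfusionMatrixDisplay

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
    # ``ConfusionMatrixDisplay.from_estimator`` replaces the deprecated
    # ``plot_confusion_matrix``
    ConfusionMatrixDisplay.from_estimator(rf, X_test, y_test, ax=axs[0], colorbar=False)
    axs[0].set_title("Random forest")

    ConfusionMatrixDisplay.from_estimator(brf, X_test, y_test, ax=axs[1], colorbar=False)
    axs[1].set_title("Balanced random forest")

    fig.tight_layout()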




.. image:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_003.png
    :alt: Random forest, Balanced random forest
    :class: sphx-glr-single-img






.. GENERATED FROM PYTHON SOURCE LINES 174-180

Boosting classifier
-------------------

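In the same manner, the easy ensemble classifier is a bag of AdaBoost
classifiers. Each one is trained on a balanced, randomly under-sampled subset
of the data. RUSBoost, the second model below, instead randomly under-samples
the data at each boosting iteration. However, the easy ensemble is slower to
train than the random forest. Like RUSBoost, it also achieves a lower balanced
accuracy (0.85) than the balanced random forest (0.88).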
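
The balanced subsets behind the easy ensemble can be sketched as follows,
assuming a binary problem. Every member keeps all minority samples. It then
draws, without replacement, an equally sized random subset of the majority
samples, on which one AdaBoost classifier would be trained. The helper is
hypothetical and simplified, and it is not imbalanced-learn's implementation.

.. code-block:: python

    import random


    def easy_ensemble_subsets(y, n_estimators, minority_label, rng):
        """Return one balanced index subset per ensemble member (binary case)."""
        minority = [i for i, label in enumerate(y) if label == minority_label]
        majority = [i for i, label in enumerate(y) if label != minority_label]
        # All minority samples plus as many randomly drawn majority samples
        return [
            minority + rng.sample(majority, len(minority))
            for _ in range(n_estimators)
        ]


    rng = random.Random(0)
    y = [0] * 90 + [1] * 10
    subsets = easy_ensemble_subsets(y, n_estimators=10, minority_label=1, rng=rng)

    # Majority samples seen by at least one member of the ensemble
    seen_majority = {i for subset in subsets for i in subset if y[i] == 0}
    print(len(subsets[0]))  # 20: 10 minority and 10 majority samples
    print(f"Majority samples used by the ensemble: {len(seen_majority)} / 90")

Each member draws a different majority subset. As a result, the ensemble as a
whole uses more of the majority samples than a single under-sampled training
set would.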

.. GENERATED FROM PYTHON SOURCE LINES 182-194

.. code-block:: default

    from sklearn.ensemble import AdaBoostClassifier
    from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier

    base_estimator = AdaBoostClassifier(n_estimators=10)
    eec = EasyEnsembleClassifier(n_estimators=10, base_estimator=base_estimator)
    eec.fit(X_train, y_train)
    y_pred_eec = eec.predict(X_test)

    rusboost = RUSBoostClassifier(n_estimators=10, base_estimator=base_estimator)
    rusboost.fit(X_train, y_train)
    y_pred_rusboost = rusboost.predict(X_test)








.. GENERATED FROM PYTHON SOURCE LINES 195-206

.. code-block:: default

    print("Easy ensemble classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_eec):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_eec):.2f}"
    )
    print("RUSBoost classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rusboost):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_rusboost):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    Easy ensemble classifier performance:
    Balanced accuracy: 0.85 - Geometric mean 0.85
    RUSBoost classifier performance:
    Balanced accuracy: 0.85 - Geometric mean 0.85




.. GENERATED FROM PYTHON SOURCE LINES 207-216

.. code-block:: default

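    from sklearn.metrics import ConfusionMatrixDisplay

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

    # ``ConfusionMatrixDisplay.from_estimator`` replaces the deprecated
    # ``plot_confusion_matrix``
    ConfusionMatrixDisplay.from_estimator(eec, X_test, y_test, ax=axs[0], colorbar=False)
    axs[0].set_title("Easy Ensemble")
    ConfusionMatrixDisplay.from_estimator(
        rusboost, X_test, y_test, ax=axs[1], colorbar=False
    )
    axs[1].set_title("RUSBoost classifier")

    fig.tight_layout()
    plt.show()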



.. image:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_004.png
    :alt: Easy Ensemble, RUSBoost classifier
    :class: sphx-glr-single-img







.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  2.926 seconds)


.. _sphx_glr_download_auto_examples_ensemble_plot_comparison_ensemble_classifier.py:


.. only :: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_comparison_ensemble_classifier.py <plot_comparison_ensemble_classifier.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_comparison_ensemble_classifier.ipynb <plot_comparison_ensemble_classifier.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
