# Source code for galeritas.precision_and_recall_by_probability_threshold

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.stats import mstats
from galeritas.utils.creditas_palette import get_palette
from galeritas.precision_recall_threshold_confidence_interval_aux import (threshold_confidence_interval,
                                                                          _get_threshold_metrics_intervals)


def _validate_binary_target(df, target_name):
    """Raise ``ValueError`` unless ``df`` has a ``target_name`` column holding only 0/1 values."""
    # Only the target column needs inspecting (the former version scanned every column of ``df``).
    if target_name not in df.columns or not np.isin(df[target_name].unique(), [0, 1]).all():
        raise ValueError(f'The target must be binary! Column "{target_name}" contains more values.')


def _resolve_colors(colors, color_palette):
    """Return the colors used for precision, recall and support rate (in that order)."""
    if colors is None:
        colors = get_palette()
    # An explicit seaborn palette takes precedence over ``colors``.
    if color_palette:
        colors = sns.color_palette(color_palette, 3)
    if colors is not None and len(colors) < 3:
        raise KeyError(f'Expected 3 colors but only {len(colors)} was/were passed.')
    return colors


def _highlight_thresholds(ax, thresholds_to_highlight, uniform_thresholds, metric_intervals):
    """Draw a labelled dashed marker (A, B, C, ...) at each threshold and build their summary table.

    :param metric_intervals: The (lower, median, upper) arrays of precision, recall, support and
        support rate, flattened in the positional order ``_get_threshold_metrics_intervals`` expects.
    :return: A DataFrame indexed by "Threshold Label", or ``None`` when there is nothing to highlight.
    """
    # Accept a single number as well as any list-like; ``tolist`` yields plain Python scalars.
    highlights = np.atleast_1d(thresholds_to_highlight).tolist()
    if not highlights:
        return None

    metrics_columns_range_order = [
        "probability_threshold",
        "Median Precision", "Precision Range",
        "Median Recall", "Recall Range",
        "Median Support", "Support Range",
        "Median Support Rate", "Support Rate Range",
    ]
    thresholds_metrics_summary_range = {}
    label = "A"
    for highlight_threshold in highlights:
        # Constant-x line whose y values reuse the threshold grid, i.e. a vertical marker over its range.
        ax.plot(np.repeat(highlight_threshold, len(uniform_thresholds)), uniform_thresholds, "k--")
        # NOTE(review): the helper's return value is assumed to be a dict-like row (it is item-assigned below).
        thresholds_metrics_summary_range[label] = _get_threshold_metrics_intervals(
            *metric_intervals, highlight_threshold, uniform_thresholds
        )
        thresholds_metrics_summary_range[label]["probability_threshold"] = highlight_threshold
        # Letter placed slightly left of the marker, just above y=1 (data coordinates).
        ax.text(highlight_threshold - 0.006, 1.08, label)
        label = chr(ord(label) + 1)

    table = pd.DataFrame(thresholds_metrics_summary_range).T[metrics_columns_range_order]
    table.index.name = "Threshold Label"
    return table


def _display_table(table):
    """Show ``table`` with IPython's rich ``display`` when available, otherwise print it."""
    import builtins

    # IPython registers ``display`` as a builtin; a plain interpreter does not, where the former
    # bare ``display(...)`` call raised NameError.
    getattr(builtins, "display", print)(table)


def plot_precision_and_recall_by_probability_threshold(
        df,
        prediction_column_name,
        target_name,
        target=1,
        n_trials=50,
        sample_size_percent=0.5,
        quantiles=(0.05, 0.5, 0.95),
        thresholds_to_highlight=None,
        x_label="Model probability threshold",
        y_label="Metric's Ratio",
        plot_title=None,
        colors=None,
        color_palette=None,
        figsize=(16, 7),
        ax=None,
        return_fig=False,
        **legend_kwargs
):
    """
    Determines precision, recall and support scores for different thresholds for the positive class, using a
    data sample with replacement.

    Adapted from `Insight Data Science's post
    <https://blog.insightdatascience.com/visualizing-machine-learning-thresholds-to-make-better-business-decisions-4ab07f823415>`__.

    The lower/median/upper ``quantiles`` of each metric, computed along axis 0 of the resampled results, give
    the shaded band and the median curve of each metric.

    :param df: Dataframe containing predictions and target columns.
    :type df: DataFrame
    :param prediction_column_name: Name of the column holding the model predictions.
    :type prediction_column_name: str
    :param target_name: Name of the target column. It must contain only the values 0 and 1.
    :type target_name: str
    :param target: Indicates the target class. |default| :code:`1`
    :type target: int, optional
    :param n_trials: Indicates the number of times to resample the data and make predictions.
        |default| :code:`50`
    :type n_trials: int, optional
    :param sample_size_percent: Fraction of the dataset used in each resample. |default| :code:`0.5`
    :type sample_size_percent: float, optional
    :param quantiles: Lower, median and upper quantiles, in this order. The lower/upper pair bounds the shaded
        band and its width is reported as the confidence interval. |default| :code:`(0.05, 0.5, 0.95)`
    :type quantiles: list or tuple, optional
    :param thresholds_to_highlight: Threshold(s) to mark with a labelled dashed line (A, B, C, ...). A table of
        the metric intervals at each one is displayed. A single number is also accepted.
        |default| :code:`None`
    :type thresholds_to_highlight: list, optional
    :param x_label: Text to describe the x-axis label. |default| :code:`"Model probability threshold"`
    :type x_label: str, optional
    :param y_label: Text to describe the y-axis label. |default| :code:`"Metric's Ratio"`
    :type y_label: str, optional
    :param plot_title: Text to describe the plot's title. |default| :code:`None`
    :type plot_title: str, optional
    :param colors: At least three colors, used for precision, recall and support rate, in that order.
        |default| :code:`None`
    :type colors: list of str, optional
    :param color_palette: If set, the first three colors of this seaborn palette are used, overriding
        ``colors``. If both colors and color_palette parameters are None, uses Galeritas default palette.
        |default| :code:`None`
    :type color_palette: str, optional
    :param figsize: A tuple that indicates the figure size (respectively, width and height in inches).
        Only used when ``ax`` is None. |default| :code:`(16, 7)`
    :type figsize: tuple, optional
    :param ax: Custom figure axes to plot. If None, a new figure is created. |default| :code:`None`
    :type ax: matplotlib.axes, optional
    :param return_fig: If True, shows and closes the figure, then returns it. |default| :code:`False`
    :type return_fig: bool, optional
    :param legend_kwargs: Matplotlib.pyplot's legend arguments such as *bbox_to_anchor* and *ncol*.
        Further information `here <http://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.legend>`__.
    :type legend_kwargs: key, value mappings
    :raises ValueError: If ``target_name`` is not a column of ``df`` or holds values other than 0 and 1.
    :raises KeyError: If fewer than three colors are available.
    :return: Returns the figure object with the plot (*return_fig parameter needs to be set)
    :rtype: Figure
    """
    _validate_binary_target(df, target_name)
    colormap = dict(zip(['precision', 'recall', 'support_rate'], _resolve_colors(colors, color_palette)))

    (uniform_precision_plots, uniform_recall_plots, uniform_support_plots, uniform_support_rate_plots,
     uniform_thresholds) = threshold_confidence_interval(
        df, target_name, prediction_column_name, target=target, n_trials=n_trials,
        sample_size_percent=sample_size_percent
    )

    # Band width in whole percent. Rounding after scaling avoids float artefacts such as
    # "28.999999999999996%" that ``100 * round(width, 2)`` could produce; float keeps the "90.0%" format.
    confidence_interval = float(round(100 * (quantiles[-1] - quantiles[0])))

    lower_precision, median_precision, upper_precision = mstats.mquantiles(
        uniform_precision_plots, quantiles, axis=0)
    lower_recall, median_recall, upper_recall = mstats.mquantiles(
        uniform_recall_plots, quantiles, axis=0)
    lower_support, median_support, upper_support = mstats.mquantiles(
        uniform_support_plots, quantiles, axis=0)
    lower_support_rate, median_support_rate, upper_support_rate = mstats.mquantiles(
        uniform_support_rate_plots, quantiles, axis=0)

    # Plot
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, dpi=120)
    else:
        # Without this, ``return_fig=True`` with a caller-supplied axes raised UnboundLocalError.
        fig = ax.figure

    metric_bands = (
        ('precision', lower_precision, median_precision, upper_precision),
        ('recall', lower_recall, median_recall, upper_recall),
        ('support_rate', lower_support_rate, median_support_rate, upper_support_rate),
    )
    median_lines = []
    for metric, _lower, median, _upper in metric_bands:
        median_lines.append(ax.plot(uniform_thresholds, median, color=colormap[metric])[0])
    for metric, lower, _median, upper in metric_bands:
        ax.fill_between(uniform_thresholds, upper, lower, alpha=0.5, linewidth=0, color=colormap[metric])

    # Explicit handles keep the labels on our median lines even if the axes already held other artists.
    ax.legend(median_lines, ('precision', 'recall', 'support'), frameon=True, **legend_kwargs)
    ax.text(
        0.05, -0.15, f"Confidence Interval: {confidence_interval}%",
        horizontalalignment='center',
        verticalalignment='center',
        bbox=dict(boxstyle='round', alpha=0.25, facecolor='gray')
    )
    ax.set_title(plot_title, pad=30)
    ax.set_xticks(np.arange(0, 1.1, 0.1))
    ax.set_yticks(np.arange(0, 1.1, 0.1))
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.grid(True, alpha=0.6, linestyle='--')

    if thresholds_to_highlight is not None:
        thresholds_metrics_dataframe = _highlight_thresholds(
            ax, thresholds_to_highlight, uniform_thresholds,
            (lower_precision, median_precision, upper_precision,
             lower_recall, median_recall, upper_recall,
             lower_support, median_support, upper_support,
             lower_support_rate, median_support_rate, upper_support_rate)
        )
        if thresholds_metrics_dataframe is not None:
            _display_table(thresholds_metrics_dataframe)

    if return_fig:
        plt.show()
        plt.close()
        return fig