Model Interpretability with XGBoost and the Agaricus Dataset

43 minute read

Overview

This notebook takes a look at the agaricus dataset (Mushroom Database), originally drawn from The Audubon Society Field Guide to North American Mushrooms and hosted in the UCI Machine Learning Repository.

The goal is to create a model that can accurately differentiate between edible and poisonous mushrooms.

To do this, two models will be used:

  • sklearn's RandomForestClassifier
  • XGBoost's XGBClassifier

Each model will be used on both a simple numeric mapping and a one-hot encoding of the dataset. In addition to model performance, feature importances will be examined for each model and decision trees built when possible.

Finally, we'll investigate each model further using:

  • Permutation Importance
  • LIME
  • SHAP

to fully understand feature contributions.

In [1]:
import itertools
from collections import Counter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns;sns.set()
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_validate
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.preprocessing import LabelEncoder,LabelBinarizer, OneHotEncoder
from sklearn.compose import ColumnTransformer
from io import StringIO
from IPython.display import Image, display_html
from sklearn import tree
import pydotplus
import eli5
from eli5.sklearn import PermutationImportance
import shap
import lime
import statsmodels.api as sm
import scipy.stats as ss
In [2]:
RS = 404 # global random state/seed to use
pd.set_option('max_columns',25)
shap.initjs()

Data

The dataset may be obtained from:

https://archive.ics.uci.edu/ml/datasets/mushroom or https://www.kaggle.com/uciml/mushroom-classification

Additionally, a dataset from the XGBoost repository is used, which can be found here:
https://github.com/dmlc/xgboost/tree/master/demo/data

The Kaggle link is preferred simply for convenience as the columns have already been labeled with sensible names.

This dataset includes descriptions of hypothetical samples corresponding to 23 species of gilled mushrooms in the Agaricus and Lepiota family, drawn from The Audubon Society Field Guide to North American Mushrooms (1981). Each species is identified as definitely edible, definitely poisonous, or of unknown edibility and not recommended; this latter class was combined with the poisonous one.

Each entry in the dataset contains only a single letter; a reference table with the corresponding meanings can be found at https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.names

or in the list shown here (a small decoding sketch follows the list):

  • class : [edible=e, poisonous=p]
  • cap-shape : [bell=b, conical=c, convex=x, flat=f, knobbed=k, sunken=s]
  • cap-surface : [fibrous=f, grooves=g, scaly=y, smooth=s]
  • cap-color : [brown=n, buff=b, cinnamon=c, gray=g, green=r, pink=p, purple=u, red=e, white=w, yellow=y]
  • bruises : [bruises=t, no=f]
  • odor : [almond=a, anise=l, creosote=c, fishy=y, foul=f, musty=m, none=n, pungent=p, spicy=s]
  • gill-attachment : [attached=a, descending=d, free=f, notched=n]
  • gill-spacing : [close=c, crowded=w, distant=d]
  • gill-size : [broad=b, narrow=n]
  • gill-color : [black=k, brown=n, buff=b, chocolate=h, gray=g, green=r, orange=o, pink=p, purple=u, red=e, white=w, yellow=y]
  • stalk-shape : [enlarging=e, tapering=t]
  • stalk-root : [bulbous=b, club=c, cup=u, equal=e, rhizomorphs=z, rooted=r, missing=?]
  • stalk-surface-above-ring : [fibrous=f, scaly=y, silky=k, smooth=s]
  • stalk-surface-below-ring : [fibrous=f, scaly=y, silky=k, smooth=s]
  • stalk-color-above-ring : [brown=n, buff=b, cinnamon=c, gray=g, orange=o, pink=p, red=e, white=w, yellow=y]
  • stalk-color-below-ring : [brown=n, buff=b, cinnamon=c, gray=g, orange=o, pink=p, red=e, white=w, yellow=y]
  • veil-type : [partial=p, universal=u]
  • veil-color : [brown=n, orange=o, white=w, yellow=y]
  • ring-number : [none=n, one=o, two=t]
  • ring-type : [cobwebby=c, evanescent=e, flaring=f, large=l, none=n, pendant=p, sheathing=s, zone=z]
  • spore-print-color : [black=k, brown=n, buff=b, chocolate=h, green=r, orange=o, purple=u, white=w, yellow=y]
  • population : [abundant=a, clustered=c, numerous=n, scattered=s, several=v, solitary=y]
  • habitat : [grasses=g, leaves=l, meadows=m, paths=p, urban=u, waste=w, woods=d]
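
For readability during exploration, these single-letter codes can be expanded with a plain mapping; a minimal sketch, where the odor_map dict is a hypothetical helper covering only odor, transcribed from the list above:

# Hypothetical helper: decode one column's single-letter codes (odor only)
odor_map = {'a': 'almond', 'l': 'anise', 'c': 'creosote', 'y': 'fishy', 'f': 'foul',
            'm': 'musty', 'n': 'none', 'p': 'pungent', 's': 'spicy'}
# e.g. agar.odor.map(odor_map) would yield human-readable odor labels once the data is loaded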

UCI Machine Learning Repository data:

!wget https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.names
!wget https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data

XGBoost data:

!wget https://raw.githubusercontent.com/dmlc/xgboost/master/demo/data/agaricus.txt.train
!wget https://raw.githubusercontent.com/dmlc/xgboost/master/demo/data/agaricus.txt.test
!wget https://raw.githubusercontent.com/dmlc/xgboost/master/demo/data/featmap.txt
In [3]:
agar = pd.read_csv('data/mushrooms.csv',dtype='category')
# replace dash with underscore for dot access
agar.columns = agar.columns.str.replace('-','_')
# class is python reserved keyword, also impedes dot access
agar.rename(columns={'class':'toxic'},inplace=True)
In [4]:
agar.head()
Out[4]:
toxic cap_shape cap_surface cap_color bruises odor gill_attachment gill_spacing gill_size gill_color stalk_shape stalk_root stalk_surface_above_ring stalk_surface_below_ring stalk_color_above_ring stalk_color_below_ring veil_type veil_color ring_number ring_type spore_print_color population habitat
0 p x s n t p f c n k e e s s w w p w o p k s u
1 e x s y t a f c b k e c s s w w p w o p n n g
2 e b s w t l f c b n e c s s w w p w o p n n m
3 p x y w t p f c n n e e s s w w p w o p k s u
4 e x s g f n f w b k t e s s w w p w o e n a g

Exploratory Data Analysis

Aside from one variable with missing values, this dataset is quite clean. It is comprised entirely of categorical values, each with relatively low cardinality. It is slightly imbalanced, with 4208 (51.8%) entries marked edible and 3916 (48.2%) marked poisonous; a small discrepancy, but one that shouldn't have a significant effect on our model.

In [5]:
agar.toxic.value_counts()
Out[5]:
e    4208
p    3916
Name: toxic, dtype: int64
In [6]:
agar.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8124 entries, 0 to 8123
Data columns (total 23 columns):
toxic                       8124 non-null category
cap_shape                   8124 non-null category
cap_surface                 8124 non-null category
cap_color                   8124 non-null category
bruises                     8124 non-null category
odor                        8124 non-null category
gill_attachment             8124 non-null category
gill_spacing                8124 non-null category
gill_size                   8124 non-null category
gill_color                  8124 non-null category
stalk_shape                 8124 non-null category
stalk_root                  8124 non-null category
stalk_surface_above_ring    8124 non-null category
stalk_surface_below_ring    8124 non-null category
stalk_color_above_ring      8124 non-null category
stalk_color_below_ring      8124 non-null category
veil_type                   8124 non-null category
veil_color                  8124 non-null category
ring_number                 8124 non-null category
ring_type                   8124 non-null category
spore_print_color           8124 non-null category
population                  8124 non-null category
habitat                     8124 non-null category
dtypes: category(23)
memory usage: 187.3 KB

At first glance, it appears as though there are no NA values, but from the data description we know that stalk_root does have missing values; they are just labeled with a '?'.

In [7]:
agar.stalk_root.value_counts()
Out[7]:
b    3776
?    2480
e    1120
c     556
r     192
Name: stalk_root, dtype: int64

Since only one variable has any missing values and we are about to encode everything anyway, there is no point in replacing the '?' with np.nan.
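
That said, if a true NaN representation were ever needed (say, for imputation experiments), a one-liner would suffice; a sketch shown for completeness only, not used below:

# Optional: treat '?' as a real missing value (not used in the rest of this analysis)
stalk_root_nan = agar.stalk_root.where(agar.stalk_root != '?')
stalk_root_nan.isna().sum()  # 2480, matching the '?' count above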

In [8]:
agar.nunique().sort_values(ascending=False)
Out[8]:
gill_color                  12
cap_color                   10
spore_print_color            9
odor                         9
stalk_color_below_ring       9
stalk_color_above_ring       9
habitat                      7
population                   6
cap_shape                    6
stalk_root                   5
ring_type                    5
stalk_surface_above_ring     4
stalk_surface_below_ring     4
veil_color                   4
cap_surface                  4
ring_number                  3
stalk_shape                  2
gill_size                    2
gill_spacing                 2
gill_attachment              2
bruises                      2
toxic                        2
veil_type                    1
dtype: int64

veil_type has only one unique value, so it adds no information to our model and can be safely dropped.

In [9]:
agar_enc = agar.drop(columns='veil_type').apply(lambda x: x.cat.codes)
In [10]:
X, y = agar.drop(columns=['toxic','veil_type']), agar.toxic
X_enc, y_enc = X.apply(lambda x: x.cat.codes), y.cat.codes # numericize
In [11]:
# categorical encoded dataset
X_train, X_test, y_train, y_test = train_test_split(X_enc, y_enc, test_size=.20, random_state=RS)
In [12]:
# One-hot encoded dataset
Xoh = pd.get_dummies(X, drop_first=False)
Xoh_train, Xoh_test, yoh_train, yoh_test = train_test_split(Xoh, y_enc, test_size=.20, random_state=RS)
X.shape,Xoh.shape
Out[12]:
((8124, 21), (8124, 116))

Using a simple one-hot encoding increases the feature count to 116. Though this makes the data rather sparse, the size is not unreasonable, since we started with only 21 features and a maximum cardinality of 12.
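
As an aside, the same encoding can be produced with sklearn's OneHotEncoder (imported earlier but otherwise unused here); a sketch, assuming sklearn >= 0.20, where the output is a scipy sparse matrix rather than a DataFrame:

# Equivalent one-hot encoding with sklearn; handle_unknown='ignore' guards against
# categories that appear only at predict time
ohe = OneHotEncoder(handle_unknown='ignore')
Xoh_sparse = ohe.fit_transform(X)
Xoh_sparse.shape  # (8124, 116)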

In [13]:
Xoh.head()
Out[13]:
cap_shape_b cap_shape_c cap_shape_f cap_shape_k cap_shape_s cap_shape_x cap_surface_f cap_surface_g cap_surface_s cap_surface_y cap_color_b cap_color_c ... population_c population_n population_s population_v population_y habitat_d habitat_g habitat_l habitat_m habitat_p habitat_u habitat_w
0 0 0 0 0 0 1 0 0 1 0 0 0 ... 0 0 1 0 0 0 0 0 0 0 1 0
1 0 0 0 0 0 1 0 0 1 0 0 0 ... 0 1 0 0 0 0 1 0 0 0 0 0
2 1 0 0 0 0 0 0 0 1 0 0 0 ... 0 1 0 0 0 0 0 0 1 0 0 0
3 0 0 0 0 0 1 0 0 0 1 0 0 ... 0 0 1 0 0 0 0 0 0 0 1 0
4 0 0 0 0 0 1 0 0 1 0 0 0 ... 0 0 0 0 0 0 1 0 0 0 0 0

5 rows × 116 columns

In [14]:
ftnames = X.columns.values # feature names
ftnames_oh = Xoh.columns.values # One-hot encoded feature names
In [15]:
# https://towardsdatascience.com/the-search-for-categorical-correlation-a1cf7f1888c9
# https://github.com/shakedzy/dython
def conditional_entropy(x, y):
    """Calculates the conditional entropy of x given y: S(x|y)

    Parameters
    ----------
    x,y : array-like, int-encoded categorical value arrays
    
    Returns
    -------
    entropy : float
    
    See Also
    --------
    Wikipedia: https://en.wikipedia.org/wiki/Conditional_entropy
    
    """
    y=y.astype(np.int64)
    y_counter = np.bincount(y)
    xy_counter = Counter(list(zip(x,y)))
    total_occurrences = y_counter.sum()
    entropy = 0.0
    for k,v in xy_counter.items():
        p_xy = v / total_occurrences
        p_y = y_counter[k[1]] / total_occurrences
        entropy += p_xy * np.log(p_y/p_xy)
    return entropy

def cramers_v(x, y):
    """
    Calculates Cramer's V statistic for categorical-categorical association.
    Uses correction from Bergsma and Wicher, Journal of the Korean Statistical Society 42 (2013): 323-328.
    This is a symmetric coefficient: V(x,y) = V(y,x)

    Parameters
    ----------
    x : list / NumPy ndarray / Pandas Series
        A sequence of categorical measurements
    y : list / NumPy ndarray / Pandas Series
        A sequence of categorical measurements
        
    Returns 
    -------
    float, between range [0,1]
    
    See Also
    --------
    Original function from: https://stackoverflow.com/a/46498792/5863503
    Wikipedia: https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V
    """
    confusion_matrix = pd.crosstab(x,y)
    chi2 = ss.chi2_contingency(confusion_matrix)[0]
    n = confusion_matrix.sum().sum()
    phi2 = chi2/n
    r,k = confusion_matrix.shape
    phi2corr = max(0, phi2-((k-1)*(r-1))/(n-1))
    rcorr = r-((r-1)**2)/(n-1)
    kcorr = k-((k-1)**2)/(n-1)
    return np.sqrt(phi2corr/min((kcorr-1),(rcorr-1)))
    
def theils_u(x, y):
    """Calculates Theil's U statistic (Uncertainty coefficient) for categorical-categorical association.
   
    This is the uncertainty of x given y: value is on the range of [0,1] 
    - where 0 means y provides no information about x, 
    and 1 means y provides full information about x.
    This is an asymmetric coefficient: U(x,y) != U(y,x)
    
    Parameters
    ----------
    x,y : array-like, int-encoded categorical value arrays
    
    Returns
    -------
    uncertainty coefficient : float between [0,1]
    
    See Also
    --------
    Wikipedia: https://en.wikipedia.org/wiki/Uncertainty_coefficient
    """
    x=x.astype(np.int64)
    s_xy = conditional_entropy(x,y)
    x_counter = np.bincount(x)
    total_occurrences = x_counter.sum()
    p_x = x_counter/total_occurrences
    s_x = ss.entropy(p_x)
    if s_x == 0:
        return 1
    return (s_x - s_xy) / s_x

def catcorr(data, method='theils'):
    """ Compute categorical correlations using uncertainty coefficients (Theil's U) or Cramer's V
    
    Parameters
    ----------
    data : pandas.DataFrame 
        dataset used to compute correlations 
    method : str, ['theils' | 'cramers']
        method to calculate correlations, cramers is symmetric, theils is asymmetric
    
    Returns
    -------
        computed correlations : pandas.DataFrame
    """
    if method == 'cramers':
        return data.corr(method=cramers_v)
    elif method != 'theils':
        raise NotImplementedError(f"method: '{method}' not implemented, choose either 'cramers' or 'theils'")
    cols = data.columns
    clen = cols.size
    
    pairings = list(itertools.product(data.columns, repeat=2))
    theils_mat = np.reshape([theils_u(data[p[1]],data[p[0]]) for p in pairings],(clen,clen))
    return pd.DataFrame(theils_mat, index=cols, columns=cols)

Since we are working with all categorical values, a traditional correlation matrix will not suffice. Instead, we'll use an uncertainty coefficient matrix to model the asymmetric relationships between value pairs. The functions used are slight modifications of those in the dython package, altered to prefer performance over flexibility.
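
Concretely, Theil's U is U(x|y) = (H(x) - H(x|y)) / H(x): the fraction of x's entropy explained by knowing y. A quick sanity check of the asymmetry before plotting the full matrix (a sketch; values not shown here):

# Theil's U is asymmetric: odor should tell us more about toxic than vice versa,
# since H(toxic) is much smaller than H(odor) while the mutual information is shared
print(theils_u(agar_enc.toxic, agar_enc.odor))  # U(toxic|odor)
print(theils_u(agar_enc.odor, agar_enc.toxic))  # U(odor|toxic), expected to be smaller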

In [16]:
fig, ax = plt.subplots(1,1, figsize=(14,10))
sns.heatmap(catcorr(agar_enc), vmin=0, vmax=1, ax=ax, annot=True, fmt="0.2f").set_title("Uncertainty coefficient matrix")
plt.show()

The matrix reveals a few key findings. First and foremost, once we know odor we already have a substantial amount of information pertaining to the response, toxic. This will be very important to keep in mind as we move forward with the analysis.

Additionally, there is a near-symmetric relationship between veil_color and the stalk_color_* features. This relationship does not hold across all color attributes, however, as seen with spore_print_color, gill_color, and cap_color.

Models

We'll be using two primary models:

  • RandomForestClassifier - sklearn
  • XGBClassifier - XGBoost

We will run both the categorical encoded and one-hot encoded data on each of these classifiers.

Helper functions

In [17]:
def multi_table(*dfs):
    """Compactly display multiple adjacent DataFrames
    
    Args: *dfs, pd.DataFrames
    See also: https://stackoverflow.com/a/44923103
    """
    html_str=''
    for df in dfs:
        html_str+=df.to_html()
    display_html(html_str.replace('table','table style="display:inline"'),raw=True)
In [18]:
def pplot_cm(y_true, y_pred, labels=None, filename=None, ymap=None, cf_report=False, figsize=(7,5), **kwargs):
    """
    Pretty plot a confusion matrix with annotations.
    Optionally save plot image to disk.
    
    Args: 
      y_true:    true label of the data, with shape (nsamples,)
      y_pred:    prediction of the data, with shape (nsamples,)
      filename:  filename of figure file to save
      labels:    array of shape [n_class,] name the order of class labels in the confusion matrix.
                 If None, defaults to the sorted unique values of y_true.
      ymap:      dict: any -> string, length == nclass.
                 if not None, map the labels & ys to more understandable strings.
                 Caution: original y_true, y_pred and labels must align.
      cf_report: if True, print a classification report in addition to confusion matrix
      figsize:   the size of the figure plotted.
      **kwargs:  keyword arguments to pass to seaborn's heatmap
      
    See also:
        https://gist.github.com/hitvoice/36cf44689065ca9b927431546381a3f7
    """
    if labels is None:
        labels = np.unique(y_true)
    if ymap is not None:
        y_pred = [ymap[yi] for yi in y_pred]
        y_true = [ymap[yi] for yi in y_true]
        labels = [ymap[yi] for yi in labels]
    if cf_report:
        print(classification_report(y_true, y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm, annot=annot, fmt='', ax=ax, **kwargs)
    plt.savefig(filename) if filename is not None else plt.show()
In [19]:
def plot_tree(dtree, featnames, cnames=None, width=600, height=800):
    """Plot a decision tree using graphviz, pydotplus and IPython.Image
    
    Args:
        dtree : fitted DecisionTree
        featnames : array-like[str], feature names to label tree nodes
        cnames : array-like[str], class names to label tree nodes (default: None)
        width : int, width(in px) of output visualization
        height : int, height(in px) of output visualization
    Returns: 
        tree plot : IPython.Image
    See also:
        https://medium.com/@rnbrown/creating-and-visualizing-decision-trees-with-python-f8e8fa394176
    """
    dot_data = StringIO()
    tree.export_graphviz(dtree, out_file=dot_data, 
                    feature_names=featnames, class_names=cnames,
                    filled=True, rounded=True, special_characters=True)

    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
    return Image(graph.create_png(), width=width, height=height)

Random Forest

In [20]:
rfc = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=RS)
rfc.fit(X_train,y_train)
preds = rfc.predict(X_test)
pplot_cm(y_test, preds, rfc.classes_, cf_report=True, figsize=(7,5), cmap='Blues')
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       840
           1       1.00      1.00      1.00       785

    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625

Our Random Forest model yielded perfect scores across the board. This may seem like fantastic news, but a perfectly performing model should be met with a great deal of skepticism. More often than not, it means there is data leakage somewhere in the dataset or a mistake was made during modeling or data prep. Perhaps testing data was accidentally introduced during training, one of the input variables is perfectly correlated with the target, or the chosen train/test split just happened to pick an easy subset.

From the uncertainty coefficient matrix, we know that no variable is perfectly correlated with the target, and there does not appear to be any overlap between the train and test data. So, we'll see if luck was the answer by doing k-fold cross-validation.

In [21]:
skf = StratifiedKFold(5, shuffle=True, random_state=RS)
for train_idx, test_idx in skf.split(X_enc, y_enc):
    X_train, X_test, y_train, y_test = X_enc.loc[train_idx], X_enc.loc[test_idx], y_enc[train_idx], y_enc[test_idx]
    rfc.fit(X_train,y_train)
    y_pred = rfc.predict(X_test)
    print(classification_report(y_test,y_pred))
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       842
           1       1.00      1.00      1.00       784

    accuracy                           1.00      1626
   macro avg       1.00      1.00      1.00      1626
weighted avg       1.00      1.00      1.00      1626

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       842
           1       1.00      1.00      1.00       783

    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       842
           1       1.00      1.00      1.00       783

    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       841
           1       1.00      1.00      1.00       783

    accuracy                           1.00      1624
   macro avg       1.00      1.00      1.00      1624
weighted avg       1.00      1.00      1.00      1624

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       841
           1       1.00      1.00      1.00       783

    accuracy                           1.00      1624
   macro avg       1.00      1.00      1.00      1624
weighted avg       1.00      1.00      1.00      1624

Perfect across all 5 folds; let's see how it does with 10.

In [22]:
metrics = ['precision','recall','f1','roc_auc']
scores = cross_validate(rfc, X_enc, y_enc, scoring=metrics, cv=10, return_train_score=True, n_jobs=-1)

for m in metrics:
    test_score, train_score = [scores[x] for x in scores.keys() if m in x]
    print(m+':\n','{:>4} train scores: {}'.format('',list(train_score)))
    print('{:>5} test scores: {}'.format('',list(test_score)))
    print('{:>5} test mean: {}'.format('',test_score.mean()))
precision:
      train scores: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      test scores: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9678217821782178, 1.0]
      test mean: 0.9967821782178218
recall:
      train scores: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      test scores: [0.3469387755102041, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      test mean: 0.9346938775510203
f1:
      train scores: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      test scores: [0.5151515151515151, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9836477987421384, 1.0]
      test mean: 0.9498799313893654
roc_auc:
      train scores: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      test scores: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      test mean: 1.0

It does indeed appear as though some subsets of the data were easier to predict than others, but on the whole we still have a very strong classifier, well above 90% on all metrics using 10-fold cross-validation.

This is with no tweaks to the model's hyperparameters at all, just the defaults. With a bit of tuning, near-perfect predictions aren't out of the realm of possibility.
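
A minimal sketch of what such tuning might look like with GridSearchCV (the grid values here are illustrative assumptions, not tuned results):

from sklearn.model_selection import GridSearchCV

# Illustrative grid only; values are assumptions rather than recommendations
param_grid = {'n_estimators': [50, 100, 200],
              'max_depth': [None, 5, 10],
              'min_samples_leaf': [1, 2, 4]}
# recall is a sensible target here: missing a poisonous mushroom is the costly error
gs = GridSearchCV(RandomForestClassifier(random_state=RS), param_grid,
                  scoring='recall', cv=5, n_jobs=-1)
gs.fit(X_enc, y_enc)
print(gs.best_params_, gs.best_score_)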

In [23]:
rfc_fi = pd.DataFrame({'feature':X.columns,'importance':rfc.feature_importances_}).sort_values(by='importance',ascending=False)
sns.catplot(x='feature',y='importance',data=rfc_fi,kind='bar',aspect=1.5).set_xticklabels(rotation=90);

odor, gill_size, gill_color, and spore_print_color encapsulate a large portion of the overall feature importance for the RandomForest model. Odor makes some intuitive sense even with no fungal knowledge. Generally, smell is a fairly strong indicator of whether or not something is safe to consume. Of course, given how we have chosen to represent the data, we do not have the level of detail required to know exactly what kind of smell means edible or poisonous.

In [24]:
# filter out any non-perfect scoring decision trees, then take tree with fewest leaves 
smallest_dt = min(filter(lambda dt: dt.score(X_test, y_test) == 1.0, rfc.estimators_), key=lambda dt: dt.get_n_leaves())
In [25]:
plot_tree(smallest_dt, ftnames, ['edible','poisonous'],500,600)
Out[25]:

Reducing the random forest to a single decision tree gives some insight into the model's decision-making process, though it is not especially intuitive from a human standpoint. Next, we'll try the one-hot encoded data to see if increased granularity provides easier interpretation.
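
Before doing so, a purely textual view of the tree can sometimes be easier to scan than the graphviz plot; a sketch using sklearn's export_text (available in sklearn >= 0.21; split thresholds refer to the integer category codes):

from sklearn.tree import export_text

# Plain-text rendering of the same tree as nested decision rules
print(export_text(smallest_dt, feature_names=list(ftnames)))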

One hot encoded

In [26]:
rfc_oh = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=RS)
rfc_oh.fit(Xoh_train,yoh_train)
preds_oh = rfc_oh.predict(Xoh_test)
In [27]:
pplot_cm(yoh_test, preds_oh, rfc_oh.classes_, cf_report=True, figsize=(7,5), cmap='Blues')
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       840
           1       1.00      1.00      1.00       785

    accuracy                           1.00      1625
   macro avg       1.00      1.00      1.00      1625
weighted avg       1.00      1.00      1.00      1625

In [28]:
rfc_oh_fi = pd.DataFrame({'feature':Xoh.columns,'importance':rfc_oh.feature_importances_}).sort_values(by='importance',ascending=False)
sns.catplot(x='feature',y='importance',data=rfc_oh_fi[:21],kind='bar',aspect=1.5).set_xticklabels(rotation=90);

odor=none and odor=foul rank 1st and 3rd in feature importance, with gill_size=narrow between them. One may speculate that having no perceptible odor indicates a mushroom may be safe for consumption.

In [29]:
odorXtox = pd.crosstab(agar.odor, agar.toxic)
gsizXtox = pd.crosstab(agar.gill_size, agar.toxic)
gcolXtox = pd.crosstab(agar.gill_color, agar.toxic)
multi_table(odorXtox,gsizXtox,gcolXtox)
toxic          e     p
odor
a            400     0
c              0   192
f              0  2160
l            400     0
m              0    36
n           3408   120
p              0   256
s              0   576
y              0   576

toxic          e     p
gill_size
b           3920  1692
n            288  2224

toxic          e     p
gill_color
b              0  1728
e             96     0
g            248   504
h            204   528
k            344    64
n            936   112
o             64     0
p            852   640
r              0    24
u            444    48
w            956   246
y             64    22

A quick crosstab reveals that there are 120 cases where a mushroom with no odor is still poisonous; in fact, odor=none is the only odor value without perfect dichotomous separability. Gill size is slightly less absolute in its claims, but a narrow gill size is a rather strong indication that a mushroom is inedible. As for color, the only attribute with substantial support is gill_color=buff.
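
These figures can be verified directly against the raw frame (a quick check):

# Sanity check: count of poisonous mushrooms with no odor
((agar.odor == 'n') & (agar.toxic == 'p')).sum()  # 120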

In [30]:
smallest_dt_oh = min(filter(lambda dt: dt.score(Xoh_test, yoh_test) == 1.0, rfc_oh.estimators_), key=lambda dt: dt.get_n_leaves())
In [31]:
plot_tree(smallest_dt_oh, ftnames_oh, ['edible','poisonous'])
Out[31]: