Feature Importance Stability

In this example, we probe the stability of the permutation feature importance metric for random forests under perturbations to data resampling, data preprocessing, and model hyperparameters.

vflow supports automatic parallelization using Ray. We can compute downstream results asynchronously by setting is_async=True when constructing a Vset:
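The idea behind asynchronous evaluation can be pictured with a standard-library analogy: every (function, input) combination in a Vset is an independent task that can run concurrently. This sketch uses threads instead of Ray, and the function names are invented for illustration; it is not vflow's implementation.

```python
# Stdlib analogy for is_async=True: dispatch each (function, input) pair as
# an independent task (vflow itself routes these tasks through Ray).
from concurrent.futures import ThreadPoolExecutor
import itertools

def square(x):
    return x * x

def cube(x):
    return x ** 3

funcs = [square, cube]   # stand-ins for the functions held by a Vset
inputs = [1, 2, 3]       # stand-ins for upstream data

with ThreadPoolExecutor() as pool:
    futures = {(f.__name__, x): pool.submit(f, x)
               for f, x in itertools.product(funcs, inputs)}
    results = {key: fut.result() for key, fut in futures.items()}

print(results[('square', 3)])  # 9
```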

Define stability target

Below, we create one Vset that applies three custom data preprocessing functions and another that computes the permutation importance metric via the function sklearn.inspection.permutation_importance.
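As a standalone reference for the metric being wrapped, here is permutation importance called directly from scikit-learn on a synthetic dataset (sklearn API only; no vflow):

```python
# Permutation importance: shuffle each feature column several times and
# record the resulting drop in model score.
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance

X, y = make_regression(n_samples=200, n_features=5, n_informative=2,
                       random_state=0)
rf = RandomForestRegressor(n_estimators=25, random_state=0).fit(X, y)

imp = permutation_importance(rf, X, y, n_repeats=10, random_state=0)
print(imp.importances_mean.shape)  # (5,) -- one mean importance per feature
```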

Define model hyperparameter perturbations

We can also specify modeling perturbations, both within a single class of models (hyperparameter perturbations) and across different classes. Here we'll use the helper build_vset to create hyperparameter perturbations for random forest.
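What build_vset does for hyperparameter perturbations can be sketched without vflow: take a model class and a parameter grid, and instantiate one model per parameter combination. The grid values below are illustrative, not vflow's API.

```python
# Vflow-free sketch of hyperparameter perturbation: one random forest per
# combination in a small parameter grid.
import itertools
from sklearn.ensemble import RandomForestRegressor

param_grid = {'n_estimators': [50, 100], 'max_features': ['sqrt', 1.0]}

keys, values = zip(*param_grid.items())
models = {tuple(zip(keys, combo)): RandomForestRegressor(**dict(zip(keys, combo)))
          for combo in itertools.product(*values)}

print(len(models))  # 4 -- one model per hyperparameter combination
```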

Define data perturbations

For stability analysis, it is often useful to add data perturbations such as the bootstrap in order to assess stability over resampling variability in the data. By setting lazy=True when constructing a Vset, the bootstrap is computed lazily: the data will not be resampled until needed:
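The lazy bootstrap idea can be sketched with plain closures: each perturbation is described up front, but no data is actually resampled until the closure is called. This mirrors the spirit of lazy=True; it is not vflow's implementation.

```python
# Lazy bootstrap sketch: build resampling closures now, draw data later.
import numpy as np

def make_bootstrap(seed):
    def resample(X, y):
        rng = np.random.default_rng(seed)
        idx = rng.integers(0, len(X), size=len(X))  # rows with replacement
        return X[idx], y[idx]
    return resample

bootstraps = [make_bootstrap(seed) for seed in range(3)]  # nothing drawn yet

X = np.arange(20).reshape(10, 2)
y = np.arange(10)
Xb, yb = bootstraps[0](X, y)  # resampling happens only here, on demand
print(Xb.shape)  # (10, 2)
```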

Fit all models for all combinations of resampling and preprocessing

Now we can load our data and fit each of the four random forest models on each of the 300 combinations of resampled training data and preprocessing functions.
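The fit step amounts to a cross-product of resamples, preprocessors, and models; a plain-Python sketch (small numbers here stand in for the 300 resample/preprocessing combinations, and all names are illustrative):

```python
# Fit one random forest per (resample, preprocessor, hyperparameter) path.
import itertools
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 4))
y = X[:, 0] + rng.normal(scale=0.1, size=60)

resamples = [rng.integers(0, len(X), size=len(X)) for _ in range(2)]
preprocessors = {'raw': lambda A: A,
                 'scaled': lambda A: StandardScaler().fit_transform(A)}
n_estimators_grid = (10, 25)  # illustrative hyperparameter perturbation

fits = {}
for (i, idx), (name, prep), n in itertools.product(
        enumerate(resamples), preprocessors.items(), n_estimators_grid):
    fits[(i, name, n)] = RandomForestRegressor(
        n_estimators=n, random_state=0).fit(prep(X[idx]), y[idx])

print(len(fits))  # 2 resamples x 2 preprocessors x 2 models = 8 fits
```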

We can examine the pipeline graph to see what has happened so far using the utility function build_graph:

Calculate feature importances and perturbation statistics

Finally, we calculate the importance metric and examine its mean and standard deviation across bootstrap perturbations for each combination of data preprocessing and modeling hyperparameters. This allows us to assess the stability of the feature importances conditioned on different pipeline paths:
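The aggregation itself is a grouped mean and standard deviation; a pandas sketch with fabricated placeholder importances (the column names and values are invented for illustration):

```python
# Perturbation-statistics sketch: group importances by pipeline path and
# aggregate over bootstrap replicates.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
rows = [{'prep': p, 'model': m, 'bootstrap': b,
         'importance': rng.normal(loc=1.0, scale=0.1)}
        for p in ('raw', 'scaled')
        for m in ('rf_50', 'rf_100')
        for b in range(5)]
df = pd.DataFrame(rows)

# Mean and std over bootstraps, conditioned on (preprocessing, model).
stats = df.groupby(['prep', 'model'])['importance'].agg(['mean', 'std'])
print(stats.shape)  # (4, 2): four pipeline paths, mean and std columns
```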

Multiple outputs can be split with dict_to_df using param_key='out'. We use it below to split feature importances into mean and std:
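The reshaping dict_to_df performs when splitting a multi-part output can be sketched in plain pandas: expand a column of dicts into one column per key. This mimics the effect, not vflow's implementation.

```python
# Expand a dict-valued 'out' column into separate 'mean' and 'std' columns.
import pandas as pd

df = pd.DataFrame({
    'pipeline': ['p0', 'p1'],
    'out': [{'mean': 0.9, 'std': 0.05}, {'mean': 0.7, 'std': 0.20}],
})

split = pd.concat([df.drop(columns='out'), df['out'].apply(pd.Series)], axis=1)
print(list(split.columns))  # ['pipeline', 'mean', 'std']
```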

We can compute statistics on a single iterable item of the output by passing wrt=out-col and split=True:

From here, we can (optionally) filter the data preprocessing and modeling perturbations via the helper filter_vset_by_metric, selecting the top combinations in terms of stability (or another metric of interest), and continue our analysis on a held-out test set.
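The filtering idea reduces to ranking pipeline paths by a stability metric and keeping the best few; a minimal sketch with invented numbers (a lower standard deviation of the importance is taken to mean a more stable path):

```python
# Rank pipeline paths by std of the importance and keep the top k.
stability = {
    ('raw', 'rf_50'): 0.05,
    ('raw', 'rf_100'): 0.12,
    ('scaled', 'rf_50'): 0.03,
    ('scaled', 'rf_100'): 0.20,
}

top_k = 2
best = sorted(stability, key=stability.get)[:top_k]  # smaller std = more stable
print(best)  # [('scaled', 'rf_50'), ('raw', 'rf_50')]
```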