bex.explainers package

Module contents

class bex.explainers.Dice(num_explanations=10, lr=0.1, num_iters=50, proximity_weight=1, diversity_weight=1)

Bases: ExplainerBase

DiCE explainer as described in https://arxiv.org/abs/1905.07697

Parameters:
  • num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)

  • lr (float, optional) – learning rate (default: 0.1)

  • num_iters (int, optional) – number of gradient descent steps to perform (default: 50)

  • proximity_weight (float, optional) – weight of the proximity (reconstruction) term \(\lambda_1\) in the loss function (default: 1.0)

  • diversity_weight (float, optional) – weight of the diversity term \(\lambda_2\) in the loss function (default: 1.0)
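How the two weights trade off against each other can be sketched as follows. This is a minimal, dependency-free illustration, not the code used by bex: the helper names `l1` and `dice_style_loss` are hypothetical, latents are plain Python lists, and the mean pairwise L1 distance is used as a simple stand-in for the determinantal diversity term of the DiCE paper.

```python
from itertools import combinations


def l1(a, b):
    """L1 distance between two latent vectors given as lists of floats."""
    return sum(abs(x - y) for x, y in zip(a, b))


def dice_style_loss(z, candidates, class_losses,
                    proximity_weight=1.0, diversity_weight=1.0):
    """Hypothetical DiCE-style objective for one sample.

    z            -- original latent vector
    candidates   -- list of counterfactual latents z'
    class_losses -- one classification loss per candidate (assumed given)
    """
    k = len(candidates)
    validity = sum(class_losses) / k
    # Proximity: keep counterfactuals close to the original latent.
    proximity = sum(l1(c, z) for c in candidates) / k
    # Diversity: stand-in term (mean pairwise distance), NOT the DPP
    # determinant used in the paper.
    pairs = list(combinations(candidates, 2))
    diversity = sum(l1(a, b) for a, b in pairs) / len(pairs) if pairs else 0.0
    # Proximity is penalised, diversity is rewarded.
    return validity + proximity_weight * proximity - diversity_weight * diversity
```

Raising `proximity_weight` pulls the counterfactuals towards \(\textbf{z}\), while raising `diversity_weight` pushes them apart from each other.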

explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)

class bex.explainers.Dive(num_explanations=10, lr=0.1, num_iters=50, diversity_weight=0.001, lasso_weight=0.1, reconstruction_weight=0.0001, method='fisher_spectral')

Bases: ExplainerBase

DiVE algorithm as described in https://arxiv.org/abs/2103.10226

Parameters:
  • num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)

  • lr (float, optional) – learning rate (default: 0.1)

  • num_iters (int, optional) – number of gradient descent steps to perform (default: 50)

  • diversity_weight (float, optional) – weight of the diversity term in the loss function (default: 0.001)

  • lasso_weight (float, optional) – factor \(\gamma\) that controls the sparsity of the latent space (default: 0.1)

  • reconstruction_weight (float, optional) – weight of the reconstruction term in the loss function (default: 0.0001)

  • method (string, optional) – method used for gradient masking (default: ‘fisher_spectral’)

explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)

class bex.explainers.ExplainerBase

Bases: object

Base class for all explainer methods

If you wish to test your own explainer on our benchmark use this as a base class and override the explain_batch method

Example

import random

import bex
from bex.explainers import ExplainerBase

class DummyExplainer(ExplainerBase):

    def __init__(self, num_explanations):
        super().__init__()
        self.num_explanations = num_explanations

    def explain_batch(self, latents, logits, images, classifier, generator):

        b = latents.shape[0]
        # we will produce self.num_explanations counterfactuals per sample
        z = latents[:, None, :].repeat(1, self.num_explanations, 1)
        z_perturbed = z + random.random() # create counterfactuals z'

        return z_perturbed.view(b, self.num_explanations, -1)

bn = bex.Benchmark()
bn.run(DummyExplainer, num_explanations=10)

abstract explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)
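When writing a custom explainer, a small check of the shape contract above can catch mistakes early. The helper below is a hypothetical sketch, not part of bex: it only reads the `.shape` attribute, so it should work with `torch.Tensor` objects or anything else exposing a shape tuple.

```python
def check_explainer_shapes(latents, logits, images, counterfactuals,
                           num_explanations):
    """Validate the documented shapes; return the expected output shape.

    Hypothetical helper: relies only on the `.shape` attribute of its inputs.
    """
    if len(latents.shape) != 2:
        raise ValueError("latents must have shape (B, Z)")
    b, z = tuple(latents.shape)
    if tuple(logits.shape) != (b, 2):
        raise ValueError("logits must have shape (B, 2)")
    if len(images.shape) != 4 or images.shape[0] != b:
        raise ValueError("images must have shape (B, C, H, W)")
    expected = (b, num_explanations, z)
    if tuple(counterfactuals.shape) != expected:
        raise ValueError(
            f"counterfactuals have shape {tuple(counterfactuals.shape)}, "
            f"expected {expected}"
        )
    return expected
```

Calling it at the end of `explain_batch`, just before returning, is one way to use it during development.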

class bex.explainers.GrowingSpheres(num_explanations=10, n_candidates=50, first_radius=10, decrease_radius=2)

Bases: ExplainerBase

Growing Spheres explainer as described in https://arxiv.org/abs/1712.08443

Parameters:
  • num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)

  • n_candidates (int, optional) – number of observations \(n\) to generate at each step (default: 50)

  • first_radius (float, optional) – radius \(\eta\) of the first hyperball generated (default: 10.0)

  • decrease_radius (float, optional) – parameter controlling the size of the radius at each step (default: 2.0)
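The roles of `n_candidates`, `first_radius` and `decrease_radius` can be illustrated with a simplified, single-sample version of the Growing Spheres search. This is a sketch under assumptions, not the bex implementation: `predict` is a hypothetical function returning a class label for one latent vector, latents are plain lists, and only the single closest counterfactual is returned instead of `num_explanations` of them.

```python
import math
import random


def sample_in_layer(center, r_min, r_max, n, rng):
    """Draw n points uniformly from the shell r_min <= ||p - center|| <= r_max."""
    d = len(center)
    points = []
    for _ in range(n):
        # Random direction: normalised Gaussian vector.
        direction = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(v * v for v in direction)) or 1.0
        # Radius drawn so that points are uniform in volume.
        u = rng.random()
        radius = (r_min ** d + u * (r_max ** d - r_min ** d)) ** (1.0 / d)
        points.append([c + radius * v / norm for c, v in zip(center, direction)])
    return points


def growing_spheres(z, predict, n_candidates=50, first_radius=10.0,
                    decrease_radius=2.0, max_steps=100, seed=0):
    """Return the closest sampled point whose predicted class differs from z's."""
    rng = random.Random(seed)
    original = predict(z)
    radius = first_radius
    # Phase 1: shrink the ball until it contains no sampled "enemy".
    for _ in range(max_steps):
        ball = sample_in_layer(z, 0.0, radius, n_candidates, rng)
        if all(predict(p) == original for p in ball):
            break
        radius /= decrease_radius
    # Phase 2: grow spherical layers until an enemy is found.
    inner, outer = radius, 2.0 * radius
    for _ in range(max_steps):
        layer = sample_in_layer(z, inner, outer, n_candidates, rng)
        enemies = [p for p in layer if predict(p) != original]
        if enemies:
            return min(enemies, key=lambda p: math.dist(p, z))
        inner, outer = outer, outer + radius
    return None  # no counterfactual found within max_steps layers
```

For instance, `growing_spheres([0.0, 0.0], lambda p: int(p[0] > 1.0))` searches for the nearest sampled point across the decision boundary at \(z_1 = 1\).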

explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)

class bex.explainers.LCF(num_explanations=10, lr=0.1, num_iters=50, p=0.1, tolerance=0.5)

Bases: ExplainerBase

Latent-CF explainer as described in https://arxiv.org/abs/2012.09301

Parameters:
  • num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)

  • lr (float, optional) – learning rate (default: 0.1)

  • num_iters (int, optional) – max number of gradient descent steps to perform without convergence (default: 50)

  • p (float, optional) – probability \(p\) of target counterfactual class (default: 0.1)

  • tolerance (float, optional) – (default: 0.5)
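The interaction between `lr`, `num_iters`, `p` and `tolerance` can be sketched on a toy problem. Everything below is an assumption made for illustration: the classifier is a logistic model on the latent itself (the generator is treated as the identity), the loss is the squared gap between the predicted probability and `p`, and `tolerance` is read as the gap at which the search stops. The function `latent_cf` is hypothetical and is not the bex implementation.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def latent_cf(z, w, b, p, tolerance, lr=0.1, num_iters=50):
    """Gradient descent on a latent z until sigmoid(w.z + b) is close to p.

    Returns (z', probability, number of steps taken).
    """
    z = list(z)
    for step in range(num_iters):
        prob = sigmoid(sum(wi * zi for wi, zi in zip(w, z)) + b)
        # Assumed stopping rule: probability within `tolerance` of the target.
        if abs(prob - p) < tolerance:
            return z, prob, step
        # Gradient of (prob - p)^2 with respect to z, via the chain rule.
        scale = 2.0 * (prob - p) * prob * (1.0 - prob)
        z = [zi - lr * scale * wi for zi, wi in zip(z, w)]
    prob = sigmoid(sum(wi * zi for wi, zi in zip(w, z)) + b)
    return z, prob, num_iters
```

On this toy model the sigmoid gradients are small, so a call such as `latent_cf([2.0, 0.0], [1.0, 0.0], 0.0, p=0.1, tolerance=0.05, lr=5.0)` needs a large step size to converge within the iteration budget.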

explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)

class bex.explainers.Stylex(num_explanations=10, t=0.3, shift_size=0.8, strategy='independent')

Bases: ExplainerBase

StylEx explainer as described in https://arxiv.org/abs/2104.13369

Parameters:
  • num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)

  • t (float, optional) – perturbation threshold \(t\) to consider a sample explained (default: 0.3)

  • shift_size (float, optional) – amount of shift applied to each coordinate (default: 0.8)

  • strategy (string, optional) – selection strategy ‘independent’ or ‘subset’ (default: ‘independent’)
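One possible reading of the ‘independent’ strategy is sketched below: each latent coordinate is shifted on its own by `shift_size` in both directions, the candidates are ranked by how much they change the classifier output, and a candidate counts as explaining the sample when that change exceeds `t`. This interpretation, the function `independent_shifts` and the `predict_proba` callable are all assumptions for illustration and may differ from what bex actually does.

```python
def independent_shifts(z, predict_proba, shift_size=0.8, t=0.3,
                       num_explanations=10):
    """Rank single-coordinate shifts of z by their effect on the classifier.

    Returns a list of (candidate, explained) pairs, strongest effect first.
    """
    base = predict_proba(z)
    scored = []
    for i in range(len(z)):
        for direction in (-1.0, 1.0):
            candidate = list(z)
            candidate[i] += direction * shift_size  # shift one coordinate only
            change = abs(predict_proba(candidate) - base)
            scored.append((change, candidate))
    # Largest change in predicted probability first.
    scored.sort(key=lambda item: item[0], reverse=True)
    top = scored[:num_explanations]
    return [(candidate, change > t) for change, candidate in top]
```

A ‘subset’ strategy would presumably combine several coordinates at once; that variant is not sketched here.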

explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)

class bex.explainers.Xgem(num_explanations=10, lr=0.1, num_iters=50, reconstruction_weight=0.001)

Bases: Dive

xGEM explainer as described in https://arxiv.org/abs/1806.08867

Parameters:
  • num_explanations (int, optional) – number of counterfactuals to be generated (default: 10)

  • lr (float, optional) – learning rate (default: 0.1)

  • num_iters (int, optional) – number of gradient descent steps to perform (default: 50)

  • reconstruction_weight (float, optional) – weight of the reconstruction term in the loss function (default: 0.001)

explain_batch(latents, logits, images, classifier, generator)

Method to generate a set of counterfactuals for a given batch

Parameters:
  • latents (torch.Tensor) – standardized latent \(\textbf{z}\) representation of samples to be perturbed

  • logits (torch.Tensor) – classifier logits given \(\textbf{z}\)

  • images (torch.Tensor) – images \(x\) produced by the generator given \(\textbf{z}\)

  • classifier (torch.nn.Module) – classifier to explain \(\hat{f}(x)\)

  • generator (callable) – function that takes a batch of latents \(\textbf{z'}\) and returns a batch of images

Returns:

the obtained counterfactuals \(\textbf{z'}\) for each batch element

Return type:

(torch.Tensor)

Shape:

latents \((B, Z)\)

logits \((B, 2)\)

images \((B, C, H, W)\)

obtained counterfactuals: \((B, n\_explanations, Z)\)