fastai Explainability with SHAP

I’ve been working my way through Practical Deep Learning for Coders, a fantastic resource from the authors of fastai. But while I enjoy the deliberate way the authors are slowly peeling back the layers to uncover what makes a neural network tick, I wanted to just rip the lid off.

So I built a classifier using the techniques provided by fastai but applied the explainability features of SHAP to understand how the deep learning model arrives at its decision.

I’ll walk you through the steps I took to create a neural network that can classify architectural styles and show you how to apply SHAP to your own fastai model. You’ll learn how to train and explain a highly accurate neural net with just a few lines of code!

Gather the data

I followed this great guide on image scraping from Google to gather images for my training set.

Due to the limited availability of images, I settled on seven architectural styles:

Gothic

Victorian

Craftsman

Classical

Modern

Tudor

Cape Cod

I do have some class imbalance:

    • Cape Cod: 94

    • Craftsman: 94

    • Tudor: 49

    • Victorian: 73

    • Classical: 148

    • Modern: 75

This could be concerning, especially given the large spread between the number of images available for Classical and Tudor architectural styles. However, the main point of this guide is to show how you can apply SHAP to a fastai model so we won’t worry too much about class imbalance here.
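If you want to check the balance of your own dataset before training, a quick count per class folder is enough. This is a minimal sketch that assumes one subdirectory per style under images/ (the layout that parent_label relies on later); the helper names and the list of file suffixes are illustrative.

```python
from collections import Counter
from pathlib import Path

# Assumed set of image extensions; adjust to whatever your scraper saved.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

def count_images_per_class(root):
    """Count image files in each immediate subdirectory of `root`."""
    counts = Counter()
    for path in Path(root).glob("*/*"):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            counts[path.parent.name] += 1
    return counts

def imbalance_ratio(counts):
    """Ratio of the largest class to the smallest one."""
    return max(counts.values()) / min(counts.values())
```

Calling count_images_per_class('images/') returns a mapping from folder name to image count. With the counts listed above, the largest class (Classical, 148) is roughly three times the size of the smallest (Tudor, 49).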

Set up your environment

I’ve been using Paperspace to train deep learning models for my personal use. It’s a clean, intuitive platform, and there’s a free GPU option, although those instances are first-come, first-served, so they’re often unavailable.

Instead, I opted to pay $8 a month for their Developer plan to gain access to the upgraded P4000 GPU at $0.51 per hour. Completely worth it, IMO, for the speed and near-guaranteed access.

BEWARE! Paperspace does not autosave your notebooks. I’ve been burned by this too many times. Don’t forget to hit save!

Ensure you have the following packages imported into your workspace:

import fastbook
from fastbook import *
from fastai.vision.all import *
fastbook.setup_book()

import tensorflow
import shap

import matplotlib.pyplot as pl
from shap.plots import colors

Create a DataLoaders object

fastai’s DataBlock API builds a DataLoaders object that reads in your data, assigns it the correct data type, resizes it, and performs data augmentation, all in one!

dblock = DataBlock(
    # define X as images and Y as categorical
    blocks=(ImageBlock(), CategoryBlock()), 

    # retrieve images from a given path
    get_items=get_image_files, 

    # set the directory name as the image classification
    get_y=parent_label, 

    # resize the images to squares of 460 pixels
    item_tfms=Resize(460), 

    # see explanation below for batch_tfms
    batch_tfms=[
        *aug_transforms(size=224, min_scale=0.75),
        Normalize
    ]
)

The batch_tfms argument performs the following transformations on each batch:

    • Resizes to squares of 224 pixels

    • Ensures that cropped images are no less than 0.75 of the original image

    • By default, flips horizontally but not vertically (desired behavior for images of buildings)

    • By default, applies a random rotation of up to 10 degrees

    • By default, adjusts brightness and contrast by up to 0.2

    • The Normalize method will normalize your pixel values to have a mean of 0 and a standard deviation of 1.

These batch transformations are performed on the GPU after the resizing specified in item_tfms, which takes place on the CPU. This order of operations ensures that our images are standardized first before we hand them off to the GPU to perform the more intensive transformations.

Let’s look at one batch.

dls = dblock.dataloaders('images/', bs=32)
dls.show_batch()

We see that some slight transformations have been performed to augment our dataset but that the integrity of the architecture has been maintained (i.e. straight lines are still straight, buildings are still the correct side up).

Train a neural net

We’ll use a technique described in Chapter 7 of Practical Deep Learning for Coders to train our neural network: progressive resizing.

The first layers of a neural network focus only on basic, low-level image characteristics, like edges and gradients, while the later layers start to discern finer features, like windows and cornices.

We can save time by training the neural network initially on smaller images so that the model begins to build those early layers on basic features. Then we hone our accuracy by training the model further on larger images that show more of the details.

Let’s first train on images that are 128 pixels square.

def get_dls(bs, size):
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        get_y=parent_label,
        item_tfms=Resize(460),
        batch_tfms=[
            *aug_transforms(size=size, min_scale=0.75),
            Normalize
        ]
    )
    return dblock.dataloaders('images/', bs=bs)

dls = get_dls(128, 128)
learn = Learner(
    dls, 
    xresnet50(n_out=dls.c),
    loss_func=LabelSmoothingCrossEntropy(), 
    metrics=accuracy
)
learn.fit_one_cycle(8, 3e-3)
epoch  train_loss  valid_loss  accuracy  time
0      2.079040    1.992693    0.168000  00:12
1      2.007177    3.295360    0.160000  00:10
2      1.963840    2.741498    0.184000  00:09
3      1.892974    3.044773    0.192000  00:09
4      1.820223    2.232864    0.344000  00:10
5      1.717004    2.542991    0.336000  00:09
6      1.640253    2.123204    0.344000  00:09
7      1.581773    1.813185    0.424000  00:10

Now we increase our image size to 224 pixels square.

learn.dls = get_dls(32, 224)
learn.fine_tune(9, 1e-3)
epoch  train_loss  valid_loss  accuracy  time
0      1.189884    1.130392    0.648000  00:12
1      1.167626    1.170792    0.680000  00:12
2      1.188928    1.349947    0.552000  00:12
3      1.158677    1.123495    0.680000  00:12
4      1.119656    1.112812    0.696000  00:11
5      1.067710    1.094795    0.720000  00:12
6      1.011549    1.011965    0.792000  00:11
7      0.962006    0.975919    0.800000  00:12
8      0.933495    0.963801    0.808000  00:11

Not bad! Roughly 80% accuracy after barely any code and just a few minutes of training time.

Evaluate the model

We do have some significant class imbalance so the accuracy shown above isn’t telling us the full story. Let’s look at class-based accuracy to see how the model performs on each architectural style.

preds = learn.get_preds()
pred_class = preds[0].max(1).indices
tgts = preds[1]

for i, name in enumerate(dls.train.vocab):
    idx = torch.nonzero(tgts==i)
    subset = (tgts == pred_class)[idx]
    acc = subset.squeeze().float().mean()
    print(f'{name}: {acc:.1%}')
cape_cod: 82.4%
classical: 88.9%
craftsman: 75.0%
gothic: 94.4%
modern: 92.3%
tudor: 62.5%
victorian: 63.6%

We have decent accuracy across the classes. Let’s create a confusion matrix to see which architectural styles the model mistakes for one another.

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

We see that Tudor can sometimes be misclassified as Craftsman. Perhaps this is because both styles rely on exposed beams. Similarly, we could surmise that Gothic is most often confused with Victorian because both contain ornate decorations or that Classical can be mistaken for Modern due to a prevalence of clean lines.
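Reading the largest off-diagonal cells by eye gets tedious with more classes. As a rough sketch, the same ranking can be pulled out of any confusion matrix in a few lines; the counts below are made up purely for illustration and are not the results of this model. fastai’s ClassificationInterpretation also exposes a most_confused method that serves the same purpose.

```python
def most_confused(matrix, labels, min_count=1):
    """Return (actual, predicted, count) for off-diagonal cells, largest first."""
    pairs = [
        (labels[i], labels[j], matrix[i][j])
        for i in range(len(labels))
        for j in range(len(labels))
        if i != j and matrix[i][j] >= min_count
    ]
    return sorted(pairs, key=lambda pair: pair[2], reverse=True)

# Made-up counts for illustration only; rows are actual, columns are predicted.
labels = ["craftsman", "tudor", "victorian"]
matrix = [
    [9, 1, 2],
    [3, 5, 0],
    [1, 0, 7],
]
confused = most_confused(matrix, labels, min_count=2)
```

Each tuple reads as "actual style, predicted style, number of mistakes", so the first entry is the pair most worth inspecting with SHAP.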

How can we test these hypotheses? This is where SHAP comes into play.

Explain with SHAP

SHAP explains feature importances through Shapley values, a concept borrowed from game theory. If you’re interested in learning more, I suggest checking out the SHAP documentation.
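To make the idea concrete, here is a tiny exact computation of Shapley values. It is a sketch of the game-theory definition only, not of how SHAP works internally: the "features" and the toy_score function are hypothetical, and enumerating every ordering is only feasible for a handful of players, which is why SHAP relies on approximations for images.

```python
from itertools import permutations

def shapley_values(players, value):
    """Exact Shapley values: average marginal contribution over all orderings."""
    totals = {player: 0.0 for player in players}
    orderings = list(permutations(players))
    for ordering in orderings:
        coalition = frozenset()
        for player in ordering:
            with_player = coalition | {player}
            totals[player] += value(with_player) - value(coalition)
            coalition = with_player
    return {player: total / len(orderings) for player, total in totals.items()}

# Hypothetical "game": each feature adds to a Tudor score, and
# beams plus a steep roof together earn an extra bonus.
def toy_score(features):
    score = 0.0
    if "beams" in features:
        score += 4.0
    if "steep_roof" in features:
        score += 2.0
    if {"beams", "steep_roof"} <= features:
        score += 2.0
    return score

values = shapley_values(["beams", "steep_roof", "chimney"], toy_score)
```

The shared bonus is split evenly between the two features that create it, the irrelevant chimney gets zero, and the three values add up to the score of the full feature set. Those are the properties that make Shapley values attractive for attributing a prediction to its inputs.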

Let’s apply SHAP to the model we trained above. First, we determine a background distribution that defines the conditional expectation function. Then we sample against this background distribution to create expected gradients, allowing us to approximate Shapley values.

# pull a sample of our data (128 images)
batch = dls.one_batch()

# specify how many images to use when creating the background distribution
num_samples = 100
explainer = shap.GradientExplainer(
    learn.model, batch[0][:num_samples]
)

# calculate Shapley values
shap_values = explainer.shap_values(
    batch[0][num_samples:]
)

Now we can overlay the Shapley values on the images to see which features the model focuses on to make a classification.

In the images below, positive Shapley values in red indicate those areas of the image that contributed to the final prediction whereas negative Shapley values in blue show areas that detracted from that prediction.

import matplotlib.pyplot as pl
from shap.plots import colors

for idx, x in enumerate(batch[0][num_samples:]):
    x = x.cpu() # move image to CPU
    label = dls.train.vocab[batch[1][num_samples:]][idx]
    sv_idx = list(dls.train.vocab).index(label)

    # plot our explanations
    fig, axes = pl.subplots(figsize=(7, 7))

    # make sure we have a 2D array for grayscale
    if len(x.shape) == 3 and x.shape[2] == 1:
        x = x.reshape(x.shape[:2])
    if x.max() > 1:
        x /= 255.

    # get a grayscale version of the image
    x_curr_gray = (
        0.2989 * x[0,:,:] +
        0.5870 * x[1,:,:] +
        0.1140 * x[2,:,:]
    )
    x_curr_disp = x

    abs_vals = np.stack(
        [np.abs(shap_values[sv_idx][idx].sum(0))], 0
    ).flatten()
    max_val = np.nanpercentile(abs_vals, 99.9)

    label_kwargs = {'fontsize': 12}
    axes.set_title(label, **label_kwargs)

    sv = shap_values[sv_idx][idx].sum(0)
    axes.imshow(
        x_curr_gray,
        cmap=pl.get_cmap('gray'),
        alpha=0.3,
        extent=(-1, sv.shape[1], sv.shape[0], -1)
    )
    im = axes.imshow(
        sv,
        cmap=colors.red_transparent_blue, 
        vmin=-max_val, 
        vmax=max_val
    )
    axes.axis('off')

    fig.tight_layout()

    cb = fig.colorbar(
        im, 
        ax=np.ravel(axes).tolist(),
        label="SHAP value",
        orientation="horizontal"
    )
    cb.outline.set_visible(False)
    pl.show()

Excellent! This aligns with our intuition. We see that the model relies on the beam latticework to predict Tudor, the roofline to predict Craftsman, and flying buttresses to predict Gothic. It seems to consider window and door trim important to Victorian architecture and narrow windows a key feature of Classical design.

Conclusion

It’s easier than ever to apply deep learning techniques to any project. But with great power comes great responsibility! Understanding how a model arrives at its conclusions is essential to building trust with stakeholders and debugging your model.

Our architecture classifier could also be a visual teaching tool to explain what makes each architectural design distinct and could be incorporated into some kind of flashcard system to help others learn the differences. Sometimes, the explanations can also be the goal itself!

Applying DAG’s to Causal Models

I’ve been reading “The Book of Why” by Judea Pearl over the past few weeks, which has really helped formalize my intuition of causation. However, the book would be much better if Pearl left out any sentences written in the first-person as he has an annoying tendency to style himself as a messiah proclaiming the enlightened concepts of Causation to all the lowly statisticians still stuck on Correlation.

If we can look past his self-aggrandizing remarks, “The Book of Why” applies causal models to examples from the surgeon general’s committee on smoking in the 1960’s to the Monty Hall paradox. By reducing these multi-faceted problems down to a causal representation, we can finally put our finger on contributing factors or “causes” and control for them (if possible) to isolate the effect we are attempting to discover.

Perhaps the biggest takeaway for me from this book is the need to understand the data generation process when working with a dataset. This might sound like a no-brainer but too often, data scientists are so eager to jump in to the big shiny ball pit of a new dataset that they don’t stop to think about what this data actually represents.

Data scientists with a new dataset

By including the process by which the data was generated in these causal models, we can augment our own mental model and unlock the true relationships behind the variables of interest.

So what’s a DAG?

Directed acyclic graphs (DAG’s) are a visual representation of a causal model. Here’s a simple one:

You were late for work because you had to change your car’s tire because it was flat. Of course, we could add on much more than this (why was it flat?) but you get the idea.

Junction Types

Let’s explore what we can do with DAG’s through different junction types.

Chain

This is the simplest DAG and is represented in the example above. A generalized representation below shows that A is a cause of B, which is itself a cause of C.

Collider

Now we have two causes for C. Both A and B affect the outcome C.

Conditioning on C will reveal a non-causal, negative correlation between A & B. This correlation is called collider bias.

We can understand this effect in crude mathematical terms. If A + B = C and we hold C constant, then we must increase A by the same amount we decrease B.

This phenomenon is sometimes called the “explain-away effect” because C “explains away” the correlation between A and B.

Note that the collider bias may be positive in cases when contributions from both A and B are necessary to affect C.

An example of a collider relationship would be the age-old nature vs. nurture question. Someone’s personality (C) is a product of both their upbringing (A) and their genes (B).
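The A + B = C intuition is easy to confirm with a simulation. This is a minimal sketch under simple assumptions: A and B are independent standard normal draws, C is their sum, and "conditioning on C" is approximated by keeping only rows where C falls in a narrow band around zero. The band width and sample size are arbitrary.

```python
import random

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5

rng = random.Random(0)
a = [rng.gauss(0, 1) for _ in range(20_000)]  # cause A
b = [rng.gauss(0, 1) for _ in range(20_000)]  # cause B, independent of A
c = [x + y for x, y in zip(a, b)]             # collider C

# "Conditioning on C": keep only rows where C is (almost) constant.
kept = [(x, y) for x, y, z in zip(a, b, c) if abs(z) < 0.1]
a_kept, b_kept = zip(*kept)

overall = pearson(a, b)                # A and B across the whole sample
conditioned = pearson(a_kept, b_kept)  # A and B once C is held fixed
```

Across the whole sample the correlation is close to zero, as expected for independent causes. Within the band it is strongly negative, even though neither variable influences the other.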

Fork

In the case of a fork, A affects both B and C.

Without conditioning on A, there exists a spurious (non-causal) correlation between B & C. A classic example of a spurious correlation is the relationship between crime (B) and ice cream sales (C). When you plot these two values over time, they appear to increase and decrease together, suggesting some kind of causality. Does ice cream cause people to commit crime?

Of course, this relationship can be explained by adding in temperature (A). Warmer weather causes people to leave their homes more often, leading to more crime (B). People also crave ice cream cones (C) on hot days.

Node Types

Mediators

A mediator is the node that “mediates” or transmits a causal effect from one node to another.

Again using the example below, B mediates the causal effect of A onto C.

Confounders

Harking back to the crime-and-ice-cream example, temperature is the confounder node as it “confounds” the relationship between ice cream sales and crime.

If we control for the confounder (A), we can isolate the relationship between C and B, if one exists. This is a key concept for experimental design.
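Controlling for temperature in the crime-and-ice-cream example can be sketched with simulated data. Everything here is invented for illustration: the coefficients, noise levels, and temperature range are arbitrary, and "controlling" is approximated by looking only at days within a single degree of each other.

```python
import random

def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5

rng = random.Random(1)
temperature = [rng.uniform(0, 35) for _ in range(20_000)]       # confounder A
crime = [0.5 * t + rng.gauss(0, 3) for t in temperature]        # B depends only on A
ice_cream = [2.0 * t + rng.gauss(0, 10) for t in temperature]   # C depends only on A

raw = pearson(crime, ice_cream)

# Control for A by comparing only days with nearly the same temperature.
same_temp = [
    (b, c) for t, b, c in zip(temperature, crime, ice_cream) if 24 <= t < 25
]
controlled = pearson(*zip(*same_temp))
```

The raw correlation is large even though neither variable appears in the other’s formula. Once temperature is held roughly constant, the correlation all but disappears, which is what we expect when the fork is the only link between B and C.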

Correcting for Confounding

Let’s spend some more time on this subject. Pearl’s assertion is that if we control for all confounders, we should be able to isolate the relationship between the variables of interest and therefore prove causation, instead of mere correlation.

Pearl defines confounding more broadly as any relationship that leads to P(Y | do(X)) ≠ P(Y | X), where the do operator implies an action. In other words, if there is a difference between the probability of an outcome Y given X and the probability of Y given X in a perfect world in which we were able to change X and only X, then confounding is afoot.
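The gap between seeing and doing can be made concrete with a small simulation. This is a sketch under made-up assumptions: a hidden binary confounder Z raises the chance of both X and Y, X has no effect on Y at all, and the probabilities (0.9, 0.1, 0.8, 0.2) are arbitrary.

```python
import random

rng = random.Random(2)

def simulate(n, force_x=None):
    """Estimate P(Y=1) among rows with X=1; force_x mimics do(X=force_x)."""
    y_when_x1 = []
    for _ in range(n):
        z = rng.random() < 0.5                      # hidden confounder Z
        if force_x is None:
            x = rng.random() < (0.9 if z else 0.1)  # Z pushes X up or down
        else:
            x = force_x                             # intervention cuts the Z -> X arrow
        y = rng.random() < (0.8 if z else 0.2)      # Y depends on Z only
        if x:
            y_when_x1.append(y)
    return sum(y_when_x1) / len(y_when_x1)

observed = simulate(50_000)                  # estimates P(Y=1 | X=1)
intervened = simulate(50_000, force_x=True)  # estimates P(Y=1 | do(X=1))
```

Under these assumptions the observed estimate lands near 0.74 while the intervened one lands near 0.5. Seeing X = 1 tells us Z is probably 1, which makes Y more likely; setting X = 1 ourselves tells us nothing about Z, so Y stays at its baseline rate.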

Four Rules of Information Flow

Pearl has 4 rules for controlling the flow of information through a DAG.

    1. In a chain (A → B → C), B carries information from A to C. Therefore, controlling for B prevents information about A from reaching C and vice versa.

    2. In a fork (A ← B → C), B is the only known common source of information between both A and C. Therefore, controlling for B prevents information about A from reaching C and vice versa.

    3. In a collider (A → B ← C), controlling for B “opens up” the pipe between A and C due to the explain-away effect.

    4. Controlling for descendants of a variable will partially control for the variable itself. Therefore, controlling the descendant of a mediator partially closes the pipe, and controlling for the descendant of a collider partially opens the pipe.

Back-door criterion

We can use these causal models as represented by DAG’s to determine how exactly we should remove this confounding from our study.

If we are interested in understanding the relationship between only X and Y, we must identify and dispatch any confounding back-door paths, where a back-door path is any path from X to Y that starts with an arrow into X.

Pearl’s Games

Pearl devises a series of games that involve increasingly complicated DAG’s where the objective is to “deconfound” the path from X to Y. This is achieved by blocking every non-causal path while leaving all causal paths intact.

In other words, we need to identify and block all back-door paths while ensuring that no variable Z we control for is a descendant of X on a causal path to Y.

Let’s go through some examples, using the numbered games from the book.

Game 2

We need to determine which variables (if any) of A, B, C, D, or E need to be controlled in order to deconfound the path from X to Y.

There is one back-door path: X ← A → B ← D → E → Y. This path is blocked by the collider at B from the third rule of information flow.

Therefore, there is no need to control any of these variables!

Game 5

This one’s a bit more interesting. We have two back-door paths:

    1. X ← A → B ← C → Y

    2. X ← B ← C → Y

The first back-door path is blocked by a collider at B so there is no need to control any variables due to this relationship.

The second path, however, represents a non-causal path between X and Y. We need to control for either B or C.

But watch out! If we control for B, we fall into the condition outlined by Pearl’s third rule above, where we’ve controlled for a collider and thus opened up the first back-door path in this diagram.

Therefore, if we control for B, we will then have to control for A or C as well. However, we can also control for only C initially and avoid the collider bias altogether.
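The bookkeeping in Game 5 can be checked mechanically. The sketch below is one possible encoding, not Pearl’s own algorithm: the edge list is reconstructed from the two back-door paths listed above plus an assumed direct arrow X → Y, and the function names are illustrative.

```python
# Each pair is (cause, effect). Reconstructed from the listed paths;
# the X -> Y arrow is an assumption about the original diagram.
EDGES = [("A", "X"), ("A", "B"), ("C", "B"), ("C", "Y"), ("B", "X"), ("X", "Y")]

def descendants(node, edges):
    """Return every node reachable from `node` by following arrows."""
    found, stack = set(), [node]
    while stack:
        current = stack.pop()
        for parent, child in edges:
            if parent == current and child not in found:
                found.add(child)
                stack.append(child)
    return found

def simple_paths(start, end, edges):
    """Yield all simple paths between start and end, ignoring arrow direction."""
    neighbors = {}
    for parent, child in edges:
        neighbors.setdefault(parent, set()).add(child)
        neighbors.setdefault(child, set()).add(parent)
    stack = [[start]]
    while stack:
        path = stack.pop()
        if path[-1] == end:
            yield path
            continue
        for nxt in sorted(neighbors[path[-1]]):
            if nxt not in path:
                stack.append(path + [nxt])

def is_blocked(path, controlled, edges):
    """Apply the four rules to every interior node of the path."""
    edge_set = set(edges)
    for prev, node, nxt in zip(path, path[1:], path[2:]):
        is_collider = (prev, node) in edge_set and (nxt, node) in edge_set
        if is_collider:
            # A collider blocks unless it, or one of its descendants, is controlled.
            if not ({node} | descendants(node, edges)) & controlled:
                return True
        elif node in controlled:
            # Controlling the middle of a chain or fork blocks the path.
            return True
    return False

def open_backdoor_paths(x, y, controlled, edges):
    """List back-door paths (first arrow points into x) that remain open."""
    edge_set = set(edges)
    return [
        path for path in simple_paths(x, y, edges)
        if (path[1], x) in edge_set and not is_blocked(path, controlled, edges)
    ]
```

With nothing controlled, only X ← B ← C → Y is open. Controlling for C closes everything; controlling for B alone closes that path but opens X ← A → B ← C → Y, and adding A (or C) closes it again, matching the reasoning above.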

Conclusion

DAG’s can be an informative way to organize our mental models around causal relationships. Keeping in mind Pearl’s Four Rules of Information Flow, we can identify confounding variables that cloud the true relationship between the variables under study.

Bringing this home for data scientists, when we include the data generation process as a variable in a DAG, we remove much of the mystery surrounding such pitfalls as Simpson’s Paradox. We’re able to think more like informed humans and less like data-crunching machines—an ability we should all be striving for in our increasingly AI-driven world.

How to Scrape LinkedIn Sales Navigator with Python

This guide will show you how to use Python to:

    1. Log into LinkedIn Sales Navigator

    2. Search by company name

    3. Filter search results

    4. Scrape returned data

My code can be found on GitHub, but I’ll explain how each section works in case you’d like to customize it for your own project.

What is LinkedIn Sales Navigator?

LinkedIn Sales Navigator is LinkedIn’s paid sales toolset. It mines all that data you and I have freely handed to LI over the years and gives sales organizations the power to create leads and manage their pipeline. It can integrate with your CRM to personalize your results and show additional information.

LI Sales Navigator markets itself to sellers but its data aggregations are also a gold mine for creating insights. For example, I wanted data on employees’ tenure with their company. This would be very difficult using vanilla LinkedIn–I’d have to click into each employee’s profile individually.

LinkedIn Sales Navigator brings that data directly to you and allows you to filter results by geography or years of experience. The screenshot below gives you an idea of the kind of data returned for employees (known as leads in LinkedIn parlance) but similar aggregations are performed on companies (known as accounts).

Example of LinkedIn Sales Navigator search results

Is scraping legal?

Scraping publicly accessible online data is legal, as ruled by a U.S. appeals court.

But just because scraping is legal doesn’t mean LinkedIn is going to make this easy for us. The code I’ll walk through will show some of the challenges you might run into when attempting to scrape LI but beware! Scraping is software’s equivalent of the Castaway raft—hacked-together and good for one ride only. LI changes its website frequently and what worked one day may not work the next.

Image credit to Spirituality and Practice

How to scrape

Setting up

I used the Python library selenium to scrape data from LinkedIn Sales Navigator within the Chrome browser. You’ll need to run the following from a terminal window to install the required libraries:

pip install selenium
pip install webdriver_manager

It’s best practice to set up a dedicated virtual environment for this task. I recommend conda, though other options include virtualenv or poetry. Check out this helpful resource to access your new conda environment from a Jupyter Notebook.

Logging in

I chose to wrap my scraping functions within a class named LIScraper. An object-oriented approach allowed me to create new Chrome sessions with each instance of LIScraper, simplifying the debugging process.

The only input when instantiating the class is path_to_li_creds. You’ll need to store your LinkedIn username and password within a text file at that destination. We can then instantiate our scraper as shown below.

scraper = LIScraper(path_to_li_creds='li_creds.txt')
scraper.log_in_to_li_sales_nav()

This code will open up a new Chrome window, navigate to the LinkedIn Sales Navigator home page, and log in using the provided credentials.
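The layout of the credentials file is up to you. A minimal sketch, assuming the username sits on the first line and the password on the second (read_credentials is a hypothetical helper for illustration, not necessarily how the LIScraper class reads the file):

```python
from pathlib import Path

def read_credentials(path):
    """Read a username and password from the first two non-empty lines."""
    lines = [
        line.strip()
        for line in Path(path).read_text(encoding="utf-8").splitlines()
        if line.strip()
    ]
    if len(lines) < 2:
        raise ValueError(f"expected a username and a password in {path}")
    return lines[0], lines[1]
```

Whatever layout you choose, keep this file out of version control so your password never ends up in a public repository.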

Start with the goal

Before we go any further, let’s take a quick look at the master function gather_all_data_for_company that accomplishes my specific scraping task.

I wanted to search for a given company, find all current employees with the keyword “data” in their job title, and then scrape their job title and company tenure from the website.

Let’s break this down sequentially.

1. Search by company

I needed to scrape results for 300 companies. I didn’t have the time or patience to manually review LI’s search results for each company name. So I programmatically entered in each company name from my list and assumed that the first search result would be the correct one.

This was a faulty assumption.

I then tried to guide the search algorithm by restricting results to companies within my CRM (crm_only=True) but this still did not guarantee that the first search result was the right one.

As a safeguard, I logged the name of the company whose data I was collecting and then manually reviewed all 300 matches after my scraping job finished to find those that did not match my expectations. For any mismatches, I manually triggered a scraping job after selecting the correct company from LI’s search results.

2. Search for employees

I then wanted to find all job titles containing a specific keyword.

You might notice several layers of nested try-except clauses in this function. I could not understand why the code would run successfully one minute but would then fail when I tried to execute it again immediately after. Alas, the problem was not in how I selected the element on the page but in when I attempted to select it.

I just needed to add more time (e.g., time.sleep(4)) before executing my next step. Webpages can take a long time to load all their elements, and this loading time can vary wildly between sessions.

Helpful hint: If your scraping code does not execute successfully in a deterministic manner, add more time between steps.
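A fixed sleep either wastes time or is occasionally too short. One alternative is to poll until the step succeeds or a deadline passes. The helper below is a generic sketch with illustrative names and default values, not code from my scraper; Selenium’s own WebDriverWait class offers the same idea out of the box.

```python
import time

def wait_for(action, timeout=10.0, interval=0.5,
             sleep=time.sleep, clock=time.monotonic):
    """Call `action` until it stops raising, or re-raise once `timeout` passes."""
    deadline = clock() + timeout
    while True:
        try:
            return action()
        except Exception:
            if clock() >= deadline:
                raise  # give up and surface the last error
            sleep(interval)

# Hypothetical usage with Selenium, where `driver` and `xpath` are yours:
# element = wait_for(lambda: driver.find_element(By.XPATH, xpath))
```

The sleep and clock parameters exist only so the helper can be tested without real waiting; in normal use you would leave them at their defaults.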

3. Gather the data

We’re now ready to scrape some data!

First, we scroll to the bottom of the page to allow all the results to load. Then we assess how many results LI returned.

CAUTION: The number of results actually returned by LinkedIn does not always match the number LinkedIn claims to have returned.

I had initially tried to scrape 25 results per page or the remainder if the number of results returned was not an even multiple of 25. For example, if the number of results LI claimed to have returned was 84, I’d scrape three pages of 25 results each and then scrape the remaining 9 results on the last page.

But my job would throw an error when this last page contained just 8 results. Why would LI claim to have found 84 results when in reality, it only had 83? That remains one of the great mysteries of the internet.

To get around this issue, I counted the number of headshots on the page to indicate how many results I’d need to scrape.
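For reference, the arithmetic behind my first attempt looks like this. It is a small sketch with an illustrative function name, and it shows exactly the plan that fails when the claimed total is wrong.

```python
def page_sizes(total_results, per_page=25):
    """Split a claimed result count into the number of rows expected per page."""
    full_pages, remainder = divmod(total_results, per_page)
    return [per_page] * full_pages + ([remainder] if remainder else [])

expected = page_sizes(84)
```

For 84 claimed results this yields three pages of 25 and a final page of 9. If the last page really holds only 8 rows, any loop that trusts this plan will reach for a ninth row that does not exist, which is why counting the elements actually present on the page is the safer approach.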

Scraping the data itself is relatively trivial once you understand the structure of the search results page. My strategy to find the path to the element I wanted was to right-click it and select “Inspect”. Then right-click on the highlighted HTML on the far right (check out the orange arrow below) and go to Copy → Copy full XPath.

I stored the page’s search results in a pandas dataframe and concatenated the data from each new page onto the previous ones.

One last warning

Remember LI isn’t a big fan of scrapers. Your Chrome window will flash the dreaded 429 error if you hit their webpage too frequently. This error occurs when you exceed LI’s rate-limiting threshold. I am not sure what that threshold is or exactly how long you must wait before they reset your allowance.

I needed to scrape data from 300 companies whose returned search results ranged from 10 to 1000. My final dataset contained nearly 32,000 job titles. I only ran into the 429 error twice. Each time I simply paused my work for a couple of hours before restarting.
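I paused by hand, but the same idea can be automated with a retry loop that waits longer after each rate-limit hit. This is a sketch of a generic backoff pattern, not something I ran against LinkedIn: the RateLimited exception, the way a 429 page would be detected, and the wait times are all assumptions you would need to adapt.

```python
import time

class RateLimited(Exception):
    """Hypothetical signal that the site answered with a 429 page."""

def run_with_backoff(job, first_wait=60.0, factor=2.0, max_tries=5,
                     sleep=time.sleep):
    """Run `job`; after each RateLimited error, wait longer before retrying."""
    wait = first_wait
    for attempt in range(1, max_tries + 1):
        try:
            return job()
        except RateLimited:
            if attempt == max_tries:
                raise  # out of retries, let the caller decide what to do
            sleep(wait)
            wait *= factor
```

Since the real threshold is unknown, starting with a long first wait is the cautious choice.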