SHAP Example

The SHAP call is very similar to the LIME call. As an example, we load an image from the imagenet-sample-images repository and generate an explanation for it in the following steps.

using ExplainableAI
using Metalhead: ResNet
using JML_XAI_Project
using Images
using VisionHeatmaps

img = load("images/dog.jpeg")

Image pre-processing

The image is converted into the input format expected by the model: a 4-dimensional Float32 array ordered as width × height × channels × batch.

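# channelview returns a (channel, height, width) array; reorder it to (width, height, channel)
imgVec = permutedims(channelview(img),(3,2,1));
# append a singleton batch dimension: (width, height, channel, 1)
imgVec = reshape(imgVec, size(imgVec)..., 1);
# convert the fixed-point pixel values (range 0 to 1) to Float32
input = Float32.(imgVec);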
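A toy array makes the layout change explicit. The sketch below uses plain Julia without any image packages. It assumes, as in the steps above, that channelview yields a (channel, height, width) array for an RGB image and that the model expects WHCN order (width, height, channels, batch), which is Flux's convention for convolutional layers. The array chw is a hypothetical stand-in for a 4×5 RGB image.

```julia
# Dummy "image" with 3 channels, height 4 and width 5, laid out like channelview output (C, H, W)
chw = reshape(collect(1:60), 3, 4, 5)

# Same steps as in the pre-processing: reorder to (W, H, C), then append a batch dimension
whc = permutedims(chw, (3, 2, 1))
whcn = reshape(whc, size(whc)..., 1)
x = Float32.(whcn)

println(size(chw), " -> ", size(x))   # (3, 4, 5) -> (5, 4, 3, 1)

# The pixel at row h, column w, channel c holds the same value in both layouts
h, w, c = 2, 5, 3
println(chw[c, h, w] == x[w, h, c, 1])   # true
```

Only the order of the dimensions changes. The pixel values themselves are untouched apart from the conversion to Float32.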

Generation of the explanation

The next step is to initialize a pre-trained ResNet-18 model and apply SHAP to it.

Info

Any classifier or regressor can be used at this point.

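model = ResNet(18; pretrain = true);
# use the underlying Flux Chain stored in the layers field as the model
model = model.layers;

# SHAP is called through the LIME interface with a different kernel function (agnostic_kernel).
# No feature selection is required, so LASSO is disabled by passing false as the third argument.
analyzer = LIME(model, agnostic_kernel, false)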
expl = analyze(input, analyzer);
XAIBase.Explanation{Array{Float64, 4}, Matrix{Float32}, CartesianIndex{2}, Nothing}([0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0;;;;], Float32[1.0904795; -0.4928558; … ; 1.3466196; 1.1687055;;], CartesianIndex(239, 1), :MyMethod, :attribution, nothing)
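Conceptually, Kernel SHAP is LIME with two changes. The samples are weighted by the Shapley kernel instead of a proximity kernel, and the surrogate is an ordinary weighted linear regression without LASSO regularization. The plain-Julia sketch below shows this for a toy model with three features.

This is not the JML_XAI_Project implementation. The following are all illustrative assumptions:

- the names shapley_kernel, kernel_shap, toy_model and baseline;
- exhaustive enumeration of coalitions;
- the large finite weight used in place of the infinite weights of the empty and full coalitions.

We also assume, based on the comment in the code above, that agnostic_kernel plays the role of the Shapley kernel.

```julia
using LinearAlgebra

# Shapley kernel weight for a coalition with s of M features present.
# In theory the empty (s = 0) and full (s = M) coalitions get infinite weight;
# a large finite weight is used here as an approximation.
function shapley_kernel(M::Int, s::Int; big_weight::Float64 = 1e6)
    (s == 0 || s == M) && return big_weight
    return (M - 1) / (binomial(M, s) * s * (M - s))
end

# Exhaustive Kernel SHAP: enumerate all 2^M coalitions and fit a weighted
# least-squares surrogate (no LASSO feature selection).
# model: function of a feature vector; x: instance to explain;
# baseline: values that stand in for "absent" features (an assumption of this sketch).
function kernel_shap(model, x::Vector{Float64}, baseline::Vector{Float64})
    M = length(x)
    n = 2^M
    A = zeros(n, M + 1)   # design matrix: intercept column + coalition indicators
    y = zeros(n)          # model output for each coalition
    w = zeros(n)          # Shapley kernel weight of each coalition
    for k in 0:(n - 1)
        z = [(k >> (j - 1)) & 1 for j in 1:M]   # bit j says whether feature j is present
        A[k + 1, 1] = 1.0
        A[k + 1, 2:end] = z
        y[k + 1] = model(ifelse.(z .== 1, x, baseline))
        w[k + 1] = shapley_kernel(M, sum(z))
    end
    W = Diagonal(w)
    coef = (A' * W * A) \ (A' * W * y)
    return coef[1], coef[2:end]   # (base value, SHAP values)
end

# Toy model with an interaction between features 1 and 2
toy_model(v) = v[1] * v[2] + v[3]
x0 = [1.0, 2.0, 3.0]
baseline = zeros(3)

base, phi = kernel_shap(toy_model, x0, baseline)
println("base value: ", round(base; digits = 3))
println("SHAP values: ", round.(phi; digits = 3))
```

The SHAP values come out approximately as 1, 1 and 3. The interaction term x1·x2 = 2 is split equally between features 1 and 2, and the base value plus the SHAP values add up to toy_model(x0) = 5. For images, the features would be superpixels or patches rather than single numbers. Because 2^M grows quickly, coalitions are then sampled instead of enumerated.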

Visualize the explanation

The generated explanation can now be displayed as a heat map.

using VisionHeatmaps
heatmap(expl.val)
Info

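The following function produces a clearer visualization. It can smooth the heatmap with a Gaussian filter (blurring=true, standard deviation gaussSTD). It can also overlay the heatmap on a grayscale version of the input image (overlay=true, which requires img).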

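function generate_heatmap(explanation; img=nothing, overlay=false, blurring=false, gaussSTD=2)
    # convert the attribution array into a color image
    hmap = heatmap(explanation.val)

    # optionally smooth the heatmap with a separable Gaussian filter
    if blurring
        gaussKern2 = ImageFiltering.KernelFactors.gaussian((gaussSTD, gaussSTD))
        hmap = ImageFiltering.imfilter(hmap, gaussKern2)
    end

    # optionally blend the heatmap 50/50 with a grayscale version of the input image
    if overlay
        img === nothing && throw(ArgumentError("overlay=true requires img"))
        hmap = 0.5 .* Gray.(img) .+ 0.5 .* hmap
    end

    return hmap
end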
generate_heatmap(expl, img=img, overlay=true, blurring=true)
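The blurring option spreads isolated high-attribution pixels into coherent regions, and gaussSTD sets the width of the Gaussian in pixels. The plain-Julia sketch below only illustrates the idea of separable Gaussian smoothing on a single-channel matrix. It is not how ImageFiltering implements imfilter. The helper names gaussian_kernel, blur_dim and gauss_blur are hypothetical, and the truncation radius of 3σ and the replicated borders are choices made for this example.

```julia
# 1D Gaussian kernel with standard deviation σ, truncated at ±radius and normalized to sum to 1
function gaussian_kernel(σ::Real; radius::Int = ceil(Int, 3σ))
    k = [exp(-t^2 / (2σ^2)) for t in -radius:radius]
    return k ./ sum(k)
end

# Filter a matrix along one dimension, replicating the edge pixels at the borders
function blur_dim(A::AbstractMatrix{Float64}, k::Vector{Float64}, dim::Int)
    r = (length(k) - 1) ÷ 2
    n = size(A, dim)
    out = similar(A)
    for I in CartesianIndices(A)
        acc = 0.0
        for (j, t) in enumerate(-r:r)
            i = clamp(I[dim] + t, 1, n)   # indices outside the matrix reuse the edge pixel
            J = dim == 1 ? CartesianIndex(i, I[2]) : CartesianIndex(I[1], i)
            acc += k[j] * A[J]
        end
        out[I] = acc
    end
    return out
end

# Separable 2D Gaussian blur: filter along dimension 1 (vertically), then dimension 2 (horizontally)
function gauss_blur(A::AbstractMatrix{Float64}, σ::Real)
    k = gaussian_kernel(σ)
    return blur_dim(blur_dim(A, k, 1), k, 2)
end

# Usage: a single isolated high-attribution pixel becomes a smooth blob
spike = zeros(9, 9)
spike[5, 5] = 1.0
blurred = gauss_blur(spike, 1.0)
println(round(sum(blurred); digits = 6))   # total attribution is preserved for this interior spike
```

Splitting the 2D Gaussian into two 1D passes gives the same result as a full 2D kernel at lower cost. The same idea underlies passing a tuple of 1D kernels from KernelFactors.gaussian to imfilter.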

Label

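To find out which label the model predicted for the image, we print the index of the highest output score. This number can be looked up in the linked text file. Julia indices start at 1, while the class numbers in that file start at 0, so 1 is subtracted before printing.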

print("Label: ", argmax(expl.output[:,1]) - 1)
Label: 238
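The same index arithmetic can be tried on a small hand-made score vector. This assumes that expl.output holds the unnormalized class scores (logits) returned by the ResNet. As an optional extra, the sketch turns the scores into probabilities with a softmax and lists the three highest-scoring class numbers. The score values are made up, and the helper softmax_probs is hypothetical.

```julia
# Hypothetical scores for 5 classes; in the example above they come from expl.output[:, 1]
scores = Float32[0.2, -1.3, 2.7, 0.9, 2.1]

# Julia indices start at 1, whereas the class numbers in the label file start at 0
class_id = argmax(scores) - 1

# Optional: softmax turns the scores into probabilities (shifted by the maximum for numerical stability)
softmax_probs(v) = (z = exp.(v .- maximum(v)); z ./ sum(z))
probs = softmax_probs(scores)

# Optional: the three highest-scoring classes, as 0-based class numbers
top3 = partialsortperm(scores, 1:3; rev = true) .- 1

println("Label: ", class_id)   # Label: 2
println("Top-3 labels: ", top3, " with probabilities ", round.(probs[top3 .+ 1]; digits = 3))
```

The softmax is monotonic, so the most probable class is always the one with the highest score. It only matters when you want calibrated-looking probabilities or a ranked list.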