
API documentation


AutoExplainer

The main class that evaluates a series of explanation methods and chooses the best one.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| raw_results | Dict | Raw values of metrics computed for each observation and each explanation. |
| first_aggregation_results | Dict | Values of metrics aggregated across observations, i.e. each explanation function has a value for each metric. |
| second_aggregation_results | Dict | Values of metrics aggregated for each explanation method. Each explanation method has a single value that represents its overall quality. |
| best_explanation_name | str | Name of the best explanation method selected. |
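
A minimal usage sketch of the typical workflow, assuming a trained CNN `model` and a batch of images `data` with integer `targets` are already available (placeholder names); the import path is inferred from the source location shown below:

```python
import torch
from autoexplainer.autoexplainer import AutoExplainer

# model: torch.nn.Module classifier, data: (N, C, H, W) tensor, targets: (N,) integer tensor
explainer = AutoExplainer(model, data, targets, device="cpu", seed=42)

# Compute attributions with every known explanation method and score them with
# every known metric; per-observation scores are stored in .raw_results.
explainer.evaluate()

# Two-stage aggregation: per-metric score per method, then one overall score per method.
explainer.aggregate()

print(explainer.best_explanation_name)       # name of the winning method, e.g. "grad_cam"
print(explainer.second_aggregation_results)  # {method_name: overall_score, ...}

explainer.to_html("report.html")             # optional: write an HTML report
```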

Source code in autoexplainer/autoexplainer.py
class AutoExplainer:
    """
    The main class that evaluates a series of explanation methods and chooses the best one.
    Attributes:
        raw_results (Dict): Raw values of metrics computed for each observation for each explanation.
        first_aggregation_results (Dict): Values of metrics aggregated across observations, i.e. each explanation
            function has value for each metric.
        second_aggregation_results (Dict): Values of metrics aggregated for each explanation method. Each explanation
            method has single value, that represents overall quality.
        best_explanation_name (str): Name of the selected best explanation found.
    """

    KNOWN_EXPLANATION_HANDLERS: Dict = {
        "kernel_shap": KernelShapHandler,
        "integrated_gradients": IntegratedGradients,
        "grad_cam": GradCamHandler,
        "saliency": SaliencyHandler,
    }
    KNOWN_METRIC_HANDLERS: Dict = {
        "faithfulness_estimate": FaithfulnessEstimateHandler,
        "average_sensitivity": AvgSensitivityHandler,
        "irof": IROFHandler,
        "sparseness": SparsenessHandler,
    }
    KNOWN_FIRST_STAGE_AGGREGATION_FUNCTIONS: Dict = {
        "mean": first_stage_aggregation_mean,
        "median": first_stage_aggregation_median,
        "max": first_stage_aggregation_max,
        "min": first_stage_aggregation_min,
    }
    KNOWN_SECOND_STAGE_AGGREGATION_FUNCTIONS: Dict = {
        "rank_based": second_stage_aggregation_rank_based,
        "weighted_mean": second_stage_aggregation_weighted_mean,
    }

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, device: str = "cpu", seed: int = 42
    ):
        """

        Args:
            model (torch.nn.Module): Convolutional neural network to be explained. On this model some explanation and metric
                parameters will be inferred.
            data (torch.Tensor): Data that will be used for explanation method evaluation. shape: (N, C, H, W).
            targets (torch.Tensor): Labels for provided data. Encoded as integer vector with shape (N,).
        """
        self._check_model(model)
        self._check_data(data, targets)
        self._second_stage_aggregation_function_name = None
        self._first_stage_aggregation_function_name = None

        self.explanation_handlers: Dict = None
        self.metric_handlers: Dict = None
        self.first_aggregation_results: Dict = None
        self.second_aggregation_results: Dict = None
        self.best_explanation_name: str = None
        self.aggregation_parameters: Dict = None

        model = model.to(device)
        model.eval()
        self.model = fix_relus_in_model(model)
        self.data = data.to(device)
        self.targets = targets.to(device)
        self.raw_results: Dict = {}
        self.times_methods: Dict = {}
        self.times_metrics: Dict = {}
        self.times_metrics_aggregated: Dict = {}
        self.times_total = 0.0
        self.device = device
        self._set_seed(seed)

    def evaluate(
        self,
        explanations: List[str] = None,
        metrics: List[str] = None,
        explanation_params: Dict = None,
        metrics_params: Dict = None,
    ) -> None:
        """
        Evaluates explanation methods. Stores results in ``.raw_results`` attribute.
        Args:
            explanations (List[str]): List of names of explanation methods to be evaluated.
                                      By default, uses all available explanation methods.
                                      Accepts lists with subset of: ``{"saliency", "grad_cam", "integrated_gradients", "kernel_shap"}``.
            metrics (List[str]): List of names of evaluation metrics to be used. By default, uses all available metrics.
                                Accepts lists with subset of: ``{"irof", "sparseness", "average_sensitivity", "faithfulness_estimate"}``.
            explanation_params (Dict[str, Dict]): Allows to override default parameters of selected explanation functions.
                                                Accept Dictionary with form ``{"explanation_name": <Dictionary with parameters>}``.
                                                See corresponding ExplanationHandler to see what parameters are accepted.
            metrics_params (Dict[str, Dict]): Allows to override default parameters of selected metrics.
                                              Accept Dictionary with form ``{"metric_name": <Dictionary with parameters>}``.
                                              See corresponding MetricHandler to see what parameters are accepted.

        """
        self._check_method_and_metric_names(explanations, metrics)
        if explanations is None:
            explanations = self.KNOWN_EXPLANATION_HANDLERS
        if metrics is None:
            metrics = self.KNOWN_METRIC_HANDLERS
        self._check_method_and_metric_params_dicts(explanations, metrics, explanation_params, metrics_params)
        if explanation_params is None:
            explanation_params = {}
        if metrics_params is None:
            metrics_params = {}

        print("\nPreparing explanation methods and metric handlers...\n")

        self.explanation_handlers = {
            explanation_name: self.KNOWN_EXPLANATION_HANDLERS[explanation_name](
                self.model, self.data, self.targets, explanation_params.get(explanation_name)
            )
            for explanation_name in explanations
        }
        self.metric_handlers = {
            metric_name: self.KNOWN_METRIC_HANDLERS[metric_name](
                self.model, self.data, self.targets, metrics_params.get(metric_name)
            )
            for metric_name in metrics
        }
        self.times_metrics = {metric_name: {} for metric_name in metrics}

        print("\tNumber of explanation methods to evaluate: ", len(self.explanation_handlers))
        print(
            "\tExplanation methods selected: "
            + f"{', '.join([EXPLANATION_NAME_SHORT_TO_LONG[x] for x in list(self.explanation_handlers.keys())])}"
        )
        print("")
        print("\tNumber of metrics used during evaluation: ", len(self.metric_handlers))
        print(
            "\tMetrics selected: "
            + f"{', '.join([METRIC_NAME_SHORT_TO_LONG[x] for x in list(self.metric_handlers.keys())])}"
        )

        pbar = tqdm.tqdm(self.explanation_handlers.items(), desc="Creating attributions")
        for explanation_name, explanation_handler in pbar:
            start_time = time.time()
            pbar.set_description(f"Creating attributions for {explanation_name}")
            explanation_handler.explain(model=self.model, data=self.data, targets=self.targets)
            self.times_methods[explanation_name] = round(time.time() - start_time, 3)

        for explanation_name in self.explanation_handlers.keys():
            self.raw_results[explanation_name] = {}

        print("Creating attribution finished. Starting evaluation.")
        print("Evaluation may take a very long time, please be patient...")

        pbar = tqdm.tqdm(
            itertools.product(self.metric_handlers.items(), self.explanation_handlers.items()),
            total=len(self.metric_handlers) * len(self.explanation_handlers),
            desc="Evaluating metrics",
        )
        for (metric_name, metric_handler), (explanation_name, explanation_handler) in pbar:
            start_time = time.time()
            pbar.set_description(f"Evaluating: method {explanation_name} and metric {metric_name}")
            self.raw_results[explanation_name][metric_name] = metric_handler.compute_metric_values(
                model=self.model,
                data=self.data,
                targets=self.targets,
                attributions=explanation_handler.attributions.to(next(self.model.parameters()).device),
                explanation_func=explanation_handler.get_explanation_function(),
            )
            self.times_metrics[metric_name][explanation_name] = round(
                time.time() - start_time + self.times_methods[explanation_name], 3
            )

        self.times_metrics_aggregated = {
            metric_name: round(sum(self.times_metrics[metric_name].values()), 3)
            for metric_name in self.times_metrics.keys()
        }
        self.times_total = round(sum(self.times_metrics_aggregated.values()), 3)

        print(f"Evaluating metrics finished after {self.times_total} seconds.")

    def aggregate(
        self,
        first_stage_aggregation_function_name: str = "mean",
        second_stage_aggregation_function_name: str = "rank_based",
        second_stage_aggregation_function_aggregation_parameters: Dict = None,
    ) -> None:
        """
        Aggregates raw result computed in .evaluate() method in two steps. First, aggregates metric scores across
        provided observations, i.e. each explanation method has a  value for each metric. Secondly, aggregates
        scores across available metrics, i.e. each explanation method has a single value that represents overall quality.

        Stores both aggregation steps in the attributes ``first_aggregation_results`` and ``second_aggregation_results``.

        Args:
            first_stage_aggregation_function_name ({"mean", "median", "min","max"}): Name of the function for the first stage aggregation.
            second_stage_aggregation_function_name ({"mean", "median", "min","max"}): Name of the function for second stage aggregaton.
            second_stage_aggregation_function_aggregation_parameters (Dict): Parameters for the second stage aggregation function.

        """

        self._check_is_after_evaluation()
        self._check_aggregation_parameters(
            first_stage_aggregation_function_name,
            second_stage_aggregation_function_name,
            second_stage_aggregation_function_aggregation_parameters,
        )
        self._first_stage_aggregation_function_name = first_stage_aggregation_function_name
        self._second_stage_aggregation_function_name = second_stage_aggregation_function_name

        if second_stage_aggregation_function_aggregation_parameters is None:
            second_stage_aggregation_function_aggregation_parameters = {}
        self.first_aggregation_results = self.KNOWN_FIRST_STAGE_AGGREGATION_FUNCTIONS[
            first_stage_aggregation_function_name
        ](self.raw_results)
        self.second_aggregation_results = self.KNOWN_SECOND_STAGE_AGGREGATION_FUNCTIONS[
            second_stage_aggregation_function_name
        ](self.first_aggregation_results, second_stage_aggregation_function_aggregation_parameters)
        sorted_results = sorted(self.second_aggregation_results.items(), key=lambda x: x[1], reverse=True)
        if len(sorted_results) > 1:
            best_result, second_best_result = sorted_results[0], sorted_results[1]
            if best_result[1] == second_best_result[1]:
                if self.times_methods[best_result[0]] > self.times_methods[second_best_result[0]]:
                    best_result = second_best_result
            self.best_explanation_name = best_result[0]
        else:
            self.best_explanation_name = sorted_results[0][0]

        self.aggregation_parameters = {
            "first_stage_aggregation_function": self.KNOWN_FIRST_STAGE_AGGREGATION_FUNCTIONS[
                first_stage_aggregation_function_name
            ],
            "second_stage_aggregation_function": self.KNOWN_SECOND_STAGE_AGGREGATION_FUNCTIONS[
                second_stage_aggregation_function_name
            ],
            "second_stage_aggregation_function_aggregation_parameters": second_stage_aggregation_function_aggregation_parameters,
        }

    def to_html(
        self, file_path: str, model_name: str = None, dataset_name: str = None, labels: Dict[int, str] = None
    ) -> None:
        """
        Generates evaluation report as HTML file.
        Args:
            file_path (str): Target file path.
            model_name (str): Name of model to show inside report.
            dataset_name (str): Name of dataset to show inside report.
            labels (Dict[int,str]): Mapping between class number and class names. e.g. ``{0:"dog", 1:"cat", 2:"fish"}`` for labels
                                    inside report.
        """
        assert self.first_aggregation_results is not None, "Aggregated results are needed for report generation."
        assert self.second_aggregation_results is not None, "Aggregated results are needed for report generation."

        environment = Environment(loader=PackageLoader("autoexplainer"))
        template = environment.get_template("report.html")

        report_info = self._get_info_for_report(labels=labels)

        pic_io_bytes = io.BytesIO()
        fig = report_info["fig_with_examples"]
        fig.savefig(pic_io_bytes, format="png")
        pic_io_bytes.seek(0)
        pic_hash = base64.b64encode(pic_io_bytes.read())

        _, _, *float_columns, _ = report_info["result_dataframe"].columns

        html_table = (
            report_info["result_dataframe"]
            .style.set_properties(
                subset=["Agg. Score", "Explanation Name"], **{"font-weight": "bold", "text-align": "center"}
            )
            .set_properties(border=0)
            .hide_index()
            .format("{:.3f}", subset=float_columns)
            .render()
        )

        rendered = template.render(
            model_name=model_name,
            dataset_name=dataset_name,
            dataframe_html=html_table,
            pic_hash=pic_hash.decode(),
            **report_info,
        )
        with open(file_path, mode="w", encoding="utf-8") as results:
            results.write(rendered)

    def to_pdf(
        self,
        folder_path: str = "",
        model_name: str = "name of the model",
        dataset_name: str = "name of the dataset",
        labels: Dict[int, str] = None,
    ) -> None:

        """
        Creates PDF report from dict stored in the attribute ``first_aggregation_results``.
        Needs Latex packages installed to run - see README.

        Args:
            folder_path (str): Path to directory, where the reports (PDF and tex) should be created.
            model_name (str): Name of model to show inside report.
            dataset_name (str): Name of dataset to show inside report.
            labels (Dict[int,str]): Mapping between class number and class names. e.g. ``{0:"dog", 1:"cat", 2:"fish"}`` for labels
                                    inside report.

        """
        self._check_is_after_aggregation()

        tex_file = os.path.join(folder_path, "report.tex")
        pdf_file = os.path.join(folder_path, "report.pdf")

        if os.path.exists(tex_file):
            os.remove(tex_file)
        if os.path.exists(pdf_file):
            os.remove(pdf_file)

        report_info = self._get_info_for_report(labels=labels)

        left_margin = 2
        max_nr_columns_in_table = 5
        geometry_options = {"tmargin": "2cm", "lmargin": f"{left_margin}cm"}
        doc = Document(geometry_options=geometry_options)
        doc.preamble.append(Command("title", "AutoeXplainer Report"))
        doc.preamble.append(Command("date", ""))
        doc.packages.append(Package("hyperref"))
        doc.packages.append(Package("booktabs"))
        doc.append(NoEscape(r"\maketitle"))

        results = report_info["result_dataframe"]

        metric_name_copy = copy.deepcopy(METRIC_NAME_SHORT_TO_LONG)
        metric_name_copy["explanation_name"] = "explanation name"
        metric_name_copy["Rank"] = "Rank"
        metric_name_copy["Agg. Score"] = "Agg. Score"
        explanation_methods = report_info["methods"]
        metrics = report_info["metrics"]
        metrics_used = copy.deepcopy(metrics)

        metrics = ["explanation name", "Rank"] + metrics + ["Agg. Score"]
        data = copy.deepcopy(results)

        def hyperlink(url: str, text: str) -> NoEscape:  # type: ignore
            return NoEscape(r"\href{" + url + "}{" + escape_latex(text) + "}")

        # create content of  the Document
        with doc.create(Section("General information", numbering=False)):
            doc.append(bold("Model name: "))
            doc.append(italic(f"{model_name} \n"))
            doc.append(bold("Dataset name: "))
            doc.append(italic(f"{dataset_name} \n"))
            doc.append(bold("Execution time: "))
            doc.append(italic(f"{report_info['execution_time']} s \n"))
            doc.append(bold("Package version: "))
            doc.append(italic(f"{report_info['autoexplainer_version']} \n"))
            doc.append(bold("Date: "))
            doc.append(italic(f"{report_info['date']} \n"))
            doc.append(bold("Selected method: "))
            doc.append(italic(f"{report_info['selected_method']} \n"))
            doc.append(bold("Number of images: "))
            doc.append(italic(f"{report_info['n_images']}"))

        with doc.create(Section("Model performance", numbering=False)):
            doc.append(bold("Accuracy: "))
            doc.append(italic(f"{report_info['model_acc']} \n"))
            doc.append(bold("F1 macro: "))
            doc.append(italic(f"{report_info['model_f1_macro']} \n"))
            doc.append(bold("Balanced accuracy: "))
            doc.append(italic(f"{report_info['model_bac']} \n"))

        with doc.create(Section("Table of results", numbering=False)):
            doc.append(NoEscape(r"\begin{footnotesize}"))
            doc.append(NoEscape(r"\begin{flushleft} "))
            doc.append(NoEscape(report_info["result_dataframe"].to_latex(index=False)))
            doc.append(NoEscape(r"\end{flushleft}"))
            doc.append(NoEscape(r"\end{footnotesize}"))
            doc.append(bold("Table description \n"))
            doc.append(
                "Arrow next to the metric names indicates whether larger or smaller values of metric are better. Time elapsed shows time that was required for computation of attribution for given batch of images. When there is a tie in Aggregated Score, the best metric is chosen based on computation time."
            )

        doc.append(NewPage())
        with doc.create(Section("Details", numbering=False)):
            with doc.create(Subsection("Explanations:", numbering=False)):
                with doc.create(Itemize()) as itemize:
                    for i in range(0, len(data.iloc[:, 0])):
                        explanation_name = EXPLANATION_NAME_SHORT_TO_LONG[explanation_methods[i]]
                        itemize.add_item(bold(explanation_name))
                        doc.append(EXPLANATION_DESCRIPTION[str(explanation_name)][0])
                        doc.append(
                            hyperlink(
                                EXPLANATION_DESCRIPTION[str(explanation_name)][1],
                                EXPLANATION_DESCRIPTION[str(explanation_name)][2],
                            )
                        )
                        doc.append("\n")
                        doc.append("Explanation's parameters: \n")
                        doc.append(NoEscape(r"\texttt{"))
                        doc.append(f"{report_info['method_parameters'][explanation_methods[i]]} \n")
                        doc.append(NoEscape(r"}"))
            doc.append(NewPage())
            with doc.create(Subsection("Metrics:", numbering=False)):
                with doc.create(Itemize()) as itemize:
                    minus = 2
                    for i in range(2, len(data.columns) - 1):
                        if data.columns[i] == "Time elapsed [s]":
                            minus += 1
                        else:
                            itemize.add_item(bold(METRIC_NAME_MEDIUM_TO_LONG[data.columns[i]]))
                            doc.append(METRIC_DESCRIPTIONS[data.columns[i]][0])
                            doc.append(
                                hyperlink(
                                    METRIC_DESCRIPTIONS[data.columns[i]][1], METRIC_DESCRIPTIONS[data.columns[i]][2]
                                )
                            )
                            doc.append("\n")
                            doc.append("Metric's parameters: \n")
                            doc.append(NoEscape(r"\texttt{"))
                            doc.append(f"{report_info['metric_parameters'][metrics_used[i-minus]]} \n")
                            doc.append(NoEscape(r"}"))
            with doc.create(Subsection("Aggregation parameters", numbering=False)):
                doc.append(NoEscape(r"\texttt{"))
                doc.append(report_info["aggregation_parameters"])
                doc.append(NoEscape(r"}"))
        doc.append(NewPage())
        with doc.create(Section("Examples of explanations", numbering=False)):
            with doc.create(Figure(position="!h")) as mini_logo:
                fig = report_info["fig_with_examples"]
                mini_logo.add_plot(fig=fig, width=f"{21 - 2 * left_margin}cm")

        doc.generate_pdf(os.path.join(folder_path, "report"), clean_tex=False)

    def get_best_explanation(self) -> BestExplanation:
        """
        Returns an object with the selected best explanation method wrapped with a few additions, see BestExplanation for more details.
        Returns (BestExplanation): BestExplanation object

        """
        self._check_is_after_aggregation()
        best_explanation_handler = self.explanation_handlers[self.best_explanation_name]
        return BestExplanation(
            attributions=best_explanation_handler.attributions,
            explanation_function=best_explanation_handler.get_explanation_function(),
            explanation_name=self.best_explanation_name,
            explanation_function_parameters=best_explanation_handler.explanation_parameters,
            metric_handlers=self.metric_handlers,
            aggregation_parameters=self.aggregation_parameters,
        )

    def _get_info_for_report(self, labels: Union[Dict[int, str], None] = None) -> Dict:
        pp = pprint.PrettyPrinter(indent=4)
        dict_list_for_df = []
        methods = []
        for k, v in self.first_aggregation_results.items():  # noqa: B007
            methods.append(k)
            dict_list_for_df.append(v)

        metrics = list(self.first_aggregation_results[methods[0]].keys())

        methods_full_names = [EXPLANATION_NAME_SHORT_TO_LONG[x] for x in methods]
        metrics_full_names = [METRIC_NAME_LONG_TO_MEDIUM[x] for x in metrics]

        df = pd.DataFrame(dict_list_for_df, index=methods_full_names)
        df.columns = metrics_full_names
        df["Time elapsed [s]"] = pd.Series(
            {EXPLANATION_NAME_SHORT_TO_LONG[k]: v for k, v in self.times_methods.items()}
        )
        agg_score = pd.Series(self.second_aggregation_results)
        agg_score = agg_score.set_axis([EXPLANATION_NAME_SHORT_TO_LONG[x] for x in agg_score.index])

        df["Agg. Score"] = agg_score
        df = df.sort_values(["Agg. Score", "Time elapsed [s]"], ascending=[False, True])

        df["Rank"] = np.arange(len(df)) + 1
        cols = df.columns.tolist()
        df = df[[cols[-1]] + cols[:-1]]
        df = df.round(3)  # type: ignore
        df.reset_index(inplace=True)
        df.rename(columns={"index": "Explanation Name"}, inplace=True)

        method_parameters = {k: pp.pformat(v.explanation_parameters) for k, v in self.explanation_handlers.items()}
        metric_parameters = {k: pp.pformat(v.metric_parameters) for k, v in self.metric_handlers.items()}

        fig = self._generate_plot_for_report(labels=labels)

        metric_parameters = extract_function_names(metric_parameters)
        method_parameters = extract_function_names(method_parameters)

        aggregation_parameters = {
            "first_stage_aggregation_function": self._first_stage_aggregation_function_name,
            "second_stage_aggregation_function": self._second_stage_aggregation_function_name,
            "second_stage_aggregation_function_aggregation_parameters": self.aggregation_parameters[
                "second_stage_aggregation_function_aggregation_parameters"
            ],
        }
        n_images = len(self.targets)
        aggregation_parameters_str = pp.pformat(aggregation_parameters)
        model_performance = self._evaluate_model_performance()
        return {
            "execution_time": self.times_total,
            "selected_method": EXPLANATION_NAME_SHORT_TO_LONG[self.best_explanation_name],
            "result_dataframe": df,
            "methods": methods,
            "metrics": metrics,
            "aggregation_parameters": aggregation_parameters_str,
            "method_parameters": method_parameters,
            "metric_parameters": metric_parameters,
            "autoexplainer_version": _get_package_version(),
            "date": date.today(),
            "fig_with_examples": fig,
            "n_images": n_images,
            **model_performance,
        }

    def _evaluate_model_performance(self) -> Dict:
        predictions = self.model(self.data).detach().cpu()
        predicted_labels = predictions.argmax(axis=1).numpy()
        y_true = self.targets.detach().cpu().numpy()
        return {
            "model_acc": round(accuracy_score(y_true, predicted_labels), 3),
            "model_f1_macro": round(f1_score(y_true, predicted_labels, average="macro"), 3),
            "model_bac": round(balanced_accuracy_score(y_true, predicted_labels), 3),
        }

    def _generate_plot_for_report(
        self, count_of_images: int = 10, labels: Union[Dict[int, str], None] = None
    ) -> plt.Figure:
        number_of_explanations = len(self.explanation_handlers)
        number_of_columns = number_of_explanations + 1
        number_of_images = min(count_of_images, len(self.data))

        if labels is None:
            labels = {}

        ids_of_images_to_show = []
        classes = list(self.targets.unique().cpu().detach().numpy())
        number_of_classes = len(classes)
        images_per_class = int(number_of_images / number_of_classes)
        for class_id in classes:
            ids_of_images_from_this_class = [
                i for i, x in enumerate(self.targets.cpu().detach().tolist()) if x == class_id
            ]
            ids_of_images_to_show += ids_of_images_from_this_class[:images_per_class]

        number_of_images = len(ids_of_images_to_show)

        fig = plt.figure(figsize=(10, 1.6 * number_of_images + 1))
        grid = ImageGrid(
            fig,
            111,
            nrows_ncols=(number_of_images, number_of_columns),
            axes_pad=0,
            share_all=True,
        )
        grid[0].set_xticks([])
        grid[0].set_yticks([])

        cmap = LinearSegmentedColormap.from_list("red-white-green", ["red", "white", "green"])
        vmin, vmax = -1, 1

        images_to_plot: Dict[str, list] = {"Original image": []}

        for original_image in self.data[ids_of_images_to_show]:
            images_to_plot["Original image"].append(normalize_image(torch_image_to_numpy_image(original_image)))

        for explanation_name in self.explanation_handlers:
            attributions_to_plot = self.explanation_handlers[explanation_name].attributions[ids_of_images_to_show]
            full_explanation_name = EXPLANATION_NAME_SHORT_TO_LONG[explanation_name]
            images_to_plot[full_explanation_name] = []
            for attribution in attributions_to_plot:
                attribution = torch_image_to_numpy_image(attribution)
                if explanation_name == "integrated_gradients":
                    attribution[attribution > np.percentile(attribution, 99.5)] = np.percentile(attribution, 99.5)
                    attribution[attribution < np.percentile(attribution, 0.5)] = np.percentile(attribution, 0.5)
                attribution_scaled = attribution / np.max(np.abs(attribution))
                images_to_plot[full_explanation_name].append(attribution_scaled)

        order_of_columns = ["Original image"] + [
            EXPLANATION_NAME_SHORT_TO_LONG[explanation_name] for explanation_name in self.explanation_handlers
        ]

        for column_num, column_name in enumerate(order_of_columns):
            grid[column_num].set_title(f"{column_name}", fontsize=11)
            for row_num, image in enumerate(images_to_plot[column_name]):
                if column_name == "Original image":
                    grid[row_num * number_of_columns + column_num].imshow(image)
                else:
                    grid[row_num * number_of_columns + column_num].imshow(image, cmap=cmap, vmin=vmin, vmax=vmax)
        for row_num in range(number_of_images):
            image_for_prediction = self.data[ids_of_images_to_show[row_num]]
            model_predition = self.model(image_for_prediction[None, :])
            predicted_class = model_predition.max(1)[1].cpu().detach().numpy()[0]
            predicted_class_softmax = (
                torch.max(torch.nn.functional.softmax(model_predition, dim=1)).cpu().detach().numpy()
            )

            class_id = int(self.targets[ids_of_images_to_show[row_num]].cpu().detach().numpy())
            grid[row_num * number_of_columns].set_ylabel(
                r"$\bf{"
                + "Real~class"
                + "}$"
                + f"\n{labels.get(class_id, class_id)}\n"
                + r"$\bf{"
                + "Predicted~class"
                + "}$"
                + f"\n{labels.get(predicted_class, predicted_class)}\n"
                + r"$\bf{"
                + "Predicted~score"
                + "}$"
                + f"\n{predicted_class_softmax:.2f}",
                rotation=0,
                size="large",
            )

            grid[row_num * number_of_columns].yaxis.set_label_coords(-0.5, 0.2)

        fig.suptitle("Examples of computed attributions", fontsize=15, y=0.99)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            if number_of_images > 4:
                fig.tight_layout(rect=[0.05, 0, 1, 1])
            else:
                fig.tight_layout(rect=[0, 0, 1, 1])

        return fig

    def _check_is_after_evaluation(self) -> None:
        if self.raw_results is None or self.raw_results == {}:
            raise ValueError("Methods are not evaluated yet. Please run .evaluate() first.")  # noqa: TC003

    def _check_is_after_aggregation(self) -> None:
        self._check_is_after_evaluation()
        if (
            self.best_explanation_name is None
            or self.first_aggregation_results is None
            or self.second_aggregation_results is None
        ):
            raise ValueError("Results are not aggregated yet. Please run .aggregate() first.")  # noqa: TC003

    def _check_aggregation_parameters(
        self,
        first_stage_aggregation_function_name: str,
        second_stage_aggregation_function_name: str,
        second_stage_aggregation_function_aggregation_parameters: Union[Dict[str, Any], None],
    ) -> None:
        if not isinstance(first_stage_aggregation_function_name, str):
            raise TypeError(
                f"First stage aggregation function name must be a string. Got {type(first_stage_aggregation_function_name)} instead."
            )
        if not isinstance(second_stage_aggregation_function_name, str):
            raise TypeError(
                f"Second stage aggregation function name must be a string. Got {type(second_stage_aggregation_function_name)} instead."
            )
        if first_stage_aggregation_function_name not in self.KNOWN_FIRST_STAGE_AGGREGATION_FUNCTIONS:
            raise ValueError(
                f"Unknown first stage aggregation function: {first_stage_aggregation_function_name}. Available functions: {list(self.KNOWN_FIRST_STAGE_AGGREGATION_FUNCTIONS.keys())}"
            )
        if second_stage_aggregation_function_name not in self.KNOWN_SECOND_STAGE_AGGREGATION_FUNCTIONS:
            raise ValueError(
                f"Unknown second stage aggregation function: {second_stage_aggregation_function_name}. Available functions: {list(self.KNOWN_SECOND_STAGE_AGGREGATION_FUNCTIONS.keys())}"
            )
        if second_stage_aggregation_function_aggregation_parameters is not None:
            if not isinstance(second_stage_aggregation_function_aggregation_parameters, dict):
                raise TypeError(
                    f"Second stage aggregation function parameters must be provided as a dictionary. Got {type(second_stage_aggregation_function_aggregation_parameters)} instead."
                )

    def _check_model(self, model: torch.nn.Module) -> None:
        if not isinstance(model, torch.nn.Module):
            raise TypeError("Model must be of type torch.nn.Module.")  # noqa: TC003

    def _check_data(self, data: torch.Tensor, targets: torch.Tensor) -> None:
        if not isinstance(data, torch.Tensor):
            raise TypeError("Data must be of type torch.Tensor.")  # noqa: TC003
        if len(data.shape) != 4:
            raise ValueError("Data must be of shape (N, C, H, W).")  # noqa: TC003
        if not isinstance(targets, torch.Tensor):
            raise TypeError("Targets must be of type torch.Tensor.")  # noqa: TC003
        if len(targets.shape) != 1:
            raise ValueError("Targets must be of shape (N,).")  # noqa: TC003
        if data.shape[0] != targets.shape[0]:
            raise ValueError("Data and targets must have the same number of observations.")  # noqa: TC003
        if torch.any(torch.isnan(data)):
            raise ValueError("Provided data has NaN values.")
        if torch.any(torch.isnan(targets)):
            raise ValueError("Targets have NaN values.")

    def _check_method_and_metric_names(
        self, method_names: Union[List[str], None], metric_names: Union[List[str], None]
    ) -> None:
        if method_names is not None:
            for method_name in method_names:
                if not isinstance(method_name, str):
                    raise ValueError("Method names must be strings.")  # noqa: TC003
                if method_name not in self.KNOWN_EXPLANATION_HANDLERS:
                    raise ValueError(
                        f"Unknown explanation method: {method_name}. Available explanation methods: {list(self.KNOWN_EXPLANATION_HANDLERS.keys())}"
                    )  # noqa: TC003
        if metric_names is not None:
            for metric_name in metric_names:
                if not isinstance(metric_name, str):
                    raise ValueError("Metric names must be strings.")  # noqa: TC003
                if metric_name not in self.KNOWN_METRIC_HANDLERS:
                    raise ValueError(
                        f"Unknown metric: {metric_name}. Available metrics: {list(self.KNOWN_METRIC_HANDLERS.keys())}"
                    )  # noqa: TC003

    def _check_method_and_metric_params_dicts(
        self,
        method_names: List[str],
        metric_names: List[str],
        method_params: Union[Dict, None],
        metric_params: Union[Dict, None],
    ) -> None:
        if method_params is not None:
            if not isinstance(method_params, dict):
                raise TypeError("Explanation parameters must be a dictionary.")  # noqa: TC003
            for method_name, method_param in method_params.items():
                if method_name not in self.KNOWN_EXPLANATION_HANDLERS:
                    raise ValueError(
                        f"Unknown explanation method: {method_name}. Available explanation methods: {list(self.KNOWN_EXPLANATION_HANDLERS.keys())}"
                    )
                if method_name not in method_names:
                    warnings.warn(
                        f"Explanation method {method_name} is not in the list of methods to evaluate but the parameters were set for this method.",
                        UserWarning,
                    )
                if not isinstance(method_param, dict):
                    raise TypeError(
                        f"Explanation method parameters must be provided as a dictionary. Got {type(method_param)} instead."
                    )
        if metric_params is not None:
            if not isinstance(metric_params, dict):
                raise TypeError("Metric parameters must be a dictionary.")  # noqa: TC003
            for metric_name, metric_param in metric_params.items():
                if metric_name not in self.KNOWN_METRIC_HANDLERS:
                    raise ValueError(
                        f"Unknown metric: {metric_name}. Available metrics: {list(self.KNOWN_METRIC_HANDLERS.keys())}"
                    )
                if metric_name not in metric_names:
                    warnings.warn(
                        f"Metric {metric_name} is not in the list of metrics to evaluate but the parameters were set for this metric.",
                        UserWarning,
                    )
                if not isinstance(metric_param, dict):
                    raise TypeError(
                        f"Metric parameters must be provided as a dictionary. Got {type(metric_param)} instead."
                    )

    def _set_seed(self, seed: int) -> None:
        """
        Sets seed for all random number generators.
        Args:
            seed (int): Seed for random number generators.
        """
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

__init__(model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, device: str = 'cpu', seed: int = 42)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | torch.nn.Module | Convolutional neural network to be explained. On this model some explanation and metric parameters will be inferred. | required |
| data | torch.Tensor | Data that will be used for explanation method evaluation. Shape: (N, C, H, W). | required |
| targets | torch.Tensor | Labels for the provided data, encoded as an integer vector with shape (N,). | required |
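
For illustration only, a constructor call on a small torchvision classifier and a random batch (all shapes and names here are arbitrary placeholders):

```python
import torch
import torchvision
from autoexplainer.autoexplainer import AutoExplainer

model = torchvision.models.resnet18(num_classes=10)  # any convolutional classifier works
data = torch.rand(8, 3, 224, 224)                    # (N, C, H, W)
targets = torch.randint(0, 10, (8,))                 # (N,) integer labels

explainer = AutoExplainer(model, data, targets, device="cpu", seed=0)
```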

aggregate(first_stage_aggregation_function_name: str = 'mean', second_stage_aggregation_function_name: str = 'rank_based', second_stage_aggregation_function_aggregation_parameters: Dict = None) -> None

Aggregates the raw results computed by the .evaluate() method in two steps. First, it aggregates metric scores across the provided observations, so that each explanation method has a value for each metric. Second, it aggregates scores across the available metrics, so that each explanation method has a single value that represents its overall quality.

Stores the results of both aggregation steps in the attributes first_aggregation_results and second_aggregation_results.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| first_stage_aggregation_function_name | {"mean", "median", "min", "max"} | Name of the function for the first stage aggregation. | 'mean' |
| second_stage_aggregation_function_name | {"rank_based", "weighted_mean"} | Name of the function for the second stage aggregation. | 'rank_based' |
| second_stage_aggregation_function_aggregation_parameters | Dict | Parameters for the second stage aggregation function. | None |
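
A short sketch of calling aggregate() with non-default choices; the accepted names come from the KNOWN_FIRST_STAGE_AGGREGATION_FUNCTIONS and KNOWN_SECOND_STAGE_AGGREGATION_FUNCTIONS dictionaries shown in the class source above:

```python
# Must be called after .evaluate(); otherwise a ValueError is raised.
explainer.aggregate(
    first_stage_aggregation_function_name="median",
    second_stage_aggregation_function_name="rank_based",
)

print(explainer.first_aggregation_results)   # {method: {metric: aggregated value, ...}, ...}
print(explainer.second_aggregation_results)  # {method: overall score, ...}
print(explainer.best_explanation_name)       # ties are broken by attribution computation time
```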

evaluate(explanations: List[str] = None, metrics: List[str] = None, explanation_params: Dict = None, metrics_params: Dict = None) -> None

Evaluates explanation methods and stores the results in the .raw_results attribute.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| explanations | List[str] | List of names of explanation methods to be evaluated. By default, uses all available explanation methods. Accepts lists with a subset of: {"saliency", "grad_cam", "integrated_gradients", "kernel_shap"}. | None |
| metrics | List[str] | List of names of evaluation metrics to be used. By default, uses all available metrics. Accepts lists with a subset of: {"irof", "sparseness", "average_sensitivity", "faithfulness_estimate"}. | None |
| explanation_params | Dict[str, Dict] | Overrides default parameters of the selected explanation functions. Accepts a dictionary of the form {"explanation_name": <dictionary with parameters>}. See the corresponding ExplanationHandler for the accepted parameters. | None |
| metrics_params | Dict[str, Dict] | Overrides default parameters of the selected metrics. Accepts a dictionary of the form {"metric_name": <dictionary with parameters>}. See the corresponding MetricHandler for the accepted parameters. | None |
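
An illustrative call restricted to two methods and two metrics; the commented-out parameter dictionaries are placeholders whose accepted keys depend on the corresponding handler classes:

```python
explainer.evaluate(
    explanations=["grad_cam", "saliency"],
    metrics=["sparseness", "irof"],
    # explanation_params={"grad_cam": {...}},  # keys are defined by GradCamHandler
    # metrics_params={"irof": {...}},          # keys are defined by IROFHandler
)

# Raw per-observation metric values, keyed by explanation method and then metric.
print(explainer.raw_results["grad_cam"]["sparseness"])
```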

get_best_explanation() -> BestExplanation

Returns the selected best explanation method wrapped in a BestExplanation object with a few additions; see BestExplanation for more details.

Returns: a BestExplanation object.
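
A sketch of retrieving the winning method after aggregation; the attribute names below mirror the keyword arguments passed to the BestExplanation constructor in the source and are assumed to be exposed as attributes:

```python
best = explainer.get_best_explanation()

print(best.explanation_name)                 # short name of the selected method
print(best.explanation_function_parameters)  # parameters used by its handler
attributions = best.attributions             # attributions computed during evaluation

# best.explanation_function is the wrapped explanation callable; its exact call
# signature is defined by the corresponding explanation handler.
```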


to_html(file_path: str, model_name: str = None, dataset_name: str = None, labels: Dict[int, str] = None) -> None

Generates the evaluation report as an HTML file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| file_path | str | Target file path. | required |
| model_name | str | Name of the model to show inside the report. | None |
| dataset_name | str | Name of the dataset to show inside the report. | None |
| labels | Dict[int, str] | Mapping between class numbers and class names, e.g. {0:"dog", 1:"cat", 2:"fish"}, for labels inside the report. | None |
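
A hypothetical call writing the HTML report into the current working directory; the file name and label mapping are placeholders:

```python
explainer.to_html(
    "autoexplainer_report.html",
    model_name="ResNet-18",
    dataset_name="validation subset",
    labels={0: "dog", 1: "cat", 2: "fish"},
)
```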
Source code in autoexplainer/autoexplainer.py
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def to_html(
    self, file_path: str, model_name: str = None, dataset_name: str = None, labels: Dict[int, str] = None
) -> None:
    """
    Generates evaluation report as HTML file.
    Args:
        file_path (str): Target file path.
        model_name (str): Name of model to show inside report.
        dataset_name (str): Name of dataset to show inside report.
        labels (Dict[int,str]): Mapping between class number and class names. e.g. ``{0:"dog", 1:"cat", 2:"fish"}`` for labels
                                inside report.
    """
    assert self.first_aggregation_results is not None, "Aggregated results are needed for report generation."
    assert self.second_aggregation_results is not None, "Aggregated results are needed for report generation."

    environment = Environment(loader=PackageLoader("autoexplainer"))
    template = environment.get_template("report.html")

    report_info = self._get_info_for_report(labels=labels)

    pic_io_bytes = io.BytesIO()
    fig = report_info["fig_with_examples"]
    fig.savefig(pic_io_bytes, format="png")
    pic_io_bytes.seek(0)
    pic_hash = base64.b64encode(pic_io_bytes.read())

    _, _, *float_columns, _ = report_info["result_dataframe"].columns

    html_table = (
        report_info["result_dataframe"]
        .style.set_properties(
            subset=["Agg. Score", "Explanation Name"], **{"font-weight": "bold", "text-align": "center"}
        )
        .set_properties(border=0)
        .hide_index()
        .format("{:.3f}", subset=float_columns)
        .render()
    )

    rendered = template.render(
        model_name=model_name,
        dataset_name=dataset_name,
        dataframe_html=html_table,
        pic_hash=pic_hash.decode(),
        **report_info,
    )
    with open(file_path, mode="w", encoding="utf-8") as results:
        results.write(rendered)
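
A hedged sketch of generating the HTML report. The explainer object is assumed to be an AutoExplainer instance that has already been evaluated and aggregated; the file name, model name, dataset name, and labels below are illustrative.

explainer.to_html(
    file_path="autoexplainer_report.html",
    model_name="ResNet-18",
    dataset_name="ImageNet subset",
    labels={0: "dog", 1: "cat", 2: "fish"},
)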

to_pdf(folder_path: str = '', model_name: str = 'name of the model', dataset_name: str = 'name of the dataset', labels: Dict[int, str] = None) -> None

Creates a PDF report from the dict stored in the attribute first_aggregation_results. Requires LaTeX packages to be installed to run - see README.

Parameters:

Name Type Description Default
folder_path str

Path to directory, where the reports (PDF and tex) should be created.

''
model_name str

Name of model to show inside report.

'name of the model'
dataset_name str

Name of dataset to show inside report.

'name of the dataset'
labels Dict[int, str]

Mapping between class numbers and class names, e.g. {0:"dog", 1:"cat", 2:"fish"}, used for labels inside the report.

None
Source code in autoexplainer/autoexplainer.py
def to_pdf(
    self,
    folder_path: str = "",
    model_name: str = "name of the model",
    dataset_name: str = "name of the dataset",
    labels: Dict[int, str] = None,
) -> None:

    """
    Creates a PDF report from the dict stored in the attribute ``first_aggregation_results``.
    Requires LaTeX packages to be installed to run - see README.

    Args:
        folder_path (str): Path to directory, where the reports (PDF and tex) should be created.
        model_name (str): Name of model to show inside report.
        dataset_name (str): Name of dataset to show inside report.
        labels (Dict[int,str]): Mapping between class numbers and class names, e.g. ``{0:"dog", 1:"cat", 2:"fish"}``, used for
                                labels inside the report.

    """
    self._check_is_after_aggregation()

    tex_file = os.path.join(folder_path, "report.tex")
    pdf_file = os.path.join(folder_path, "report.pdf")

    if os.path.exists(tex_file):
        os.remove(tex_file)
    if os.path.exists(pdf_file):
        os.remove(pdf_file)

    report_info = self._get_info_for_report(labels=labels)

    left_margin = 2
    max_nr_columns_in_table = 5
    geometry_options = {"tmargin": "2cm", "lmargin": f"{left_margin}cm"}
    doc = Document(geometry_options=geometry_options)
    doc.preamble.append(Command("title", "AutoeXplainer Report"))
    doc.preamble.append(Command("date", ""))
    doc.packages.append(Package("hyperref"))
    doc.packages.append(Package("booktabs"))
    doc.append(NoEscape(r"\maketitle"))

    results = report_info["result_dataframe"]

    metric_name_copy = copy.deepcopy(METRIC_NAME_SHORT_TO_LONG)
    metric_name_copy["explanation_name"] = "explanation name"
    metric_name_copy["Rank"] = "Rank"
    metric_name_copy["Agg. Score"] = "Agg. Score"
    explanation_methods = report_info["methods"]
    metrics = report_info["metrics"]
    metrics_used = copy.deepcopy(metrics)

    metrics = ["explanation name", "Rank"] + metrics + ["Agg. Score"]
    data = copy.deepcopy(results)

    def hyperlink(url: str, text: str) -> NoEscape:  # type: ignore
        return NoEscape(r"\href{" + url + "}{" + escape_latex(text) + "}")

    # create content of  the Document
    with doc.create(Section("General information", numbering=False)):
        doc.append(bold("Model name: "))
        doc.append(italic(f"{model_name} \n"))
        doc.append(bold("Dataset name: "))
        doc.append(italic(f"{dataset_name} \n"))
        doc.append(bold("Execution time: "))
        doc.append(italic(f"{report_info['execution_time']} s \n"))
        doc.append(bold("Package version: "))
        doc.append(italic(f"{report_info['autoexplainer_version']} \n"))
        doc.append(bold("Date: "))
        doc.append(italic(f"{report_info['date']} \n"))
        doc.append(bold("Selected method: "))
        doc.append(italic(f"{report_info['selected_method']} \n"))
        doc.append(bold("Number of images: "))
        doc.append(italic(f"{report_info['n_images']}"))

    with doc.create(Section("Model performance", numbering=False)):
        doc.append(bold("Accuracy: "))
        doc.append(italic(f"{report_info['model_acc']} \n"))
        doc.append(bold("F1 macro: "))
        doc.append(italic(f"{report_info['model_f1_macro']} \n"))
        doc.append(bold("Balanced accuracy: "))
        doc.append(italic(f"{report_info['model_bac']} \n"))

    with doc.create(Section("Table of results", numbering=False)):
        doc.append(NoEscape(r"\begin{footnotesize}"))
        doc.append(NoEscape(r"\begin{flushleft} "))
        doc.append(NoEscape(report_info["result_dataframe"].to_latex(index=False)))
        doc.append(NoEscape(r"\end{flushleft}"))
        doc.append(NoEscape(r"\end{footnotesize}"))
        doc.append(bold("Table description \n"))
        doc.append(
            "Arrow next to the metric names indicates whether larger or smaller values of metric are better. Time elapsed shows time that was required for computation of attribution for given batch of images. When there is a tie in Aggregated Score, the best metric is chosen based on computation time."
        )

    doc.append(NewPage())
    with doc.create(Section("Details", numbering=False)):
        with doc.create(Subsection("Explanations:", numbering=False)):
            with doc.create(Itemize()) as itemize:
                for i in range(0, len(data.iloc[:, 0])):
                    explanation_name = EXPLANATION_NAME_SHORT_TO_LONG[explanation_methods[i]]
                    itemize.add_item(bold(explanation_name))
                    doc.append(EXPLANATION_DESCRIPTION[str(explanation_name)][0])
                    doc.append(
                        hyperlink(
                            EXPLANATION_DESCRIPTION[str(explanation_name)][1],
                            EXPLANATION_DESCRIPTION[str(explanation_name)][2],
                        )
                    )
                    doc.append("\n")
                    doc.append("Explanation's parameters: \n")
                    doc.append(NoEscape(r"\texttt{"))
                    doc.append(f"{report_info['method_parameters'][explanation_methods[i]]} \n")
                    doc.append(NoEscape(r"}"))
        doc.append(NewPage())
        with doc.create(Subsection("Metrics:", numbering=False)):
            with doc.create(Itemize()) as itemize:
                minus = 2
                for i in range(2, len(data.columns) - 1):
                    if data.columns[i] == "Time elapsed [s]":
                        minus += 1
                    else:
                        itemize.add_item(bold(METRIC_NAME_MEDIUM_TO_LONG[data.columns[i]]))
                        doc.append(METRIC_DESCRIPTIONS[data.columns[i]][0])
                        doc.append(
                            hyperlink(
                                METRIC_DESCRIPTIONS[data.columns[i]][1], METRIC_DESCRIPTIONS[data.columns[i]][2]
                            )
                        )
                        doc.append("\n")
                        doc.append("Metric's parameters: \n")
                        doc.append(NoEscape(r"\texttt{"))
                        doc.append(f"{report_info['metric_parameters'][metrics_used[i-minus]]} \n")
                        doc.append(NoEscape(r"}"))
        with doc.create(Subsection("Aggregation parameters", numbering=False)):
            doc.append(NoEscape(r"\texttt{"))
            doc.append(report_info["aggregation_parameters"])
            doc.append(NoEscape(r"}"))
    doc.append(NewPage())
    with doc.create(Section("Examples of explanations", numbering=False)):
        with doc.create(Figure(position="!h")) as mini_logo:
            fig = report_info["fig_with_examples"]
            mini_logo.add_plot(fig=fig, width=f"{21 - 2 * left_margin}cm")

    doc.generate_pdf(os.path.join(folder_path, "report"), clean_tex=False)
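
A hedged sketch of generating the PDF report; it requires a working LaTeX installation (see README). The explainer object and the names below are assumptions, and the target folder is expected to exist.

explainer.to_pdf(
    folder_path="reports",
    model_name="ResNet-18",
    dataset_name="ImageNet subset",
    labels={0: "dog", 1: "cat", 2: "fish"},
)
# writes reports/report.tex and reports/report.pdf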

Metric handlers

AvgSensitivityHandler

Bases: MetricHandler

Metric Handler for Average Sensitivity metric (Yeh et al., 2019).

Measures the average sensitivity of an explanation using a Monte Carlo sampling-based approximation.

Dictionary with parameters to override must be in the form:

metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                   "call": <dictionary with parameters used in metric's __call__>}

Parameters accepted in metric_parameters:

"init":

  • abs: a bool stating if absolute operation should be taken on the attributions
  • normalise: a bool stating if the attributions should be normalised
  • normalise_func: a Callable that makes a normalising transformation of the attributions
  • lower_bound (float): lower bound of the perturbation, default=0.2
  • upper_bound (None, float): upper bound of the perturbation, default=None
  • nr_samples (integer): the number of samples iterated, default=200.
  • norm_numerator (callable): function for norm calculations on the numerator, default=fro_norm.
  • norm_denominator (callable): function for norm calculations on the denominator, default=fro_norm.
  • perturb_func (callable): input perturbation function, default=uniform_noise.
  • similarity_func (callable): similarity function applied to compare input and perturbed input.

"call": No parameters are used.

Source code in autoexplainer/metrics.py
class AvgSensitivityHandler(MetricHandler):
    """
    Metric Handler for Average Sensitivity metric [(Yeh et al., 2019)](https://arxiv.org/abs/1901.09392).

    Measures the average sensitivity of an explanation using a Monte Carlo sampling-based approximation.

    Dictionary with parameters to override must be in the form:
    ```
    metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                       "call": <dictionary with parameters used in metric's __call__>}
    ```

    Parameters accepted in `metric_parameters`:

    **"init"**:

    - `abs`: a bool stating if absolute operation should be taken on the attributions
    - `normalise`: a bool stating if the attributions should be normalised
    - `normalise_func`: a Callable that makes a normalising transformation of the attributions
    - `lower_bound` (float): lower bound of the perturbation, default=0.2
    - `upper_bound` (None, float): upper bound of the perturbation, default=None
    - `nr_samples` (integer): the number of samples iterated, default=200.
    - `norm_numerator` (callable): function for norm calculations on the numerator, default=fro_norm.
    - `norm_denominator` (callable): function for norm calculations on the denominator, default=fro_norm.
    - `perturb_func` (callable): input perturbation function, default=uniform_noise.
    - `similarity_func` (callable): similarity function applied to compare input and perturbed input.

    **"call"**: No parameters are used.

    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, metric_parameters: Dict = None
    ) -> None:
        self.metric_parameters = {"init": self._infer_metric_parameters(model, data, targets), "call": {}}
        self.metric_parameters["init"] = self._add_hide_output_parameters_to_dict(self.metric_parameters["init"])
        if str(next(model.parameters()).device) != "cpu":
            self.metric_parameters["call"] = {"device": "cuda"}
        if metric_parameters is not None:
            update_dictionary(self.metric_parameters, metric_parameters)
        self.metric = quantus.AvgSensitivity(**self.metric_parameters["init"])

    def _infer_metric_parameters(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> Dict:
        parameters = {
            "normalise": True,
            "nr_samples": 20,
            "lower_bound": 0.2,
            "norm_numerator": quantus.fro_norm,
            "norm_denominator": quantus.fro_norm,
            "perturb_func": quantus.uniform_noise,
            "similarity_func": quantus.difference,
            "perturb_radius": 0.2,
        }
        return parameters
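
A minimal sketch of overriding some of the inferred "init" parameters described above. The toy model, data, and override values are illustrative only, not recommendations.

import torch
import torch.nn as nn
from autoexplainer.metrics import AvgSensitivityHandler

# toy CNN and a small batch of images, purely for illustration
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))
data = torch.rand(4, 3, 32, 32)      # (N, C, H, W)
targets = torch.randint(0, 2, (4,))  # (N,)

handler = AvgSensitivityHandler(
    model, data, targets,
    metric_parameters={"init": {"nr_samples": 50, "lower_bound": 0.1}, "call": {}},
)
print(handler.get_parameters())  # parameters actually used by the handler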

FaithfulnessEstimateHandler

Bases: MetricHandler

Metric handler for Faithfulness Estimate metric (Alvarez-Melis et al., 2018).

Computes the correlation between probability drops and attribution scores on various points.

Dictionary with parameters to override must be in the form:

metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                   "call": <dictionary with parameters used in metric's __call__>}

Parameters accepted in metric_parameters:

"init":

  • abs: a bool stating if absolute operation should be taken on the attributions
  • normalise: a bool stating if the attributions should be normalised
  • normalise_func: a Callable that makes a normalising transformation of the attributions
  • nr_runs (integer): the number of runs (for each input and explanation pair), default=100.
  • subset_size (integer): the size of subset, default=224.
  • perturb_baseline (string): indicates the type of baseline: "mean", "random", "uniform", "black" or "white", default="mean".
  • similarity_func (callable): Similarity function applied to compare input and perturbed input, default=correlation_spearman.
  • perturb_func (callable): input perturbation function, default=baseline_replacement_by_indices.
  • features_in_step (integer): the size of the step, default=256.
  • softmax (boolean): indicates whether to use softmax probabilities or logits in model prediction, default=True.

"call": No parameters are used.

Source code in autoexplainer/metrics.py
class FaithfulnessEstimateHandler(MetricHandler):
    """
    Metric handler for Faithfulness Estimate metric [(Alvarez-Melis et al., 2018)](https://arxiv.org/abs/1806.07538).

    Computes the correlation between probability drops and attribution scores on various points.

    Dictionary with parameters to override must be in the form:
    ```
    metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                       "call": <dictionary with parameters used in metric's __call__>}
    ```

    Parameters accepted in `metric_parameters`:

    **"init"**:

    * `abs`: a bool stating if absolute operation should be taken on the attributions
    * `normalise`: a bool stating if the attributions should be normalised
    * `normalise_func`: a Callable that makes a normalising transformation of the attributions
    * `nr_runs` (integer): the number of runs (for each input and explanation pair), default=100.
    * `subset_size` (integer): the size of subset, default=224.
    * `perturb_baseline` (string): indicates the type of baseline: "mean", "random", "uniform", "black" or "white", default="mean".
    * `similarity_func` (callable): Similarity function applied to compare input and perturbed input, default=correlation_spearman.
    * `perturb_func` (callable): input perturbation function, default=baseline_replacement_by_indices.
    * `features_in_step` (integer): the size of the step, default=256.
    * `softmax` (boolean): indicates whether to use softmax probabilities or logits in model prediction, default=True.

    "call": No parameters are used.

    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, metric_parameters: Dict = None
    ) -> None:
        self.metric_parameters = {}
        self.metric_parameters["init"] = self._infer_metric_parameters(model, data, targets)
        self.metric_parameters["init"] = self._add_hide_output_parameters_to_dict(self.metric_parameters["init"])
        if str(next(model.parameters()).device) == "cpu":
            self.metric_parameters["call"] = {}
        else:
            self.metric_parameters["call"] = {"device": "cuda"}
        if metric_parameters is not None:
            update_dictionary(self.metric_parameters, metric_parameters)
        self.metric = quantus.FaithfulnessEstimate(**self.metric_parameters["init"])

    def _infer_metric_parameters(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> Dict:
        parameters = {
            "normalise": True,
            "features_in_step": 256,
            "perturb_baseline": "black",
            "softmax": True,
        }
        return parameters
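
A minimal sketch with illustrative overrides of the "init" parameters listed above, assuming model, data, and targets are defined as in the AvgSensitivityHandler example.

from autoexplainer.metrics import FaithfulnessEstimateHandler

handler = FaithfulnessEstimateHandler(
    model, data, targets,
    metric_parameters={"init": {"features_in_step": 128, "perturb_baseline": "mean"}, "call": {}},
)
print(handler.get_parameters())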

IROFHandler

Bases: MetricHandler

Metric handler for Iterative Removal Of Features metric (Rieger et al., 2020).

Computes the area over the curve per class for sorted mean importances of feature segments (superpixels) as they are iteratively removed (and prediction scores are collected), averaged over several test samples.

Dictionary with parameters to override must be in the form:

metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                   "call": <dictionary with parameters used in metric's __call__>}

Parameters accepted in metric_parameters:

"init":

  • abs: a bool stating if absolute operation should be taken on the attributions
  • normalise: a bool stating if the attributions should be normalised
  • normalise_func: a Callable that makes a normalising transformation of the attributions
  • segmentation_method (string): image segmentation method: 'slic' or 'felzenszwalb', default="slic"
  • perturb_baseline (string): indicates the type of baseline: "mean", "random", "uniform", "black" or "white", default="mean"
  • perturb_func (callable): input perturbation function, default=baseline_replacement_by_indices
  • softmax (boolean): indicates whether to use softmax probabilities or logits in model prediction

"call": No parameters are used.

Source code in autoexplainer/metrics.py
class IROFHandler(MetricHandler):
    """
    Metric handler for Iterative Removal Of Features metric [(Rieger et al., 2020)](https://arxiv.org/abs/2003.08747).

    Computes the area over the curve per class for sorted mean importances of feature segments (superpixels)
    as they are iteratively removed (and prediction scores are collected), averaged over several test samples.

    Dictionary with parameters to override must be in the form:
    ```
    metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                       "call": <dictionary with parameters used in metric's __call__>}
    ```
    Parameters accepted in `metric_parameters`:
    "init":
    - abs: a bool stating if absolute operation should be taken on the attributions
    - normalise: a bool stating if the attributions should be normalised
    - normalise_func: a Callable that makes a normalising transformation of the attributions
    - segmentation_method (string): Image segmentation method: 'slic' or 'felzenszwalb', default="slic"
    - perturb_baseline (string): indicates the type of baseline: "mean", "random", "uniform", "black" or "white", default="mean"
    - perturb_func (callable): input perturbation function, default=baseline_replacement_by_indices
    - softmax (boolean): indicates whether to use softmax probabilities or logits in model prediction

    "call": No parameters are used.

    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, metric_parameters: Dict = None
    ) -> None:
        self.metric_parameters = {"init": self._infer_irof_parameters(model, data, targets), "call": {}}
        self.metric_parameters["init"] = self._add_hide_output_parameters_to_dict(self.metric_parameters["init"])
        if str(next(model.parameters()).device) != "cpu":
            self.metric_parameters["call"] = {"device": "cuda"}
        if metric_parameters is not None:
            update_dictionary(self.metric_parameters, metric_parameters)
        self.metric = quantus.IterativeRemovalOfFeatures(**self.metric_parameters["init"])

    def _infer_irof_parameters(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> Dict:
        parameters = {
            "segmentation_method": "slic",
            "perturb_baseline": "mean",
            "softmax": True,
            "return_aggregate": False,
        }
        return parameters
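
A minimal sketch, assuming model, data, and targets as in the earlier examples; switching the segmentation method and baseline is shown purely as an illustration.

from autoexplainer.metrics import IROFHandler

handler = IROFHandler(
    model, data, targets,
    metric_parameters={"init": {"segmentation_method": "felzenszwalb", "perturb_baseline": "uniform"}, "call": {}},
)
print(handler.get_parameters())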

MetricHandler

Bases: ABC

Abstract class for metrics handlers.

Parameters:

Name Type Description Default
model torch.nn.Module

Model used for metrics' parameters inference.

required
data torch.Tensor

Data used for metrics' parameters inference.

required
targets torch.Tensor

Target used for metrics' parameters inference.

required
metric_parameters Dict

Metric parameters to overwrite inferred parameters. Dictionary must be in the form:

metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                   "call": <dictionary with parameters used in metric's __call__>}
None

Attributes:

Name Type Description
metric quantus.Metric

Attribute that stores the created metric object after determining its parameters.

metric_parameters Dict

Dictionary with parameters used for this metric.

Source code in autoexplainer/metrics.py
class MetricHandler(ABC):
    """
    Abstract class for metrics handlers.



    Args:
        model (torch.nn.Module): Model used for metrics' parameters inference.
        data (torch.Tensor): Data used for metrics' parameters inference.
        targets (torch.Tensor): Target used for metrics' parameters inference.
        metric_parameters (Dict): Metric parameters to overwrite inferred parameters. Dictionary must be in the form:

                                  ```
                                  metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                                                       "call": <dictionary with parameters used in metric's __call__>}
                                      ```


    Attributes:
        metric (quantus.Metric): Attribute that stores the created metric object after determining its parameters.
        metric_parameters (Dict): Dictionary with parameters used for this metric.

    """

    metric: quantus.Metric = NotImplemented  # type: ignore[no-any-unimported]
    metric_parameters: Dict = None

    @abstractmethod
    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, metric_parameters: Dict = None
    ) -> None:
        pass

    def compute_metric_values(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        targets: torch.Tensor,
        attributions: torch.Tensor = None,
        explanation_func: Callable = None,
    ) -> np.ndarray:
        """
        Computes metric values for given model, dataset, and explanation function.
        Args:
            model: Model to be explained.
            data: Batch of images with shape (N, C, H, W).
            targets: Labels for the provided data, as an integer vector with shape (N,).
            attributions: Precomputed attributions for the data; if not given, they are computed with explanation_func.
            explanation_func: Explanation function used to compute attributions when they are not provided.

        Returns:
            NumPy array with metric values for each given image.
        """
        x_batch = deepcopy(data.cpu().detach().numpy())
        y_batch = deepcopy(targets.cpu().detach().numpy())
        if attributions is not None:
            a_batch = deepcopy(attributions.cpu().detach().numpy())
        else:
            a_batch = None
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            result_list: List[float] = self.metric(
                model=model,
                x_batch=x_batch,
                y_batch=y_batch,
                a_batch=a_batch,
                explain_func=explanation_func,
                **self.metric_parameters["call"],
            )
        result: np.ndarray = np.array(result_list)
        return result

    def get_parameters(self) -> Dict:
        return self.metric_parameters

    def _add_hide_output_parameters_to_dict(self, metric_parameters: Dict) -> Dict:
        metric_parameters["disable_warnings"] = True
        metric_parameters["display_progressbar"] = False
        return metric_parameters

compute_metric_values(model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, attributions: torch.Tensor = None, explanation_func: Callable = None) -> np.ndarray

Computes metric values for given model, dataset, and explanation function.

Parameters:

Name Type Description Default
model torch.nn.Module

Model to be explained.

required
data torch.Tensor

Batch of images with shape (N, C, H, W).

required
targets torch.Tensor

Labels for the provided data, as an integer vector with shape (N,).

required
attributions torch.Tensor

Precomputed attributions for the data; if not given, they are computed with explanation_func.

None
explanation_func Callable

Explanation function used to compute attributions when they are not provided.

None

Returns:

Type Description
np.ndarray

NumPy array with metric values for each given image.

Source code in autoexplainer/metrics.py
def compute_metric_values(
    self,
    model: torch.nn.Module,
    data: torch.Tensor,
    targets: torch.Tensor,
    attributions: torch.Tensor = None,
    explanation_func: Callable = None,
) -> np.ndarray:
    """
    Computes metric values for given model, dataset, and explanation function.
    Args:
        model: Model to be explained.
        data: Batch of images with shape (N, C, H, W).
        targets: Labels for the provided data, as an integer vector with shape (N,).
        attributions: Precomputed attributions for the data; if not given, they are computed with explanation_func.
        explanation_func: Explanation function used to compute attributions when they are not provided.

    Returns:
        NumPy array with metric values for each given image.
    """
    x_batch = deepcopy(data.cpu().detach().numpy())
    y_batch = deepcopy(targets.cpu().detach().numpy())
    if attributions is not None:
        a_batch = deepcopy(attributions.cpu().detach().numpy())
    else:
        a_batch = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        result_list: List[float] = self.metric(
            model=model,
            x_batch=x_batch,
            y_batch=y_batch,
            a_batch=a_batch,
            explain_func=explanation_func,
            **self.metric_parameters["call"],
        )
    result: np.ndarray = np.array(result_list)
    return result
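
A hedged sketch of calling compute_metric_values directly: handler is one of the concrete metric handlers on this page, model, data, and targets are as in the earlier examples, and explanation_func is assumed to be a callable obtained from an explanation handler's get_explanation_function() (see the explanation method handlers below).

values = handler.compute_metric_values(
    model=model,
    data=data,
    targets=targets,
    attributions=None,                 # with no attributions given, the metric computes them via explanation_func
    explanation_func=explanation_func,
)
print(values.shape)  # one value per image in the batch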

SparsenessHandler

Bases: MetricHandler

Metric Handler for Sparseness metric (Chalasani et al., 2020).

Uses the Gini Index to measure whether only highly attributed features are truly predictive of the model output.

Dictionary with parameters to override must be in the form:

metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                   "call": <dictionary with parameters used in metric's __call__>}

Parameters accepted in metric_parameters:

"init":

  • abs: a bool stating if absolute operation should be taken on the attributions
  • normalise: a bool stating if the attributions should be normalised
  • normalise_func: a Callable that makes a normalising transformation of the attributions

"call": No parameters are used.

Source code in autoexplainer/metrics.py
class SparsenessHandler(MetricHandler):
    """
    Metric Handler for Sparseness metric [(Chalasani et al., 2020)](https://arxiv.org/abs/1810.06583).

    Uses the Gini Index to measure whether only highly attributed features are truly predictive of the model output.


    Dictionary with parameters to override must be in the form:
    ```
    metric_parameters = {"init": <dictionary with parameters used in metric's __init__>,
                       "call": <dictionary with parameters used in metric's __call__>}
    ```

    Parameters accepted in `metric_parameters`:

    "init":

    - abs: a bool stating if absolute operation should be taken on the attributions
    - normalise: a bool stating if the attributions should be normalised
    - normalise_func: a Callable that makes a normalising transformation of the attributions

    "call": No parameters are used.


    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, metric_parameters: Dict = None
    ) -> None:
        self.metric_parameters = {"init": self._infer_sparsness_parameters(model, data, targets), "call": {}}
        self.metric_parameters["init"] = self._add_hide_output_parameters_to_dict(self.metric_parameters["init"])
        if str(next(model.parameters()).device) != "cpu":
            self.metric_parameters["call"] = {"device": "cuda"}
        if metric_parameters is not None:
            update_dictionary(self.metric_parameters, metric_parameters)
        self.metric = quantus.Sparseness(**self.metric_parameters["init"])

    def _infer_sparsness_parameters(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> Dict:
        return {}
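
A minimal sketch, assuming model, data, and targets as before. Sparseness infers no metric-specific parameters, so only the generic "init" options listed above are worth overriding; the values are illustrative.

from autoexplainer.metrics import SparsenessHandler

handler = SparsenessHandler(
    model, data, targets,
    metric_parameters={"init": {"abs": True, "normalise": True}, "call": {}},
)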

Explanation methods handlers

BestExplanation

Class for an object that wraps the best explanation method selected during the evaluation process.

Attributes:

Name Type Description
attributions torch.Tensor

Attributions computed during evaluation using this explanation method only.

explanation_function Callable

Function that computes attributions for the provided model and data.

name str

Name of this explanation method.

parameters Dict

Parameters used in this explanation method.

metric_handlers Dict

Dictionary with metric handlers that this explanation method was evaluated with.

aggregation_parameters Dict

Parameters that were used during the aggregation of metric values.

Source code in autoexplainer/explanations/explanation_handlers.py
class BestExplanation:
    """
    Class for an object that wraps the best explanation method selected during the evaluation process.

    Attributes:
         attributions (torch.Tensor): Attributions computed during evaluation using this explanation method only.
         explanation_function (Callable): Function that computes attributions for the provided model and data.
         name (str): Name of this explanation method.
         parameters (Dict): Parameters used in this explanation method.
         metric_handlers (Dict): Dictionary with metric handlers that this explanation method was evaluated with.
         aggregation_parameters (Dict): Parameters that were used during the aggregation of metric values.
    """

    attributions: torch.Tensor = None
    explanation_function: Callable = None
    name: str = None
    parameters: Dict = None
    metric_handlers: Dict = None
    aggregation_parameters: Dict = None

    def __init__(
        self,
        attributions: torch.Tensor,
        explanation_function: Callable,
        explanation_name: str,
        explanation_function_parameters: Dict,
        metric_handlers: Dict,
        aggregation_parameters: Dict,
    ) -> None:
        self.attributions = attributions
        self.explanation_function = explanation_function
        self.name = explanation_name
        self.parameters = explanation_function_parameters
        self.metric_handlers = metric_handlers
        self.aggregation_parameters = aggregation_parameters

    def explain(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute new attributions.
        Args:
            model (torch.nn.Module): Convolutional neural network to be explained.
            data (torch.Tensor): Data for which attributions will be computed. shape: (N, C, H, W).
            targets (torch.Tensor): Labels for provided data. Encoded as integer vector with shape (N,).

        Returns:
            attributions (torch.Tensor)

        """
        self._check_model_and_data(model, data, targets)
        print(f"Computing attributions using {EXPLANATION_NAME_SHORT_TO_LONG[self.name]} method.")
        print("This may take a while, depending on the number of samples to be explained.")
        model = fix_relus_in_model(model)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            attributions_list = []
            for x, y in tqdm.tqdm(zip(data, targets), total=len(data), desc="Calculating attributions"):
                attr = self.explanation_function(
                    model=model,
                    inputs=x.reshape(1, *x.shape).to(next(model.parameters()).device),
                    targets=y.to(next(model.parameters()).device),
                )
                attributions_list.append(torch.tensor(attr)[0])
        all_attributions = torch.stack(attributions_list, dim=0)
        print("Finished.")
        return all_attributions

    def _check_model_and_data(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> None:
        if not isinstance(model, torch.nn.Module):
            raise TypeError("Model must be of type torch.nn.Module.")  # noqa: TC003
        if not isinstance(data, torch.Tensor):
            raise TypeError("Data must be of type torch.Tensor.")  # noqa: TC003
        if len(data.shape) != 4:
            raise ValueError("Data must be of shape (N, C, H, W).")  # noqa: TC003
        if not isinstance(targets, torch.Tensor):
            raise TypeError("Targets must be of type torch.Tensor.")  # noqa: TC003
        if len(targets.shape) != 1:
            raise ValueError("Targets must be of shape (N,).")  # noqa: TC003
        if data.shape[0] != targets.shape[0]:
            raise ValueError("Data and targets must have the same number of observations.")  # noqa: TC003

    def evaluate(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        targets: torch.Tensor,
        attributions: torch.Tensor = None,
        aggregate: bool = False,
    ) -> Dict:
        """
        Evaluate the selected best explanation method again on new data.
        Args:
            model (torch.nn.Module): Convolutional neural network to be explained.
            data (torch.Tensor): Data that will be used for the evaluation of the explanation method. shape: (N, C, H, W).
            targets (torch.Tensor): Labels for provided data. Encoded as integer vector with shape (N,).
            attributions (torch.Tensor, optional): Attributions for this data that were previously computed, to skip computing them once more.
            aggregate (bool, optional): Indicates whether results should be aggregated (in the same manner as in AutoExplainer).

        Returns:
            results (Dict): Results of evaluation.
        """
        self._check_model_and_data(model, data, targets)
        if attributions is not None:
            self._check_attributions(data, attributions)
        self._check_is_bool(aggregate, "aggregate")
        print(
            f"Evaluating explanation method {EXPLANATION_NAME_SHORT_TO_LONG[self.name]} using {len(self.metric_handlers)} metrics."
        )
        print(f"\tMetrics: {', '.join([METRIC_NAME_SHORT_TO_LONG[x] for x in list(self.metric_handlers.keys())])}")
        print("This may take a long time, depending on the number of samples and metrics.")
        model = fix_relus_in_model(model)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            raw_results = {}
            pbar = tqdm.tqdm(self.metric_handlers.items(), total=len(self.metric_handlers), desc="Evaluating")
            for metric_name, metric_handler in pbar:
                raw_results[metric_name] = metric_handler.compute_metric_values(
                    model=model,
                    data=data,
                    targets=targets,
                    attributions=attributions,
                    explanation_func=self.explanation_function,
                )
        if not aggregate:
            return raw_results
        else:
            raw_results = {self.name: raw_results}
            first_aggregation_results = self.aggregation_parameters["first_stage_aggregation_function"](raw_results)
            return first_aggregation_results[self.name]

    def _check_attributions(self, data: torch.Tensor, attributions: torch.Tensor) -> None:
        if not isinstance(attributions, torch.Tensor):
            raise TypeError("Attributions must be of type torch.Tensor.")  # noqa: TC003
        if len(attributions.shape) != 4:
            raise ValueError("Attributions must be of shape (N, C, H, W).")  # noqa: TC003
        if data.shape[0] != attributions.shape[0]:
            raise ValueError("Data and targets must have the same number of observations.")  # noqa: TC003
        if data.shape[2:] != attributions.shape[2:]:
            raise ValueError("Data and attributions must have the same image shape.")  # noqa: TC003

    def _check_is_bool(self, value: bool, value_name: str) -> None:
        if not isinstance(value, bool):
            raise TypeError(f"Value  of {value_name} must be of type bool.")  # noqa: TC003

evaluate(model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, attributions: torch.Tensor = None, aggregate: bool = False) -> Dict

Evaluate the selected best explanation method again on new data.

Parameters:

Name Type Description Default
model torch.nn.Module

Convolutional neural network to be explained.

required
data torch.Tensor

Data that will be used for the evaluation of the explanation method. shape: (N, C, H, W).

required
targets torch.Tensor

Labels for provided data. Encoded as integer vector with shape (N,).

required
attributions torch.Tensor

Attributions for this data that were previously computed, to skip computing them once more.

None
aggregate bool

Indicates whether results should be aggregated (in the same manner as in AutoExplainer).

False

Returns:

Name Type Description
results Dict

Results of evaluation.

Source code in autoexplainer/explanations/explanation_handlers.py
def evaluate(
    self,
    model: torch.nn.Module,
    data: torch.Tensor,
    targets: torch.Tensor,
    attributions: torch.Tensor = None,
    aggregate: bool = False,
) -> Dict:
    """
    Evaluate the selected best explanation method again on new data.
    Args:
        model (torch.nn.Module): Convolutional neural network to be explained.
        data (torch.Tensor): Data that will be used for the evaluation of the explanation method. shape: (N, C, H, W).
        targets (torch.Tensor): Labels for provided data. Encoded as integer vector with shape (N,).
        attributions (torch.Tensor, optional): Attributions for this data that were previously computed, to skip computing them once more.
        aggregate (bool, optional): Indicates whether results should be aggregated (in the same manner as in AutoExplainer).

    Returns:
        results (Dict): Results of evaluation.
    """
    self._check_model_and_data(model, data, targets)
    if attributions is not None:
        self._check_attributions(data, attributions)
    self._check_is_bool(aggregate, "aggregate")
    print(
        f"Evaluating explanation method {EXPLANATION_NAME_SHORT_TO_LONG[self.name]} using {len(self.metric_handlers)} metrics."
    )
    print(f"\tMetrics: {', '.join([METRIC_NAME_SHORT_TO_LONG[x] for x in list(self.metric_handlers.keys())])}")
    print("This may take a long time, depending on the number of samples and metrics.")
    model = fix_relus_in_model(model)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        raw_results = {}
        pbar = tqdm.tqdm(self.metric_handlers.items(), total=len(self.metric_handlers), desc="Evaluating")
        for metric_name, metric_handler in pbar:
            raw_results[metric_name] = metric_handler.compute_metric_values(
                model=model,
                data=data,
                targets=targets,
                attributions=attributions,
                explanation_func=self.explanation_function,
            )
    if not aggregate:
        return raw_results
    else:
        raw_results = {self.name: raw_results}
        first_aggregation_results = self.aggregation_parameters["first_stage_aggregation_function"](raw_results)
        return first_aggregation_results[self.name]
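
A hedged usage sketch: best is a BestExplanation returned by AutoExplainer.get_best_explanation(), and model, new_data (a tensor of shape (N, C, H, W)), and new_targets (an integer tensor of shape (N,)) are assumed to exist.

raw_results = best.evaluate(model, new_data, new_targets)
# raw_results: dictionary mapping each metric name to an array with one value per image

aggregated = best.evaluate(model, new_data, new_targets, aggregate=True)
# aggregated: values aggregated across observations, as in AutoExplainer's first aggregation step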

explain(model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> torch.Tensor

Compute new attributions.

Parameters:

Name Type Description Default
model torch.nn.Module

Convolutional neural network to be explained.

required
data torch.Tensor

Data for which attributions will be computed. shape: (N, C, H, W).

required
targets torch.Tensor

Labels for provided data. Encoded as integer vector with shape (N,).

required

Returns:

Type Description
torch.Tensor

attributions (torch.Tensor)

Source code in autoexplainer/explanations/explanation_handlers.py
def explain(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Compute new attributions.
    Args:
        model (torch.nn.Module): Convolutional neural network to be explained.
        data (torch.Tensor): Data for which attributions will be computed. shape: (N, C, H, W).
        targets (torch.Tensor): Labels for provided data. Encoded as integer vector with shape (N,).

    Returns:
        attributions (torch.Tensor)

    """
    self._check_model_and_data(model, data, targets)
    print(f"Computing attributions using {EXPLANATION_NAME_SHORT_TO_LONG[self.name]} method.")
    print("This may take a while, depending on the number of samples to be explained.")
    model = fix_relus_in_model(model)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        attributions_list = []
        for x, y in tqdm.tqdm(zip(data, targets), total=len(data), desc="Calculating attributions"):
            attr = self.explanation_function(
                model=model,
                inputs=x.reshape(1, *x.shape).to(next(model.parameters()).device),
                targets=y.to(next(model.parameters()).device),
            )
            attributions_list.append(torch.tensor(attr)[0])
    all_attributions = torch.stack(attributions_list, dim=0)
    print("Finished.")
    return all_attributions
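
A hedged sketch of computing attributions for new images with the selected method; best, model, new_data, and new_targets are assumed as in the evaluate example above.

new_attributions = best.explain(model, new_data, new_targets)
print(new_attributions.shape[0] == new_data.shape[0])  # one attribution map per input image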

ExplanationHandler

Bases: ABC

Abstract class for explanation method handlers. Handlers manage explanation methods: they read and adapt parameters for the given model and data. They also create explanation functions that may be used directly by the user or passed to metric handlers.

Parameters:

Name Type Description Default
model torch.nn.Module

Model used for methods' parameter adaptation.

required
data torch.Tensor

Data used for method's parameters adaptation. Tensor with shape (N, C, H, W)

required
targets torch.Tensor

Target used for method's parameters inference - integer vector with shape (N,)

required
explanation_parameters Dict

Explanation method parameters to be overwritten.

None

Attributes:

Name Type Description
explanation_function Callable

Explanation method as function ready to be used with already set parameters.

explanation_parameters Dict

Parameters chosen for given explanation method.

attributions torch.Tensor

Computed attributions, only most recent.

Source code in autoexplainer/explanations/explanation_handlers.py
class ExplanationHandler(ABC):
    """
    Abstract class for explanation method handlers. Handlers manage explanation methods: they read and adapt parameters for
    the given model and data. They also create explanation functions that may be used directly by the user or passed to metric handlers.

    Parameters:
        model (torch.nn.Module): Model used for methods' parameter adaptation.
        data (torch.Tensor): Data used for method's parameters adaptation. Tensor with shape (N, C, H, W)
        targets (torch.Tensor): Target used for method's parameters inference - integer vector with shape (N,)
        explanation_parameters (Dict): Explanation method parameters to be overwritten.

    Attributes:
        explanation_function (Callable): Explanation method as function ready to be used with already set parameters.
        explanation_parameters (Dict): Parameters chosen for given explanation method.
        attributions (torch.Tensor): Computed attributions, only most recent.

    """

    explanation_function: Callable = NotImplemented
    explanation_parameters: Dict = None
    attributions: torch.Tensor = None

    @abstractmethod
    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, explanation_parameters: Dict = None
    ) -> None:
        pass

    def explain(self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            self.attributions = torch.tensor(self.explanation_function(model, data, targets)).to(
                next(model.parameters()).device
            )
        return self.attributions

    def get_explanation_function(self) -> Callable:
        """Return function that can be run by Quantus metrics."""
        return self.explanation_function

    def get_parameters(self) -> Dict:
        return self.explanation_parameters

    def _create_mask_function(self, mask_parameters: Dict) -> Callable:
        mask_function_name: str = mask_parameters.get("mask_function_name")
        del mask_parameters["mask_function_name"]
        return partial(batch_segmentation, mask_function_name=mask_function_name, **mask_parameters)

get_explanation_function() -> Callable

Return function that can be run by Quantus metrics.

Source code in autoexplainer/explanations/explanation_handlers.py
def get_explanation_function(self) -> Callable:
    """Return function that can be run by Quantus metrics."""
    return self.explanation_function

GradCamHandler

Bases: ExplanationHandler

Handler for GradCam explanation method. Uses captum implementation of GradCam.

By default, the last convolutional layer is chosen as a parameter for GradCam.

To overwrite default parameters, passed dictionary must be in the form:

explanation_parameters = {"explanation_parameters":{ <parameters accepted by GradCam in Captum> }}
Source code in autoexplainer/explanations/explanation_handlers.py
class GradCamHandler(ExplanationHandler):
    """
    Handler for GradCam explanation method. Uses captum implementation of GradCam.

    By default, the last convolutional layer is chosen as a parameter for GradCam.

    To overwrite default parameters, passed dictionary must be in the form:
    ```
    explanation_parameters = {"explanation_parameters":{ <parameters accepted by GradCam in Captum> }}
    ```
    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, explanation_parameters: Dict = None
    ) -> None:
        self.explanation_parameters = {}
        self.explanation_parameters["explanation_parameters"] = self._infer_grad_cam_parameters(model, data)
        if explanation_parameters is not None:
            update_dictionary(self.explanation_parameters, explanation_parameters)
        if self.explanation_parameters["explanation_parameters"].get("selected_layer") is None:
            raise ValueError("Unrecognized model, you need to pass selected layer name for GradCam.")  # noqa: TC003
        self.explanation_function = partial(grad_cam, **self.explanation_parameters["explanation_parameters"])

    def _infer_grad_cam_parameters(self, model: torch.nn.Module, data: torch.Tensor) -> Dict:
        parameters = {}
        layer_name = self._get_last_conv_layer_name(model)
        parameters["selected_layer"] = layer_name
        parameters["relu_attributions"] = True
        return parameters

    def _get_last_conv_layer_name(self, model: torch.nn.Module) -> Union[str, None]:
        last_conv_layer_name = None
        for name, layer in model.named_modules():
            if isinstance(layer, torch.nn.Conv2d):
                last_conv_layer_name = name
        if last_conv_layer_name:
            return last_conv_layer_name
        return None
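
A minimal sketch, assuming model is a CNN containing at least one Conv2d layer and data and targets are as in the earlier examples. By default the last convolutional layer is detected automatically; the explicit layer name in the commented line is purely illustrative.

from autoexplainer.explanations.explanation_handlers import GradCamHandler

handler = GradCamHandler(model, data, targets)
# handler = GradCamHandler(model, data, targets,
#     explanation_parameters={"explanation_parameters": {"selected_layer": "layer4.1.conv2"}})
attributions = handler.explain(model, data, targets)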

IntegratedGradients

Bases: ExplanationHandler

Handler for Integrated Gradients explanation method. Uses implementation in Quantus library.

To overwrite default parameters, passed dictionary must be in the form:

explanation_parameters = {"explanation_parameters":{ <parameters> }}

Integrated Gradients method accepts parameters:

  • normalise (bool) - Normalize attribution values. default=False
  • abs (bool) - Return absolute values of attribution. default=False
  • pos_only (bool) - Clip negative values of attribution to 0.0. default=False
  • neg_only (bool) - Clip positive values of attribution to 0.0. default=False
Source code in autoexplainer/explanations/explanation_handlers.py
class IntegratedGradients(ExplanationHandler):
    """
    Handler for Integrated Gradients explanation method. Uses implementation in Quantus library.

    To overwrite default parameters, passed dictionary must be in the form:
    ```
    explanation_parameters = {"explanation_parameters":{ <parameters> }}
    ```

    Integrated Gradients method accepts parameters:

    - `normalise` (bool) - Normalize attribution values. default=False
    - `abs` (bool) - Return absolute values of attribution. default=False
    - `pos_only` (bool) - Clip negative values of attribution to 0.0. default=False
    - `neg_only` (bool) - Clip positive values of attribution to 0.0. default=False

    """

    # TODO: use the original Captum implementation instead of Quantus'. The Quantus implementation has hardcoded parameters
    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, explanation_parameters: Dict = None
    ) -> None:
        self.explanation_parameters = {}

        self.explanation_parameters["explanation_parameters"] = self._infer_ig_parameters(model, data)
        if explanation_parameters is not None:
            update_dictionary(self.explanation_parameters, explanation_parameters)
        self._map_baseline_function_name_to_function()
        self.explanation_function = partial(
            integrated_gradients_explanation,
            **self.explanation_parameters["explanation_parameters"],
        )

    def _infer_ig_parameters(self, model: torch.nn.Module, data: torch.Tensor) -> Dict:
        parameters = {"baseline_function_name": "black", "n_steps": 20}
        return parameters

    def _map_baseline_function_name_to_function(self) -> None:
        self.explanation_parameters["explanation_parameters"]["baseline_function"] = BASELINE_FUNCTIONS[
            self.explanation_parameters["explanation_parameters"]["baseline_function_name"]
        ]
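
A minimal sketch, assuming model, data, and targets as before; the n_steps override is illustrative, and "black" is the baseline already inferred by default.

from autoexplainer.explanations.explanation_handlers import IntegratedGradients

handler = IntegratedGradients(
    model, data, targets,
    explanation_parameters={"explanation_parameters": {"n_steps": 50, "baseline_function_name": "black"}},
)
attributions = handler.explain(model, data, targets)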

KernelShapHandler

Bases: ExplanationHandler

Handler for the Kernel Shap explanation method. Uses the Captum implementation of Kernel Shap.

To overwrite default parameters, passed dictionary must be in the form:

explanation_parameters = {"mask_parameters": { "mask_function_name":<str>, <other parameters for chosen mask function> },
                   "explanation_parameters":{ "baseline_function_name":<str>, <parameters accepted by KernelShap in Captum> }}
Source code in autoexplainer/explanations/explanation_handlers.py
class KernelShapHandler(ExplanationHandler):
    """
    Handler for the Kernel Shap explanation method. Uses the Captum implementation of Kernel Shap.

    To overwrite default parameters, passed dictionary must be in the form:
    ```
    explanation_parameters = {"mask_parameters": { "mask_function_name":<str>, <other parameters for chosen mask function> },
                       "explanation_parameters":{ "baseline_function_name":<str>, <parameters accepted by KernelShap in Captum> }}
    ```

    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, explanation_parameters: Dict = None
    ) -> None:
        self.explanation_parameters = {}

        self.explanation_parameters["mask_parameters"] = self._infer_mask_parameters(data)
        self.explanation_parameters["explanation_parameters"] = self._infer_kernel_shap_parameters(data)
        if explanation_parameters is not None:
            update_dictionary(self.explanation_parameters, explanation_parameters)
        self._set_baseline_function()
        mask_function = self._create_mask_function(self.explanation_parameters["mask_parameters"])
        self.explanation_function = partial(
            shap_explanation, mask_function=mask_function, **self.explanation_parameters["explanation_parameters"]
        )

    # this method probably will be moved to the super class
    def _infer_mask_parameters(self, data: torch.Tensor) -> Dict:
        parameters = {}
        parameters["mask_function_name"] = "slic"
        parameters["n_segments"] = 50
        return parameters

    def _infer_kernel_shap_parameters(self, data: torch.Tensor) -> Dict:
        parameters = {}
        parameters["n_samples"] = 50
        parameters["baseline_function_name"] = "black"
        return parameters

    def _set_baseline_function(self) -> None:
        self.explanation_parameters["explanation_parameters"]["baseline_function"] = BASELINE_FUNCTIONS[
            self.explanation_parameters["explanation_parameters"]["baseline_function_name"]
        ]
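
A minimal sketch, assuming model, data, and targets as before; the segment and sample counts are illustrative overrides of the inferred defaults.

from autoexplainer.explanations.explanation_handlers import KernelShapHandler

handler = KernelShapHandler(
    model, data, targets,
    explanation_parameters={
        "mask_parameters": {"mask_function_name": "slic", "n_segments": 100},
        "explanation_parameters": {"n_samples": 100, "baseline_function_name": "black"},
    },
)
attributions = handler.explain(model, data, targets)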

SaliencyHandler

Bases: ExplanationHandler

Handler for Saliency explanation method. Uses implementation in Quantus library.

To overwrite default parameters, passed dictionary must be in the form:

explanation_parameters = {"explanation_parameters":{ <parameters> }}

Saliency method accepts parameters:

  • normalise (bool) - Normalize attribution values. default=False
  • abs (bool) - Return absolute values of attribution. default=False
  • pos_only (bool) - Clip negative values of attribution to 0.0. default=False
  • neg_only (bool) - Clip positive values of attribution to 0.0. default=False
Source code in autoexplainer/explanations/explanation_handlers.py
class SaliencyHandler(ExplanationHandler):
    """
    Handler for Saliency explanation method. Uses implementation in Quantus library.

    To overwrite default parameters, passed dictionary must be in the form:
    ```
    explanation_parameters = {"explanation_parameters":{ <parameters> }}
    ```

    Saliency method accepts parameters:

    - `normalise` (bool) - Normalize attribution values. default=False
    - `abs` (bool) - Return absolute values of attribution. default=False
    - `pos_only` (bool) - Clip negative values of attribution to 0.0. default=False
    - `neg_only` (bool) - Clip positive values of attribution to 0.0. default=False

    """

    def __init__(
        self, model: torch.nn.Module, data: torch.Tensor, targets: torch.Tensor, explanation_parameters: Dict = None
    ) -> None:
        self.explanation_parameters = {}

        self.explanation_parameters["explanation_parameters"] = self._infer_saliency_parameters(model, data)
        if explanation_parameters is not None:
            update_dictionary(self.explanation_parameters, explanation_parameters)
        self.explanation_function = partial(
            saliency_explanation,
            **self.explanation_parameters["explanation_parameters"],
        )

    def _infer_saliency_parameters(self, model: torch.nn.Module, data: torch.Tensor) -> Dict:
        parameters = {"abs": True}
        return parameters
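
A minimal sketch, assuming model, data, and targets as before. The handler defaults to abs=True, and get_explanation_function() returns a callable that can also be passed to the metric handlers above.

from autoexplainer.explanations.explanation_handlers import SaliencyHandler

handler = SaliencyHandler(model, data, targets)
attributions = handler.explain(model, data, targets)
explanation_func = handler.get_explanation_function()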