• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Image Classification Runner."""
16import os
17import re
18import json
19from time import time
20
21import numpy as np
22from scipy.stats import beta
23from PIL import Image
24
25import mindspore as ms
26from mindspore import context
27from mindspore import log
28import mindspore.dataset as ds
29from mindspore.dataset import Dataset
30from mindspore.nn import Cell, SequentialCell
31from mindspore.ops.operations import ExpandDims
32from mindspore.train._utils import check_value_type
33from mindspore.train.summary._summary_adapter import _convert_image_format
34from mindspore.train.summary.summary_record import SummaryRecord
35from mindspore.train.summary_pb2 import Explain
36from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation
37from mindspore.explainer.benchmark import Localization
38from mindspore.explainer.benchmark._attribution.metric import AttributionMetric
39from mindspore.explainer.benchmark._attribution.metric import LabelSensitiveMetric
40from mindspore.explainer.benchmark._attribution.metric import LabelAgnosticMetric
41from mindspore.explainer.explanation import RISE
42from mindspore.explainer.explanation._attribution.attribution import Attribution
43from mindspore.explainer.explanation._counterfactual import hierarchical_occlusion as hoc
44from mindspore.explainer._utils import deprecated_error
45
46
47_EXPAND_DIMS = ExpandDims()
48
49
50def _normalize(img_np):
51    """Normalize the numpy image to the range of [0, 1]. """
52    max_ = img_np.max()
53    min_ = img_np.min()
54    normed = (img_np - min_) / (max_ - min_).clip(min=1e-10)
55    return normed
56
57
def _np_to_image(img_np, mode):
    """
    Convert numpy array to PIL image.

    Args:
        img_np (numpy.ndarray): The image data, it is multiplied by 255 and cast to uint8 before conversion.
        mode (str): The PIL image mode passed to `Image.fromarray`.

    Returns:
        PIL.Image.Image, the converted image.
    """
    pixels = np.uint8(img_np * 255)
    return Image.fromarray(pixels, mode=mode)
61
62
63class _Verifier:
64    """Verification of dataset and settings of ImageClassificationRunner."""
65    ALL = 0xFFFFFFFF
66    REGISTRATION = 1
67    DATA_N_NETWORK = 1 << 1
68    SALIENCY = 1 << 2
69    HOC = 1 << 3
70    ENVIRONMENT = 1 << 4
71
72    def _verify(self, flags):
73        """
74        Verify datasets and settings.
75
76        Args:
77            flags (int): Verification flags, use bitwise or '|' to combine multiple flags.
78                Possible bitwise flags are shown as follow.
79
80                - ALL: Verify everything.
81                - REGISTRATION: Verify explainer module registration.
82                - DATA_N_NETWORK: Verify dataset and network.
83                - SALIENCY: Verify saliency related settings.
84                - HOC: Verify HOC related settings.
85                - ENVIRONMENT: Verify the runtime environment.
86
87        Raises:
88                ValueError: Be raised for any data or settings' value problem.
89                TypeError: Be raised for any data or settings' type problem.
90                RuntimeError: Be raised for any runtime problem.
91        """
92        if flags & self.ENVIRONMENT:
93            device_target = context.get_context('device_target')
94            if device_target not in ("Ascend", "GPU"):
95                raise RuntimeError(f"Unsupported device_target: '{device_target}', "
96                                   f"only 'Ascend' or 'GPU' is supported. "
97                                   f"Please call context.set_context(device_target='Ascend') or "
98                                   f"context.set_context(device_target='GPU').")
99        if flags & (self.ENVIRONMENT | self.SALIENCY):
100            if self._is_saliency_registered:
101                mode = context.get_context('mode')
102                if mode != context.PYNATIVE_MODE:
103                    raise RuntimeError("Context mode: GRAPH_MODE is not supported, "
104                                       "please call context.set_context(mode=context.PYNATIVE_MODE).")
105
106        if flags & self.REGISTRATION:
107            if self._is_uncertainty_registered and not self._is_saliency_registered:
108                raise ValueError("Function register_uncertainty() is called but register_saliency() is not.")
109            if not self._is_saliency_registered and not self._is_hoc_registered:
110                raise ValueError(
111                    "No explanation module was registered, user should at least call register_saliency() "
112                    "or register_hierarchical_occlusion() once with proper arguments.")
113
114        if flags & (self.DATA_N_NETWORK | self.SALIENCY | self.HOC):
115            self._verify_data()
116
117        if flags & self.DATA_N_NETWORK:
118            self._verify_network()
119
120        if flags & self.SALIENCY:
121            self._verify_saliency()
122
123    def _verify_labels(self):
124        """Verify labels."""
125        label_set = set()
126        if not self._labels:
127            raise ValueError(f"The label list provided is empty.")
128        for i, label in enumerate(self._labels):
129            if label.strip() == "":
130                raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is "
131                                 f"no empty label.")
132            if label in label_set:
133                raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.")
134            label_set.add(label)
135
136    def _verify_ds_inputs_shape(self, sample, inputs, labels):
137        """Verify a dataset sample's input shape."""
138
139        if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]:
140            raise ValueError(
141                "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format(
142                    inputs.shape))
143        if len(inputs.shape) == 3:
144            log.warning(
145                "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th"
146                " dimension as batch data.".format(inputs.shape))
147        if len(sample) > 1:
148            if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1:
149                raise ValueError(
150                    "Labels shape {} is unrecognizable: outputs should not have more than two dimensions"
151                    " with length greater than 1.".format(labels.shape))
152
153        if self._is_hoc_registered:
154            if inputs.shape[-3] != 3:
155                raise ValueError(
156                    "Hierarchical occlusion is registered, images must be in 3 channels format, but "
157                    "{} channel(s) is(are) encountered.".format(inputs.shape[-3]))
158            short_side = min(inputs.shape[-2:])
159            if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN:
160                raise ValueError(
161                    "Hierarchical occlusion is registered, images' short side must be equals to or greater then "
162                    "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side))
163
164    def _verify_ds_sample(self, sample):
165        """Verify a dataset sample."""
166        if len(sample) not in [1, 2, 3]:
167            raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]"
168                             " as columns.")
169
170        if len(sample) == 3:
171            inputs, labels, bboxes = sample
172            if bboxes.shape[-1] != 4:
173                raise ValueError("The third element of dataset should be bounding boxes with shape of "
174                                 "[batch_size, num_ground_truth, 4].")
175        else:
176            if self._benchmarkers is not None:
177                if any([isinstance(bench, Localization) for bench in self._benchmarkers]):
178                    raise ValueError("The dataset must provide bboxes if Localization is to be computed.")
179
180            if len(sample) == 2:
181                inputs, labels = sample
182            if len(sample) == 1:
183                inputs = sample[0]
184
185        self._verify_ds_inputs_shape(sample, inputs, labels)
186
187    def _verify_data(self):
188        """Verify dataset and labels."""
189        self._verify_labels()
190
191        try:
192            sample = next(self._dataset.create_tuple_iterator())
193        except StopIteration:
194            raise ValueError("The dataset provided is empty.")
195
196        self._verify_ds_sample(sample)
197
198    def _verify_network(self):
199        """Verify the network."""
200        next_element = next(self._dataset.create_tuple_iterator())
201        inputs, _, _ = self._unpack_next_element(next_element)
202        prop_test = self._full_network(inputs)
203        check_value_type("output of network in explainer", prop_test, ms.Tensor)
204        if prop_test.shape[1] != len(self._labels):
205            raise ValueError("The dimension of network output does not match the no. of classes. Please "
206                             "check labels or the network in the explainer again.")
207
208    def _verify_saliency(self):
209        """Verify the saliency settings."""
210        if self._explainers:
211            explainer_classes = []
212            for explainer in self._explainers:
213                if explainer.__class__ in explainer_classes:
214                    raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! "
215                                     "Please make sure all explainers' class is distinct.")
216                if explainer.network is not self._network:
217                    raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different "
218                                     "instance from network of runner. Please make sure they are the same "
219                                     "instance.")
220                explainer_classes.append(explainer.__class__)
221        if self._benchmarkers:
222            benchmarker_classes = []
223            for benchmarker in self._benchmarkers:
224                if benchmarker.__class__ in benchmarker_classes:
225                    raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! "
226                                     "Please make sure all benchmarkers' class is distinct.")
227                if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels):
228                    raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different "
229                                     "from no. of labels of runner. Please make them are the same.")
230                benchmarker_classes.append(benchmarker.__class__)
231
232
233@deprecated_error
234class ImageClassificationRunner(_Verifier):
235    """
236    A high-level API for users to generate and store results of the explanation methods and the evaluation methods.
237
238    Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version
239    will be deprecated and will not be supported in MindInsight of current version.
240
241    Args:
242        summary_dir (str): The directory path to save the summary files which store the generated results.
243        data (tuple[Dataset, list[str]]): Tuple of dataset and the corresponding class label list. The dataset
244            should provides [images], [images, labels] or [images, labels, bboxes] as columns. The label list must
245            share the exact same length and order of the network outputs.
246        network (Cell): The network(with logit outputs) to be explained.
247        activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For
248            single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification
249            tasks, `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long
250            as when combining this function with network, the final output is the probability of the input.
251
252    Raises:
253        TypeError: Be raised for any argument type problem.
254
255    Supported Platforms:
256        ``Ascend`` ``GPU``
257
258    Examples:
259        >>> from mindspore.explainer import ImageClassificationRunner
260        >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient
261        >>> from mindspore.explainer.benchmark import Faithfulness
262        >>> from mindspore.nn import Softmax
263        >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net
264        >>> from mindspore import context
265        >>>
266        >>> context.set_context(mode=context.PYNATIVE_MODE)
267        >>> # The detail of AlexNet is shown in model_zoo.official.cv.alexnet.src.alexnet.py
268        >>> net = AlexNet(10)
269        >>> # Load the checkpoint
270        >>> param_dict = load_checkpoint("/path/to/checkpoint")
271        >>> load_param_into_net(net, param_dict)
272        []
273        >>>
274        >>> # Prepare the dataset for explaining and evaluation.
275        >>> # The detail of create_dataset_cifar10 method is shown in model_zoo.official.cv.alexnet.src.dataset.py
276        >>>
277        >>> dataset = create_dataset_cifar10("/path/to/cifar/dataset", 1)
278        >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
279        >>>
280        >>> activation_fn = Softmax()
281        >>> gbp = GuidedBackprop(net)
282        >>> gradient = Gradient(net)
283        >>> explainers = [gbp, gradient]
284        >>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness")
285        >>> benchmarkers = [faithfulness]
286        >>>
287        >>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn)
288        >>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers)
289        >>> runner.run()
290    """
291
292    # datafile directory names
293    _DATAFILE_DIRNAME_PREFIX = "_explain_"
294    _ORIGINAL_IMAGE_DIRNAME = "origin_images"
295    _HEATMAP_DIRNAME = "heatmap"
    # special filenames
297    _MANIFEST_FILENAME = "manifest.json"
298    # max. no. of sample per directory
299    _SAMPLE_PER_DIR = 1000
300    # seed for fixing the iterating order of the dataset
301    _DATASET_SEED = 58
302    # printing spacer
303    _SPACER = "{:120}\r"
304    # datafile directory's permission
305    _DIR_MODE = 0o700
306    # datafile's permission
307    _FILE_MODE = 0o400
308
309    def __init__(self,
310                 summary_dir,
311                 data,
312                 network,
313                 activation_fn):
314
315        check_value_type("data", data, tuple)
316        if len(data) != 2:
317            raise ValueError("Argument data is not a tuple with 2 elements")
318        check_value_type("data[0]", data[0], Dataset)
319        check_value_type("data[1]", data[1], list)
320        if not all(isinstance(ele, str) for ele in data[1]):
321            raise ValueError("Argument data[1] is not list of str.")
322
323        check_value_type("summary_dir", summary_dir, str)
324        check_value_type("network", network, Cell)
325        check_value_type("activation_fn", activation_fn, Cell)
326
327        self._summary_dir = summary_dir
328        self._dataset = data[0]
329        self._labels = data[1]
330        self._network = network
331        self._explainers = None
332        self._benchmarkers = None
333        self._uncertainty = None
334        self._hoc_searcher = None
335        self._summary_timestamp = None
336        self._sample_index = -1
337        self._manifest = None
338
339        self._full_network = SequentialCell([self._network, activation_fn])
340        self._full_network.set_train(False)
341
342        self._verify(_Verifier.DATA_N_NETWORK | _Verifier.ENVIRONMENT)
343
344    def register_saliency(self,
345                          explainers,
346                          benchmarkers=None):
347        """
348        Register saliency explanation instances.
349
350        .. warning::
351            This function can not be invoked more than once on each runner.
352
353        Args:
354            explainers (list[Attribution]): The explainers to be evaluated,
355                see `mindspore.explainer.explanation`. All explainers' class must be distinct and their network
356                must be the exact same instance of the runner's network.
357            benchmarkers (list[AttributionMetric], optional): The benchmarkers for scoring the explainers,
358                see `mindspore.explainer.benchmark`. All benchmarkers' class must be distinct.
359
360        Raises:
361            ValueError: Be raised for any data or settings' value problem.
362            TypeError: Be raised for any data or settings' type problem.
363            RuntimeError: Be raised if this function was invoked before.
364        """
365        check_value_type("explainers", explainers, list)
366        if not all(isinstance(ele, Attribution) for ele in explainers):
367            raise TypeError("Argument explainers is not list of mindspore.explainer.explanation .")
368
369        if not explainers:
370            raise ValueError("Argument explainers is empty.")
371
372        if benchmarkers is not None:
373            check_value_type("benchmarkers", benchmarkers, list)
374            if not all(isinstance(ele, AttributionMetric) for ele in benchmarkers):
375                raise TypeError("Argument benchmarkers is not list of mindspore.explainer.benchmark .")
376
377        if self._explainers is not None:
378            raise RuntimeError("Function register_saliency() was invoked already.")
379
380        self._explainers = explainers
381        self._benchmarkers = benchmarkers
382
383        try:
384            self._verify(_Verifier.SALIENCY | _Verifier.ENVIRONMENT)
385        except (ValueError, TypeError):
386            self._explainers = None
387            self._benchmarkers = None
388            raise
389
390    def register_hierarchical_occlusion(self):
391        """
392        Register hierarchical occlusion instances.
393
394        .. warning::
395            This function can not be invoked more than once on each runner.
396
397        Note:
398            Input images are required to be in 3 channels formats and the length of side short must be equals to or
399            greater than 56 pixels.
400
401        Raises:
402            ValueError: Be raised for any data or settings' value problem.
403            RuntimeError: Be raised if the function was called already.
404        """
405        if self._hoc_searcher is not None:
406            raise RuntimeError("Function register_hierarchical_occlusion() was invoked already.")
407
408        self._hoc_searcher = hoc.Searcher(self._full_network)
409
410        try:
411            self._verify(_Verifier.HOC | _Verifier.ENVIRONMENT)
412        except ValueError:
413            self._hoc_searcher = None
414            raise
415
416    def register_uncertainty(self):
417        """
418        Register uncertainty instance to compute the epistemic uncertainty base on the Bayes' theorem.
419
420        .. warning::
421            This function can not be invoked more than once on each runner.
422
423        Note:
424            Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the
425            details. The actual output is standard deviation of the classification predictions and the corresponding
426            95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are
427            going to be shown on the saliency map page in MindInsight.
428
429        Raises:
430            RuntimeError: Be raised if the function was called already.
431        """
432        if self._uncertainty is not None:
433            raise RuntimeError("Function register_uncertainty() was invoked already.")
434
435        self._uncertainty = UncertaintyEvaluation(model=self._full_network,
436                                                  train_dataset=None,
437                                                  task_type='classification',
438                                                  num_classes=len(self._labels))
439
440    def run(self):
441        """
442        Run the explain job and save the result as a summary in summary_dir.
443
444        Note:
445            User should call register_saliency() once before running this function.
446
447        Raises:
448            ValueError: Be raised for any data or settings' value problem.
449            TypeError: Be raised for any data or settings' type problem.
450            RuntimeError: Be raised for any runtime problem.
451        """
452        self._verify(_Verifier.ALL)
453        self._manifest = {"saliency_map": False,
454                          "benchmark": False,
455                          "uncertainty": False,
456                          "hierarchical_occlusion": False}
457        with SummaryRecord(self._summary_dir, raise_exception=True) as summary:
458            print("Start running and writing......")
459            begin = time()
460
461            self._summary_timestamp = self._extract_timestamp(summary.file_info['file_name'])
462            if self._summary_timestamp is None:
463                raise RuntimeError("Cannot extract timestamp from summary filename!"
464                                   " It should contains a timestamp after 'summary.' .")
465
466            self._save_metadata(summary)
467
468            imageid_labels = self._run_inference(summary)
469            sample_count = self._sample_index
470            if self._is_saliency_registered:
471                self._run_saliency(summary, imageid_labels)
472                if not self._manifest["saliency_map"]:
473                    raise RuntimeError(
474                        f"No saliency map was generated in {sample_count} samples. "
475                        f"Please make sure the dataset, labels, activation function and network are properly trained "
476                        f"and configured.")
477
478            if self._is_hoc_registered and not self._manifest["hierarchical_occlusion"]:
479                raise RuntimeError(
480                    f"No Hierarchical Occlusion result was found in {sample_count} samples. "
481                    f"Please make sure the dataset, labels, activation function and network are properly trained "
482                    f"and configured.")
483
484            self._save_manifest()
485
486            print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin))
487
488    @property
489    def _is_hoc_registered(self):
490        """Check if HOC module is registered."""
491        return self._hoc_searcher is not None
492
493    @property
494    def _is_saliency_registered(self):
495        """Check if saliency module is registered."""
496        return bool(self._explainers)
497
498    @property
499    def _is_uncertainty_registered(self):
500        """Check if uncertainty module is registered."""
501        return self._uncertainty is not None
502
503    def _save_metadata(self, summary):
504        """Save metadata of the explain job to summary."""
505        print("Start writing metadata......")
506
507        explain = Explain()
508        explain.metadata.label.extend(self._labels)
509
510        if self._is_saliency_registered:
511            exp_names = [exp.__class__.__name__ for exp in self._explainers]
512            explain.metadata.explain_method.extend(exp_names)
513            if self._benchmarkers is not None:
514                bench_names = [bench.__class__.__name__ for bench in self._benchmarkers]
515                explain.metadata.benchmark_method.extend(bench_names)
516
517        summary.add_value("explainer", "metadata", explain)
518        summary.record(1)
519
520        print("Finish writing metadata.")
521
522    def _run_inference(self, summary, threshold=0.5):
523        """
524        Run inference for the dataset and write the inference related data into summary.
525
526        Args:
527            summary (SummaryRecord): The summary object to store the data.
528            threshold (float): The threshold for prediction.
529
530        Returns:
531            dict, The map of sample d to the union of its ground truth and predicted labels.
532        """
533        sample_id_labels = {}
534        self._sample_index = 0
535        ds.config.set_seed(self._DATASET_SEED)
536        for j, batch in enumerate(self._dataset):
537            now = time()
538            self._infer_batch(summary, batch, sample_id_labels, threshold)
539            self._spaced_print("Finish running and writing {}-th batch inference data."
540                               " Time elapsed: {:.3f} s".format(j, time() - now))
541        return sample_id_labels
542
543    def _infer_batch(self, summary, batch, sample_id_labels, threshold):
544        """
545        Infer a batch.
546
547        Args:
548            summary (SummaryRecord): The summary object to store the data.
549            batch (tuple): The next dataset sample.
550            sample_id_labels (dict): The sample id to labels dictionary.
551            threshold (float): The threshold for prediction.
552        """
553        inputs, labels, _ = self._unpack_next_element(batch)
554        prob = self._full_network(inputs).asnumpy()
555
556        if self._uncertainty is not None:
557            prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs)
558        else:
559            prob_var = None
560
561        for idx, inp in enumerate(inputs):
562            gt_labels = labels[idx]
563            gt_probs = [float(prob[idx][i]) for i in gt_labels]
564
565            if prob_var is not None:
566                gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels]
567                gt_itl_lows, gt_itl_his, gt_prob_sds = \
568                    self._calc_beta_intervals(gt_probs, gt_prob_vars)
569
570            data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW')
571            original_image = _np_to_image(_normalize(data_np), mode='RGB')
572            original_image_path = self._save_original_image(self._sample_index, original_image)
573
574            predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]]
575            predicted_probs = [float(prob[idx][i]) for i in predicted_labels]
576
577            if prob_var is not None:
578                predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels]
579                predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \
580                    self._calc_beta_intervals(predicted_probs, predicted_prob_vars)
581
582            union_labs = list(set(gt_labels + predicted_labels))
583            sample_id_labels[str(self._sample_index)] = union_labs
584
585            explain = Explain()
586            explain.sample_id = self._sample_index
587            explain.image_path = original_image_path
588            summary.add_value("explainer", "sample", explain)
589
590            explain = Explain()
591            explain.sample_id = self._sample_index
592            explain.ground_truth_label.extend(gt_labels)
593            explain.inference.ground_truth_prob.extend(gt_probs)
594            explain.inference.predicted_label.extend(predicted_labels)
595            explain.inference.predicted_prob.extend(predicted_probs)
596
597            if prob_var is not None:
598                explain.inference.ground_truth_prob_sd.extend(gt_prob_sds)
599                explain.inference.ground_truth_prob_itl95_low.extend(gt_itl_lows)
600                explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his)
601                explain.inference.predicted_prob_sd.extend(predicted_prob_sds)
602                explain.inference.predicted_prob_itl95_low.extend(predicted_itl_lows)
603                explain.inference.predicted_prob_itl95_hi.extend(predicted_itl_his)
604
605                self._manifest["uncertainty"] = True
606
607            summary.add_value("explainer", "inference", explain)
608            summary.record(1)
609
610            if self._is_hoc_registered:
611                self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx])
612
613            self._sample_index += 1
614
615    def _run_explainer(self, summary, sample_id_labels, explainer):
616        """
617        Run the explainer.
618
619        Args:
620            summary (SummaryRecord): The summary object to store the data.
621            sample_id_labels (dict): A dict that maps the sample id and its union labels.
622            explainer (_Attribution): An Attribution object to generate saliency maps.
623        """
624        for idx, next_element in enumerate(self._dataset):
625            now = time()
626            self._spaced_print("Start running {}-th explanation data for {}......".format(
627                idx, explainer.__class__.__name__))
628            saliency_dict_lst = self._run_exp_step(next_element, explainer, sample_id_labels, summary)
629            self._spaced_print(
630                "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format(
631                    idx, explainer.__class__.__name__, time() - now))
632
633            if not self._benchmarkers:
634                continue
635
636            for bench in self._benchmarkers:
637                now = time()
638                self._spaced_print(
639                    "Start running {}-th batch {} data for {}......".format(
640                        idx, bench.__class__.__name__, explainer.__class__.__name__))
641                self._run_exp_benchmark_step(next_element, explainer, bench, saliency_dict_lst)
642                self._spaced_print(
643                    "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format(
644                        idx, bench.__class__.__name__, explainer.__class__.__name__, time() - now))
645
646    def _run_saliency(self, summary, sample_id_labels):
647        """Run the saliency explanations."""
648
649        for explainer in self._explainers:
650            explain = Explain()
651            if self._benchmarkers:
652                for bench in self._benchmarkers:
653                    bench.reset()
654            print(f"Start running and writing explanation for {explainer.__class__.__name__}......")
655            self._sample_index = 0
656            start = time()
657            ds.config.set_seed(self._DATASET_SEED)
658            self._run_explainer(summary, sample_id_labels, explainer)
659
660            if not self._benchmarkers:
661                continue
662
663            for bench in self._benchmarkers:
664                benchmark = explain.benchmark.add()
665                benchmark.explain_method = explainer.__class__.__name__
666                benchmark.benchmark_method = bench.__class__.__name__
667
668                benchmark.total_score = bench.performance
669                if isinstance(bench, LabelSensitiveMetric):
670                    benchmark.label_score.extend(bench.class_performances)
671
672            self._spaced_print("Finish running and writing explanation and benchmark data for {}. "
673                               "Time elapsed: {:.3f} s".format(explainer.__class__.__name__, time() - start))
674            summary.add_value('explainer', 'benchmark', explain)
675            summary.record(1)
676
677    def _run_hoc(self, summary, sample_id, sample_input, prob):
678        """
679        Run HOC search for a sample image, and then save the result to summary.
680
681        Args:
682            summary (SummaryRecord): The summary object to store the data.
683            sample_id (int): The sample ID.
684            sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCWH(N=1).
685            prob (Union[Tensor, np.ndarray]): List of sample's classification prediction output, HOC will run for
686                labels with prediction output strictly larger then HOC searcher's threshold(0.5 by default).
687        """
688        if isinstance(sample_input, ms.Tensor):
689            sample_input = sample_input.asnumpy()
690        if len(sample_input.shape) == 3:
691            sample_input = np.expand_dims(sample_input, axis=0)
692
693        explain = None
694        str_mask = hoc.auto_str_mask(sample_input)
695        compiled_mask = None
696
697        for label_idx, label_prob in enumerate(prob):
698            if label_prob <= self._hoc_searcher.threshold:
699                continue
700            if compiled_mask is None:
701                compiled_mask = hoc.compile_mask(str_mask, sample_input)
702            try:
703                edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask)
704            except hoc.NoValidResultError:
705                log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} "
706                            f"label:{self._labels[label_idx]}, skipped.")
707                continue
708
709            if explain is None:
710                explain = Explain()
711                explain.sample_id = sample_id
712
713            self._add_hoc_result_to_explain(label_idx, str_mask, edit_tree, layer_outputs, explain)
714
715        if explain is not None:
716            summary.add_value("explainer", "hoc", explain)
717            summary.record(1)
718            self._manifest['hierarchical_occlusion'] = True
719
720    @staticmethod
721    def _add_hoc_result_to_explain(label_idx, str_mask, edit_tree, layer_outputs, explain):
722        """
723        Add HOC result to Explain record.
724
725        Args:
726            label_idx (int): The label index.
727            str_mask (str): The mask string.
728            edit_tree (EditStep): The result HOC edit tree.
729            layer_outputs (list[float]): The network output confident of each layer.
730            explain (Explain): The Explain record.
731        """
732        hoc_rec = explain.hoc.add()
733        hoc_rec.label = label_idx
734        hoc_rec.mask = str_mask
735        layer_count = edit_tree.max_layer + 1
736        for layer in range(layer_count):
737            steps = edit_tree.get_layer_or_leaf_steps(layer)
738            layer_output = layer_outputs[layer]
739            hoc_layer = hoc_rec.layer.add()
740            hoc_layer.prob = layer_output
741            for step in steps:
742                hoc_layer.box.extend(list(step.box))
743
744    def _add_exp_step_samples(self, explainer, sample_label_sets, batch_saliency_full, summary):
745        """
746        Add explanation results of samples to summary record.
747
748        Args:
749            explainer (Attribution): The explainer to be run.
750            sample_label_sets (list[list[int]]): The label sets of samples.
751            batch_saliency_full (Tensor): The saliency output from explainer.
752            summary (SummaryRecord): The summary record.
753        """
754        saliency_dict_lst = []
755        has_saliency_rec = False
756        for idx, label_set in enumerate(sample_label_sets):
757            saliency_dict = {}
758            explain = Explain()
759            explain.sample_id = self._sample_index
760            for k, lab in enumerate(label_set):
761                saliency = batch_saliency_full[idx:idx + 1, k:k + 1]
762                saliency_dict[lab] = saliency
763
764                saliency_np = _normalize(saliency.asnumpy().squeeze())
765                saliency_image = _np_to_image(saliency_np, mode='L')
766                heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab,
767                                                  self._sample_index, saliency_image)
768
769                explanation = explain.explanation.add()
770                explanation.explain_method = explainer.__class__.__name__
771                explanation.heatmap_path = heatmap_path
772                explanation.label = lab
773
774                has_saliency_rec = True
775
776            summary.add_value("explainer", "explanation", explain)
777            summary.record(1)
778
779            self._sample_index += 1
780            saliency_dict_lst.append(saliency_dict)
781
782        return saliency_dict_lst, has_saliency_rec
783
784    def _run_exp_step(self, next_element, explainer, sample_id_labels, summary):
785        """
786        Run the explanation for each step and write explanation results into summary.
787
788        Args:
789            next_element (Tuple): Data of one step
790            explainer (_Attribution): An Attribution object to generate saliency maps.
791            sample_id_labels (dict): A dict that maps the sample id and its union labels.
792            summary (SummaryRecord): The summary object to store the data.
793
794        Returns:
795            list, List of dict that maps label to its corresponding saliency map.
796        """
797        inputs, labels, _ = self._unpack_next_element(next_element)
798        sample_index = self._sample_index
799        sample_label_sets = []
800        for _ in range(len(labels)):
801            sample_label_sets.append(sample_id_labels[str(sample_index)])
802            sample_index += 1
803
804        batch_label_sets = self._make_label_batch(sample_label_sets)
805
806        if isinstance(explainer, RISE):
807            batch_saliency_full = explainer(inputs, batch_label_sets)
808        else:
809            batch_saliency_full = []
810            for i in range(len(batch_label_sets[0])):
811                batch_saliency = explainer(inputs, batch_label_sets[:, i])
812                batch_saliency_full.append(batch_saliency)
813            concat = ms.ops.operations.Concat(1)
814            batch_saliency_full = concat(tuple(batch_saliency_full))
815
816        saliency_dict_lst, has_saliency_rec = \
817            self._add_exp_step_samples(explainer, sample_label_sets, batch_saliency_full, summary)
818
819        if has_saliency_rec:
820            self._manifest['saliency_map'] = True
821
822        return saliency_dict_lst
823
824    def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst):
825        """Run the explanation and evaluation for each step and write explanation results into summary."""
826        inputs, labels, _ = self._unpack_next_element(next_element)
827        for idx, inp in enumerate(inputs):
828            inp = _EXPAND_DIMS(inp, 0)
829            self._manifest['benchmark'] = True
830            if isinstance(benchmarker, LabelAgnosticMetric):
831                res = benchmarker.evaluate(explainer, inp)
832                benchmarker.aggregate(res)
833                continue
834            saliency_dict = saliency_dict_lst[idx]
835            for label, saliency in saliency_dict.items():
836                if isinstance(benchmarker, Localization):
837                    _, _, bboxes = self._unpack_next_element(next_element, True)
838                    if label in labels[idx]:
839                        res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label],
840                                                   saliency=saliency)
841                        benchmarker.aggregate(res, label)
842                elif isinstance(benchmarker, LabelSensitiveMetric):
843                    res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency)
844                    benchmarker.aggregate(res, label)
845                else:
846                    raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but'
847                                    'receive {}'.format(type(benchmarker)))
848
849    @staticmethod
850    def _calc_beta_intervals(means, variances, prob=0.95):
851        """Calculate confidence interval of beta distributions."""
852        if not isinstance(means, np.ndarray):
853            means = np.array(means)
854        if not isinstance(variances, np.ndarray):
855            variances = np.array(variances)
856        with np.errstate(divide='ignore'):
857            coef_a = ((means ** 2) * (1 - means) / variances) - means
858            coef_b = (coef_a * (1 - means)) / means
859            itl_lows, itl_his = beta.interval(prob, coef_a, coef_b)
860            sds = np.sqrt(variances)
861        for i in range(itl_lows.shape[0]):
862            if not np.isfinite(sds[i]) or not np.isfinite(itl_lows[i]) or not np.isfinite(itl_his[i]):
863                itl_lows[i] = means[i]
864                itl_his[i] = means[i]
865                sds[i] = 0
866        return itl_lows, itl_his, sds
867
868    def _transform_bboxes(self, inputs, labels, bboxes, ifbbox):
869        """
870        Transform the bounding boxes.
871        Args:
872            inputs (Tensor): the image data
873            labels (Tensor): the labels
874            bboxes (Tensor): the boudnding boxes data
875            ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t
876                label id will be returned. If False, the returned bboxes is the the parsed bboxes.
877
878         Returns:
879            bboxes (Union[list[dict], None, Tensor]): the bounding boxes
880        """
881        input_len = len(inputs)
882        if bboxes is None or not ifbbox:
883            return bboxes
884        bboxes = ms.Tensor(bboxes, ms.int32)
885        masks_lst = []
886        labels = labels.asnumpy().reshape([input_len, -1])
887        bboxes = bboxes.asnumpy().reshape([input_len, -1, 4])
888        for idx, label in enumerate(labels):
889            height, width = inputs[idx].shape[-2], inputs[idx].shape[-1]
890            masks = {}
891            for j, label_item in enumerate(label):
892                target = int(label_item)
893                if not -1 < target < len(self._labels):
894                    continue
895                if target not in masks:
896                    mask = np.zeros((1, 1, height, width))
897                else:
898                    mask = masks[target]
899                x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int)
900                mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1
901                masks[target] = mask
902            masks_lst.append(masks)
903        bboxes = masks_lst
904        return bboxes
905
906    def _transform_data(self, inputs, labels, bboxes, ifbbox):
907        """
908        Transform the data from one iteration of dataset to a unifying form for the follow-up operations.
909
910        Args:
911            inputs (Tensor): the image data
912            labels (Tensor): the labels
913            bboxes (Tensor): the boudnding boxes data
914            ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t
915                label id will be returned. If False, the returned bboxes is the the parsed bboxes.
916
917        Returns:
918            inputs (Tensor): the image data, unified to a 4D Tensor.
919            labels (list[list[int]]): the ground truth labels.
920            bboxes (Union[list[dict], None, Tensor]): the bounding boxes
921        """
922        inputs = ms.Tensor(inputs, ms.float32)
923        if len(inputs.shape) == 3:
924            inputs = _EXPAND_DIMS(inputs, 0)
925            if isinstance(labels, ms.Tensor):
926                labels = ms.Tensor(labels, ms.int32)
927                labels = _EXPAND_DIMS(labels, 0)
928            if isinstance(bboxes, ms.Tensor):
929                bboxes = ms.Tensor(bboxes, ms.int32)
930                bboxes = _EXPAND_DIMS(bboxes, 0)
931
932        bboxes = self._transform_bboxes(inputs, labels, bboxes, ifbbox)
933
934        labels = ms.Tensor(labels, ms.int32)
935        if len(labels.shape) == 1:
936            labels_lst = [[int(i)] for i in labels.asnumpy()]
937        else:
938            labels = labels.asnumpy().reshape([len(inputs), -1])
939            labels_lst = []
940            for item in labels:
941                labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._labels))))
942        labels = labels_lst
943        return inputs, labels, bboxes
944
945    def _unpack_next_element(self, next_element, ifbbox=False):
946        """
947        Unpack a single iteration of dataset.
948
949        Args:
950            next_element (Tuple): a single element iterated from dataset object.
951            ifbbox (bool): whether to preprocess bboxes in self._transform_data.
952
953        Returns:
954            tuple, a unified Tuple contains image_data, labels, and bounding boxes.
955        """
956        if len(next_element) == 3:
957            inputs, labels, bboxes = next_element
958        elif len(next_element) == 2:
959            inputs, labels = next_element
960            bboxes = None
961        else:
962            inputs = next_element[0]
963            labels = [[] for _ in inputs]
964            bboxes = None
965        inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox)
966        return inputs, labels, bboxes
967
968    @staticmethod
969    def _make_label_batch(labels):
970        """
971        Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max
972        length of all the rows in labels.
973
974        Args:
975            labels (List[List]): the union labels of a data batch.
976
977        Returns:
978            2D Tensor.
979        """
980        max_len = max([len(label) for label in labels])
981        batch_labels = np.zeros((len(labels), max_len))
982
983        for idx, _ in enumerate(batch_labels):
984            length = len(labels[idx])
985            batch_labels[idx, :length] = np.array(labels[idx])
986
987        return ms.Tensor(batch_labels, ms.int32)
988
989    def _save_manifest(self):
990        """Save manifest.json underneath datafile directory."""
991        if self._manifest is None:
992            raise RuntimeError("Manifest not yet be initialized.")
993        path_tokens = [self._summary_dir,
994                       self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)]
995        abs_dir_path = self._create_subdir(*path_tokens)
996        save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME)
997        fd = os.open(save_path, os.O_WRONLY | os.O_CREAT, mode=self._FILE_MODE)
998        file = os.fdopen(fd, "w")
999        try:
1000            json.dump(self._manifest, file, indent=4)
1001        except IOError:
1002            log.error(f"Failed to save manifest as {save_path}!")
1003            raise
1004        finally:
1005            file.flush()
1006            os.close(fd)
1007        os.chmod(save_path, self._FILE_MODE)
1008
1009    def _save_original_image(self, sample_id, image):
1010        """Save an image to summary directory."""
1011        id_dirname = self._get_sample_dirname(sample_id)
1012        path_tokens = [self._summary_dir,
1013                       self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
1014                       self._ORIGINAL_IMAGE_DIRNAME,
1015                       id_dirname]
1016
1017        abs_dir_path = self._create_subdir(*path_tokens)
1018        filename = f"{sample_id}.jpg"
1019        save_path = os.path.join(abs_dir_path, filename)
1020        image.save(save_path)
1021        os.chmod(save_path, self._FILE_MODE)
1022        return os.path.join(*path_tokens[1:], filename)
1023
1024    def _save_heatmap(self, explain_method, class_id, sample_id, image):
1025        """Save heatmap image to summary directory."""
1026        id_dirname = self._get_sample_dirname(sample_id)
1027        path_tokens = [self._summary_dir,
1028                       self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp),
1029                       self._HEATMAP_DIRNAME,
1030                       explain_method,
1031                       id_dirname]
1032
1033        abs_dir_path = self._create_subdir(*path_tokens)
1034        filename = f"{sample_id}_{class_id}.jpg"
1035        save_path = os.path.join(abs_dir_path, filename)
1036        image.save(save_path, optimize=True)
1037        os.chmod(save_path, self._FILE_MODE)
1038        return os.path.join(*path_tokens[1:], filename)
1039
1040    def _create_subdir(self, *args):
1041        """Recursively create subdirectories."""
1042        abs_path = None
1043        for token in args:
1044            if abs_path is None:
1045                abs_path = os.path.realpath(token)
1046            else:
1047                abs_path = os.path.join(abs_path, token)
1048            # os.makedirs() don't set intermediate dir permission properly, we mkdir() one by one
1049            try:
1050                os.mkdir(abs_path, mode=self._DIR_MODE)
1051                # In some platform, mode may be ignored in os.mkdir(), we have to chmod() again to make sure
1052                os.chmod(abs_path, mode=self._DIR_MODE)
1053            except FileExistsError:
1054                pass
1055        return abs_path
1056
1057    @classmethod
1058    def _get_sample_dirname(cls, sample_id):
1059        """Get the name of parent directory of the image id."""
1060        return str(int(sample_id / cls._SAMPLE_PER_DIR) * cls._SAMPLE_PER_DIR)
1061
1062    @staticmethod
1063    def _extract_timestamp(filename):
1064        """Extract timestamp from summary filename."""
1065        matched = re.search(r"summary\.(\d+)", filename)
1066        if matched:
1067            return int(matched.group(1))
1068        return None
1069
1070    @classmethod
1071    def _spaced_print(cls, message):
1072        """Spaced message printing."""
1073        # workaround to print logs starting new line in case line width mismatch.
1074        print(cls._SPACER.format(message))
1075