1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Image Classification Runner.""" 16import os 17import re 18import json 19from time import time 20 21import numpy as np 22from scipy.stats import beta 23from PIL import Image 24 25import mindspore as ms 26from mindspore import context 27from mindspore import log 28import mindspore.dataset as ds 29from mindspore.dataset import Dataset 30from mindspore.nn import Cell, SequentialCell 31from mindspore.ops.operations import ExpandDims 32from mindspore.train._utils import check_value_type 33from mindspore.train.summary._summary_adapter import _convert_image_format 34from mindspore.train.summary.summary_record import SummaryRecord 35from mindspore.train.summary_pb2 import Explain 36from mindspore.nn.probability.toolbox.uncertainty_evaluation import UncertaintyEvaluation 37from mindspore.explainer.benchmark import Localization 38from mindspore.explainer.benchmark._attribution.metric import AttributionMetric 39from mindspore.explainer.benchmark._attribution.metric import LabelSensitiveMetric 40from mindspore.explainer.benchmark._attribution.metric import LabelAgnosticMetric 41from mindspore.explainer.explanation import RISE 42from mindspore.explainer.explanation._attribution.attribution import Attribution 43from mindspore.explainer.explanation._counterfactual import hierarchical_occlusion as hoc 44from mindspore.explainer._utils import deprecated_error 45 46 47_EXPAND_DIMS = ExpandDims() 48 49 50def _normalize(img_np): 51 """Normalize the numpy image to the range of [0, 1]. """ 52 max_ = img_np.max() 53 min_ = img_np.min() 54 normed = (img_np - min_) / (max_ - min_).clip(min=1e-10) 55 return normed 56 57 58def _np_to_image(img_np, mode): 59 """Convert numpy array to PIL image.""" 60 return Image.fromarray(np.uint8(img_np * 255), mode=mode) 61 62 63class _Verifier: 64 """Verification of dataset and settings of ImageClassificationRunner.""" 65 ALL = 0xFFFFFFFF 66 REGISTRATION = 1 67 DATA_N_NETWORK = 1 << 1 68 SALIENCY = 1 << 2 69 HOC = 1 << 3 70 ENVIRONMENT = 1 << 4 71 72 def _verify(self, flags): 73 """ 74 Verify datasets and settings. 75 76 Args: 77 flags (int): Verification flags, use bitwise or '|' to combine multiple flags. 78 Possible bitwise flags are shown as follow. 79 80 - ALL: Verify everything. 81 - REGISTRATION: Verify explainer module registration. 82 - DATA_N_NETWORK: Verify dataset and network. 83 - SALIENCY: Verify saliency related settings. 84 - HOC: Verify HOC related settings. 85 - ENVIRONMENT: Verify the runtime environment. 86 87 Raises: 88 ValueError: Be raised for any data or settings' value problem. 89 TypeError: Be raised for any data or settings' type problem. 90 RuntimeError: Be raised for any runtime problem. 91 """ 92 if flags & self.ENVIRONMENT: 93 device_target = context.get_context('device_target') 94 if device_target not in ("Ascend", "GPU"): 95 raise RuntimeError(f"Unsupported device_target: '{device_target}', " 96 f"only 'Ascend' or 'GPU' is supported. " 97 f"Please call context.set_context(device_target='Ascend') or " 98 f"context.set_context(device_target='GPU').") 99 if flags & (self.ENVIRONMENT | self.SALIENCY): 100 if self._is_saliency_registered: 101 mode = context.get_context('mode') 102 if mode != context.PYNATIVE_MODE: 103 raise RuntimeError("Context mode: GRAPH_MODE is not supported, " 104 "please call context.set_context(mode=context.PYNATIVE_MODE).") 105 106 if flags & self.REGISTRATION: 107 if self._is_uncertainty_registered and not self._is_saliency_registered: 108 raise ValueError("Function register_uncertainty() is called but register_saliency() is not.") 109 if not self._is_saliency_registered and not self._is_hoc_registered: 110 raise ValueError( 111 "No explanation module was registered, user should at least call register_saliency() " 112 "or register_hierarchical_occlusion() once with proper arguments.") 113 114 if flags & (self.DATA_N_NETWORK | self.SALIENCY | self.HOC): 115 self._verify_data() 116 117 if flags & self.DATA_N_NETWORK: 118 self._verify_network() 119 120 if flags & self.SALIENCY: 121 self._verify_saliency() 122 123 def _verify_labels(self): 124 """Verify labels.""" 125 label_set = set() 126 if not self._labels: 127 raise ValueError(f"The label list provided is empty.") 128 for i, label in enumerate(self._labels): 129 if label.strip() == "": 130 raise ValueError(f"Label [{i}] is all whitespaces or empty. Please make sure there is " 131 f"no empty label.") 132 if label in label_set: 133 raise ValueError(f"Duplicated label:{label}! Please make sure all labels are unique.") 134 label_set.add(label) 135 136 def _verify_ds_inputs_shape(self, sample, inputs, labels): 137 """Verify a dataset sample's input shape.""" 138 139 if len(inputs.shape) > 4 or len(inputs.shape) < 3 or inputs.shape[-3] not in [1, 3, 4]: 140 raise ValueError( 141 "Image shape {} is unrecognizable: the dimension of image can only be CHW or NCHW.".format( 142 inputs.shape)) 143 if len(inputs.shape) == 3: 144 log.warning( 145 "Image shape {} is 3-dimensional. All the data will be automatically unsqueezed at the 0-th" 146 " dimension as batch data.".format(inputs.shape)) 147 if len(sample) > 1: 148 if len(labels.shape) > 2 and (np.array(labels.shape[1:]) > 1).sum() > 1: 149 raise ValueError( 150 "Labels shape {} is unrecognizable: outputs should not have more than two dimensions" 151 " with length greater than 1.".format(labels.shape)) 152 153 if self._is_hoc_registered: 154 if inputs.shape[-3] != 3: 155 raise ValueError( 156 "Hierarchical occlusion is registered, images must be in 3 channels format, but " 157 "{} channel(s) is(are) encountered.".format(inputs.shape[-3])) 158 short_side = min(inputs.shape[-2:]) 159 if short_side < hoc.AUTO_IMAGE_SHORT_SIDE_MIN: 160 raise ValueError( 161 "Hierarchical occlusion is registered, images' short side must be equals to or greater then " 162 "{}, but {} is encountered.".format(hoc.AUTO_IMAGE_SHORT_SIDE_MIN, short_side)) 163 164 def _verify_ds_sample(self, sample): 165 """Verify a dataset sample.""" 166 if len(sample) not in [1, 2, 3]: 167 raise ValueError("The dataset should provide [images] or [images, labels], [images, labels, bboxes]" 168 " as columns.") 169 170 if len(sample) == 3: 171 inputs, labels, bboxes = sample 172 if bboxes.shape[-1] != 4: 173 raise ValueError("The third element of dataset should be bounding boxes with shape of " 174 "[batch_size, num_ground_truth, 4].") 175 else: 176 if self._benchmarkers is not None: 177 if any([isinstance(bench, Localization) for bench in self._benchmarkers]): 178 raise ValueError("The dataset must provide bboxes if Localization is to be computed.") 179 180 if len(sample) == 2: 181 inputs, labels = sample 182 if len(sample) == 1: 183 inputs = sample[0] 184 185 self._verify_ds_inputs_shape(sample, inputs, labels) 186 187 def _verify_data(self): 188 """Verify dataset and labels.""" 189 self._verify_labels() 190 191 try: 192 sample = next(self._dataset.create_tuple_iterator()) 193 except StopIteration: 194 raise ValueError("The dataset provided is empty.") 195 196 self._verify_ds_sample(sample) 197 198 def _verify_network(self): 199 """Verify the network.""" 200 next_element = next(self._dataset.create_tuple_iterator()) 201 inputs, _, _ = self._unpack_next_element(next_element) 202 prop_test = self._full_network(inputs) 203 check_value_type("output of network in explainer", prop_test, ms.Tensor) 204 if prop_test.shape[1] != len(self._labels): 205 raise ValueError("The dimension of network output does not match the no. of classes. Please " 206 "check labels or the network in the explainer again.") 207 208 def _verify_saliency(self): 209 """Verify the saliency settings.""" 210 if self._explainers: 211 explainer_classes = [] 212 for explainer in self._explainers: 213 if explainer.__class__ in explainer_classes: 214 raise ValueError(f"Repeated {explainer.__class__.__name__} explainer! " 215 "Please make sure all explainers' class is distinct.") 216 if explainer.network is not self._network: 217 raise ValueError(f"The network of {explainer.__class__.__name__} explainer is different " 218 "instance from network of runner. Please make sure they are the same " 219 "instance.") 220 explainer_classes.append(explainer.__class__) 221 if self._benchmarkers: 222 benchmarker_classes = [] 223 for benchmarker in self._benchmarkers: 224 if benchmarker.__class__ in benchmarker_classes: 225 raise ValueError(f"Repeated {benchmarker.__class__.__name__} benchmarker! " 226 "Please make sure all benchmarkers' class is distinct.") 227 if isinstance(benchmarker, LabelSensitiveMetric) and benchmarker.num_labels != len(self._labels): 228 raise ValueError(f"The num_labels of {benchmarker.__class__.__name__} benchmarker is different " 229 "from no. of labels of runner. Please make them are the same.") 230 benchmarker_classes.append(benchmarker.__class__) 231 232 233@deprecated_error 234class ImageClassificationRunner(_Verifier): 235 """ 236 A high-level API for users to generate and store results of the explanation methods and the evaluation methods. 237 238 Update in 2020.11: Adjust the storage structure and format of the data. Summary files generated by previous version 239 will be deprecated and will not be supported in MindInsight of current version. 240 241 Args: 242 summary_dir (str): The directory path to save the summary files which store the generated results. 243 data (tuple[Dataset, list[str]]): Tuple of dataset and the corresponding class label list. The dataset 244 should provides [images], [images, labels] or [images, labels, bboxes] as columns. The label list must 245 share the exact same length and order of the network outputs. 246 network (Cell): The network(with logit outputs) to be explained. 247 activation_fn (Cell): The activation layer that transforms logits to prediction probabilities. For 248 single label classification tasks, `nn.Softmax` is usually applied. As for multi-label classification 249 tasks, `nn.Sigmoid` is usually be applied. Users can also pass their own customized `activation_fn` as long 250 as when combining this function with network, the final output is the probability of the input. 251 252 Raises: 253 TypeError: Be raised for any argument type problem. 254 255 Supported Platforms: 256 ``Ascend`` ``GPU`` 257 258 Examples: 259 >>> from mindspore.explainer import ImageClassificationRunner 260 >>> from mindspore.explainer.explanation import GuidedBackprop, Gradient 261 >>> from mindspore.explainer.benchmark import Faithfulness 262 >>> from mindspore.nn import Softmax 263 >>> from mindspore.train.serialization import load_checkpoint, load_param_into_net 264 >>> from mindspore import context 265 >>> 266 >>> context.set_context(mode=context.PYNATIVE_MODE) 267 >>> # The detail of AlexNet is shown in model_zoo.official.cv.alexnet.src.alexnet.py 268 >>> net = AlexNet(10) 269 >>> # Load the checkpoint 270 >>> param_dict = load_checkpoint("/path/to/checkpoint") 271 >>> load_param_into_net(net, param_dict) 272 [] 273 >>> 274 >>> # Prepare the dataset for explaining and evaluation. 275 >>> # The detail of create_dataset_cifar10 method is shown in model_zoo.official.cv.alexnet.src.dataset.py 276 >>> 277 >>> dataset = create_dataset_cifar10("/path/to/cifar/dataset", 1) 278 >>> labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 279 >>> 280 >>> activation_fn = Softmax() 281 >>> gbp = GuidedBackprop(net) 282 >>> gradient = Gradient(net) 283 >>> explainers = [gbp, gradient] 284 >>> faithfulness = Faithfulness(len(labels), activation_fn, "NaiveFaithfulness") 285 >>> benchmarkers = [faithfulness] 286 >>> 287 >>> runner = ImageClassificationRunner("./summary_dir", (dataset, labels), net, activation_fn) 288 >>> runner.register_saliency(explainers=explainers, benchmarkers=benchmarkers) 289 >>> runner.run() 290 """ 291 292 # datafile directory names 293 _DATAFILE_DIRNAME_PREFIX = "_explain_" 294 _ORIGINAL_IMAGE_DIRNAME = "origin_images" 295 _HEATMAP_DIRNAME = "heatmap" 296 # specfial filenames 297 _MANIFEST_FILENAME = "manifest.json" 298 # max. no. of sample per directory 299 _SAMPLE_PER_DIR = 1000 300 # seed for fixing the iterating order of the dataset 301 _DATASET_SEED = 58 302 # printing spacer 303 _SPACER = "{:120}\r" 304 # datafile directory's permission 305 _DIR_MODE = 0o700 306 # datafile's permission 307 _FILE_MODE = 0o400 308 309 def __init__(self, 310 summary_dir, 311 data, 312 network, 313 activation_fn): 314 315 check_value_type("data", data, tuple) 316 if len(data) != 2: 317 raise ValueError("Argument data is not a tuple with 2 elements") 318 check_value_type("data[0]", data[0], Dataset) 319 check_value_type("data[1]", data[1], list) 320 if not all(isinstance(ele, str) for ele in data[1]): 321 raise ValueError("Argument data[1] is not list of str.") 322 323 check_value_type("summary_dir", summary_dir, str) 324 check_value_type("network", network, Cell) 325 check_value_type("activation_fn", activation_fn, Cell) 326 327 self._summary_dir = summary_dir 328 self._dataset = data[0] 329 self._labels = data[1] 330 self._network = network 331 self._explainers = None 332 self._benchmarkers = None 333 self._uncertainty = None 334 self._hoc_searcher = None 335 self._summary_timestamp = None 336 self._sample_index = -1 337 self._manifest = None 338 339 self._full_network = SequentialCell([self._network, activation_fn]) 340 self._full_network.set_train(False) 341 342 self._verify(_Verifier.DATA_N_NETWORK | _Verifier.ENVIRONMENT) 343 344 def register_saliency(self, 345 explainers, 346 benchmarkers=None): 347 """ 348 Register saliency explanation instances. 349 350 .. warning:: 351 This function can not be invoked more than once on each runner. 352 353 Args: 354 explainers (list[Attribution]): The explainers to be evaluated, 355 see `mindspore.explainer.explanation`. All explainers' class must be distinct and their network 356 must be the exact same instance of the runner's network. 357 benchmarkers (list[AttributionMetric], optional): The benchmarkers for scoring the explainers, 358 see `mindspore.explainer.benchmark`. All benchmarkers' class must be distinct. 359 360 Raises: 361 ValueError: Be raised for any data or settings' value problem. 362 TypeError: Be raised for any data or settings' type problem. 363 RuntimeError: Be raised if this function was invoked before. 364 """ 365 check_value_type("explainers", explainers, list) 366 if not all(isinstance(ele, Attribution) for ele in explainers): 367 raise TypeError("Argument explainers is not list of mindspore.explainer.explanation .") 368 369 if not explainers: 370 raise ValueError("Argument explainers is empty.") 371 372 if benchmarkers is not None: 373 check_value_type("benchmarkers", benchmarkers, list) 374 if not all(isinstance(ele, AttributionMetric) for ele in benchmarkers): 375 raise TypeError("Argument benchmarkers is not list of mindspore.explainer.benchmark .") 376 377 if self._explainers is not None: 378 raise RuntimeError("Function register_saliency() was invoked already.") 379 380 self._explainers = explainers 381 self._benchmarkers = benchmarkers 382 383 try: 384 self._verify(_Verifier.SALIENCY | _Verifier.ENVIRONMENT) 385 except (ValueError, TypeError): 386 self._explainers = None 387 self._benchmarkers = None 388 raise 389 390 def register_hierarchical_occlusion(self): 391 """ 392 Register hierarchical occlusion instances. 393 394 .. warning:: 395 This function can not be invoked more than once on each runner. 396 397 Note: 398 Input images are required to be in 3 channels formats and the length of side short must be equals to or 399 greater than 56 pixels. 400 401 Raises: 402 ValueError: Be raised for any data or settings' value problem. 403 RuntimeError: Be raised if the function was called already. 404 """ 405 if self._hoc_searcher is not None: 406 raise RuntimeError("Function register_hierarchical_occlusion() was invoked already.") 407 408 self._hoc_searcher = hoc.Searcher(self._full_network) 409 410 try: 411 self._verify(_Verifier.HOC | _Verifier.ENVIRONMENT) 412 except ValueError: 413 self._hoc_searcher = None 414 raise 415 416 def register_uncertainty(self): 417 """ 418 Register uncertainty instance to compute the epistemic uncertainty base on the Bayes' theorem. 419 420 .. warning:: 421 This function can not be invoked more than once on each runner. 422 423 Note: 424 Please refer to the documentation of mindspore.nn.probability.toolbox.uncertainty_evaluation for the 425 details. The actual output is standard deviation of the classification predictions and the corresponding 426 95% confidence intervals. Users have to invoke register_saliency() as well for the uncertainty results are 427 going to be shown on the saliency map page in MindInsight. 428 429 Raises: 430 RuntimeError: Be raised if the function was called already. 431 """ 432 if self._uncertainty is not None: 433 raise RuntimeError("Function register_uncertainty() was invoked already.") 434 435 self._uncertainty = UncertaintyEvaluation(model=self._full_network, 436 train_dataset=None, 437 task_type='classification', 438 num_classes=len(self._labels)) 439 440 def run(self): 441 """ 442 Run the explain job and save the result as a summary in summary_dir. 443 444 Note: 445 User should call register_saliency() once before running this function. 446 447 Raises: 448 ValueError: Be raised for any data or settings' value problem. 449 TypeError: Be raised for any data or settings' type problem. 450 RuntimeError: Be raised for any runtime problem. 451 """ 452 self._verify(_Verifier.ALL) 453 self._manifest = {"saliency_map": False, 454 "benchmark": False, 455 "uncertainty": False, 456 "hierarchical_occlusion": False} 457 with SummaryRecord(self._summary_dir, raise_exception=True) as summary: 458 print("Start running and writing......") 459 begin = time() 460 461 self._summary_timestamp = self._extract_timestamp(summary.file_info['file_name']) 462 if self._summary_timestamp is None: 463 raise RuntimeError("Cannot extract timestamp from summary filename!" 464 " It should contains a timestamp after 'summary.' .") 465 466 self._save_metadata(summary) 467 468 imageid_labels = self._run_inference(summary) 469 sample_count = self._sample_index 470 if self._is_saliency_registered: 471 self._run_saliency(summary, imageid_labels) 472 if not self._manifest["saliency_map"]: 473 raise RuntimeError( 474 f"No saliency map was generated in {sample_count} samples. " 475 f"Please make sure the dataset, labels, activation function and network are properly trained " 476 f"and configured.") 477 478 if self._is_hoc_registered and not self._manifest["hierarchical_occlusion"]: 479 raise RuntimeError( 480 f"No Hierarchical Occlusion result was found in {sample_count} samples. " 481 f"Please make sure the dataset, labels, activation function and network are properly trained " 482 f"and configured.") 483 484 self._save_manifest() 485 486 print("Finish running and writing. Total time elapsed: {:.3f} s".format(time() - begin)) 487 488 @property 489 def _is_hoc_registered(self): 490 """Check if HOC module is registered.""" 491 return self._hoc_searcher is not None 492 493 @property 494 def _is_saliency_registered(self): 495 """Check if saliency module is registered.""" 496 return bool(self._explainers) 497 498 @property 499 def _is_uncertainty_registered(self): 500 """Check if uncertainty module is registered.""" 501 return self._uncertainty is not None 502 503 def _save_metadata(self, summary): 504 """Save metadata of the explain job to summary.""" 505 print("Start writing metadata......") 506 507 explain = Explain() 508 explain.metadata.label.extend(self._labels) 509 510 if self._is_saliency_registered: 511 exp_names = [exp.__class__.__name__ for exp in self._explainers] 512 explain.metadata.explain_method.extend(exp_names) 513 if self._benchmarkers is not None: 514 bench_names = [bench.__class__.__name__ for bench in self._benchmarkers] 515 explain.metadata.benchmark_method.extend(bench_names) 516 517 summary.add_value("explainer", "metadata", explain) 518 summary.record(1) 519 520 print("Finish writing metadata.") 521 522 def _run_inference(self, summary, threshold=0.5): 523 """ 524 Run inference for the dataset and write the inference related data into summary. 525 526 Args: 527 summary (SummaryRecord): The summary object to store the data. 528 threshold (float): The threshold for prediction. 529 530 Returns: 531 dict, The map of sample d to the union of its ground truth and predicted labels. 532 """ 533 sample_id_labels = {} 534 self._sample_index = 0 535 ds.config.set_seed(self._DATASET_SEED) 536 for j, batch in enumerate(self._dataset): 537 now = time() 538 self._infer_batch(summary, batch, sample_id_labels, threshold) 539 self._spaced_print("Finish running and writing {}-th batch inference data." 540 " Time elapsed: {:.3f} s".format(j, time() - now)) 541 return sample_id_labels 542 543 def _infer_batch(self, summary, batch, sample_id_labels, threshold): 544 """ 545 Infer a batch. 546 547 Args: 548 summary (SummaryRecord): The summary object to store the data. 549 batch (tuple): The next dataset sample. 550 sample_id_labels (dict): The sample id to labels dictionary. 551 threshold (float): The threshold for prediction. 552 """ 553 inputs, labels, _ = self._unpack_next_element(batch) 554 prob = self._full_network(inputs).asnumpy() 555 556 if self._uncertainty is not None: 557 prob_var = self._uncertainty.eval_epistemic_uncertainty(inputs) 558 else: 559 prob_var = None 560 561 for idx, inp in enumerate(inputs): 562 gt_labels = labels[idx] 563 gt_probs = [float(prob[idx][i]) for i in gt_labels] 564 565 if prob_var is not None: 566 gt_prob_vars = [float(prob_var[idx][i]) for i in gt_labels] 567 gt_itl_lows, gt_itl_his, gt_prob_sds = \ 568 self._calc_beta_intervals(gt_probs, gt_prob_vars) 569 570 data_np = _convert_image_format(np.expand_dims(inp.asnumpy(), 0), 'NCHW') 571 original_image = _np_to_image(_normalize(data_np), mode='RGB') 572 original_image_path = self._save_original_image(self._sample_index, original_image) 573 574 predicted_labels = [int(i) for i in (prob[idx] > threshold).nonzero()[0]] 575 predicted_probs = [float(prob[idx][i]) for i in predicted_labels] 576 577 if prob_var is not None: 578 predicted_prob_vars = [float(prob_var[idx][i]) for i in predicted_labels] 579 predicted_itl_lows, predicted_itl_his, predicted_prob_sds = \ 580 self._calc_beta_intervals(predicted_probs, predicted_prob_vars) 581 582 union_labs = list(set(gt_labels + predicted_labels)) 583 sample_id_labels[str(self._sample_index)] = union_labs 584 585 explain = Explain() 586 explain.sample_id = self._sample_index 587 explain.image_path = original_image_path 588 summary.add_value("explainer", "sample", explain) 589 590 explain = Explain() 591 explain.sample_id = self._sample_index 592 explain.ground_truth_label.extend(gt_labels) 593 explain.inference.ground_truth_prob.extend(gt_probs) 594 explain.inference.predicted_label.extend(predicted_labels) 595 explain.inference.predicted_prob.extend(predicted_probs) 596 597 if prob_var is not None: 598 explain.inference.ground_truth_prob_sd.extend(gt_prob_sds) 599 explain.inference.ground_truth_prob_itl95_low.extend(gt_itl_lows) 600 explain.inference.ground_truth_prob_itl95_hi.extend(gt_itl_his) 601 explain.inference.predicted_prob_sd.extend(predicted_prob_sds) 602 explain.inference.predicted_prob_itl95_low.extend(predicted_itl_lows) 603 explain.inference.predicted_prob_itl95_hi.extend(predicted_itl_his) 604 605 self._manifest["uncertainty"] = True 606 607 summary.add_value("explainer", "inference", explain) 608 summary.record(1) 609 610 if self._is_hoc_registered: 611 self._run_hoc(summary, self._sample_index, inputs[idx], prob[idx]) 612 613 self._sample_index += 1 614 615 def _run_explainer(self, summary, sample_id_labels, explainer): 616 """ 617 Run the explainer. 618 619 Args: 620 summary (SummaryRecord): The summary object to store the data. 621 sample_id_labels (dict): A dict that maps the sample id and its union labels. 622 explainer (_Attribution): An Attribution object to generate saliency maps. 623 """ 624 for idx, next_element in enumerate(self._dataset): 625 now = time() 626 self._spaced_print("Start running {}-th explanation data for {}......".format( 627 idx, explainer.__class__.__name__)) 628 saliency_dict_lst = self._run_exp_step(next_element, explainer, sample_id_labels, summary) 629 self._spaced_print( 630 "Finish writing {}-th batch explanation data for {}. Time elapsed: {:.3f} s".format( 631 idx, explainer.__class__.__name__, time() - now)) 632 633 if not self._benchmarkers: 634 continue 635 636 for bench in self._benchmarkers: 637 now = time() 638 self._spaced_print( 639 "Start running {}-th batch {} data for {}......".format( 640 idx, bench.__class__.__name__, explainer.__class__.__name__)) 641 self._run_exp_benchmark_step(next_element, explainer, bench, saliency_dict_lst) 642 self._spaced_print( 643 "Finish running {}-th batch {} data for {}. Time elapsed: {:.3f} s".format( 644 idx, bench.__class__.__name__, explainer.__class__.__name__, time() - now)) 645 646 def _run_saliency(self, summary, sample_id_labels): 647 """Run the saliency explanations.""" 648 649 for explainer in self._explainers: 650 explain = Explain() 651 if self._benchmarkers: 652 for bench in self._benchmarkers: 653 bench.reset() 654 print(f"Start running and writing explanation for {explainer.__class__.__name__}......") 655 self._sample_index = 0 656 start = time() 657 ds.config.set_seed(self._DATASET_SEED) 658 self._run_explainer(summary, sample_id_labels, explainer) 659 660 if not self._benchmarkers: 661 continue 662 663 for bench in self._benchmarkers: 664 benchmark = explain.benchmark.add() 665 benchmark.explain_method = explainer.__class__.__name__ 666 benchmark.benchmark_method = bench.__class__.__name__ 667 668 benchmark.total_score = bench.performance 669 if isinstance(bench, LabelSensitiveMetric): 670 benchmark.label_score.extend(bench.class_performances) 671 672 self._spaced_print("Finish running and writing explanation and benchmark data for {}. " 673 "Time elapsed: {:.3f} s".format(explainer.__class__.__name__, time() - start)) 674 summary.add_value('explainer', 'benchmark', explain) 675 summary.record(1) 676 677 def _run_hoc(self, summary, sample_id, sample_input, prob): 678 """ 679 Run HOC search for a sample image, and then save the result to summary. 680 681 Args: 682 summary (SummaryRecord): The summary object to store the data. 683 sample_id (int): The sample ID. 684 sample_input (Union[Tensor, np.ndarray]): Sample image tensor in CHW or NCWH(N=1). 685 prob (Union[Tensor, np.ndarray]): List of sample's classification prediction output, HOC will run for 686 labels with prediction output strictly larger then HOC searcher's threshold(0.5 by default). 687 """ 688 if isinstance(sample_input, ms.Tensor): 689 sample_input = sample_input.asnumpy() 690 if len(sample_input.shape) == 3: 691 sample_input = np.expand_dims(sample_input, axis=0) 692 693 explain = None 694 str_mask = hoc.auto_str_mask(sample_input) 695 compiled_mask = None 696 697 for label_idx, label_prob in enumerate(prob): 698 if label_prob <= self._hoc_searcher.threshold: 699 continue 700 if compiled_mask is None: 701 compiled_mask = hoc.compile_mask(str_mask, sample_input) 702 try: 703 edit_tree, layer_outputs = self._hoc_searcher.search(sample_input, label_idx, compiled_mask) 704 except hoc.NoValidResultError: 705 log.warning(f"No Hierarchical Occlusion result was found in sample#{sample_id} " 706 f"label:{self._labels[label_idx]}, skipped.") 707 continue 708 709 if explain is None: 710 explain = Explain() 711 explain.sample_id = sample_id 712 713 self._add_hoc_result_to_explain(label_idx, str_mask, edit_tree, layer_outputs, explain) 714 715 if explain is not None: 716 summary.add_value("explainer", "hoc", explain) 717 summary.record(1) 718 self._manifest['hierarchical_occlusion'] = True 719 720 @staticmethod 721 def _add_hoc_result_to_explain(label_idx, str_mask, edit_tree, layer_outputs, explain): 722 """ 723 Add HOC result to Explain record. 724 725 Args: 726 label_idx (int): The label index. 727 str_mask (str): The mask string. 728 edit_tree (EditStep): The result HOC edit tree. 729 layer_outputs (list[float]): The network output confident of each layer. 730 explain (Explain): The Explain record. 731 """ 732 hoc_rec = explain.hoc.add() 733 hoc_rec.label = label_idx 734 hoc_rec.mask = str_mask 735 layer_count = edit_tree.max_layer + 1 736 for layer in range(layer_count): 737 steps = edit_tree.get_layer_or_leaf_steps(layer) 738 layer_output = layer_outputs[layer] 739 hoc_layer = hoc_rec.layer.add() 740 hoc_layer.prob = layer_output 741 for step in steps: 742 hoc_layer.box.extend(list(step.box)) 743 744 def _add_exp_step_samples(self, explainer, sample_label_sets, batch_saliency_full, summary): 745 """ 746 Add explanation results of samples to summary record. 747 748 Args: 749 explainer (Attribution): The explainer to be run. 750 sample_label_sets (list[list[int]]): The label sets of samples. 751 batch_saliency_full (Tensor): The saliency output from explainer. 752 summary (SummaryRecord): The summary record. 753 """ 754 saliency_dict_lst = [] 755 has_saliency_rec = False 756 for idx, label_set in enumerate(sample_label_sets): 757 saliency_dict = {} 758 explain = Explain() 759 explain.sample_id = self._sample_index 760 for k, lab in enumerate(label_set): 761 saliency = batch_saliency_full[idx:idx + 1, k:k + 1] 762 saliency_dict[lab] = saliency 763 764 saliency_np = _normalize(saliency.asnumpy().squeeze()) 765 saliency_image = _np_to_image(saliency_np, mode='L') 766 heatmap_path = self._save_heatmap(explainer.__class__.__name__, lab, 767 self._sample_index, saliency_image) 768 769 explanation = explain.explanation.add() 770 explanation.explain_method = explainer.__class__.__name__ 771 explanation.heatmap_path = heatmap_path 772 explanation.label = lab 773 774 has_saliency_rec = True 775 776 summary.add_value("explainer", "explanation", explain) 777 summary.record(1) 778 779 self._sample_index += 1 780 saliency_dict_lst.append(saliency_dict) 781 782 return saliency_dict_lst, has_saliency_rec 783 784 def _run_exp_step(self, next_element, explainer, sample_id_labels, summary): 785 """ 786 Run the explanation for each step and write explanation results into summary. 787 788 Args: 789 next_element (Tuple): Data of one step 790 explainer (_Attribution): An Attribution object to generate saliency maps. 791 sample_id_labels (dict): A dict that maps the sample id and its union labels. 792 summary (SummaryRecord): The summary object to store the data. 793 794 Returns: 795 list, List of dict that maps label to its corresponding saliency map. 796 """ 797 inputs, labels, _ = self._unpack_next_element(next_element) 798 sample_index = self._sample_index 799 sample_label_sets = [] 800 for _ in range(len(labels)): 801 sample_label_sets.append(sample_id_labels[str(sample_index)]) 802 sample_index += 1 803 804 batch_label_sets = self._make_label_batch(sample_label_sets) 805 806 if isinstance(explainer, RISE): 807 batch_saliency_full = explainer(inputs, batch_label_sets) 808 else: 809 batch_saliency_full = [] 810 for i in range(len(batch_label_sets[0])): 811 batch_saliency = explainer(inputs, batch_label_sets[:, i]) 812 batch_saliency_full.append(batch_saliency) 813 concat = ms.ops.operations.Concat(1) 814 batch_saliency_full = concat(tuple(batch_saliency_full)) 815 816 saliency_dict_lst, has_saliency_rec = \ 817 self._add_exp_step_samples(explainer, sample_label_sets, batch_saliency_full, summary) 818 819 if has_saliency_rec: 820 self._manifest['saliency_map'] = True 821 822 return saliency_dict_lst 823 824 def _run_exp_benchmark_step(self, next_element, explainer, benchmarker, saliency_dict_lst): 825 """Run the explanation and evaluation for each step and write explanation results into summary.""" 826 inputs, labels, _ = self._unpack_next_element(next_element) 827 for idx, inp in enumerate(inputs): 828 inp = _EXPAND_DIMS(inp, 0) 829 self._manifest['benchmark'] = True 830 if isinstance(benchmarker, LabelAgnosticMetric): 831 res = benchmarker.evaluate(explainer, inp) 832 benchmarker.aggregate(res) 833 continue 834 saliency_dict = saliency_dict_lst[idx] 835 for label, saliency in saliency_dict.items(): 836 if isinstance(benchmarker, Localization): 837 _, _, bboxes = self._unpack_next_element(next_element, True) 838 if label in labels[idx]: 839 res = benchmarker.evaluate(explainer, inp, targets=label, mask=bboxes[idx][label], 840 saliency=saliency) 841 benchmarker.aggregate(res, label) 842 elif isinstance(benchmarker, LabelSensitiveMetric): 843 res = benchmarker.evaluate(explainer, inp, targets=label, saliency=saliency) 844 benchmarker.aggregate(res, label) 845 else: 846 raise TypeError('Benchmarker must be one of LabelSensitiveMetric or LabelAgnosticMetric, but' 847 'receive {}'.format(type(benchmarker))) 848 849 @staticmethod 850 def _calc_beta_intervals(means, variances, prob=0.95): 851 """Calculate confidence interval of beta distributions.""" 852 if not isinstance(means, np.ndarray): 853 means = np.array(means) 854 if not isinstance(variances, np.ndarray): 855 variances = np.array(variances) 856 with np.errstate(divide='ignore'): 857 coef_a = ((means ** 2) * (1 - means) / variances) - means 858 coef_b = (coef_a * (1 - means)) / means 859 itl_lows, itl_his = beta.interval(prob, coef_a, coef_b) 860 sds = np.sqrt(variances) 861 for i in range(itl_lows.shape[0]): 862 if not np.isfinite(sds[i]) or not np.isfinite(itl_lows[i]) or not np.isfinite(itl_his[i]): 863 itl_lows[i] = means[i] 864 itl_his[i] = means[i] 865 sds[i] = 0 866 return itl_lows, itl_his, sds 867 868 def _transform_bboxes(self, inputs, labels, bboxes, ifbbox): 869 """ 870 Transform the bounding boxes. 871 Args: 872 inputs (Tensor): the image data 873 labels (Tensor): the labels 874 bboxes (Tensor): the boudnding boxes data 875 ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t 876 label id will be returned. If False, the returned bboxes is the the parsed bboxes. 877 878 Returns: 879 bboxes (Union[list[dict], None, Tensor]): the bounding boxes 880 """ 881 input_len = len(inputs) 882 if bboxes is None or not ifbbox: 883 return bboxes 884 bboxes = ms.Tensor(bboxes, ms.int32) 885 masks_lst = [] 886 labels = labels.asnumpy().reshape([input_len, -1]) 887 bboxes = bboxes.asnumpy().reshape([input_len, -1, 4]) 888 for idx, label in enumerate(labels): 889 height, width = inputs[idx].shape[-2], inputs[idx].shape[-1] 890 masks = {} 891 for j, label_item in enumerate(label): 892 target = int(label_item) 893 if not -1 < target < len(self._labels): 894 continue 895 if target not in masks: 896 mask = np.zeros((1, 1, height, width)) 897 else: 898 mask = masks[target] 899 x_min, y_min, x_len, y_len = bboxes[idx][j].astype(int) 900 mask[:, :, x_min:x_min + x_len, y_min:y_min + y_len] = 1 901 masks[target] = mask 902 masks_lst.append(masks) 903 bboxes = masks_lst 904 return bboxes 905 906 def _transform_data(self, inputs, labels, bboxes, ifbbox): 907 """ 908 Transform the data from one iteration of dataset to a unifying form for the follow-up operations. 909 910 Args: 911 inputs (Tensor): the image data 912 labels (Tensor): the labels 913 bboxes (Tensor): the boudnding boxes data 914 ifbbox (bool): whether to preprocess bboxes. If True, a dictionary that indicates bounding boxes w.r.t 915 label id will be returned. If False, the returned bboxes is the the parsed bboxes. 916 917 Returns: 918 inputs (Tensor): the image data, unified to a 4D Tensor. 919 labels (list[list[int]]): the ground truth labels. 920 bboxes (Union[list[dict], None, Tensor]): the bounding boxes 921 """ 922 inputs = ms.Tensor(inputs, ms.float32) 923 if len(inputs.shape) == 3: 924 inputs = _EXPAND_DIMS(inputs, 0) 925 if isinstance(labels, ms.Tensor): 926 labels = ms.Tensor(labels, ms.int32) 927 labels = _EXPAND_DIMS(labels, 0) 928 if isinstance(bboxes, ms.Tensor): 929 bboxes = ms.Tensor(bboxes, ms.int32) 930 bboxes = _EXPAND_DIMS(bboxes, 0) 931 932 bboxes = self._transform_bboxes(inputs, labels, bboxes, ifbbox) 933 934 labels = ms.Tensor(labels, ms.int32) 935 if len(labels.shape) == 1: 936 labels_lst = [[int(i)] for i in labels.asnumpy()] 937 else: 938 labels = labels.asnumpy().reshape([len(inputs), -1]) 939 labels_lst = [] 940 for item in labels: 941 labels_lst.append(list(set(int(i) for i in item if -1 < int(i) < len(self._labels)))) 942 labels = labels_lst 943 return inputs, labels, bboxes 944 945 def _unpack_next_element(self, next_element, ifbbox=False): 946 """ 947 Unpack a single iteration of dataset. 948 949 Args: 950 next_element (Tuple): a single element iterated from dataset object. 951 ifbbox (bool): whether to preprocess bboxes in self._transform_data. 952 953 Returns: 954 tuple, a unified Tuple contains image_data, labels, and bounding boxes. 955 """ 956 if len(next_element) == 3: 957 inputs, labels, bboxes = next_element 958 elif len(next_element) == 2: 959 inputs, labels = next_element 960 bboxes = None 961 else: 962 inputs = next_element[0] 963 labels = [[] for _ in inputs] 964 bboxes = None 965 inputs, labels, bboxes = self._transform_data(inputs, labels, bboxes, ifbbox) 966 return inputs, labels, bboxes 967 968 @staticmethod 969 def _make_label_batch(labels): 970 """ 971 Unify a List of List of labels to be a 2D Tensor with shape (b, m), where b = len(labels) and m is the max 972 length of all the rows in labels. 973 974 Args: 975 labels (List[List]): the union labels of a data batch. 976 977 Returns: 978 2D Tensor. 979 """ 980 max_len = max([len(label) for label in labels]) 981 batch_labels = np.zeros((len(labels), max_len)) 982 983 for idx, _ in enumerate(batch_labels): 984 length = len(labels[idx]) 985 batch_labels[idx, :length] = np.array(labels[idx]) 986 987 return ms.Tensor(batch_labels, ms.int32) 988 989 def _save_manifest(self): 990 """Save manifest.json underneath datafile directory.""" 991 if self._manifest is None: 992 raise RuntimeError("Manifest not yet be initialized.") 993 path_tokens = [self._summary_dir, 994 self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp)] 995 abs_dir_path = self._create_subdir(*path_tokens) 996 save_path = os.path.join(abs_dir_path, self._MANIFEST_FILENAME) 997 fd = os.open(save_path, os.O_WRONLY | os.O_CREAT, mode=self._FILE_MODE) 998 file = os.fdopen(fd, "w") 999 try: 1000 json.dump(self._manifest, file, indent=4) 1001 except IOError: 1002 log.error(f"Failed to save manifest as {save_path}!") 1003 raise 1004 finally: 1005 file.flush() 1006 os.close(fd) 1007 os.chmod(save_path, self._FILE_MODE) 1008 1009 def _save_original_image(self, sample_id, image): 1010 """Save an image to summary directory.""" 1011 id_dirname = self._get_sample_dirname(sample_id) 1012 path_tokens = [self._summary_dir, 1013 self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp), 1014 self._ORIGINAL_IMAGE_DIRNAME, 1015 id_dirname] 1016 1017 abs_dir_path = self._create_subdir(*path_tokens) 1018 filename = f"{sample_id}.jpg" 1019 save_path = os.path.join(abs_dir_path, filename) 1020 image.save(save_path) 1021 os.chmod(save_path, self._FILE_MODE) 1022 return os.path.join(*path_tokens[1:], filename) 1023 1024 def _save_heatmap(self, explain_method, class_id, sample_id, image): 1025 """Save heatmap image to summary directory.""" 1026 id_dirname = self._get_sample_dirname(sample_id) 1027 path_tokens = [self._summary_dir, 1028 self._DATAFILE_DIRNAME_PREFIX + str(self._summary_timestamp), 1029 self._HEATMAP_DIRNAME, 1030 explain_method, 1031 id_dirname] 1032 1033 abs_dir_path = self._create_subdir(*path_tokens) 1034 filename = f"{sample_id}_{class_id}.jpg" 1035 save_path = os.path.join(abs_dir_path, filename) 1036 image.save(save_path, optimize=True) 1037 os.chmod(save_path, self._FILE_MODE) 1038 return os.path.join(*path_tokens[1:], filename) 1039 1040 def _create_subdir(self, *args): 1041 """Recursively create subdirectories.""" 1042 abs_path = None 1043 for token in args: 1044 if abs_path is None: 1045 abs_path = os.path.realpath(token) 1046 else: 1047 abs_path = os.path.join(abs_path, token) 1048 # os.makedirs() don't set intermediate dir permission properly, we mkdir() one by one 1049 try: 1050 os.mkdir(abs_path, mode=self._DIR_MODE) 1051 # In some platform, mode may be ignored in os.mkdir(), we have to chmod() again to make sure 1052 os.chmod(abs_path, mode=self._DIR_MODE) 1053 except FileExistsError: 1054 pass 1055 return abs_path 1056 1057 @classmethod 1058 def _get_sample_dirname(cls, sample_id): 1059 """Get the name of parent directory of the image id.""" 1060 return str(int(sample_id / cls._SAMPLE_PER_DIR) * cls._SAMPLE_PER_DIR) 1061 1062 @staticmethod 1063 def _extract_timestamp(filename): 1064 """Extract timestamp from summary filename.""" 1065 matched = re.search(r"summary\.(\d+)", filename) 1066 if matched: 1067 return int(matched.group(1)) 1068 return None 1069 1070 @classmethod 1071 def _spaced_print(cls, message): 1072 """Spaced message printing.""" 1073 # workaround to print logs starting new line in case line width mismatch. 1074 print(cls._SPACER.format(message)) 1075