# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical occlusion edit tree searcher."""
from collections import deque
from enum import Enum
import copy
import re
import math

import numpy as np
from scipy.ndimage import gaussian_filter

from mindspore import nn
from mindspore import Tensor
from mindspore.ops import Squeeze
from mindspore.train._utils import check_value_type
from mindspore.explainer._utils import deprecated_error


AUTO_LAYER_MAX = 3  # maximum number of layer by auto settings
AUTO_WIN_SIZE_MIN = 28  # minimum window size by auto settings
AUTO_WIN_SIZE_DIV = 2  # denominator of windows size calculations by auto settings
AUTO_STRIDE_DIV = 5  # denominator of stride calculations by auto settings
AUTO_MASK_GAUSSIAN_RADIUS_DIV = 25  # denominator of gaussian mask radius calculations by auto settings
DEFAULT_THRESHOLD = 0.5  # default target prediction threshold
DEFAULT_BATCH_SIZE = 64  # default batch size for batch inference search
MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$'  # gaussian mask string pattern

# minimum length of input images' short side with auto settings
AUTO_IMAGE_SHORT_SIDE_MIN = AUTO_WIN_SIZE_MIN * AUTO_WIN_SIZE_DIV


@deprecated_error
def is_valid_str_mask(mask):
    """
    Check if it is a valid string mask.

    Args:
        mask (str): String mask to be checked, e.g. 'gaussian:9'.

    Returns:
        bool, True if the string matches the gaussian mask pattern with a positive radius.
    """
    check_value_type('mask', mask, str)
    match = re.match(MASK_GAUSSIAN_RE, mask)
    # always answer with a real bool instead of leaking None or the match object's truthiness
    return bool(match and int(match.group(1)) > 0)


@deprecated_error
def compile_mask(mask, image):
    """
    Compile mask to a ready to use object.

    Args:
        mask (Union[str, tuple, float, numpy.ndarray], optional): The mask. None means an auto gaussian
            string mask is generated from the image. String masks are compiled to a numpy.ndarray, tuple
            masks are type checked, numpy.ndarray masks are aligned with the rank of the image.
        image (numpy.ndarray): Image tensor the mask is compiled for.

    Returns:
        Union[tuple, float, numpy.ndarray], the compiled mask.

    Raises:
        ValueError: Be raised if a numpy.ndarray mask does not match the image in shape.
    """
    if mask is None:
        return compile_str_mask(auto_str_mask(image), image)
    check_value_type('mask', mask, (str, tuple, float, np.ndarray))
    if isinstance(mask, str):
        return compile_str_mask(mask, image)

    if isinstance(mask, tuple):
        _check_iterable_type('mask', mask, tuple, float)
    elif isinstance(mask, np.ndarray):
        # add or drop the leading batch axis so that the mask has the same rank as the image
        if len(image.shape) == 4 and len(mask.shape) == 3:
            mask = np.expand_dims(mask, axis=0)
        elif len(image.shape) == 3 and len(mask.shape) == 4 and mask.shape[0] == 1:
            mask = mask.squeeze(0)
        if image.shape != mask.shape:
            raise ValueError("Image and mask is not match in shape.")
    return mask


@deprecated_error
def auto_str_mask(image):
    """
    Generate auto string mask for the image.

    Args:
        image (numpy.ndarray): Image tensor, the last two axes are taken as the spatial axes.

    Returns:
        str, gaussian string mask whose radius is the short side divided by
        AUTO_MASK_GAUSSIAN_RADIUS_DIV and rounded.

    Raises:
        ValueError: Be raised if the calculated radius is zero.
    """
    check_value_type('image', image, np.ndarray)
    short_side = np.min(image.shape[-2:])
    radius = int(round(short_side/AUTO_MASK_GAUSSIAN_RADIUS_DIV))
    if radius == 0:
        raise ValueError(f"Input image's short side:{short_side} is too small for auto mask, "
                         f"at least {AUTO_MASK_GAUSSIAN_RADIUS_DIV}pixels is required.")
    return f'gaussian:{radius}'


@deprecated_error
def compile_str_mask(mask, image):
    """
    Convert string mask to numpy.ndarray.

    Args:
        mask (str): String mask, e.g. 'gaussian:9'.
        image (numpy.ndarray): Image tensor to be blurred.

    Returns:
        numpy.ndarray, the image blurred along its last two axes.

    Raises:
        ValueError: Be raised if the string mask is invalid.
    """
    check_value_type('mask', mask, str)
    check_value_type('image', image, np.ndarray)
    match = re.match(MASK_GAUSSIAN_RE, mask)
    if match:
        radius = int(match.group(1))
        if radius > 0:
            # blur only the last two axes, a zero sigma leaves the leading axes untouched
            sigma = [0] * len(image.shape)
            sigma[-2] = radius
            sigma[-1] = radius
            return gaussian_filter(image, sigma=sigma, mode='nearest')
    raise ValueError(f"Invalid string mask: '{mask}'.")


@deprecated_error
class EditStep:
    """
    Edit step that describes a box region, also represents an edit tree.

    Args:
        layer (int): Layer number, -1 is root layer, 0 or above is normal edit layer.
        box (tuple[int, int, int, int]): Tuple of x, y, width, height.
    """
    def __init__(self, layer, box):
        self.layer = layer
        self.box = box
        self.network_output = 0
        self.step_change = 0
        self.children = None

    @property
    def x(self):
        """X-coordinate of the box."""
        return self.box[0]

    @property
    def y(self):
        """Y-coordinate of the box."""
        return self.box[1]

    @property
    def width(self):
        """Width of the box."""
        return self.box[2]

    @property
    def height(self):
        """Height of the box."""
        return self.box[3]

    @property
    def is_leaf(self):
        """Returns True if no child edit step."""
        return not self.children

    @property
    def leaf_steps(self):
        """Returns all leaf edit steps in the tree."""
        if self.is_leaf:
            return [self]
        steps = []
        for child in self.children:
            steps.extend(child.leaf_steps)
        return steps

    @property
    def max_layer(self):
        """Maximum layer number in the edit tree."""
        if self.is_leaf:
            return self.layer
        return max(self.layer, max(child.max_layer for child in self.children))

    def add_child(self, child):
        """Add a child edit step."""
        if self.children is None:
            self.children = [child]
        else:
            self.children.append(child)

    def remove_all_children(self):
        """Remove all child steps."""
        self.children = None

    def get_layer_or_leaf_steps(self, layer):
        """Get all edit steps of the layer and all leaf edit steps above the layer."""
        if self.layer == layer or (self.layer < layer and self.is_leaf):
            return [self]
        steps = []
        if self.layer < layer and self.children:
            for child in self.children:
                steps.extend(child.get_layer_or_leaf_steps(layer))
        return steps

    def get_layer_steps(self, layer):
        """Get all edit steps of the layer."""
        if self.layer == layer:
            return [self]
        steps = []
        if self.layer < layer and self.children:
            for child in self.children:
                steps.extend(child.get_layer_steps(layer))
        return steps

    @classmethod
    def apply(cls,
              image,
              mask,
              edit_steps,
              by_masking=False,
              inplace=False):
        """
        Apply edit steps.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.
            by_masking (bool): Whether it is masking mode.
            inplace (bool): Whether the modification is going to take place in the input image tensor. False to
                construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        if by_masking:
            return cls.apply_masking(image, mask, edit_steps, inplace)
        return cls.apply_unmasking(image, mask, edit_steps, inplace)

    @classmethod
    def apply_masking(cls,
                      image,
                      mask,
                      edit_steps,
                      inplace=False):
        """
        Apply edit steps in masking mode.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.
            inplace (bool): Whether the modification is going to take place in the input image tensor. False to
                construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """

        cls._apply_check_args(image, mask, edit_steps)

        mask = compile_mask(mask, image)

        background = image if inplace else np.copy(image)

        if not edit_steps:
            return background

        # a scalar mask is expanded once to a solid color of 3 equal channels
        if not isinstance(mask, np.ndarray) and isinstance(mask, (int, float)):
            mask = (mask, mask, mask)

        for step in edit_steps:

            x_max, y_max = cls._get_step_xy_max(step, background.shape[-1], background.shape[-2])

            # skip boxes that are empty after clipping to the image
            if x_max <= step.x or y_max <= step.y:
                continue

            if isinstance(mask, np.ndarray):
                background[..., step.y:y_max, step.x:x_max] = mask[..., step.y:y_max, step.x:x_max]
            else:
                for c in range(3):
                    background[..., c, step.y:y_max, step.x:x_max] = mask[c]
        return background

    @classmethod
    def apply_unmasking(cls,
                        image,
                        mask,
                        edit_steps,
                        inplace=False):
        """
        Apply edit steps in unmasking mode.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep]): Edit steps to be applied.
            inplace (bool): Whether the modification is going to take place in the input mask tensor. False to
                construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """

        cls._apply_check_args(image, mask, edit_steps)

        mask = compile_mask(mask, image)

        # the background is the mask itself, regions of the image are copied on top of it
        if isinstance(mask, np.ndarray):
            if inplace:
                background = mask
            else:
                background = np.copy(mask)
        else:
            if inplace:
                raise ValueError('Inplace cannot be True when mask is not a numpy.ndarray')

            background = np.zeros_like(image)
            if isinstance(mask, (int, float)):
                background.fill(mask)
            else:
                for c in range(3):
                    background[..., c, :, :] = mask[c]

        if not edit_steps:
            return background

        for step in edit_steps:

            x_max, y_max = cls._get_step_xy_max(step, background.shape[-1], background.shape[-2])

            # skip boxes that are empty after clipping to the image
            if x_max <= step.x or y_max <= step.y:
                continue

            background[..., step.y:y_max, step.x:x_max] = image[..., step.y:y_max, step.x:x_max]

        return background

    @staticmethod
    def _apply_check_args(image, mask, edit_steps):
        """
        Check arguments for apply edit steps.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        check_value_type('image', image, np.ndarray)
        check_value_type('mask', mask, (str, tuple, float, np.ndarray))
        if isinstance(mask, tuple):
            _check_iterable_type('mask', mask, tuple, float)

        if edit_steps is not None:
            _check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

    @staticmethod
    def _get_step_xy_max(step, x_limit, y_limit):
        """Get the step x and y max. position, clipped to the given limits."""
        x_max = min(step.x + step.width, x_limit)
        y_max = min(step.y + step.height, y_limit)
        return x_max, y_max


class NoValidResultError(RuntimeError):
    """Error for no edit step layer's network output meet the threshold."""


class OriginalOutputError(RuntimeError):
    """Error for network output of the original image is not strictly larger than the threshold."""


@deprecated_error
class Searcher:
    """
    Edit step searcher.

    Args:
        network (Cell): The network to be searched against, it is called with an image tensor in NCHW format
            and its output is indexed as output[0, class_idx].
        win_sizes (list[int], optional): Moving square window size (length of side) of layers,
            None means by auto calculation. Must be given together with strides.
        strides (list[int], optional): Stride of layers, None means by auto calculation. Must be given
            together with win_sizes.
        threshold (float): Threshold network output value of the target class.
        by_masking (bool): Whether it is masking mode. In masking mode the search stops when the network output
            drops to or below the threshold, otherwise when it rises to or above the threshold.

    Raises:
        ValueError: Be raised for any data or settings' value problem.
        TypeError: Be raised for any data or settings' type problem.
        RuntimeError: Be raised if this function was invoked before.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self,
                 network,
                 win_sizes=None,
                 strides=None,
                 threshold=DEFAULT_THRESHOLD,
                 by_masking=False):

        check_value_type('network', network, nn.Cell)

        if win_sizes is not None:
            _check_iterable_type('win_sizes', win_sizes, list, int)
            if not win_sizes:
                raise ValueError('Argument win_sizes is empty.')

            for i in range(1, len(win_sizes)):
                if win_sizes[i] >= win_sizes[i-1]:
                    raise ValueError('Argument win_sizes is not strictly descending.')

            # the list is strictly descending, checking the last element covers all of them
            if win_sizes[-1] <= 0:
                raise ValueError('Argument win_sizes has non-positive number.')
        elif strides is not None:
            raise ValueError('Argument win_sizes cannot be None if strides is not None.')

        if strides is not None:
            _check_iterable_type('strides', strides, list, int)
            # reject an empty list explicitly, otherwise strides[-1] fails with an IndexError
            if not strides:
                raise ValueError('Argument strides is empty.')

            for i in range(1, len(strides)):
                if strides[i] >= strides[i-1]:
                    raise ValueError('Argument strides is not strictly descending.')

            if strides[-1] <= 0:
                raise ValueError('Argument strides has non-positive number.')

            if len(strides) != len(win_sizes):
                raise ValueError('Length of strides and win_sizes is not equal.')
        elif win_sizes is not None:
            raise ValueError('Argument strides cannot be None if win_sizes is not None.')

        self._network = copy.deepcopy(network)
        self._compiled_mask = None
        self._threshold = threshold
        self._win_sizes = copy.copy(win_sizes) if win_sizes else None
        self._strides = copy.copy(strides) if strides else None
        self._by_masking = by_masking

    @property
    def network(self):
        """Get the network."""
        return self._network

    @property
    def by_masking(self):
        """Check if it is masking mode."""
        return self._by_masking

    @property
    def threshold(self):
        """The network output threshold to stop the search."""
        return self._threshold

    @property
    def win_sizes(self):
        """Windows sizes in pixels."""
        return self._win_sizes

    @property
    def strides(self):
        """Strides in pixels."""
        return self._strides

    @property
    def compiled_mask(self):
        """The compiled mask after a successful search() call."""
        return self._compiled_mask

    def search(self, image, class_idx, mask=None):
        """
        Search smallest sufficient/destruction region on an image.

        Args:
            image (Union[Tensor, numpy.ndarray]): Image tensor in CHW or NCHW(N=1) format.
            class_idx (int): Target class index.
            mask (Union[str, tuple[float, float, float], float], optional): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                None: By auto calculation.

        Returns:
            tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the
            layer steps.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
            NoValidResultError: Be raised if no valid result was found.
            OriginalOutputError: Be raised if network output of the original image is not strictly larger than
                the threshold.
        """
        check_value_type('image', image, (Tensor, np.ndarray))

        if isinstance(image, Tensor):
            image = image.asnumpy()

        # normalize the image to NCHW(N=1)
        if len(image.shape) == 4:
            if image.shape[0] != 1:
                raise ValueError("Argument image's batch size is not 1.")
        elif len(image.shape) == 3:
            image = np.expand_dims(image, axis=0)
        else:
            raise ValueError("Argument image is not in CHW or NCHW(N=1) format.")

        check_value_type('class_idx', class_idx, int)

        if class_idx < 0:
            raise ValueError("Argument class_idx is less then zero.")

        self._compiled_mask = compile_mask(mask, image)

        short_side = np.min(image.shape[-2:])
        if self._win_sizes is None:
            win_sizes, strides = self._auto_win_sizes_strides(short_side)
        else:
            win_sizes, strides = self._win_sizes, self._strides

        if short_side <= win_sizes[0]:
            raise ValueError(f"Input image's short side is shorter then or "
                             f"equals to the first window size:{win_sizes[0]}.")

        self._network.set_train(False)

        # the search result will be stored as an edit tree that is attached to the root step.
        root_step = EditStep(-1, (0, 0, image.shape[-1], image.shape[-2]))
        root_job = _SearchJob(by_masking=self._by_masking,
                              class_idx=class_idx,
                              win_sizes=win_sizes,
                              strides=strides,
                              layer=0,
                              search_field=root_step.box,
                              pre_edit_steps=None,
                              parent_step=root_step)
        self._process_root_job(image, root_job)
        return self._touch_result(image, class_idx, root_step)

    def _touch_result(self, image, class_idx, root_step):
        """
        Final treatment to the search result.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            class_idx (int): Target class index.
            root_step (EditStep): The searched root step.

        Returns:
            tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the
            layer steps.

        Raises:
            NoValidResultError: Be raised if no valid result was found.
        """
        # the leaf layer's network output may not meet the threshold,
        # we have to cutoff the unqualified layers
        layer_count = root_step.max_layer + 1
        if layer_count == 0:
            raise NoValidResultError("No edit step layer was found.")

        # gather the network output of each layer
        layer_outputs = [None] * layer_count
        for layer in range(layer_count):
            steps = root_step.get_layer_or_leaf_steps(layer)
            if not steps:
                continue
            masked_image = EditStep.apply(image, self._compiled_mask, steps, by_masking=self._by_masking)
            output = self._network(Tensor(masked_image))
            output = output[0, class_idx].asnumpy().item()
            layer_outputs[layer] = output

        # determine which layer we have to cutoff, the deepest layer that meets the threshold wins
        cutoff_layer = None
        for layer in reversed(range(layer_count)):
            if layer_outputs[layer] is not None and self._is_threshold_met(layer_outputs[layer]):
                cutoff_layer = layer
                break

        if cutoff_layer is None or root_step.is_leaf:
            raise NoValidResultError(f"No edit step layer's network output meet the threshold: {self._threshold}.")

        # cutoff the layer by removing all children of the layer's steps.
        steps = root_step.get_layer_steps(cutoff_layer)
        for step in steps:
            step.remove_all_children()
        layer_outputs = layer_outputs[:cutoff_layer + 1]

        return root_step, layer_outputs

    def _process_root_job(self, sample_input, root_job):
        """
        Process job queue.

        Args:
            sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
            root_job (_SearchJob): Root search job.
        """
        # first-in-first-out queue, deque pops from the left in constant time
        job_queue = deque([root_job])
        while job_queue:
            job = job_queue.popleft()
            sub_job_queue = []
            job_edit_steps, stop_reason = self._process_job(job, sample_input, sub_job_queue)

            # only a job that reached its target contributes its steps and sub jobs to the tree
            if stop_reason in (self._StopReason.THRESHOLD_MET, self._StopReason.STEP_CHANGE_MET):
                for step in job_edit_steps:
                    job.parent_step.add_child(step)
                job_queue.extend(sub_job_queue)

    def _prepare_job(self, job, sample_input):
        """
        Prepare a job for process.

        Args:
            job (_SearchJob): Search job to be processed.
            sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.

        Returns:
            numpy.ndarray, the image tensor workpiece.

        Raises:
            OriginalOutputError: Be raised if network output of the original image is not strictly larger than the
                threshold.
        """
        # make sure the network output with the original image is strictly larger than the threshold
        if job.layer == 0:
            original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item()
            if original_output <= self._threshold:
                raise OriginalOutputError(f'The original output is not strictly larger the threshold: '
                                          f'{self._threshold}')

        # applying the pre-edit steps from the parent steps
        if job.pre_edit_steps:
            # use the latest leaf steps to increase the accuracy
            leaf_steps = []
            for step in job.pre_edit_steps:
                leaf_steps.extend(step.leaf_steps)
            pre_edit_steps = leaf_steps
        else:
            pre_edit_steps = None
        workpiece = EditStep.apply(sample_input,
                                   self._compiled_mask,
                                   pre_edit_steps,
                                   self._by_masking)

        job.on_start(sample_input, workpiece, self._compiled_mask, self._network)
        return workpiece

    def _process_job(self, job, sample_input, job_queue):
        """
        Process a job.

        Args:
            job (_SearchJob): Search job to be processed.
            sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
            job_queue (list[_SearchJob]): Job queue, sub jobs created by this job are appended to it.

        Returns:
            tuple[list[EditStep], _StopReason], result edit steps and the stop reason.

        Raises:
            OriginalOutputError: Be raised if network output of the original image is not strictly larger than the
                threshold.
        """
        workpiece = self._prepare_job(job, sample_input)

        start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item()
        last_output = start_output
        edit_steps = []
        # greedy search loop
        while True:

            if self._is_threshold_met(last_output):
                return edit_steps, self._StopReason.THRESHOLD_MET

            try:
                best_edit = job.find_best_edit()
            except _NoNewStepError:
                return edit_steps, self._StopReason.NO_NEW_STEP
            except _RepeatedStepError:
                return edit_steps, self._StopReason.REPEATED_STEP

            best_edit.step_change = best_edit.network_output - last_output

            if job.layer < job.layer_count - 1 and self._is_greedy(best_edit.step_change):
                # create next layer search job if new edit step is valid and not yet reaching
                # the final layer
                if job.pre_edit_steps:
                    pre_edit_steps = list(job.pre_edit_steps)
                    pre_edit_steps.extend(edit_steps)
                else:
                    pre_edit_steps = list(edit_steps)

                sub_job = job.create_sub_job(best_edit, pre_edit_steps)
                job_queue.append(sub_job)

            edit_steps.append(best_edit)

            if job.layer > 0:
                # stop if the step change meet the parent step change only after layer 0
                change = best_edit.network_output - start_output
                if self._is_step_change_met(job.parent_step.step_change, change):
                    return edit_steps, self._StopReason.STEP_CHANGE_MET

            last_output = best_edit.network_output

    def _is_threshold_met(self, network_output):
        """Check if the threshold was met."""
        if self._by_masking:
            return network_output <= self._threshold
        return network_output >= self._threshold

    def _is_step_change_met(self, target, step_change):
        """Check if the change target was met."""
        if self._by_masking:
            return step_change <= target
        return step_change >= target

    def _is_greedy(self, step_change):
        """Check if it is a greedy step."""
        if self._by_masking:
            return step_change < 0
        return step_change > 0

    @classmethod
    def _auto_win_sizes_strides(cls, short_side):
        """
        Calculate auto window sizes and strides.

        Args:
            short_side (int): Length of search space.

        Returns:
            tuple[list[int], list[int]], window sizes and strides.

        Raises:
            ValueError: Be raised if the short side is too small for any layer.
        """
        win_sizes = []
        strides = []
        # each layer's window is the previous one divided by AUTO_WIN_SIZE_DIV
        cur_len = int(short_side/AUTO_WIN_SIZE_DIV)
        while len(win_sizes) < AUTO_LAYER_MAX and cur_len >= AUTO_WIN_SIZE_MIN:
            stride = int(cur_len/AUTO_STRIDE_DIV)
            if stride <= 0:
                break
            win_sizes.append(cur_len)
            strides.append(stride)
            cur_len = int(cur_len/AUTO_WIN_SIZE_DIV)
        if not win_sizes:
            raise ValueError(f"Image's short side is less then {AUTO_IMAGE_SHORT_SIDE_MIN}, "
                             f"unable to calculates auto settings.")
        return win_sizes, strides

    class _StopReason(Enum):
        """Stop reason of search job."""
        THRESHOLD_MET = 0     # threshold was met.
        STEP_CHANGE_MET = 1   # parent step change was met.
        NO_NEW_STEP = 2       # no new step was found.
        REPEATED_STEP = 3     # repeated step was found.


def _check_iterable_type(arg_name, arg_value, container_type, elem_types):
    """Check the data type of an iterable argument and of each of its elements."""
    check_value_type(arg_name, arg_value, container_type)
    for elem in arg_value:
        check_value_type(arg_name + ' element', elem, elem_types)


class _NoNewStepError(Exception):
    """Error for no new step was found."""


class _RepeatedStepError(Exception):
    """Error for repeated step was found."""


class _SearchJob:
    """
    Search job.

    Args:
        by_masking (bool): Whether it is masking mode.
        class_idx (int): Target class index.
        win_sizes (list[int]): Moving square window size (length of side) of layers.
        strides (list[int]): Strides of layers.
        layer (int): Layer number.
        search_field (tuple[int, int, int, int]): Search field in x, y, width, height format.
        pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching.
        parent_step (EditStep): Parent edit step.
        batch_size (int): Batch size of batched inferences.

    Raises:
        ValueError: Be raised if layer is not less than the number of window sizes.
    """

    def __init__(self,
                 by_masking,
                 class_idx,
                 win_sizes,
                 strides,
                 layer,
                 search_field,
                 pre_edit_steps,
                 parent_step,
                 batch_size=DEFAULT_BATCH_SIZE):

        if layer >= len(win_sizes):
            raise ValueError('Layer is larger then number of window sizes.')

        self.by_masking = by_masking
        self.class_idx = class_idx
        self.win_sizes = win_sizes
        self.strides = strides
        self.layer = layer
        self.search_field = search_field
        self.pre_edit_steps = pre_edit_steps
        self.parent_step = parent_step
        self.batch_size = batch_size
        # the following attributes are assigned in on_start()
        self.network = None
        self.mask = None
        self.original_input = None

        self._workpiece = None
        self._found_best_edits = None
        self._found_uvs = None
        self._u_pixels = None
        self._v_pixels = None

    @property
    def layer_count(self):
        """Number of layers."""
        return len(self.win_sizes)

    def on_start(self, original_input, workpiece, mask, network):
        """
        Notification of the start of the search job.

        Args:
            original_input (numpy.ndarray): The original image tensor in CHW or NCHW(N=1) format.
            workpiece (numpy.ndarray): The intermediate image tensor in CHW or NCHW(N=1) format.
            mask (Union[tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask, has same format of original_input.
            network (nn.Cell): Classification network.
        """
        self.original_input = original_input
        self.mask = mask
        self.network = network

        self._workpiece = workpiece
        self._found_best_edits = []
        self._found_uvs = []
        # u runs along x (field x and width), v runs along y (field y and height)
        self._u_pixels = self._calc_uv_pixels(self.search_field[0], self.search_field[2])
        self._v_pixels = self._calc_uv_pixels(self.search_field[1], self.search_field[3])

    def create_sub_job(self, parent_step, pre_edit_steps):
        """Create next layer search job, its search field is the box of the parent step."""
        return self.__class__(by_masking=self.by_masking,
                              class_idx=self.class_idx,
                              win_sizes=self.win_sizes,
                              strides=self.strides,
                              layer=self.layer + 1,
                              search_field=copy.copy(parent_step.box),
                              pre_edit_steps=pre_edit_steps,
                              parent_step=parent_step,
                              batch_size=self.batch_size)

    def find_best_edit(self):
        """
        Find the next best edit step.

        Returns:
            EditStep, the next best edit step.

        Raises:
            _NoNewStepError: Be raised if there is no window position left to try.
            _RepeatedStepError: Be raised if the best window position was already taken.
        """
        workpiece = self._workpiece
        if len(workpiece.shape) == 3:
            workpiece = np.expand_dims(workpiece, axis=0)

        # generate input tensors with shifted masked/unmasked region and pack into a batch
        best_new_workpiece = None
        best_output = None
        best_edit = None
        best_uv = None
        batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
        batch_uvs = []
        batch_steps = []
        batch_i = 0
        win_size = self.win_sizes[self.layer]
        for u, x in enumerate(self._u_pixels):
            for v, y in enumerate(self._v_pixels):
                if (u, v) in self._found_uvs:
                    continue

                edit_step = EditStep(self.layer, (x, y, win_size, win_size))

                # edit the batch slot in place: either mask a region of the workpiece copy, or
                # unmask a region of it by copying from the original input
                if self.by_masking:
                    EditStep.apply(batch[batch_i],
                                   self.mask,
                                   [edit_step],
                                   self.by_masking,
                                   inplace=True)
                else:
                    EditStep.apply(self.original_input,
                                   batch[batch_i],
                                   [edit_step],
                                   self.by_masking,
                                   inplace=True)

                batch_i += 1
                batch_uvs.append((u, v))
                batch_steps.append(edit_step)
                if batch_i != self.batch_size:
                    continue

                # the batch is full, inference and empty it
                updated = self._update_best(batch, batch_uvs, batch_steps, best_output)
                if updated:
                    best_output, best_uv, best_edit, best_new_workpiece = updated

                batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
                batch_uvs = []
                # batch_steps must be emptied together with batch_uvs, otherwise the indices of
                # the next batch would point to the steps of the first batch
                batch_steps = []
                batch_i = 0

        if batch_i > 0:
            # don't forget the last half full batch
            updated = self._update_best(batch, batch_uvs, batch_steps, best_output, batch_i)
            if updated:
                best_output, best_uv, best_edit, best_new_workpiece = updated

        if best_edit is None:
            raise _NoNewStepError

        if best_uv in self._found_uvs:
            raise _RepeatedStepError

        self._found_uvs.append(best_uv)
        self._found_best_edits.append(best_edit)
        best_edit.network_output = best_output

        # continue on the best workpiece in the next function call
        self._workpiece = best_new_workpiece

        return best_edit

    def _update_best(self, batch, batch_uvs, batch_steps, best_output, batch_i=None):
        """
        Update the best edit step.

        Args:
            batch (numpy.ndarray): Batch of edited image tensors.
            batch_uvs (list[tuple[int, int]]): Window position indices of the batch entries.
            batch_steps (list[EditStep]): Edit steps of the batch entries.
            best_output (float, optional): The best network output found so far, None if there is none.
            batch_i (int, optional): Number of valid entries in the batch, None means the batch is full.

        Returns:
            tuple[float, tuple[int, int], EditStep, numpy.ndarray], the new best output, position indices,
            edit step and workpiece. None if the batch has no better entry.
        """
        squeeze = Squeeze()
        # the whole batch is inferred, entries beyond batch_i are ignored by the aggregation below
        batch_output = self.network(Tensor(batch))
        batch_output = batch_output[:, self.class_idx]
        if len(batch_output.shape) > 1:
            batch_output = squeeze(batch_output)

        aggregation = np.argmin if self.by_masking else np.argmax
        if batch_i is None:
            batch_best_i = aggregation(batch_output.asnumpy())
        else:
            batch_best_i = aggregation(batch_output.asnumpy()[:batch_i, ...])
        batch_best_i = int(batch_best_i)
        batch_best_output = batch_output[batch_best_i].asnumpy().item()

        if best_output is None or self._is_output0_better(batch_best_output, best_output):
            best_output = batch_best_output
            best_uv = batch_uvs[batch_best_i]
            best_edit = batch_steps[batch_best_i]
            best_new_workpiece = batch[batch_best_i]
            return best_output, best_uv, best_edit, best_new_workpiece
        return None

    def _is_output0_better(self, output0, output1):
        """Check if the network output0 is better."""
        if self.by_masking:
            return output0 < output1
        return output0 > output1

    def _calc_uv_pixels(self, begin, length):
        """
        Calculate the pixel coordinate of shifts.

        Args:
            begin (int): The beginning pixel coordinate of search field.
            length (int): The length of search field.

        Returns:
            list[int], pixel coordinate of shifts.
        """
        win_size = self.win_sizes[self.layer]
        stride = self.strides[self.layer]
        shift_count = self._calc_shift_count(length, win_size, stride)
        pixels = [0] * shift_count
        for i in range(shift_count):
            if i == shift_count - 1:
                # the last window is aligned to the end of the search field
                pixels[i] = begin + length - win_size
            else:
                pixels[i] = begin + i*stride
        return pixels

    @staticmethod
    def _calc_shift_count(length, win_size, stride):
        """
        Calculate the number of shifts in search field.

        Args:
            length (int): The length of search field.
            win_size (int): The length of sides of moving window.
            stride (int): The stride.

        Returns:
            int, number of shifts.

        Raises:
            ValueError: Be raised if length <= win_size, win_size < stride or stride <= 0.
        """
        if length <= win_size or win_size < stride or stride <= 0:
            raise ValueError("Invalid length, win_size or stride.")
        count = int(math.ceil((length - win_size)/stride))
        # one more shift is needed if the windows do not reach the end of the search field yet
        if (count - 1)*stride + win_size < length:
            return count + 1
        return count