• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Hierarchical occlusion edit tree searcher."""
16from enum import Enum
17import copy
18import re
19import math
20
21import numpy as np
22from scipy.ndimage import gaussian_filter
23
24from mindspore import nn
25from mindspore import Tensor
26from mindspore.ops import Squeeze
27from mindspore.train._utils import check_value_type
28from mindspore.explainer._utils import deprecated_error
29
30
# --- settings used when window sizes / strides / mask are derived automatically ---
# maximum number of layers
AUTO_LAYER_MAX = 3
# minimum window size
AUTO_WIN_SIZE_MIN = 28
# denominator of the window size calculations
AUTO_WIN_SIZE_DIV = 2
# denominator of the stride calculations
AUTO_STRIDE_DIV = 5
# denominator of the gaussian mask radius calculations
AUTO_MASK_GAUSSIAN_RADIUS_DIV = 25

# --- defaults ---
# default target prediction threshold
DEFAULT_THRESHOLD = 0.5
# default batch size for batch inference search
DEFAULT_BATCH_SIZE = 64

# gaussian mask string pattern, the captured group is the radius
MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$'

# minimum length of input images' short side with auto settings
AUTO_IMAGE_SHORT_SIDE_MIN = AUTO_WIN_SIZE_MIN * AUTO_WIN_SIZE_DIV
42
43
@deprecated_error
def is_valid_str_mask(mask):
    """
    Check if it is a valid string mask.

    Args:
        mask (str): String mask to be checked, e.g. 'gaussian:9'.

    Returns:
        bool, True if the mask matches the gaussian mask pattern and its radius is larger than zero,
            False otherwise.
    """
    check_value_type('mask', mask, str)
    match = re.match(MASK_GAUSSIAN_RE, mask)
    # compare against None explicitly so that a strict bool is returned instead of leaking None
    # when the pattern does not match
    return match is not None and int(match.group(1)) > 0
50
51
@deprecated_error
def compile_mask(mask, image):
    """
    Compile mask to a ready to use object.

    None is replaced by the auto string mask of the image and string masks are compiled against the image.
    Tuple masks are element checked and returned as they are, numpy.ndarray masks are aligned with the image on
    the batch axis and must end up with the image's shape. Any other accepted mask is returned untouched.
    """
    if mask is None:
        auto_mask = auto_str_mask(image)
        return compile_str_mask(auto_mask, image)

    check_value_type('mask', mask, (str, tuple, float, np.ndarray))

    if isinstance(mask, str):
        return compile_str_mask(mask, image)

    if isinstance(mask, tuple):
        _check_iterable_type('mask', mask, tuple, float)
        return mask

    if not isinstance(mask, np.ndarray):
        return mask

    # add or drop the batch axis of the mask so that it lines up with the image
    image_rank = len(image.shape)
    mask_rank = len(mask.shape)
    if image_rank == 4 and mask_rank == 3:
        mask = np.expand_dims(mask, axis=0)
    elif image_rank == 3 and mask_rank == 4 and mask.shape[0] == 1:
        mask = mask.squeeze(0)

    if mask.shape != image.shape:
        raise ValueError("Image and mask is not match in shape.")
    return mask
71
72
@deprecated_error
def auto_str_mask(image):
    """
    Generate auto string mask for the image.

    The gaussian radius is the rounded quotient of the length of the shorter of the last two axes and
    AUTO_MASK_GAUSSIAN_RADIUS_DIV, a ValueError is raised if it rounds to zero.
    """
    check_value_type('image', image, np.ndarray)
    spatial_dims = image.shape[-2:]
    short_side = np.min(spatial_dims)
    radius = int(round(short_side / AUTO_MASK_GAUSSIAN_RADIUS_DIV))
    if radius != 0:
        return f'gaussian:{radius}'
    raise ValueError(f"Input image's short side:{short_side} is too small for auto mask, "
                     f"at least {AUTO_MASK_GAUSSIAN_RADIUS_DIV}pixels is required.")
83
84
@deprecated_error
def compile_str_mask(mask, image):
    """
    Convert string mask to numpy.ndarray.

    Only the gaussian pattern with a positive radius is accepted, the result is the image blurred along its
    last two axes. Any other string raises ValueError.
    """
    check_value_type('mask', mask, str)
    check_value_type('image', image, np.ndarray)

    match = re.match(MASK_GAUSSIAN_RE, mask)
    radius = int(match.group(1)) if match else 0
    if radius <= 0:
        raise ValueError(f"Invalid string mask: '{mask}'.")

    # blur the last two axes only, every other axis keeps a zero sigma
    sigma = [0 for _ in image.shape]
    sigma[-2] = sigma[-1] = radius
    return gaussian_filter(image, sigma=sigma, mode='nearest')
99
100
@deprecated_error
class EditStep:
    """
    Edit step that describes a box region, also represents an edit tree.

    Args:
        layer (int): Layer number, -1 is root layer, 0 or above is normal edit layer.
        box (tuple[int, int, int, int]): Tuple of x, y, width, height.
    """
    def __init__(self, layer, box):
        self.layer = layer
        self.box = box
        self.network_output = 0
        self.step_change = 0
        self.children = None

    @property
    def x(self):
        """X-coordinate of the box, the first box element."""
        return self.box[0]

    @property
    def y(self):
        """Y-coordinate of the box, the second box element."""
        return self.box[1]

    @property
    def width(self):
        """Width of the box, the third box element."""
        return self.box[2]

    @property
    def height(self):
        """Height of the box, the fourth box element."""
        return self.box[3]

    @property
    def is_leaf(self):
        """Returns True if there is no child edit step."""
        return not self.children

    @property
    def leaf_steps(self):
        """Returns all leaf edit steps in the tree, the step itself if it is a leaf."""
        if self.is_leaf:
            return [self]
        return [leaf for child in self.children for leaf in child.leaf_steps]

    @property
    def max_layer(self):
        """Maximum layer number in the edit tree."""
        if self.is_leaf:
            return self.layer
        return max(self.layer, *(child.max_layer for child in self.children))

    def add_child(self, child):
        """Add a child edit step."""
        if self.children is None:
            self.children = []
        self.children.append(child)

    def remove_all_children(self):
        """Remove all child steps."""
        self.children = None

    def get_layer_or_leaf_steps(self, layer):
        """Get all edit steps of the layer and all leaf edit steps above the layer."""
        if self.layer == layer:
            return [self]
        if not self.layer < layer:
            # this step lies beyond the requested layer
            return []
        if self.is_leaf:
            return [self]
        collected = []
        for child in self.children:
            collected.extend(child.get_layer_or_leaf_steps(layer))
        return collected

    def get_layer_steps(self, layer):
        """Get all edit steps of the layer."""
        if self.layer == layer:
            return [self]
        if not (self.layer < layer and self.children):
            return []
        return [step for child in self.children for step in child.get_layer_steps(layer)]

    @classmethod
    def apply(cls,
              image,
              mask,
              edit_steps,
              by_masking=False,
              inplace=False):
        """
        Apply edit steps, either in masking or in unmasking mode.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.
            by_masking (bool): Whether it is masking mode.
            inplace (bool): Whether the modification is going to take place in the input tensor. False to
                construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        applier = cls.apply_masking if by_masking else cls.apply_unmasking
        return applier(image, mask, edit_steps, inplace)

    @classmethod
    def apply_masking(cls,
                      image,
                      mask,
                      edit_steps,
                      inplace=False):
        """
        Apply edit steps in masking mode, the box region of every step is overwritten by the mask.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.
            inplace (bool): Whether the modification is going to take place in the input image tensor. False to
                construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        cls._apply_check_args(image, mask, edit_steps)
        mask = compile_mask(mask, image)

        background = image if inplace else np.copy(image)
        if not edit_steps:
            return background

        is_array_mask = isinstance(mask, np.ndarray)
        if not is_array_mask and isinstance(mask, (int, float)):
            # a grey scale mask is the same value on each of the 3 channels
            mask = (mask,) * 3

        x_limit, y_limit = background.shape[-1], background.shape[-2]
        for step in edit_steps:
            x_max, y_max = cls._get_step_xy_max(step, x_limit, y_limit)
            if x_max <= step.x or y_max <= step.y:
                # nothing of the box is left after clipping
                continue

            rows = slice(step.y, y_max)
            cols = slice(step.x, x_max)
            if is_array_mask:
                background[..., rows, cols] = mask[..., rows, cols]
            else:
                for channel in range(3):
                    background[..., channel, rows, cols] = mask[channel]
        return background

    @classmethod
    def apply_unmasking(cls,
                        image,
                        mask,
                        edit_steps,
                        inplace=False):
        """
        Apply edit steps in unmasking mode, the box region of every step is copied from the image onto the mask.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep]): Edit steps to be applied.
            inplace (bool): Whether the modification is going to take place in the input mask tensor. False to
                construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem, including inplace being True while
                the compiled mask is not a numpy.ndarray.
        """
        cls._apply_check_args(image, mask, edit_steps)
        mask = compile_mask(mask, image)

        if isinstance(mask, np.ndarray):
            background = mask if inplace else np.copy(mask)
        elif inplace:
            raise ValueError('Inplace cannot be True when mask is not a numpy.ndarray')
        else:
            # build the solid color background from scratch
            background = np.zeros_like(image)
            if isinstance(mask, (int, float)):
                background.fill(mask)
            else:
                for channel in range(3):
                    background[..., channel, :, :] = mask[channel]

        if not edit_steps:
            return background

        x_limit, y_limit = background.shape[-1], background.shape[-2]
        for step in edit_steps:
            x_max, y_max = cls._get_step_xy_max(step, x_limit, y_limit)
            if x_max <= step.x or y_max <= step.y:
                # nothing of the box is left after clipping
                continue

            rows = slice(step.y, y_max)
            cols = slice(step.x, x_max)
            background[..., rows, cols] = image[..., rows, cols]

        return background

    @staticmethod
    def _apply_check_args(image, mask, edit_steps):
        """
        Check arguments for apply edit steps.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        check_value_type('image', image, np.ndarray)
        check_value_type('mask', mask, (str, tuple, float, np.ndarray))
        if isinstance(mask, tuple):
            _check_iterable_type('mask', mask, tuple, float)

        if edit_steps is None:
            return
        _check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

    @staticmethod
    def _get_step_xy_max(step, x_limit, y_limit):
        """Get the exclusive x and y end positions of the step box, clipped to the limits."""
        x_max = min(step.x + step.width, x_limit)
        y_max = min(step.y + step.height, y_limit)
        return x_max, y_max
381
382
class NoValidResultError(RuntimeError):
    """Raised when none of the edit step layers yields a network output that meets the threshold."""
385
386
class OriginalOutputError(RuntimeError):
    """Raised when the network output of the original image is not strictly larger than the threshold."""
389
390
@deprecated_error
class Searcher:
    """
    Edit step searcher.

    Args:
        network (Cell): The network to be searched against, its output is indexed as [0, class_idx]. A deep
            copy of it is kept by the searcher.
        win_sizes (list[int], optional): Moving square window size (length of side) of layers, must be
            non-empty, positive and strictly descending. None means by auto calculation. win_sizes and strides
            must be both None or both specified. Default: None.
        strides (list[int], optional): Stride of layers, must be non-empty, positive, strictly descending and
            of the same length as win_sizes. None means by auto calculation. Default: None.
        threshold (float): Threshold network output value of the target class. Default: DEFAULT_THRESHOLD.
        by_masking (bool): Whether it is masking mode. Default: False.

    Raises:
        ValueError: Be raised for any data or settings' value problem.
        TypeError: Be raised for any data or settings' type problem.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    def __init__(self,
                 network,
                 win_sizes=None,
                 strides=None,
                 threshold=DEFAULT_THRESHOLD,
                 by_masking=False):

        check_value_type('network', network, nn.Cell)

        if win_sizes is not None:
            self._check_descending_positive('win_sizes', win_sizes)
        elif strides is not None:
            raise ValueError('Argument win_sizes cannot be None if strides is not None.')

        if strides is not None:
            # an empty strides list is rejected here with a ValueError, it used to fail with an IndexError
            self._check_descending_positive('strides', strides)
            if len(strides) != len(win_sizes):
                raise ValueError('Length of strides and win_sizes is not equal.')
        elif win_sizes is not None:
            raise ValueError('Argument strides cannot be None if win_sizes is not None.')

        self._network = copy.deepcopy(network)
        self._compiled_mask = None
        self._threshold = threshold
        # keep shallow copies so that later changes to the caller's lists do not affect the searcher
        self._win_sizes = copy.copy(win_sizes) if win_sizes else None
        self._strides = copy.copy(strides) if strides else None
        self._by_masking = by_masking

    @staticmethod
    def _check_descending_positive(arg_name, values):
        """
        Check that the argument is a non-empty, strictly descending list of positive integers.

        Args:
            arg_name (str): Argument name to be shown in the error messages.
            values (list[int]): Argument value to be checked.

        Raises:
            TypeError: Be raised if values is not a list of int.
            ValueError: Be raised if values is empty, not strictly descending or has non-positive number.
        """
        _check_iterable_type(arg_name, values, list, int)
        if not values:
            raise ValueError(f'Argument {arg_name} is empty.')

        for previous, current in zip(values, values[1:]):
            if current >= previous:
                raise ValueError(f'Argument {arg_name} is not strictly descending.')

        # the list is strictly descending, so checking the last element covers all of them
        if values[-1] <= 0:
            raise ValueError(f'Argument {arg_name} has non-positive number.')

    @property
    def network(self):
        """Get the network."""
        return self._network

    @property
    def by_masking(self):
        """Check if it is masking mode."""
        return self._by_masking

    @property
    def threshold(self):
        """The network output threshold to stop the search."""
        return self._threshold

    @property
    def win_sizes(self):
        """Windows sizes in pixels, None if they are calculated automatically."""
        return self._win_sizes

    @property
    def strides(self):
        """Strides in pixels, None if they are calculated automatically."""
        return self._strides

    @property
    def compiled_mask(self):
        """The mask compiled by the latest search() call, None if search() was never called."""
        return self._compiled_mask

    def search(self, image, class_idx, mask=None):
        """
        Search smallest sufficient/destruction region on an image.

        Args:
            image (Union[Tensor, numpy.ndarray]): Image tensor in CHW or NCHW(N=1) format.
            class_idx (int): Target class index.
            mask (Union[str, tuple[float, float, float], float], optional): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask,
                float: Grey scale solid color mask.
                None: By auto calculation.

        Returns:
            tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the
                layer steps.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
            NoValidResultError: Be raised if no valid result was found.
            OriginalOutputError: Be raised if network output of the original image is not strictly larger than
                the threshold.
        """
        check_value_type('image', image, (Tensor, np.ndarray))

        if isinstance(image, Tensor):
            image = image.asnumpy()

        # from here on the image is always in NCHW(N=1) format
        if len(image.shape) == 4:
            if image.shape[0] != 1:
                raise ValueError("Argument image's batch size is not 1.")
        elif len(image.shape) == 3:
            image = np.expand_dims(image, axis=0)
        else:
            raise ValueError("Argument image is not in CHW or NCHW(N=1) format.")

        check_value_type('class_idx', class_idx, int)

        if class_idx < 0:
            raise ValueError("Argument class_idx is less then zero.")

        self._compiled_mask = compile_mask(mask, image)

        short_side = np.min(image.shape[-2:])
        if self._win_sizes is None:
            win_sizes, strides = self._auto_win_sizes_strides(short_side)
        else:
            win_sizes, strides = self._win_sizes, self._strides

        if short_side <= win_sizes[0]:
            raise ValueError(f"Input image's short side is shorter then or "
                             f"equals to the first window size:{win_sizes[0]}.")

        self._network.set_train(False)

        # the search result will be stored as an edit tree that is attached to the root step.
        root_step = EditStep(-1, (0, 0, image.shape[-1], image.shape[-2]))
        root_job = _SearchJob(by_masking=self._by_masking,
                              class_idx=class_idx,
                              win_sizes=win_sizes,
                              strides=strides,
                              layer=0,
                              search_field=root_step.box,
                              pre_edit_steps=None,
                              parent_step=root_step)
        self._process_root_job(image, root_job)
        return self._touch_result(image, class_idx, root_step)

    def _touch_result(self, image, class_idx, root_step):
        """
        Final treatment to the search result.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            class_idx (int): Target class index.
            root_step (EditStep): The searched root step.

        Returns:
            tuple[EditStep, list[float]], the root edit step and network output of each layer after applied the
                layer steps.

        Raises:
            NoValidResultError: Be raised if no valid result was found.
        """
        # the leaf layer's network output may not meet the threshold,
        # we have to cutoff the unqualified layers
        layer_count = root_step.max_layer + 1
        if layer_count == 0:
            raise NoValidResultError("No edit step layer was found.")

        # gather the network output of each layer
        layer_outputs = [None] * layer_count
        for layer in range(layer_count):
            steps = root_step.get_layer_or_leaf_steps(layer)
            if not steps:
                continue
            masked_image = EditStep.apply(image, self._compiled_mask, steps, by_masking=self._by_masking)
            output = self._network(Tensor(masked_image))
            output = output[0, class_idx].asnumpy().item()
            layer_outputs[layer] = output

        # determine which layer we have to cutoff, the deepest layer that meets the threshold wins
        cutoff_layer = None
        for layer in reversed(range(layer_count)):
            if layer_outputs[layer] is not None and self._is_threshold_met(layer_outputs[layer]):
                cutoff_layer = layer
                break

        if cutoff_layer is None or root_step.is_leaf:
            raise NoValidResultError(f"No edit step layer's network output meet the threshold: {self._threshold}.")

        # cutoff the layer by removing all children of the layer's steps.
        steps = root_step.get_layer_steps(cutoff_layer)
        for step in steps:
            step.remove_all_children()
        layer_outputs = layer_outputs[:cutoff_layer + 1]

        return root_step, layer_outputs

    def _process_root_job(self, sample_input, root_job):
        """
        Process job queue.

        Jobs are processed in first-in-first-out order. The edit steps of a job are attached to its parent
        step and its sub jobs are queued only if the job stopped by meeting the threshold or the parent step
        change.

        Args:
            sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
            root_job (_SearchJob): Root search job.
        """
        job_queue = [root_job]
        while job_queue:
            job = job_queue.pop(0)
            sub_job_queue = []
            job_edit_steps, stop_reason = self._process_job(job, sample_input, sub_job_queue)

            if stop_reason in (self._StopReason.THRESHOLD_MET, self._StopReason.STEP_CHANGE_MET):
                for step in job_edit_steps:
                    job.parent_step.add_child(step)
                job_queue.extend(sub_job_queue)

    def _prepare_job(self, job, sample_input):
        """
        Prepare a job for process.

        Args:
            job (_SearchJob): Search job to be processed.
            sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.

        Returns:
            numpy.ndarray, the image tensor workpiece.

        Raises:
            OriginalOutputError: Be raised if network output of the original image is not strictly larger than the
                threshold.
        """
        # make sure the network output with the original image is strictly larger than the threshold
        if job.layer == 0:
            original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item()
            if original_output <= self._threshold:
                raise OriginalOutputError(f'The original output is not strictly larger the threshold: '
                                          f'{self._threshold}')

        # applying the pre-edit steps from the parent steps
        if job.pre_edit_steps:
            # use the latest leaf steps to increase the accuracy
            leaf_steps = []
            for step in job.pre_edit_steps:
                leaf_steps.extend(step.leaf_steps)
            pre_edit_steps = leaf_steps
        else:
            pre_edit_steps = None
        workpiece = EditStep.apply(sample_input,
                                   self._compiled_mask,
                                   pre_edit_steps,
                                   self._by_masking)

        job.on_start(sample_input, workpiece, self._compiled_mask, self._network)
        return workpiece

    def _process_job(self, job, sample_input, job_queue):
        """
        Process a job.

        Args:
            job (_SearchJob): Search job to be processed.
            sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format.
            job_queue (list[_SearchJob]): Job queue, sub jobs created by this job are appended to it.

        Returns:
            tuple[list[EditStep], _StopReason], result edit steps and the stop reason.

        Raises:
            OriginalOutputError: Be raised if network output of the original image is not strictly larger than the
                threshold.
        """
        workpiece = self._prepare_job(job, sample_input)

        start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item()
        last_output = start_output
        edit_steps = []
        # greedy search loop
        while True:

            if self._is_threshold_met(last_output):
                return edit_steps, self._StopReason.THRESHOLD_MET

            try:
                best_edit = job.find_best_edit()
            except _NoNewStepError:
                return edit_steps, self._StopReason.NO_NEW_STEP
            except _RepeatedStepError:
                return edit_steps, self._StopReason.REPEATED_STEP

            best_edit.step_change = best_edit.network_output - last_output

            if job.layer < job.layer_count - 1 and self._is_greedy(best_edit.step_change):
                # create next layer search job if new edit step is valid and not yet reaching
                # the final layer
                if job.pre_edit_steps:
                    pre_edit_steps = list(job.pre_edit_steps)
                    pre_edit_steps.extend(edit_steps)
                else:
                    pre_edit_steps = list(edit_steps)

                sub_job = job.create_sub_job(best_edit, pre_edit_steps)
                job_queue.append(sub_job)

            edit_steps.append(best_edit)

            if job.layer > 0:
                # stop if the step change meet the parent step change only after layer 0
                change = best_edit.network_output - start_output
                if self._is_step_change_met(job.parent_step.step_change, change):
                    return edit_steps, self._StopReason.STEP_CHANGE_MET

            last_output = best_edit.network_output

    def _is_threshold_met(self, network_output):
        """Check if the threshold was met, i.e. not above it in masking mode, not below it otherwise."""
        if self._by_masking:
            return network_output <= self._threshold
        return network_output >= self._threshold

    def _is_step_change_met(self, target, step_change):
        """Check if the change target was met, i.e. not above it in masking mode, not below it otherwise."""
        if self._by_masking:
            return step_change <= target
        return step_change >= target

    def _is_greedy(self, step_change):
        """Check if it is a greedy step, i.e. a negative change in masking mode, a positive one otherwise."""
        if self._by_masking:
            return step_change < 0
        return step_change > 0

    @classmethod
    def _auto_win_sizes_strides(cls, short_side):
        """
        Calculate auto window sizes and strides.

        The first window size is short_side divided by AUTO_WIN_SIZE_DIV, each following one is the previous
        one divided by AUTO_WIN_SIZE_DIV. At most AUTO_LAYER_MAX layers are generated and none of the window
        sizes is smaller than AUTO_WIN_SIZE_MIN. Each stride is the window size divided by AUTO_STRIDE_DIV.

        Args:
            short_side (int): Length of search space.

        Returns:
            tuple[list[int], list[int]], window sizes and strides.

        Raises:
            ValueError: Be raised if short_side is too small to generate any layer.
        """
        win_sizes = []
        strides = []
        cur_len = int(short_side/AUTO_WIN_SIZE_DIV)
        while len(win_sizes) < AUTO_LAYER_MAX and cur_len >= AUTO_WIN_SIZE_MIN:
            stride = int(cur_len/AUTO_STRIDE_DIV)
            if stride <= 0:
                break
            win_sizes.append(cur_len)
            strides.append(stride)
            cur_len = int(cur_len/AUTO_WIN_SIZE_DIV)
        if not win_sizes:
            raise ValueError(f"Image's short side is less then {AUTO_IMAGE_SHORT_SIDE_MIN}, "
                             f"unable to calculates auto settings.")
        return win_sizes, strides

    class _StopReason(Enum):
        """Stop reason of search job."""
        THRESHOLD_MET = 0       # threshold was met.
        STEP_CHANGE_MET = 1     # parent step change was met.
        NO_NEW_STEP = 2         # no new step was found.
        REPEATED_STEP = 3       # repeated step was found.
773
774
def _check_iterable_type(arg_name, arg_value, container_type, elem_types):
    """Check the type of an iterable argument, first the container itself and then each of its elements."""
    check_value_type(arg_name, arg_value, container_type)
    for item in arg_value:
        check_value_type(arg_name + ' element', item, elem_types)
780
781
class _NoNewStepError(Exception):
    """Raised when no new edit step was found."""
784
785
class _RepeatedStepError(Exception):
    """Raised when a repeated edit step was found."""
788
789
790class _SearchJob:
791    """
792    Search job.
793
794    Args:
795        by_masking (bool): Whether it is masking mode.
796        class_idx (int): Target class index.
797        win_sizes (list[int]): Moving square window size (length of side) of layers.
798        strides (list[int]): Strides of layers.
799        layer (int): Layer number.
800        search_field (tuple[int, int, int, int]): Search field in x, y, width, height format.
801        pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching.
802        parent_step (EditStep): Parent edit step.
803        batch_size (int): Batch size of batched inferences.
804    """
805
806    def __init__(self,
807                 by_masking,
808                 class_idx,
809                 win_sizes,
810                 strides,
811                 layer,
812                 search_field,
813                 pre_edit_steps,
814                 parent_step,
815                 batch_size=DEFAULT_BATCH_SIZE):
816
817        if layer >= len(win_sizes):
818            raise ValueError('Layer is larger then number of window sizes.')
819
820        self.by_masking = by_masking
821        self.class_idx = class_idx
822        self.win_sizes = win_sizes
823        self.strides = strides
824        self.layer = layer
825        self.search_field = search_field
826        self.pre_edit_steps = pre_edit_steps
827        self.parent_step = parent_step
828        self.batch_size = batch_size
829        self.network = None
830        self.mask = None
831        self.original_input = None
832
833        self._workpiece = None
834        self._found_best_edits = None
835        self._found_uvs = None
836        self._u_pixels = None
837        self._v_pixels = None
838
839    @property
840    def layer_count(self):
841        """Number of layers."""
842        return len(self.win_sizes)
843
844    def on_start(self, original_input, workpiece, mask, network):
845        """
846        Notification of the start of the search job.
847
848        Args:
849            original_input (numpy.ndarray): The original image tensor in CHW or NCHW(N=1) format.
850            workpiece (numpy.ndarray): The intermediate image tensor in CHW or NCHW(N=1) format.
851            mask (Union[tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
852                tuple[float, float, float]: RGB solid color mask,
853                float: Grey scale solid color mask.
854                numpy.ndarray: Image mask, has same format of original_input.
855            network (nn.Cell): Classification network.
856        """
857        self.original_input = original_input
858        self.mask = mask
859        self.network = network
860
861        self._workpiece = workpiece
862        self._found_best_edits = []
863        self._found_uvs = []
864        self._u_pixels = self._calc_uv_pixels(self.search_field[0], self.search_field[2])
865        self._v_pixels = self._calc_uv_pixels(self.search_field[1], self.search_field[3])
866
867    def create_sub_job(self, parent_step, pre_edit_steps):
868        """Create next layer search job."""
869        return self.__class__(by_masking=self.by_masking,
870                              class_idx=self.class_idx,
871                              win_sizes=self.win_sizes,
872                              strides=self.strides,
873                              layer=self.layer + 1,
874                              search_field=copy.copy(parent_step.box),
875                              pre_edit_steps=pre_edit_steps,
876                              parent_step=parent_step,
877                              batch_size=self.batch_size)
878
879    def find_best_edit(self):
880        """
881        Find the next best edit step.
882
883        Returns:
884            EditStep, the next best edit step.
885        """
886        workpiece = self._workpiece
887        if len(workpiece.shape) == 3:
888            workpiece = np.expand_dims(workpiece, axis=0)
889
890        # generate input tensors with shifted masked/unmasked region and pack into a batch
891        best_new_workpiece = None
892        best_output = None
893        best_edit = None
894        best_uv = None
895        batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
896        batch_uvs = []
897        batch_steps = []
898        batch_i = 0
899        win_size = self.win_sizes[self.layer]
900        for u, x in enumerate(self._u_pixels):
901            for v, y in enumerate(self._v_pixels):
902                if (u, v) in self._found_uvs:
903                    continue
904
905                edit_step = EditStep(self.layer, (x, y, win_size, win_size))
906
907                if self.by_masking:
908                    EditStep.apply(batch[batch_i],
909                                   self.mask,
910                                   [edit_step],
911                                   self.by_masking,
912                                   inplace=True)
913                else:
914                    EditStep.apply(self.original_input,
915                                   batch[batch_i],
916                                   [edit_step],
917                                   self.by_masking,
918                                   inplace=True)
919
920                batch_i += 1
921                batch_uvs.append((u, v))
922                batch_steps.append(edit_step)
923                if batch_i != self.batch_size:
924                    continue
925
926                # the batch is full, inference and empty it
927                updated = self._update_best(batch, batch_uvs, batch_steps, best_output)
928                if updated:
929                    best_output, best_uv, best_edit, best_new_workpiece = updated
930
931                batch = np.repeat(workpiece, repeats=self.batch_size, axis=0)
932                batch_uvs = []
933                batch_i = 0
934
935        if batch_i > 0:
936            # don't forget the last half full batch
937            updated = self._update_best(batch, batch_uvs, batch_steps, best_output, batch_i)
938            if updated:
939                best_output, best_uv, best_edit, best_new_workpiece = updated
940
941        if best_edit is None:
942            raise _NoNewStepError
943
944        if best_uv in self._found_uvs:
945            raise _RepeatedStepError
946
947        self._found_uvs.append(best_uv)
948        self._found_best_edits.append(best_edit)
949        best_edit.network_output = best_output
950
951        # continue on the best workpiece in the next function call
952        self._workpiece = best_new_workpiece
953
954        return best_edit
955
956    def _update_best(self, batch, batch_uvs, batch_steps, best_output, batch_i=None):
957        """Update the best edit step."""
958        squeeze = Squeeze()
959        batch_output = self.network(Tensor(batch))
960        batch_output = batch_output[:, self.class_idx]
961        if len(batch_output.shape) > 1:
962            batch_output = squeeze(batch_output)
963
964        aggregation = np.argmin if self.by_masking else np.argmax
965        if batch_i is None:
966            batch_best_i = aggregation(batch_output.asnumpy())
967        else:
968            batch_best_i = aggregation(batch_output.asnumpy()[:batch_i, ...])
969        batch_best_output = batch_output[int(batch_best_i)].asnumpy().item()
970
971        if best_output is None or self._is_output0_better(batch_best_output, best_output):
972            best_output = batch_best_output
973            best_uv = batch_uvs[batch_best_i]
974            best_edit = batch_steps[batch_best_i]
975            best_new_workpiece = batch[batch_best_i]
976            return best_output, best_uv, best_edit, best_new_workpiece
977        return None
978
979    def _is_output0_better(self, output0, output1):
980        """Check if the network output0 is better."""
981        if self.by_masking:
982            return output0 < output1
983        return output0 > output1
984
985    def _calc_uv_pixels(self, begin, length):
986        """
987        Calculate the pixel coordinate of shifts.
988
989        Args:
990            begin (int): The beginning pixel coordinate of search field.
991            length (int): The length of search field.
992
993        Returns:
994             list[int], pixel coordinate of shifts.
995        """
996        win_size = self.win_sizes[self.layer]
997        stride = self.strides[self.layer]
998        shift_count = self._calc_shift_count(length, win_size, stride)
999        pixels = [0] * shift_count
1000        for i in range(shift_count):
1001            if i == shift_count - 1:
1002                pixels[i] = begin + length - win_size
1003            else:
1004                pixels[i] = begin + i*stride
1005        return pixels
1006
1007    @staticmethod
1008    def _calc_shift_count(length, win_size, stride):
1009        """
1010        Calculate the number of shifts in search field.
1011
1012        Args:
1013            length (int): The length of search field.
1014            win_size (int): The length of sides of moving window.
1015            stride (int): The stride.
1016
1017        Returns:
1018             int, number of shifts.
1019        """
1020        if length <= win_size or win_size < stride or stride <= 0:
1021            raise ValueError("Invalid length, win_size or stride.")
1022        count = int(math.ceil((length - win_size)/stride))
1023        if (count - 1)*stride + win_size < length:
1024            return count + 1
1025        return count
1026