# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical occlusion edit tree searcher."""
from enum import Enum
import copy
import re
import math

import numpy as np
from scipy.ndimage import gaussian_filter

from mindspore import nn
from mindspore import Tensor
from mindspore.ops import Squeeze
from mindspore.train._utils import check_value_type
from mindspore.explainer._utils import deprecated_error

AUTO_LAYER_MAX = 3  # maximum number of layers by auto settings
AUTO_WIN_SIZE_MIN = 28  # minimum window size by auto settings
AUTO_WIN_SIZE_DIV = 2  # denominator of window size calculations by auto settings
AUTO_STRIDE_DIV = 5  # denominator of stride calculations by auto settings
AUTO_MASK_GAUSSIAN_RADIUS_DIV = 25  # denominator of gaussian mask radius calculations by auto settings
DEFAULT_THRESHOLD = 0.5  # default target prediction threshold
DEFAULT_BATCH_SIZE = 64  # default batch size for batch inference search
MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$'  # gaussian mask string pattern

# minimum length of input images' short side with auto settings
AUTO_IMAGE_SHORT_SIDE_MIN = AUTO_WIN_SIZE_MIN * AUTO_WIN_SIZE_DIV


@deprecated_error
def is_valid_str_mask(mask):
    """Check if it is a valid string mask."""
    check_value_type('mask', mask, str)
    match = re.match(MASK_GAUSSIAN_RE, mask)
    return match and int(match.group(1)) > 0


@deprecated_error
def compile_mask(mask, image):
    """Compile mask to a ready-to-use object."""
    if mask is None:
        return compile_str_mask(auto_str_mask(image), image)
    check_value_type('mask', mask, (str, tuple, float, np.ndarray))
    if isinstance(mask, str):
        return compile_str_mask(mask, image)
    if isinstance(mask, tuple):
        _check_iterable_type('mask', mask, tuple, float)
    elif isinstance(mask, np.ndarray):
        if len(image.shape) == 4 and len(mask.shape) == 3:
            mask = np.expand_dims(mask, axis=0)
        elif len(image.shape) == 3 and len(mask.shape) == 4 and mask.shape[0] == 1:
            mask = mask.squeeze(0)
        if image.shape != mask.shape:
            raise ValueError("Image and mask do not match in shape.")
    return mask


@deprecated_error
def auto_str_mask(image):
    """Generate an auto string mask for the image."""
    check_value_type('image', image, np.ndarray)
    short_side = np.min(image.shape[-2:])
    radius = int(round(short_side/AUTO_MASK_GAUSSIAN_RADIUS_DIV))
    if radius == 0:
        raise ValueError(f"Input image's short side:{short_side} is too small for auto mask, "
                         f"at least {AUTO_MASK_GAUSSIAN_RADIUS_DIV} pixels is required.")
    return f'gaussian:{radius}'


@deprecated_error
def compile_str_mask(mask, image):
    """Convert a string mask to numpy.ndarray."""
    check_value_type('mask', mask, str)
    check_value_type('image', image, np.ndarray)
    match = re.match(MASK_GAUSSIAN_RE, mask)
    if match:
        radius = int(match.group(1))
        if radius > 0:
            sigma = [0] * len(image.shape)
            sigma[-2] = radius
            sigma[-1] = radius
            return gaussian_filter(image, sigma=sigma, mode='nearest')
    raise ValueError(f"Invalid string mask: '{mask}'.")
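

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): what the auto
# 'gaussian:<radius>' mask amounts to. The radius is the image short side
# divided by AUTO_MASK_GAUSSIAN_RADIUS_DIV, and compiling the mask is just a
# Gaussian blur over the spatial axes. `_demo_gaussian_mask` is a hypothetical
# helper added here purely for illustration; it mirrors auto_str_mask and
# compile_str_mask without calling the deprecated wrappers.
def _demo_gaussian_mask():
    """Sketch: build the blurred background that a 'gaussian:<radius>' mask denotes."""
    image = np.random.rand(3, 224, 224).astype(np.float32)  # assumed CHW float image
    radius = int(round(min(image.shape[-2:]) / AUTO_MASK_GAUSSIAN_RADIUS_DIV))  # round(224 / 25) -> 9
    sigma = [0] * len(image.shape)
    sigma[-2] = sigma[-1] = radius  # blur H and W only, leave channels untouched
    return gaussian_filter(image, sigma=sigma, mode='nearest')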


@deprecated_error
class EditStep:
    """
    Edit step that describes a box region, also represents an edit tree.

    Args:
        layer (int): Layer number, -1 is the root layer, 0 or above is a normal edit layer.
        box (tuple[int, int, int, int]): Tuple of x, y, width, height.
    """
    def __init__(self, layer, box):
        self.layer = layer
        self.box = box
        self.network_output = 0
        self.step_change = 0
        self.children = None

    @property
    def x(self):
        """X-coordinate of the box."""
        return self.box[0]

    @property
    def y(self):
        """Y-coordinate of the box."""
        return self.box[1]

    @property
    def width(self):
        """Width of the box."""
        return self.box[2]

    @property
    def height(self):
        """Height of the box."""
        return self.box[3]

    @property
    def is_leaf(self):
        """Returns True if there is no child edit step."""
        return not self.children

    @property
    def leaf_steps(self):
        """Returns all leaf edit steps in the tree."""
        if self.is_leaf:
            return [self]
        steps = []
        for child in self.children:
            steps.extend(child.leaf_steps)
        return steps

    @property
    def max_layer(self):
        """Maximum layer number in the edit tree."""
        if self.is_leaf:
            return self.layer
        layer = self.layer
        for child in self.children:
            child_max_layer = child.max_layer
            if child_max_layer > layer:
                layer = child_max_layer
        return layer

    def add_child(self, child):
        """Add a child edit step."""
        if self.children is None:
            self.children = [child]
        else:
            self.children.append(child)

    def remove_all_children(self):
        """Remove all child steps."""
        self.children = None

    def get_layer_or_leaf_steps(self, layer):
        """Get all edit steps of the layer and all leaf edit steps above the layer."""
        if self.layer == layer or (self.layer < layer and self.is_leaf):
            return [self]
        steps = []
        if self.layer < layer and self.children:
            for child in self.children:
                steps.extend(child.get_layer_or_leaf_steps(layer))
        return steps

    def get_layer_steps(self, layer):
        """Get all edit steps of the layer."""
        if self.layer == layer:
            return [self]
        steps = []
        if self.layer < layer and self.children:
            for child in self.children:
                steps.extend(child.get_layer_steps(layer))
        return steps

    @classmethod
    def apply(cls, image, mask, edit_steps, by_masking=False, inplace=False):
        """
        Apply edit steps.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask.
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.
            by_masking (bool): Whether it is masking mode.
            inplace (bool): Whether the modification is going to take place in the input image tensor.
                False to construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        if by_masking:
            return cls.apply_masking(image, mask, edit_steps, inplace)
        return cls.apply_unmasking(image, mask, edit_steps, inplace)

    @classmethod
    def apply_masking(cls, image, mask, edit_steps, inplace=False):
        """
        Apply edit steps in masking mode.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask.
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.
            inplace (bool): Whether the modification is going to take place in the input image tensor.
                False to construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        cls._apply_check_args(image, mask, edit_steps)
        mask = compile_mask(mask, image)
        background = image if inplace else np.copy(image)
        if not edit_steps:
            return background
        for step in edit_steps:
            x_max, y_max = cls._get_step_xy_max(step, background.shape[-1], background.shape[-2])
            if x_max <= step.x or y_max <= step.y:
                continue
            if isinstance(mask, np.ndarray):
                background[..., step.y:y_max, step.x:x_max] = mask[..., step.y:y_max, step.x:x_max]
            else:
                if isinstance(mask, (int, float)):
                    mask = (mask, mask, mask)
                for c in range(3):
                    background[..., c, step.y:y_max, step.x:x_max] = mask[c]
        return background

    @classmethod
    def apply_unmasking(cls, image, mask, edit_steps, inplace=False):
        """
        Apply edit steps in unmasking mode.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask.
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep]): Edit steps to be applied.
            inplace (bool): Whether the modification is going to take place in the input mask tensor.
                False to construct a new image tensor as result.

        Returns:
            numpy.ndarray, the result image tensor.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        cls._apply_check_args(image, mask, edit_steps)
        mask = compile_mask(mask, image)

        if isinstance(mask, np.ndarray):
            if inplace:
                background = mask
            else:
                background = np.copy(mask)
        else:
            if inplace:
                raise ValueError('Inplace cannot be True when mask is not a numpy.ndarray.')
            background = np.zeros_like(image)
            if isinstance(mask, (int, float)):
                background.fill(mask)
            else:
                for c in range(3):
                    background[..., c, :, :] = mask[c]

        if not edit_steps:
            return background

        for step in edit_steps:
            x_max, y_max = cls._get_step_xy_max(step, background.shape[-1], background.shape[-2])
            if x_max <= step.x or y_max <= step.y:
                continue
            background[..., step.y:y_max, step.x:x_max] = image[..., step.y:y_max, step.x:x_max]
        return background

    @staticmethod
    def _apply_check_args(image, mask, edit_steps):
        """
        Check arguments for applying edit steps.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            mask (Union[str, tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask.
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask in CHW or NCHW(N=1) format.
            edit_steps (list[EditStep], optional): Edit steps to be applied.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
        """
        check_value_type('image', image, np.ndarray)
        check_value_type('mask', mask, (str, tuple, float, np.ndarray))
        if isinstance(mask, tuple):
            _check_iterable_type('mask', mask, tuple, float)
        if edit_steps is not None:
            _check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep)

    @staticmethod
    def _get_step_xy_max(step, x_limit, y_limit):
        """Get the step's maximum x and y positions."""
        x_max = step.x + step.width
        y_max = step.y + step.height
        if x_max > x_limit:
            x_max = x_limit
        if y_max > y_limit:
            y_max = y_limit
        return x_max, y_max
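

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the core operation
# behind EditStep.apply. Masking copies the compiled background into the box
# region of the image; unmasking starts from the background and copies the
# original pixels back into the box. `_demo_apply_box` is a hypothetical
# helper showing both directions for one (x, y, width, height) box.
def _demo_apply_box(image, background, box, by_masking=True):
    """Sketch: apply one box edit on a CHW image against a same-shaped background."""
    x, y, width, height = box
    x_max = min(x + width, image.shape[-1])
    y_max = min(y + height, image.shape[-2])
    if by_masking:
        result = np.copy(image)
        result[..., y:y_max, x:x_max] = background[..., y:y_max, x:x_max]
    else:
        result = np.copy(background)
        result[..., y:y_max, x:x_max] = image[..., y:y_max, x:x_max]
    return result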
""" check_value_type('image', image, np.ndarray) check_value_type('mask', mask, (str, tuple, float, np.ndarray)) if isinstance(mask, tuple): _check_iterable_type('mask', mask, tuple, float) if edit_steps is not None: _check_iterable_type('edit_steps', edit_steps, (tuple, list), EditStep) @staticmethod def _get_step_xy_max(step, x_limit, y_limit): """Get the step x and y max. position.""" x_max = step.x + step.width y_max = step.y + step.height if x_max > x_limit: x_max = x_limit if y_max > y_limit: y_max = y_limit return x_max, y_max class NoValidResultError(RuntimeError): """Error for no edit step layer's network output meet the threshold.""" class OriginalOutputError(RuntimeError): """Error for network output of the original image is not strictly larger than the threshold.""" @deprecated_error class Searcher: """ Edit step searcher. Args: network (Cell): Image tensor in CHW or NCHW(N=1) format. win_sizes (Union(list[int], optional): Moving square window size (length of side) of layers, None means by auto calcuation. strides (Union(list[int], optional): Stride of layers, None means by auto calcuation. threshold (float): Threshold network output value of the target class. by_masking (bool): Whether it is masking mode. Raises: ValueError: Be raised for any data or settings' value problem. TypeError: Be raised for any data or settings' type problem. RuntimeError: Be raised if this function was invoked before. Supported Platforms: ``Ascend`` ``GPU`` """ def __init__(self, network, win_sizes=None, strides=None, threshold=DEFAULT_THRESHOLD, by_masking=False): check_value_type('network', network, nn.Cell) if win_sizes is not None: _check_iterable_type('win_sizes', win_sizes, list, int) if not win_sizes: raise ValueError('Argument win_sizes is empty.') for i in range(1, len(win_sizes)): if win_sizes[i] >= win_sizes[i-1]: raise ValueError('Argument win_sizes is not strictly descending.') if win_sizes[-1] <= 0: raise ValueError('Argument win_sizes has non-positive number.') elif strides is not None: raise ValueError('Argument win_sizes cannot be None if strides is not None.') if strides is not None: _check_iterable_type('strides', strides, list, int) for i in range(1, len(strides)): if strides[i] >= strides[i-1]: raise ValueError('Argument win_sizes is not strictly descending.') if strides[-1] <= 0: raise ValueError('Argument strides has non-positive number.') if len(strides) != len(win_sizes): raise ValueError('Length of strides and win_sizes is not equal.') elif win_sizes is not None: raise ValueError('Argument strides cannot be None if win_sizes is not None.') self._network = copy.deepcopy(network) self._compiled_mask = None self._threshold = threshold self._win_sizes = copy.copy(win_sizes) if win_sizes else None self._strides = copy.copy(strides) if strides else None self._by_masking = by_masking @property def network(self): """Get the network.""" return self._network @property def by_masking(self): """Check if it is masking mode.""" return self._by_masking @property def threshold(self): """The network output threshold to stop the search.""" return self._threshold @property def win_sizes(self): """Windows sizes in pixels.""" return self._win_sizes @property def strides(self): """Strides in pixels.""" return self._strides @property def compiled_mask(self): """The compiled mask after a successful search() call.""" return self._compiled_mask def search(self, image, class_idx, mask=None): """ Search smallest sufficient/destruction region on an image. 

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            class_idx (int): Target class index.
            mask (Union[str, tuple[float, float, float], float], optional): The mask, type can be
                str: String mask, e.g. 'gaussian:9' - Gaussian blur with radius of 9.
                tuple[float, float, float]: RGB solid color mask.
                float: Grey scale solid color mask.
                None: By auto calculation.

        Returns:
            tuple[EditStep, list[float]], the root edit step and the network output of each layer
                after the layer steps were applied.

        Raises:
            TypeError: Be raised for any argument or data type problem.
            ValueError: Be raised for any argument or data value problem.
            NoValidResultError: Be raised if no valid result was found.
            OriginalOutputError: Be raised if the network output of the original image is not
                strictly larger than the threshold.
        """
        check_value_type('image', image, (Tensor, np.ndarray))
        if isinstance(image, Tensor):
            image = image.asnumpy()

        if len(image.shape) == 4:
            if image.shape[0] != 1:
                raise ValueError("Argument image's batch size is not 1.")
        elif len(image.shape) == 3:
            image = np.expand_dims(image, axis=0)
        else:
            raise ValueError("Argument image is not in CHW or NCHW(N=1) format.")

        check_value_type('class_idx', class_idx, int)
        if class_idx < 0:
            raise ValueError("Argument class_idx is less than zero.")

        self._compiled_mask = compile_mask(mask, image)

        short_side = np.min(image.shape[-2:])
        if self._win_sizes is None:
            win_sizes, strides = self._auto_win_sizes_strides(short_side)
        else:
            win_sizes, strides = self._win_sizes, self._strides

        if short_side <= win_sizes[0]:
            raise ValueError(f"Input image's short side is shorter than or "
                             f"equal to the first window size:{win_sizes[0]}.")

        self._network.set_train(False)

        # the search result will be stored as an edit tree attached to the root step.
        root_step = EditStep(-1, (0, 0, image.shape[-1], image.shape[-2]))
        root_job = _SearchJob(by_masking=self._by_masking,
                              class_idx=class_idx,
                              win_sizes=win_sizes,
                              strides=strides,
                              layer=0,
                              search_field=root_step.box,
                              pre_edit_steps=None,
                              parent_step=root_step)

        self._process_root_job(image, root_job)

        return self._touch_result(image, class_idx, root_step)

    def _touch_result(self, image, class_idx, root_step):
        """
        Final treatment to the search result.

        Args:
            image (numpy.ndarray): Image tensor in CHW or NCHW(N=1) format.
            class_idx (int): Target class index.
            root_step (EditStep): The searched root step.

        Returns:
            tuple[EditStep, list[float]], the root edit step and the network output of each layer
                after the layer steps were applied.

        Raises:
            NoValidResultError: Be raised if no valid result was found.
""" # the leaf layer's network output may not meet the threshold, # we have to cutoff the unqualified layers layer_count = root_step.max_layer + 1 if layer_count == 0: raise NoValidResultError("No edit step layer was found.") # gather the network output of each layer layer_outputs = [None] * layer_count for layer in range(layer_count): steps = root_step.get_layer_or_leaf_steps(layer) if not steps: continue masked_image = EditStep.apply(image, self._compiled_mask, steps, by_masking=self._by_masking) output = self._network(Tensor(masked_image)) output = output[0, class_idx].asnumpy().item() layer_outputs[layer] = output # determine which layer we have to cutoff cutoff_layer = None for layer in reversed(range(layer_count)): if layer_outputs[layer] is not None and self._is_threshold_met(layer_outputs[layer]): cutoff_layer = layer break if cutoff_layer is None or root_step.is_leaf: raise NoValidResultError(f"No edit step layer's network output meet the threshold: {self._threshold}.") # cutoff the layer by removing all children of the layer's steps. steps = root_step.get_layer_steps(cutoff_layer) for step in steps: step.remove_all_children() layer_outputs = layer_outputs[:cutoff_layer + 1] return root_step, layer_outputs def _process_root_job(self, sample_input, root_job): """ Process job queue. Args: sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format. root_job (_SearchJob): Root search job. """ job_queue = [root_job] while job_queue: job = job_queue.pop(0) sub_job_queue = [] job_edit_steps, stop_reason = self._process_job(job, sample_input, sub_job_queue) if stop_reason in (self._StopReason.THRESHOLD_MET, self._StopReason.STEP_CHANGE_MET): for step in job_edit_steps: job.parent_step.add_child(step) job_queue.extend(sub_job_queue) def _prepare_job(self, job, sample_input): """ Prepare a job for process. Args: job (_SearchJob): Search job to be processed. sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format. Returns: numpy.ndarray, the image tensor workpiece. Raises: OriginalOutputError: Be raised if network output of the original image is not strictly larger than the threshold. """ # make the network output with the original image is strictly larger than the threshold if job.layer == 0: original_output = self._network(Tensor(sample_input))[0, job.class_idx].asnumpy().item() if original_output <= self._threshold: raise OriginalOutputError(f'The original output is not strictly larger the threshold: ' f'{self._threshold}') # applying the pre-edit steps from the parent steps if job.pre_edit_steps: # use the latest leaf steps to increase the accuracy leaf_steps = [] for step in job.pre_edit_steps: leaf_steps.extend(step.leaf_steps) pre_edit_steps = leaf_steps else: pre_edit_steps = None workpiece = EditStep.apply(sample_input, self._compiled_mask, pre_edit_steps, self._by_masking) job.on_start(sample_input, workpiece, self._compiled_mask, self._network) return workpiece def _process_job(self, job, sample_input, job_queue): """ Process a job. Args: job (_SearchJob): Search job to be processed. sample_input (numpy.ndarray): Image tensor in NCHW(N=1) format. job_queue (list[_SearchJob]): Job queue. Returns: tuple[list[EditStep], _StopReason], result edit stop and the stop reason. Raises: OriginalOutputError: Be raised if network output of the original image is not strictly larger than the threshold. 
""" workpiece = self._prepare_job(job, sample_input) start_output = self._network(Tensor(workpiece))[0, job.class_idx].asnumpy().item() last_output = start_output edit_steps = [] # greedy search loop while True: if self._is_threshold_met(last_output): return edit_steps, self._StopReason.THRESHOLD_MET try: best_edit = job.find_best_edit() except _NoNewStepError: return edit_steps, self._StopReason.NO_NEW_STEP except _RepeatedStepError: return edit_steps, self._StopReason.REPEATED_STEP best_edit.step_change = best_edit.network_output - last_output if job.layer < job.layer_count - 1 and self._is_greedy(best_edit.step_change): # create net layer search job if new edit step is valid and not yet reaching # the final layer if job.pre_edit_steps: pre_edit_steps = list(job.pre_edit_steps) pre_edit_steps.extend(edit_steps) else: pre_edit_steps = list(edit_steps) sub_job = job.create_sub_job(best_edit, pre_edit_steps) job_queue.append(sub_job) edit_steps.append(best_edit) if job.layer > 0: # stop if the step change meet the parent step change only after layer 0 change = best_edit.network_output - start_output if self._is_step_change_met(job.parent_step.step_change, change): return edit_steps, self._StopReason.STEP_CHANGE_MET last_output = best_edit.network_output def _is_threshold_met(self, network_output): """Check if the threshold was met.""" if self._by_masking: return network_output <= self._threshold return network_output >= self._threshold def _is_step_change_met(self, target, step_change): """Check if the change target was met.""" if self._by_masking: return step_change <= target return step_change >= target def _is_greedy(self, step_change): """Check if it is a greedy step.""" if self._by_masking: return step_change < 0 return step_change > 0 @classmethod def _auto_win_sizes_strides(cls, short_side): """ Calculate auto window sizes and strides. Args: short_side (int): Length of search space. Returns: tuple[list[int], list[int]], window sizes and strides. """ win_sizes = [] strides = [] cur_len = int(short_side/AUTO_WIN_SIZE_DIV) while len(win_sizes) < AUTO_LAYER_MAX and cur_len >= AUTO_WIN_SIZE_MIN: stride = int(cur_len/AUTO_STRIDE_DIV) if stride <= 0: break win_sizes.append(cur_len) strides.append(stride) cur_len = int(cur_len/AUTO_WIN_SIZE_DIV) if not win_sizes: raise ValueError(f"Image's short side is less then {AUTO_IMAGE_SHORT_SIDE_MIN}, " f"unable to calculates auto settings.") return win_sizes, strides class _StopReason(Enum): """Stop reason of search job.""" THRESHOLD_MET = 0 # threshold was met. STEP_CHANGE_MET = 1 # parent step change was met. NO_NEW_STEP = 2 # no new step was found. REPEATED_STEP = 3 # repeated step was found. def _check_iterable_type(arg_name, arg_value, container_type, elem_types): """Concert iterable argument data type.""" check_value_type(arg_name, arg_value, container_type) for elem in arg_value: check_value_type(arg_name + ' element', elem, elem_types) class _NoNewStepError(Exception): """Error for no new step was found.""" class _RepeatedStepError(Exception): """Error for repeated step was found.""" class _SearchJob: """ Search job. Args: by_masking (bool): Whether it is masking mode. class_idx (int): Target class index. win_sizes (list[int]): Moving square window size (length of side) of layers. strides (list[int]): Strides of layers. layer (int): Layer number. search_field (tuple[int, int, int, int]): Search field in x, y, width, height format. pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching. 


class _SearchJob:
    """
    Search job.

    Args:
        by_masking (bool): Whether it is masking mode.
        class_idx (int): Target class index.
        win_sizes (list[int]): Moving square window sizes (length of side) of layers.
        strides (list[int]): Strides of layers.
        layer (int): Layer number.
        search_field (tuple[int, int, int, int]): Search field in x, y, width, height format.
        pre_edit_steps (list[EditStep], optional): Edit steps to be applied before searching.
        parent_step (EditStep): Parent edit step.
        batch_size (int): Batch size of batched inferences.
    """
    def __init__(self,
                 by_masking,
                 class_idx,
                 win_sizes,
                 strides,
                 layer,
                 search_field,
                 pre_edit_steps,
                 parent_step,
                 batch_size=DEFAULT_BATCH_SIZE):
        if layer >= len(win_sizes):
            raise ValueError('Layer is larger than the number of window sizes.')
        self.by_masking = by_masking
        self.class_idx = class_idx
        self.win_sizes = win_sizes
        self.strides = strides
        self.layer = layer
        self.search_field = search_field
        self.pre_edit_steps = pre_edit_steps
        self.parent_step = parent_step
        self.batch_size = batch_size
        self.network = None
        self.mask = None
        self.original_input = None
        self._workpiece = None
        self._found_best_edits = None
        self._found_uvs = None
        self._u_pixels = None
        self._v_pixels = None

    @property
    def layer_count(self):
        """Number of layers."""
        return len(self.win_sizes)

    def on_start(self, original_input, workpiece, mask, network):
        """
        Notification of the start of the search job.

        Args:
            original_input (numpy.ndarray): The original image tensor in CHW or NCHW(N=1) format.
            workpiece (numpy.ndarray): The intermediate image tensor in CHW or NCHW(N=1) format.
            mask (Union[tuple[float, float, float], float, numpy.ndarray]): The mask, type can be
                tuple[float, float, float]: RGB solid color mask.
                float: Grey scale solid color mask.
                numpy.ndarray: Image mask, has the same format as original_input.
            network (nn.Cell): Classification network.
        """
        self.original_input = original_input
        self.mask = mask
        self.network = network
        self._workpiece = workpiece
        self._found_best_edits = []
        self._found_uvs = []
        self._u_pixels = self._calc_uv_pixels(self.search_field[0], self.search_field[2])
        self._v_pixels = self._calc_uv_pixels(self.search_field[1], self.search_field[3])

    def create_sub_job(self, parent_step, pre_edit_steps):
        """Create the next layer search job."""
        return self.__class__(by_masking=self.by_masking,
                              class_idx=self.class_idx,
                              win_sizes=self.win_sizes,
                              strides=self.strides,
                              layer=self.layer + 1,
                              search_field=copy.copy(parent_step.box),
                              pre_edit_steps=pre_edit_steps,
                              parent_step=parent_step,
                              batch_size=self.batch_size)

    def find_best_edit(self):
        """
        Find the next best edit step.

        Returns:
            EditStep, the next best edit step.
""" workpiece = self._workpiece if len(workpiece.shape) == 3: workpiece = np.expand_dims(workpiece, axis=0) # generate input tensors with shifted masked/unmasked region and pack into a batch best_new_workpiece = None best_output = None best_edit = None best_uv = None batch = np.repeat(workpiece, repeats=self.batch_size, axis=0) batch_uvs = [] batch_steps = [] batch_i = 0 win_size = self.win_sizes[self.layer] for u, x in enumerate(self._u_pixels): for v, y in enumerate(self._v_pixels): if (u, v) in self._found_uvs: continue edit_step = EditStep(self.layer, (x, y, win_size, win_size)) if self.by_masking: EditStep.apply(batch[batch_i], self.mask, [edit_step], self.by_masking, inplace=True) else: EditStep.apply(self.original_input, batch[batch_i], [edit_step], self.by_masking, inplace=True) batch_i += 1 batch_uvs.append((u, v)) batch_steps.append(edit_step) if batch_i != self.batch_size: continue # the batch is full, inference and empty it updated = self._update_best(batch, batch_uvs, batch_steps, best_output) if updated: best_output, best_uv, best_edit, best_new_workpiece = updated batch = np.repeat(workpiece, repeats=self.batch_size, axis=0) batch_uvs = [] batch_i = 0 if batch_i > 0: # don't forget the last half full batch updated = self._update_best(batch, batch_uvs, batch_steps, best_output, batch_i) if updated: best_output, best_uv, best_edit, best_new_workpiece = updated if best_edit is None: raise _NoNewStepError if best_uv in self._found_uvs: raise _RepeatedStepError self._found_uvs.append(best_uv) self._found_best_edits.append(best_edit) best_edit.network_output = best_output # continue on the best workpiece in the next function call self._workpiece = best_new_workpiece return best_edit def _update_best(self, batch, batch_uvs, batch_steps, best_output, batch_i=None): """Update the best edit step.""" squeeze = Squeeze() batch_output = self.network(Tensor(batch)) batch_output = batch_output[:, self.class_idx] if len(batch_output.shape) > 1: batch_output = squeeze(batch_output) aggregation = np.argmin if self.by_masking else np.argmax if batch_i is None: batch_best_i = aggregation(batch_output.asnumpy()) else: batch_best_i = aggregation(batch_output.asnumpy()[:batch_i, ...]) batch_best_output = batch_output[int(batch_best_i)].asnumpy().item() if best_output is None or self._is_output0_better(batch_best_output, best_output): best_output = batch_best_output best_uv = batch_uvs[batch_best_i] best_edit = batch_steps[batch_best_i] best_new_workpiece = batch[batch_best_i] return best_output, best_uv, best_edit, best_new_workpiece return None def _is_output0_better(self, output0, output1): """Check if the network output0 is better.""" if self.by_masking: return output0 < output1 return output0 > output1 def _calc_uv_pixels(self, begin, length): """ Calculate the pixel coordinate of shifts. Args: begin (int): The beginning pixel coordinate of search field. length (int): The length of search field. Returns: list[int], pixel coordinate of shifts. """ win_size = self.win_sizes[self.layer] stride = self.strides[self.layer] shift_count = self._calc_shift_count(length, win_size, stride) pixels = [0] * shift_count for i in range(shift_count): if i == shift_count - 1: pixels[i] = begin + length - win_size else: pixels[i] = begin + i*stride return pixels @staticmethod def _calc_shift_count(length, win_size, stride): """ Calculate the number of shifts in search field. Args: length (int): The length of search field. win_size (int): The length of sides of moving window. stride (int): The stride. 
    @staticmethod
    def _calc_shift_count(length, win_size, stride):
        """
        Calculate the number of shifts in the search field.

        Args:
            length (int): The length of the search field.
            win_size (int): The length of sides of the moving window.
            stride (int): The stride.

        Returns:
            int, number of shifts.
        """
        if length <= win_size or win_size < stride or stride <= 0:
            raise ValueError("Invalid length, win_size or stride.")
        count = int(math.ceil((length - win_size)/stride))
        if (count - 1)*stride + win_size < length:
            return count + 1
        return count
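

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes
# `net` is a trained classification nn.Cell returning class probabilities and
# that the deprecation decorator above still allows construction; this is a
# sketch of the intended call flow, not a definitive recipe.
def _demo_search(net, image_chw, class_idx):
    """Sketch: run a hierarchical occlusion search and collect the found boxes."""
    searcher = Searcher(net, threshold=DEFAULT_THRESHOLD, by_masking=False)
    root_step, layer_outputs = searcher.search(image_chw, class_idx, mask='gaussian:9')
    # each leaf step is one (x, y, width, height) region kept by the search
    boxes = [step.box for step in root_step.leaf_steps]
    return boxes, layer_outputs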