# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Base class `PerturbationAttribution`"""

from mindspore.train._utils import check_value_type
from mindspore.nn import Cell

from ..attribution import Attribution


class PerturbationAttribution(Attribution):
    """
    Base class for perturbation-based attribution methods.

    All perturbation-based attribution methods extend from this class. The
    constructor validates and stores the settings shared by those methods.

    Args:
        network: The model to explain. It is passed unchanged to the
            `Attribution` base-class constructor.
        activation_fn (Cell): Must be a `Cell` instance. Stored as
            `self._activation_fn`; presumably applied to the network output
            by subclasses (confirm against subclasses).
        perturbation_per_eval (int): Must be a positive integer. Stored as
            `self._perturbation_per_eval`; presumably the number of perturbed
            inputs evaluated per network call (confirm against subclasses).

    Raises:
        ValueError: If `perturbation_per_eval` is zero or negative.
        Type errors for `activation_fn` / `perturbation_per_eval` are raised
        by `check_value_type`.
    """

    def __init__(self, network, activation_fn, perturbation_per_eval):
        # Base-class setup runs first, before any argument of this class is
        # validated, so a problem with `network` is reported first.
        super().__init__(network)

        check_value_type("activation_fn", activation_fn, Cell)
        self._activation_fn = activation_fn

        # The type check guarantees an int here, so `< 1` rejects exactly the
        # zero and negative values.
        check_value_type('perturbation_per_eval', perturbation_per_eval, int)
        if perturbation_per_eval < 1:
            raise ValueError('Argument perturbation_per_eval should be a positive integer.')
        self._perturbation_per_eval = perturbation_per_eval