# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Defines deprecated operators."""
import itertools
import numpy as np
from mindspore.common._decorator import deprecated
from mindspore import context
from mindspore import _checkparam as validator
from mindspore.ops import signature as sig
from mindspore.ops.primitive import Primitive, prim_attr_register
from mindspore.ops.operations.math_ops import _MathBinaryOp
from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple


def _check_nhwc_on_gpu(prim_name, data_format):
    """Reject the 'NHWC' data format unless the current device target is GPU.

    Args:
        prim_name (str): Name of the primitive, used in the error message.
        data_format (str): The already-validated data format ('NCHW' or 'NHWC').

    Raises:
        ValueError: If `data_format` is 'NHWC' and the device target is not 'GPU'.
    """
    # Query the device target once; it is used both for the test and the message.
    device_target = context.get_context("device_target")
    if device_target != "GPU" and data_format == "NHWC":
        raise ValueError(f"For '{prim_name}', the 'NHWC' format is only supported in GPU target, "
                         f"but got the 'data_format' is {data_format} and "
                         f"the platform is {device_target}.")


class BNTrainingReduce(Primitive):
    """
    Please use BatchNorm instead.
    """
    @deprecated("1.5", "ops.BatchNorm", False)
    @prim_attr_register
    def __init__(self, data_format="NCHW"):
        """Initialize BNTrainingReduce.

        Args:
            data_format (str): 'NCHW' or 'NHWC'. 'NHWC' is only accepted on the GPU target.

        Raises:
            ValueError: If `data_format` is 'NHWC' on a non-GPU target.
        """
        super().__init__(name="BNTrainingReduce")
        self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        _check_nhwc_on_gpu(self.name, self.format)
        self.add_prim_attr('data_format', self.format)


class BNTrainingUpdate(Primitive):
    """
    Please use BatchNorm instead.
    """
    @deprecated("1.5", "ops.BatchNorm", False)
    @prim_attr_register
    def __init__(self, isRef=True, epsilon=1e-5, factor=0.1, data_format="NCHW"):
        """Initialize BNTrainingUpdate.

        Args:
            isRef (bool): Must be a bool.
            epsilon (float): Range-checked against 0..1 with `validator.INC_RIGHT`.
            factor (float): Range-checked against 0..1 with `validator.INC_BOTH`.
            data_format (str): 'NCHW' or 'NHWC'. 'NHWC' is only accepted on the GPU target.

        Raises:
            ValueError: If `data_format` is 'NHWC' on a non-GPU target.
        """
        super().__init__(name="BNTrainingUpdate")
        self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
                                outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
        validator.check_value_type("isRef", isRef, [bool], self.name)
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("factor", factor, [float], self.name)
        # Use self.name (== "BNTrainingUpdate") for consistency with the other checks.
        self.epsilon = validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', self.name)
        self.factor = validator.check_float_range(factor, 0, 1, validator.INC_BOTH, 'factor', self.name)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        _check_nhwc_on_gpu(self.name, self.format)
        self.add_prim_attr('data_format', self.format)


class MaxPoolWithArgmax(Primitive):
    """
    Please use :class:`mindspore.ops.MaxPoolWithArgmaxV2` instead.

    Supported Platforms:
        Deprecated
    """
    @deprecated("2.0", "ops.MaxPoolWithArgmaxV2", False)
    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
        """Initialize MaxPoolWithArgmax.

        Args:
            kernel_size (Union[int, tuple]): Pooling window size.
            strides (Union[int, tuple]): Pooling stride.
            pad_mode (str): 'valid' or 'same' (case-insensitive; stored upper-cased).
            data_format (str): 'NCHW' or 'NHWC'. 'NHWC' is only accepted on the GPU target.

        Raises:
            ValueError: If `data_format` is 'NHWC' on a non-GPU target.
        """
        super().__init__(name="MaxPoolWithArgmax")
        self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        _check_nhwc_on_gpu(self.name, self.format)

        # Normalise to a 4-tuple, then keep only the last two entries wrapped as
        # (1, a, b, 1); this layout is used regardless of data_format.
        self.kernel_size = _check_positive_int_or_tuple(
            "kernel_size", kernel_size, self.name, allow_four=False, ret_four=True)
        self.kernel_size = (1, self.kernel_size[-2], self.kernel_size[-1], 1)
        self.add_prim_attr("kernel_size", self.kernel_size)

        self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True)
        self.strides = (1, self.strides[-2], self.strides[-1], 1)
        self.add_prim_attr("strides", self.strides)


class DropoutGenMask(Primitive):
    """
    Please use Dropout instead.
    """
    @deprecated("1.5", "ops.Dropout", False)
    @prim_attr_register
    def __init__(self, Seed0=0, Seed1=0):
        """Initialize DropoutGenMask.

        Args:
            Seed0 (int): First random seed.
            Seed1 (int): Second random seed.
        """
        self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
        validator.check_value_type("Seed0", Seed0, [int], self.name)
        validator.check_value_type("Seed1", Seed1, [int], self.name)
        self.add_prim_attr("side_effect_hidden", True)


class DropoutDoMask(Primitive):
    """
    Please use Dropout instead.
    """
    @deprecated("1.5", "ops.Dropout", False)
    @prim_attr_register
    def __init__(self):
        """Initialize DropoutDoMask."""
        super().__init__(name="DropoutDoMask")


class Gelu(Primitive):
    """
    Please use GeLU instead.
    """
    @deprecated("1.1", "GeLU", True)
    @prim_attr_register
    def __init__(self):
        """Initialize Gelu"""
        super().__init__(name="Gelu")
        self.init_prim_io_names(inputs=['x'], outputs=['output'])


class FastGelu(Primitive):
    """
    Please use FastGeLU instead.
    """
    @deprecated("1.1", "FastGeLU", True)
    @prim_attr_register
    def __init__(self):
        """Initialize FastGelu."""
        super().__init__(name="FastGelu")
        self.init_prim_io_names(inputs=['x'], outputs=['output'])


class TensorAdd(_MathBinaryOp):
    """
    Please use Add instead.
    """
    @deprecated("1.1", "Add", True)
    @prim_attr_register
    def __init__(self):
        """Initialize TensorAdd."""
        _MathBinaryOp.__init__(self)


class InplaceUpdate(Primitive):
    """
    Please use :class:`mindspore.ops.InplaceUpdateV2` instead.

    Supported Platforms:
        Deprecated
    """
    @deprecated("2.0", "ops.InplaceUpdateV2", False)
    @prim_attr_register
    def __init__(self, indices):
        """Initialize InplaceUpdate.

        Args:
            indices (Union[int, tuple[int]]): Indices to update; a single int is
                stored as a one-element tuple.
        """
        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
        self.indices = indices
        validator.check_value_type("indices", indices, [int, tuple], self.name)
        if isinstance(indices, int):
            self.indices = (indices,)
        for item in self.indices:
            validator.check_value_type("item of indices", item, [int], self.name)


class DynamicShape(Primitive):
    """
    Please use TensorShape instead.
    """
    @deprecated("1.7", "TensorShape", True)
    @prim_attr_register
    def __init__(self, dtype=9):
        """Initialize DynamicShape.

        Args:
            dtype (int): Accepted for backward compatibility; not used in this body.
        """
        super().__init__(name="DynamicShape")
        self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
        self.add_prim_attr('is_dynamic_shape', True)


class GatherV2(Primitive):
    """
    Please use Gather instead.
    """
    @deprecated("1.1", "Gather", True)
    @prim_attr_register
    def __init__(self):
        """Initialize GatherV2"""
        super().__init__(name="GatherV2")
        self.add_prim_attr("batch_dims", 0)
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])


class ScalarToArray(Primitive):
    """
    Please use scalar_to_tensor instead.
    """
    @deprecated("2.0", "ops.scalar_to_tensor", False)
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarToArray."""
        super().__init__(name="ScalarToArray")


class Pack(Primitive):
    """
    Please use Stack instead.
    """
    @deprecated("1.1", "Stack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Pack"""
        super().__init__(name="Pack")
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis


class Unpack(Primitive):
    """
    Please use Unstack instead.
    """
    @deprecated("1.1", "Unstack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Unpack"""
        super().__init__(name="Unpack")
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis


class ScatterNonAliasingAdd(Primitive):
    """
    Please use TensorScatterAdd instead.

    Supported Platforms:
        Deprecated
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T)
    )

    # The replacement op must match the docstring; it previously pointed at itself.
    @deprecated("2.1", "ops.TensorScatterAdd", False)
    @prim_attr_register
    def __init__(self):
        """Initialize ScatterNonAliasingAdd"""
        super().__init__(name="ScatterNonAliasingAdd")
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])
        self.add_prim_attr('side_effect_mem', True)


class BatchToSpaceND(Primitive):
    """
    Please use batch_to_space_nd instead.

    Supported Platforms:
        Deprecated
    """
    @deprecated("2.0", "ops.batch_to_space_nd", False)
    @prim_attr_register
    def __init__(self, block_shape, crops):
        """Initialize BatchToSpaceND.

        Args:
            block_shape (Union[int, list, tuple]): Block sizes. An int is broadcast to
                a tuple whose length is the first dimension of `crops`. Each element
                must be an int >= 1. On the Ascend target, exactly 2 elements are required.
            crops (Union[list, tuple]): Must have shape (len(block_shape), 2) and
                contain non-negative ints.
        """
        super().__init__(name="BatchToSpaceND")
        if isinstance(block_shape, int):
            block_shape = (block_shape,) * np.array(crops).shape[0]
        self.add_prim_attr("block_shape", block_shape)
        validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
        validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, validator.EQ, self.name)
        block_rank = len(block_shape)
        if context.get_context("device_target") == "Ascend":
            validator.check('block_shape length', block_rank, '', 2, validator.EQ, self.name)
        for elem in block_shape:
            # Type-check first so a non-int element raises the validator's type error
            # instead of failing inside the numeric comparison.
            validator.check_value_type('block_shape element', elem, [int], self.name)
            validator.check('block_shape element', elem, '', 1, validator.GE, self.name)
        self.block_shape = block_shape

        validator.check_value_type('crops type', crops, [list, tuple], self.name)
        validator.check('crops length', len(crops), '', 1, validator.GE, self.name)
        validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), validator.EQ, self.name)
        for elem in itertools.chain(*crops):
            validator.check_value_type('crops element', elem, [int], self.name)
            validator.check_non_negative_int(elem, 'crops element', self.name)
        self.crops = crops


class identity(Primitive):
    """
    Please use side_effect_propagate instead.
    """
    # NOTE(review): the docstring suggests side_effect_propagate, while the
    # deprecation on __call__ points users at nn.Identity; confirm which is intended.

    # Side effects are propagated from the first argument to the return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize identity."""
        super().__init__(name="identity")
        self.add_prim_attr('side_effect_propagate', 1)

    @deprecated('2.0', 'nn.Identity', False)
    def __call__(self, x):
        """Return `x` unchanged."""
        return x