# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""Defines deprecated operators."""
17import itertools
18import numpy as np
19from mindspore.common._decorator import deprecated
20from mindspore import context
21from mindspore import _checkparam as validator
22from mindspore.ops import signature as sig
23from mindspore.ops.primitive import Primitive, prim_attr_register
24from mindspore.ops.operations.math_ops import _MathBinaryOp
25from mindspore.ops.operations.nn_ops import _check_positive_int_or_tuple
26
27
class BNTrainingReduce(Primitive):
    """
    Deprecated since 1.5; please use BatchNorm instead.
    """
    @deprecated("1.5", "ops.BatchNorm", False)
    @prim_attr_register
    def __init__(self, data_format="NCHW"):
        """Initialize BNTrainingReduce."""
        super().__init__(name="BNTrainingReduce")
        self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
        # Only 'NCHW' and 'NHWC' layouts are accepted; 'NHWC' is GPU-only.
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if self.format == "NHWC" and context.get_context("device_target") != "GPU":
            raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
                             f"but got the 'data_format' is {self.format} and "
                             f"the platform is {context.get_context('device_target')}.")
        self.add_prim_attr('data_format', self.format)
44
45
class BNTrainingUpdate(Primitive):
    """
    Deprecated since 1.5; please use BatchNorm instead.
    """
    @deprecated("1.5", "ops.BatchNorm", False)
    @prim_attr_register
    def __init__(self, isRef=True, epsilon=1e-5, factor=0.1, data_format="NCHW"):
        """Initialize BNTrainingUpdate."""
        super().__init__(name="BNTrainingUpdate")
        self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
                                outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
        # Type checks first, then range checks: epsilon in (0, 1], factor in [0, 1].
        validator.check_value_type("isRef", isRef, [bool], self.name)
        validator.check_value_type("epsilon", epsilon, [float], self.name)
        validator.check_value_type("factor", factor, [float], self.name)
        self.epsilon = validator.check_float_range(epsilon, 0, 1, validator.INC_RIGHT, 'epsilon', 'BNTrainingUpdate')
        self.factor = validator.check_float_range(factor, 0, 1, validator.INC_BOTH, 'factor', 'BNTrainingUpdate')
        # Only 'NCHW' and 'NHWC' layouts are accepted; 'NHWC' is GPU-only.
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        if self.format == "NHWC" and context.get_context("device_target") != "GPU":
            raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
                             f"but got the 'data_format' is {self.format} and "
                             f"the platform is {context.get_context('device_target')}.")
        self.add_prim_attr('data_format', self.format)
68
69
class MaxPoolWithArgmax(Primitive):
    """
    Please use :class:`mindspore.ops.MaxPoolWithArgmaxV2` instead.

    Deprecated max-pooling primitive producing two outputs, 'output' and 'mask'.

    Supported Platforms:
        Deprecated
    """
    @deprecated("2.0", "ops.MaxPoolWithArgmaxV2", False)
    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
        """Initialize MaxPoolWithArgmax."""
        super().__init__(name="MaxPoolWithArgmax")
        self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        # pad_mode is normalized to upper case before being registered.
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
        # 'NHWC' is only accepted when running on a GPU target.
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
                             f"but got the 'data_format' is {self.format} and "
                             f"the platform is {context.get_context('device_target')}.")
        # Expand kernel_size to four elements, then rearrange into the
        # (1, H, W, 1) layout expected by the registered attribute.
        self.kernel_size = _check_positive_int_or_tuple(
            "kernel_size", kernel_size, self.name, allow_four=False, ret_four=True)
        self.kernel_size = (1, self.kernel_size[-2], self.kernel_size[-1], 1)
        self.add_prim_attr("kernel_size", self.kernel_size)

        # Same expansion and (1, H, W, 1) rearrangement for strides.
        self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True)
        self.strides = (1, self.strides[-2], self.strides[-1], 1)
        self.add_prim_attr("strides", self.strides)
101
102
class DropoutGenMask(Primitive):
    """
    Please use Dropout instead.
    """
    @deprecated("1.5", "ops.Dropout", False)
    @prim_attr_register
    def __init__(self, Seed0=0, Seed1=0):
        """Initialize DropoutGenMask."""
        # NOTE(review): unlike sibling primitives in this file, there is no
        # explicit super().__init__ call here — presumably prim_attr_register
        # performs the base initialization; confirm before relying on it.
        self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output'])
        validator.check_value_type("Seed0", Seed0, [int], self.name)
        validator.check_value_type("Seed1", Seed1, [int], self.name)
        # Mark a hidden side effect so the executor does not fold or
        # reorder this primitive.
        self.add_prim_attr("side_effect_hidden", True)
115
116
class DropoutDoMask(Primitive):
    """
    Please use Dropout instead.
    """
    @deprecated("1.5", "ops.Dropout", False)
    @prim_attr_register
    def __init__(self):
        """Initialize DropoutDoMask."""
        super().__init__(name="DropoutDoMask")
125
126
class Gelu(Primitive):
    """
    Please use GeLU instead.
    """
    @deprecated("1.1", "GeLU", True)
    @prim_attr_register
    def __init__(self):
        """Initialize Gelu."""
        super().__init__(name="Gelu")
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
137
138
class FastGelu(Primitive):
    """
    Please use FastGeLU instead.
    """
    @deprecated("1.1", "FastGeLU", True)
    @prim_attr_register
    def __init__(self):
        """Initialize FastGelu."""
        super().__init__(name="FastGelu")
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
149
150
class TensorAdd(_MathBinaryOp):
    """
    Deprecated since 1.1; please use Add instead.
    """
    @deprecated("1.1", "Add", True)
    @prim_attr_register
    def __init__(self):
        """Initialize TensorAdd."""
        # _MathBinaryOp is the direct base, so cooperative super() is
        # equivalent to the explicit _MathBinaryOp.__init__(self) call.
        super().__init__()
160
161
class InplaceUpdate(Primitive):
    """
    Deprecated since 2.0; please use :class:`mindspore.ops.InplaceUpdateV2` instead.

    Supported Platforms:
        Deprecated
    """
    @deprecated("2.0", "ops.InplaceUpdateV2", False)
    @prim_attr_register
    def __init__(self, indices):
        """Initialize InplaceUpdate"""
        self.init_prim_io_names(inputs=['x', 'v'], outputs=['y'])
        self.indices = indices
        validator.check_value_type("indices", indices, [int, tuple], self.name)
        if isinstance(indices, int):
            # Normalize a scalar index to a one-element tuple.
            self.indices = (indices,)
        # Every entry of the tuple must itself be an int.
        for idx in self.indices:
            validator.check_value_type("item of indices", idx, [int], self.name)
180
181
class DynamicShape(Primitive):
    """
    Please use TensorShape instead.
    """
    @deprecated("1.7", "TensorShape", True)
    @prim_attr_register
    def __init__(self, dtype=9):
        """Initialize DynamicShape."""
        # `dtype` is unused in this body; presumably it is recorded as a
        # primitive attribute by prim_attr_register — TODO confirm.
        super().__init__(name="DynamicShape")
        self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
        self.add_prim_attr('is_dynamic_shape', True)
193
194
class GatherV2(Primitive):
    """
    Please use Gather instead.
    """
    @deprecated("1.1", "Gather", True)
    @prim_attr_register
    def __init__(self):
        """Initialize GatherV2."""
        super().__init__(name="GatherV2")
        # Plain GatherV2 never batches over leading dimensions.
        self.add_prim_attr("batch_dims", 0)
        self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
206
207
class ScalarToArray(Primitive):
    """
    Please use scalar_to_tensor instead.
    """
    @deprecated("2.0", "ops.scalar_to_tensor", False)
    @prim_attr_register
    def __init__(self):
        """Initialize ScalarToArray."""
        super().__init__(name="ScalarToArray")
216
217
class Pack(Primitive):
    """
    Please use Stack instead.
    """
    @deprecated("1.1", "Stack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Pack."""
        super().__init__(name="Pack")
        # Only the type of axis is checked at construction time.
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis
229
230
class Unpack(Primitive):
    """
    Please use Unstack instead.
    """
    @deprecated("1.1", "Unstack", True)
    @prim_attr_register
    def __init__(self, axis=0):
        """Initialize Unpack."""
        super().__init__(name="Unpack")
        # Only the type of axis is checked at construction time.
        validator.check_value_type("axis", axis, [int], self.name)
        self.axis = axis
242
243
class ScatterNonAliasingAdd(Primitive):
    """
    Please use TensorScatterAdd instead.

    Supported Platforms:
        Deprecated
    """
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T)
    )
    # Fix: the decorator previously recommended "ops.ScatterNonAliasingAdd"
    # (the op itself) as its own replacement; the docstring's actual
    # replacement is TensorScatterAdd.
    @deprecated("2.1", "ops.TensorScatterAdd", False)
    @prim_attr_register
    def __init__(self):
        """Initialize ScatterNonAliasingAdd."""
        super().__init__(name="ScatterNonAliasingAdd")
        self.init_prim_io_names(inputs=['input_x', 'indices', 'updates'], outputs=['y'])
        # input_x is written in place (RW_WRITE signature), so register a
        # memory side effect for the executor.
        self.add_prim_attr('side_effect_mem', True)
263
264
class BatchToSpaceND(Primitive):
    """
    Please use batch_to_space_nd instead.

    Supported Platforms:
        Deprecated
    """
    @deprecated("2.0", "ops.batch_to_space_nd", False)
    @prim_attr_register
    def __init__(self, block_shape, crops):
        """Initialize BatchToSpaceND."""
        super().__init__(name="BatchToSpaceND")
        # A scalar block_shape is broadcast to one entry per spatial
        # dimension, inferred from the number of rows in crops.
        if isinstance(block_shape, int):
            block_shape = (block_shape,) * np.array(crops).shape[0]
        self.add_prim_attr("block_shape", block_shape)
        # NOTE(review): the attribute is registered before block_shape is
        # validated below — confirm this ordering is intentional.
        validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
        validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, validator.EQ, self.name)
        block_rank = len(block_shape)
        # Ascend accepts exactly two spatial dimensions.
        if context.get_context("device_target") == "Ascend":
            validator.check('block_shape length', block_rank, '', 2, validator.EQ, self.name)
        for elem in block_shape:
            validator.check('block_shape element', elem, '', 1, validator.GE, self.name)
            validator.check_value_type('block_shape element', elem, [int], self.name)
        self.block_shape = block_shape

        # crops must be a (block_rank, 2) array of non-negative ints.
        validator.check_value_type('crops type', crops, [list, tuple], self.name)
        validator.check('crops length', len(crops), '', 1, validator.GE, self.name)
        validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), validator.EQ, self.name)
        for elem in itertools.chain(*crops):
            validator.check_non_negative_int(elem, 'crops element', self.name)
            validator.check_value_type('crops element', elem, [int], self.name)
        self.crops = crops
297
298
class identity(Primitive):
    """
    Please use side_effect_propagate instead.
    """

    # The side effect will be propagated from the first argument to the
    # return value.
    side_effect_propagate = 1

    @prim_attr_register
    def __init__(self):
        """Initialize identity."""
        super().__init__(name="identity")
        self.add_prim_attr('side_effect_propagate', 1)

    @deprecated('2.0', 'nn.Identity', False)
    def __call__(self, x):
        # Pure pass-through: returns its input unchanged.
        return x
316