# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Functional interface of MindSpore ops.

Shared primitive instances, small helper functions, and Tensor operator
registrations are collected here.
"""

from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.ops import _constants
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .composite import GradOperation
from .._c_expression import security

# Type-introspection primitives and shared instances of common tensor
# operators. Each name is bound once at import time, so every importer of this
# module sees the same primitive object.
typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
# Marked as a "const" primitive right after construction; presumably this lets
# the compiler evaluate it as a constant -- confirm against Primitive.set_const_prim.
isconstant = Primitive('is_constant')
isconstant.set_const_prim(True)

# Trailing underscores avoid shadowing the builtins issubclass/isinstance.
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
select = P.Select()
size = P.Size()
ones_like = P.OnesLike()
shape = P.Shape()
rank = P.Rank()
reshape = P.Reshape()

# Control-flow helpers plus element-wise arithmetic, comparison, reduction and
# math primitives. Several ``tensor_*`` instances also get a short alias that
# is bound to the very same object (e.g. ``add is tensor_add``). Names such as
# ``absolute`` and ``pows`` presumably avoid shadowing the builtins abs/pow.
merge = P.Merge()
geswitch = P.GeSwitch()
addn = P.AddN()
absolute = P.Abs()
tensor_add = P.Add()
add = tensor_add
neg_tensor = P.Neg()
tensor_lt = P.Less()
less = tensor_lt
tensor_le = P.LessEqual()
le = tensor_le
tensor_gt = P.Greater()
gt = tensor_gt
tensor_ge = P.GreaterEqual()
ge = tensor_ge
tensor_sub = P.Sub()
sub = tensor_sub
tensor_mul = P.Mul()
mul = tensor_mul
tensor_div = P.RealDiv()
div = tensor_div
tensor_floordiv = P.FloorDiv()
floordiv = tensor_floordiv
tensor_pow = P.Pow()
pows = tensor_pow
tensor_mod = P.FloorMod()
floormod = tensor_mod
tensor_exp = P.Exp()
exp = tensor_exp
tensor_expm1 = P.Expm1()
# Defined exactly once; it is also registered as the Tensor 'tensor_slice' op.
tensor_slice = P.Slice()
strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape()
check_bprop = P.CheckBprop()
equal = P.Equal()
not_equal = P.NotEqual()
isfinite = P.IsFinite()
isnan = P.IsNan()
assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
log = P.Log()
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()
reduce_mean = P.ReduceMean()
reduce_prod = P.ReduceProd()
maximum = P.Maximum()
minimum = P.Minimum()
floor = P.Floor()
logical_not = P.LogicalNot()
logical_or = P.LogicalOr()
logical_and = P.LogicalAnd()
sin = P.Sin()
cos = P.Cos()
tan = P.Tan()
asin = P.Asin()
acos = P.ACos()
atan = P.Atan()
sinh = P.Sinh()
cosh = P.Cosh()
tanh = P.Tanh()
asinh = P.Asinh()
acosh = P.Acosh()
atanh = P.Atanh()
atan2 = P.Atan2()
bitwise_and = P.BitwiseAnd()
bitwise_or = P.BitwiseOr()
bitwise_xor = P.BitwiseXor()
invert = P.Invert()
erf = P.Erf()
erfc = P.Erfc()
sort = P.Sort()
tensor_range = P.Range()

# Scalar/tuple conversion primitives and array-manipulation primitives.
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
# ``print_`` (trailing underscore avoids shadowing the builtin print) is only
# defined when security mode is disabled; otherwise the name does not exist in
# this module at all, so callers must not assume it is present.
if not security.enable_security():
    print_ = P.Print()
expand_dims = P.ExpandDims()
transpose = P.Transpose()
squeeze = P.Squeeze()
scatter_nd = P.ScatterNd()
gather = P.Gather()
gather_d = P.GatherD()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()


def pack(x):
    """
    Deprecated alias of ``stack``.

    Emits a deprecation notice via ``print`` and then forwards the input to
    the module-level ``stack`` primitive.

    Args:
        x: Input forwarded unchanged to ``stack``.

    Returns:
        The result of ``stack(x)``.
    """
    print("WARNING: 'pack' is deprecated from version 1.1 and will be removed "
          "in a future version, use 'stack' instead.")
    stacked = stack(x)
    return stacked


partial = P.Partial()
# depend: mount (attach) one node to another node.
depend = P.Depend()
identity = P.identity()

# Pre-built gradient operators used by ``grad``: get_all=False yields the
# gradient with respect to the first input only, get_all=True yields the
# gradients with respect to all inputs. Neither takes a sensitivity input
# (sens_param=False) or differentiates a parameter list (get_by_list=False).
grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)


def grad(fn, grad_first_param=False):
    """
    Build the gradient function of ``fn``.

    Args:
        fn (Function): Function to differentiate with GradOperation.
        grad_first_param (bool): When truthy, differentiate with respect to the
            first input only; otherwise differentiate with respect to all
            inputs. Default: False.

    Returns:
        Function, the result of applying the selected GradOperation to ``fn``.
    """
    # Pick the pre-built operator once, then apply it.
    grad_op = grad_first_parameter if grad_first_param else grad_all_parameters
    return grad_op(fn)


# Primitives for Python containers (tuple/list/dict), ranges and slices.
tuple_setitem = Primitive('tuple_setitem')
# Primitive name comes from the shared ``_constants`` module rather than a literal.
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div")
tuple_len = Primitive("tuple_len")
list_len = Primitive("list_len")
tuple_reversed = Primitive("tuple_reversed")
make_range = Primitive("make_range")
make_tuple = Primitive('MakeTuple')
make_dict = Primitive('make_dict')
make_list = Primitive('make_list')
make_slice = Primitive('make_slice')
tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
make_ref = Primitive("make_ref")

# Scalar, string and boolean primitives, followed by miscellaneous primitives
# used by the compiler and autodiff. Where a name comes from ``_constants`` it
# is taken from that shared module instead of a string literal.
# logical_and/logical_or/logical_not are defined once alongside the other
# element-wise tensor primitives and are not re-instantiated here.
scalar_add = Primitive(_constants.kScalarAdd)
scalar_mul = Primitive(_constants.kScalarMul)
scalar_sub = Primitive(_constants.kScalarSub)
scalar_div = Primitive(_constants.kScalarDiv)
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
scalar_log = Primitive('scalar_log')
scalar_pow = Primitive(_constants.kScalarPow)
scalar_gt = Primitive('scalar_gt')
scalar_ge = Primitive('scalar_ge')
scalar_le = Primitive('scalar_le')
scalar_lt = Primitive('scalar_lt')
scalar_eq = Primitive('scalar_eq')
scalar_ne = Primitive('scalar_ne')
scalar_uadd = Primitive(_constants.kScalarUadd)
scalar_usub = Primitive(_constants.kScalarUsub)
scalar_mod = Primitive(_constants.kScalarMod)
string_eq = Primitive('string_equal')
string_concat = Primitive('string_concat')
bool_not = Primitive("bool_not")
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
bool_eq = Primitive("bool_eq")
cumsum = P.CumSum()
cumprod = P.CumProd()
tensor_scatter_add = P.TensorScatterAdd()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
env_setitem = Primitive('env_setitem')
env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
J = Primitive('J')
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# reduced_shape: used by the backward propagation (bprop) of sum reductions.
reduced_shape = Primitive("reduced_shape")
# shape_mul: the input must be a shape tuple; it multiplies the tuple's elements together.
shape_mul = Primitive("shape_mul")
# stop_gradient: presumably blocks gradient flow through its input (confirm in
# the autodiff implementation).
stop_gradient = Primitive("stop_gradient")

# Constructor and field accessors for RowTensor values; each name wraps the
# primitive whose string name it mirrors.
make_row_tensor = Primitive('MakeRowTensor')
row_tensor_get_values = Primitive('RowTensorGetValues')
row_tensor_get_indices = Primitive('RowTensorGetIndices')
row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
row_tensor_add = Primitive('RowTensorAdd')

# Constructor and field accessors for SparseTensor values.
make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')

# Expose primitives to Tensor methods through the shared registry.
# Some keys map to a primitive *class* (P.X) and others to a ready-made
# instance; presumably consumers instantiate the classes themselves (e.g. with
# attributes) and call instances directly -- confirm against the Tensor method
# implementations before changing a key's value kind.
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', P.MatMul)
tensor_operator_registry.register('argmax', P.Argmax)
tensor_operator_registry.register('cumsum', P.CumSum)
tensor_operator_registry.register('reduce_max', P.ReduceMax)
tensor_operator_registry.register('reduce_min', P.ReduceMin)
tensor_operator_registry.register('maximum', P.Maximum)
tensor_operator_registry.register('minimum', P.Minimum)
tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logical_not', P.LogicalNot)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
# Comparison and unary dunder operators map to primitive instances.
# Original note from the authors: "ms cannot support Tensor(True) compare".
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__logical_not__', logical_not)
tensor_operator_registry.register('shape', shape)
tensor_operator_registry.register('squeeze', squeeze)
# Non-comparison operators; per the original note these registrations exist
# to support the GE backend.
tensor_operator_registry.register('cast', cast)
tensor_operator_registry.register('shape_mul', shape_mul)
# 'fill' is registered exactly once, with the ``fill`` primitive instance.
tensor_operator_registry.register('fill', fill)
tensor_operator_registry.register('concatenate', P.Concat)
tensor_operator_registry.register('eye', eye)
tensor_operator_registry.register('reduce_sum', reduce_sum)
tensor_operator_registry.register('tensor_slice', tensor_slice)
tensor_operator_registry.register('select', select)
tensor_operator_registry.register('gather_d', gather_d)
tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', P.Stack)
tensor_operator_registry.register('log', log)
tensor_operator_registry.register('floor', floor)

# Export every public (non-underscore) name bound so far, except the Primitive
# class, which is imported here only to build the instances above. dir() is an
# argument to filter, so it is evaluated at module scope and lists the module's
# globals (including ``print_`` only when it was defined).
__all__ = list(filter(lambda name: not name.startswith("_"), dir()))
__all__.remove('Primitive')
