# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""The names of the functional part are summarized here.

Each name below is bound at import time to a primitive instance, either one
built from a class in ``operations`` (imported as ``P``) or a bare
``Primitive`` created from its string name.
"""

from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.ops import _constants
from .primitive import Primitive
from . import operations as P
from .operations import _grad_ops
from .composite import GradOperation
from .._c_expression import security

# Type inspection and conversion.
typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.set_const_prim(True)

issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
select = P.Select()
size = P.Size()
ones_like = P.OnesLike()
shape = P.Shape()
rank = P.Rank()
reshape = P.Reshape()

merge = P.Merge()
geswitch = P.GeSwitch()
addn = P.AddN()
absolute = P.Abs()

# Arithmetic and comparison operators.  Each ``tensor_*`` name has a short
# alias bound to the very same instance.
tensor_add = P.Add()
add = tensor_add
neg_tensor = P.Neg()
tensor_lt = P.Less()
less = tensor_lt
tensor_le = P.LessEqual()
le = tensor_le
tensor_gt = P.Greater()
gt = tensor_gt
tensor_ge = P.GreaterEqual()
ge = tensor_ge
tensor_sub = P.Sub()
sub = tensor_sub
tensor_mul = P.Mul()
mul = tensor_mul
tensor_div = P.RealDiv()
div = tensor_div
tensor_floordiv = P.FloorDiv()
floordiv = tensor_floordiv
tensor_pow = P.Pow()
pows = tensor_pow
tensor_mod = P.FloorMod()
floormod = tensor_mod
tensor_exp = P.Exp()
exp = tensor_exp
tensor_expm1 = P.Expm1()
# tensor_slice is bound exactly once; a second, identical
# ``tensor_slice = P.Slice()`` further down was redundant and has been dropped.
tensor_slice = P.Slice()
strided_slice = P.StridedSlice()
same_type_shape = P.SameTypeShape()
check_bprop = P.CheckBprop()
equal = P.Equal()
not_equal = P.NotEqual()
isfinite = P.IsFinite()
isnan = P.IsNan()
assign_sub = P.AssignSub()
assign_add = P.AssignAdd()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
log = P.Log()

# Reductions.
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()
reduce_mean = P.ReduceMean()
reduce_prod = P.ReduceProd()

maximum = P.Maximum()
minimum = P.Minimum()
floor = P.Floor()
logical_not = P.LogicalNot()
logical_or = P.LogicalOr()
logical_and = P.LogicalAnd()
sin = P.Sin()
cos = P.Cos()
tan = P.Tan()
asin = P.Asin()
acos = P.ACos()
atan = P.Atan()
sinh = P.Sinh()
cosh = P.Cosh()
tanh = P.Tanh()
asinh = P.Asinh()
acosh = P.Acosh()
atanh = P.Atanh()
atan2 = P.Atan2()
bitwise_and = P.BitwiseAnd()
bitwise_or = P.BitwiseOr()
bitwise_xor = P.BitwiseXor()
invert = P.Invert()
erf = P.Erf()
erfc = P.Erfc()
sort = P.Sort()
tensor_range = P.Range()

scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
# print_ is only bound when security.enable_security() returns a falsy value.
if not security.enable_security():
    print_ = P.Print()
expand_dims = P.ExpandDims()
transpose = P.Transpose()
squeeze = P.Squeeze()
scatter_nd = P.ScatterNd()
gather = P.Gather()
gather_d = P.GatherD()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()


def pack(x):
    """
    Deprecated spelling of `stack`.

    Prints a deprecation notice and then forwards `x` to `stack`.

    Args:
        x: The argument passed unchanged to `stack`.

    Returns:
        Whatever `stack(x)` returns.
    """
    print("WARNING: 'pack' is deprecated from version 1.1 and will be removed in a future version, use 'stack' instead"
          ".")
    return stack(x)


partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
identity = P.identity()

# Pre-built gradient operations used by `grad` below; they differ only in `get_all`.
grad_first_parameter = GradOperation(get_all=False, get_by_list=False, sens_param=False)
grad_all_parameters = GradOperation(get_all=True, get_by_list=False, sens_param=False)


def grad(fn, grad_first_param=False):
    """
    A wrapper function to generate the gradient function for the input function.

    Args:
        fn (Function): Function to do GradOperation.
        grad_first_param (bool): If True, get the gradient with respect to first input.
            If False, get all the gradients with respect to inputs. Default: False.

    Returns:
        The result of applying `grad_first_parameter` to `fn` when `grad_first_param`
        is truthy, otherwise the result of applying `grad_all_parameters` to `fn`.
    """
    if grad_first_param:
        return grad_first_parameter(fn)
    return grad_all_parameters(fn)


# Container primitives (tuple / list / dict / slice / ref).
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div")
tuple_len = Primitive("tuple_len")
list_len = Primitive("list_len")
tuple_reversed = Primitive("tuple_reversed")
make_range = Primitive("make_range")
make_tuple = Primitive('MakeTuple')
make_dict = Primitive('make_dict')
make_list = Primitive('make_list')
make_slice = Primitive('make_slice')
tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
make_ref = Primitive("make_ref")

# Scalar, string and bool primitives.
scalar_add = Primitive(_constants.kScalarAdd)
scalar_mul = Primitive(_constants.kScalarMul)
scalar_sub = Primitive(_constants.kScalarSub)
scalar_div = Primitive(_constants.kScalarDiv)
scalar_floordiv = Primitive(_constants.kScalarFloordiv)
scalar_log = Primitive('scalar_log')
scalar_pow = Primitive(_constants.kScalarPow)
scalar_gt = Primitive('scalar_gt')
scalar_ge = Primitive('scalar_ge')
scalar_le = Primitive('scalar_le')
scalar_lt = Primitive('scalar_lt')
scalar_eq = Primitive('scalar_eq')
scalar_ne = Primitive('scalar_ne')
scalar_uadd = Primitive(_constants.kScalarUadd)
scalar_usub = Primitive(_constants.kScalarUsub)
scalar_mod = Primitive(_constants.kScalarMod)
string_eq = Primitive('string_equal')
string_concat = Primitive('string_concat')
bool_not = Primitive("bool_not")
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
bool_eq = Primitive("bool_eq")
# logical_and / logical_or / logical_not are already bound earlier in this
# module (next to maximum / minimum / floor), so they are not instantiated a
# second time here.
cumsum = P.CumSum()
cumprod = P.CumProd()
tensor_scatter_add = P.TensorScatterAdd()
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
mixed_precision_cast = Primitive("mixed_precision_cast")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
zeros_like = P.ZerosLike()
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
env_setitem = Primitive('env_setitem')
env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
J = Primitive('J')
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# for sum bprop
reduced_shape = Primitive("reduced_shape")
# shape_mul: the input must be a shape; multiplies the elements in tuple(shape)
shape_mul = Primitive("shape_mul")
# NOTE(review): the comment that used to sit here read "a primitive to compare
# between tuple", which does not describe stop_gradient; it looks stale.
stop_gradient = Primitive("stop_gradient")

# Row tensor primitives.
make_row_tensor = Primitive('MakeRowTensor')
row_tensor_get_values = Primitive('RowTensorGetValues')
row_tensor_get_indices = Primitive('RowTensorGetIndices')
row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
row_tensor_add = Primitive('RowTensorAdd')

# Sparse tensor primitives.
make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')

# Entries for the tensor operator registry (operator classes, not instances).
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
# Further entries for the tensor operator registry, registered one by one in
# the order listed.  Some entries are operator classes (``P.Xxx``), others are
# the module-level instances bound earlier in this file.
# NOTE(review): 'fill' appears twice (first with the class P.Fill, later with
# the `fill` instance); which entry wins depends on how the registry treats a
# repeated key -- confirm before removing either one.
_REGISTRY_ENTRIES = (
    ('matmul', P.MatMul),
    ('argmax', P.Argmax),
    ('cumsum', P.CumSum),
    ('reduce_max', P.ReduceMax),
    ('reduce_min', P.ReduceMin),
    ('maximum', P.Maximum),
    ('minimum', P.Minimum),
    ('fill', P.Fill),
    ('tile', P.Tile),
    ('logical_not', P.LogicalNot),
    ('sum', P.ReduceSum),
    ('split', P.Split),
    # ms cannot support Tensor(True) compare
    ('__eq__', equal),
    ('__ne__', not_equal),
    ('__neg__', neg_tensor),
    ('__lt__', tensor_lt),
    ('__le__', tensor_le),
    ('__gt__', tensor_gt),
    ('__ge__', tensor_ge),
    ('__logical_not__', logical_not),
    ('shape', shape),
    ('squeeze', squeeze),
    # support GE backend for no compare operators
    ('cast', cast),
    ('shape_mul', shape_mul),
    ('fill', fill),
    ('concatenate', P.Concat),
    ('eye', eye),
    ('reduce_sum', reduce_sum),
    ('tensor_slice', tensor_slice),
    ('select', select),
    ('gather_d', gather_d),
    ('gather_nd', gather_nd),
)
for _entry_name, _entry_op in _REGISTRY_ENTRIES:
    tensor_operator_registry.register(_entry_name, _entry_op)
# Remove the temporaries so the module namespace is left exactly as before.
del _REGISTRY_ENTRIES, _entry_name, _entry_op
# Last registry entries: the Stack class plus the `log` and `floor` instances.
for _key, _op in (('stack', P.Stack), ('log', log), ('floor', floor)):
    tensor_operator_registry.register(_key, _op)
del _key, _op

# Public names: every module-level name that does not start with an
# underscore, minus the imported `Primitive` class.
__all__ = [name for name in dir() if not name.startswith("_")]
__all__.remove('Primitive')