# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Functional-style names for MindSpore primitives and Tensor operators.

This module:

* re-exports the functional API from ``mindspore.ops.function``;
* instantiates frequently used primitives as module-level callables
  (``cast``, ``reduce_sum``, ``tuple_getitem``, ...);
* registers functional implementations into ``tensor_operator_registry``
  under string keys (``'add'``, ``'matmul'``, ``'__eq__'``, ...);
* exports every non-underscore module-level name through ``__all__``.
"""

from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.ops import _constants
from mindspore.ops.function import *
from mindspore.ops.function.array_func import narrow, flatten
from mindspore.ops.function.math_func import all, argmax_ext
from mindspore.ops.function.random_func import uniform_ext
from mindspore.ops import operations as P
from mindspore.ops.operations import array_ops
from mindspore.ops.operations._sequence_ops import TensorToTuple
from mindspore.ops.primitive import Primitive
from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops, _sequence_ops, other_ops
from mindspore.ops.operations.math_ops import Median
from mindspore.ops.operations.array_ops import UniqueConsecutive
from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
from mindspore.ops.operations.math_ops import Roll
from mindspore.ops.composite.math_ops import mm
from mindspore.ops.function.math_func import dot
from mindspore.ops import auto_generate
from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
from mindspore.ops.operations.manually_defined.ops_def import scalar_div, scalar_mod, scalar_add, scalar_mul,\
    scalar_sub, scalar_gt, scalar_ge, scalar_le, scalar_lt, scalar_eq, scalar_floordiv, scalar_log, scalar_pow,\
    scalar_uadd, scalar_usub, flash_attention_score

# ----------------------------------------------------------------------------
# Primitive instances exposed as plain module-level callables.
# ----------------------------------------------------------------------------
typeof = Primitive('typeof')
hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = _inner_ops.IsConstant()
isconstant.set_const_prim(True)
merge = P.Merge()
geswitch = P.GeSwitch()
reduce_sum = P.ReduceSum()
reduce_max = P.ReduceMax()
reduce_min = P.ReduceMin()
reduce_mean = P.ReduceMean()
tensor_range = P.Range()
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
mixed_precision_cast = _inner_ops.MixedPrecisionCast()
# The leading underscore keeps these two out of the ``__all__`` built at the
# bottom of this module.
_py_interpret = other_ops.PyInterpret()
_dtype_to_enum = DtypeToEnum()

# Dynamic shape queries.
is_sequence_value_unknown = Primitive("IsShapeUnKnown")
is_sequence_shape_unknown = Primitive("IsDimUnKnown")
is_dynamic_sequence_element_unknown = Primitive("IsElementUnknown")
is_tensor_bool_cond = Primitive("IsTensorBoolCond")

partial = P.Partial()
# depend: attaches one node to another so the dependency is kept in the graph.
depend = P.Depend()
identity = P.identity()

# Tuple / list / dict / scalar / string / bool operations.
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('list_getitem')
list_setitem = Primitive('list_setitem')
dict_getitem = Primitive('dict_getitem')
dict_setitem = Primitive('dict_setitem')
tuple_div = Primitive("tuple_div")
# Tuples and lists share the same underlying "sequence_len" primitive.
tuple_len = Primitive("sequence_len")
list_len = Primitive("sequence_len")
tuple_reversed = Primitive("tuple_reversed")
make_range = Primitive("make_range")
make_tuple = Primitive('MakeTuple')
make_dict = Primitive('make_dict')
make_list = Primitive('make_list')
make_slice = Primitive('make_slice')
tuple_equal = Primitive("tuple_equal")
list_equal = Primitive("list_equal")
scalar_ne = Primitive('scalar_ne')
string_eq = Primitive('string_eq')
string_concat = Primitive('string_concat')
bool_not = Primitive('BoolNot')
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
bool_eq = Primitive("bool_eq")
array_to_scalar = Primitive('array_to_scalar')
is_ = Primitive("is_")
is_not = Primitive("is_not")
in_dict = Primitive("in_dict")
not_in_dict = Primitive("not_in_dict")
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
distribute = Primitive('distribute')
embed = Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
# Environ* primitives: names must be spelled exactly as the backend expects
# (the "get" variant follows the same "Environ" prefix as its siblings).
environ_create = Primitive('EnvironCreate')
environ_set = Primitive('EnvironSet')
environ_get = Primitive('EnvironGet')
environ_add = Primitive('EnvironAdd')
J = Primitive('J')
SliceGetItem = Primitive("SliceGetItem")
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# Used by the bprop of sum.
reduced_shape = Primitive("reduced_shape")
# shape_mul: the input must be a shape; it multiplies the elements of tuple(shape).
shape_mul = _sequence_ops.shape_mul()

# ----------------------------------------------------------------------------
# Tensor operator registrations: each call binds a string key on
# ``tensor_operator_registry`` to a functional op or a primitive class.
# ----------------------------------------------------------------------------
setattr(tensor_operator_registry, 'tuple_to_tensor', _sequence_ops.TupleToTensor)
setattr(tensor_operator_registry, 'add', add)
setattr(tensor_operator_registry, 'softmax', softmax)
setattr(tensor_operator_registry, 'addr', addr)
setattr(tensor_operator_registry, 'addcdiv', addcdiv)
setattr(tensor_operator_registry, 'addcmul', addcmul)
setattr(tensor_operator_registry, 'all', all)
setattr(tensor_operator_registry, 'angle', angle)
setattr(tensor_operator_registry, 'any', any)
setattr(tensor_operator_registry, 'atan2', atan2)
setattr(tensor_operator_registry, 'abs', abs)
setattr(tensor_operator_registry, 'baddbmm', baddbmm)
setattr(tensor_operator_registry, 'geqrf', geqrf)
setattr(tensor_operator_registry, 'histc', histc)
setattr(tensor_operator_registry, 'real', real)
setattr(tensor_operator_registry, 'reciprocal', reciprocal)
setattr(tensor_operator_registry, 'rsqrt', rsqrt)
setattr(tensor_operator_registry, 'bincount', bincount)
setattr(tensor_operator_registry, 'slogdet', slogdet)
setattr(tensor_operator_registry, 'trace', trace)
setattr(tensor_operator_registry, 'tril', tril)
setattr(tensor_operator_registry, 'chunk', chunk)
setattr(tensor_operator_registry, 'count_nonzero', count_nonzero)
setattr(tensor_operator_registry, 'sqrt', sqrt)
setattr(tensor_operator_registry, 'square', square)
setattr(tensor_operator_registry, 'sub', sub)
setattr(tensor_operator_registry, 'triu', triu)
setattr(tensor_operator_registry, 'tan', tan)
setattr(tensor_operator_registry, 't', t)
setattr(tensor_operator_registry, 'cauchy', P.Cauchy)
setattr(tensor_operator_registry, 'log_normal', P.LogNormalReverse)
setattr(tensor_operator_registry, 'acos', acos)
setattr(tensor_operator_registry, 'cos', cos)
setattr(tensor_operator_registry, 'acosh', acosh)
setattr(tensor_operator_registry, 'cosh', cosh)
setattr(tensor_operator_registry, 'cov', cov)
setattr(tensor_operator_registry, 'asin', asin)
setattr(tensor_operator_registry, 'sin', sin)
setattr(tensor_operator_registry, 'sinc', sinc)
setattr(tensor_operator_registry, 'pow', pow)
setattr(tensor_operator_registry, 'negative', neg)
setattr(tensor_operator_registry, 'amin', amin)
setattr(tensor_operator_registry, 'amax', amax)
setattr(tensor_operator_registry, 'aminmax', aminmax)
setattr(tensor_operator_registry, 'mean', mean)
setattr(tensor_operator_registry, 'prod', prod)
setattr(tensor_operator_registry, 'round', round)
setattr(tensor_operator_registry, 'reshape', reshape)
setattr(tensor_operator_registry, 'reverse', reverse)
setattr(tensor_operator_registry, 'reverse_sequence', reverse_sequence)
setattr(tensor_operator_registry, 'xlogy', xlogy)
setattr(tensor_operator_registry, 'flatten', flatten)
# Tensor registrations: layout changes, arg-reductions, cumulative ops,
# flips, bitwise ops and max/min reductions. Each (key, op) pair is bound
# on ``tensor_operator_registry`` in the order listed.
_registrations = (
    ('transpose', transpose),
    ('broadcast_to', broadcast_to),
    ('matmul', matmul),
    ('inner', inner),
    ('xdivy', xdivy),
    ('argmax', argmax),
    ('argmin', argmin),
    ('cumsum', P.CumSum),
    ('cummin', cummin),
    ('cummax', cummax),
    ('nelement', numel),
    ('numel', numel),
    ('positive', positive),
    ('permute', permute),
    ('remainder', remainder),
    ('index_fill', index_fill),
    ('index_select', index_select),
    ('flip', flip),
    ('fliplr', fliplr),
    ('flipud', flipud),
    ('float_power', float_power),
    ('fmax', fmax),
    ('fmin', fmin),
    ('fmod', fmod),
    ('is_floating_point', is_floating_point),
    ('bitwise_and', bitwise_and),
    ('bitwise_or', bitwise_or),
    ('bitwise_xor', bitwise_xor),
    ('bitwise_left_shift', bitwise_left_shift),
    ('bitwise_right_shift', bitwise_right_shift),
    ('ger', ger),
    ('reduce_max', P.ReduceMax),
    ('reduce_min', P.ReduceMin),
)
for _key, _op in _registrations:
    setattr(tensor_operator_registry, _key, _op)
# Drop the helpers so they do not linger in the module namespace.
del _registrations, _key, _op
# Tensor registrations: padding, linear algebra, splitting, masking, special
# functions, sampling primitives and elementwise math. Each (key, op) pair
# is bound on ``tensor_operator_registry`` in the order listed.
_registrations = (
    ('random_categorical', random_categorical),
    ('mirror_pad', P.MirrorPad),
    ('minimum', minimum),
    ('matrix_power', matrix_power),
    ('det', det),
    ('dot', dot),
    ('outer', outer),
    ('log1p', log1p),
    ('logdet', logdet),
    ('log_matrix_determinant', log_matrix_determinant),
    ('matrix_determinant', matrix_determinant),
    ('ceil', ceil),
    ('fillv2', P.FillV2),
    ('tile', tile),
    ('logit', logit),
    ('sum', sum),
    ('split', split),
    ('tensor_split', tensor_split),
    ('vsplit', vsplit),
    ('hsplit', hsplit),
    ('dsplit', dsplit),
    ('zeros_like', zeros_like),
    ('scalar_to_tensor', scalar_to_tensor),
    ('stop_gradient', stop_gradient),
    ('masked_fill', masked_fill),
    ('masked_select', masked_select),
    ('nonzero', nonzero),
    ('i0', i0),
    ('isclose', isclose),
    ('isneginf', isneginf),
    ('isposinf', isposinf),
    ('isreal', isreal),
    ('inv', inv),
    ('digamma', digamma),
    ('lgamma', lgamma),
    ('logaddexp', logaddexp),
    ('logaddexp2', logaddexp2),
    ('logcumsumexp', logcumsumexp),
    ('logsumexp', logsumexp),
    ('inverse', inverse),
    ('invert', invert),
    ('hardshrink', hardshrink),
    ('heaviside', heaviside),
    ('hypot', hypot),
    ('searchsorted', P.SearchSorted),
    ('soft_shrink', soft_shrink),
    ('svd', linalg_ops.Svd),
    ('diag', diag),
    ('diagflat', diagflat),
    ('unique_consecutive', UniqueConsecutive),
    ('unique_with_pad', unique_with_pad),
    ('inplace_update', inplace_update),
    ('col2im', col2im),
    ('standard_laplace', P.StandardLaplace),
    ('erf', erf),
    ('erfc', erfc),
    ('standard_normal', P.StandardNormal),
    ('sigmoid', sigmoid),
    ('median', Median),
    ('tanh', tanh),
    ('exp', exp),
    ('addbmm', addbmm),
    ('addmm', addmm),
    ('addmv', addmv),
    ('adjoint', adjoint),
    ('asinh', asinh),
    ('arcsinh', arcsinh),
)
for _key, _op in _registrations:
    setattr(tensor_operator_registry, _key, _op)
# Drop the helpers so they do not linger in the module namespace.
del _registrations, _key, _op
# Tensor registrations: inverse trig, batched/linear-algebra ops, scatter
# variants, quantiles and the first comparison dunders. Each (key, op) pair
# is bound on ``tensor_operator_registry`` in the order listed.
_registrations = (
    ('atan', atan),
    ('atanh', atanh),
    ('arctanh', arctanh),
    ('bmm', bmm),
    ('conj', conj),
    ('cross', cross),
    ('erfinv', erfinv),
    ('less_equal', less_equal),
    ('lcm', lcm),
    ('ldexp', ldexp),
    ('clamp', clamp),
    ('fold', fold),
    ('unfold', unfold),
    ('diagonal', diagonal),
    ('diagonal_scatter', diagonal_scatter),
    ('index_add', index_add),
    ('greater', greater),
    ('greater_equal', greater_equal),
    ('igamma', igamma),
    ('igammac', igammac),
    ('lu_solve', lu_solve),
    ('nextafter', nextafter),
    ('qr', qr),
    ('ormqr', ormqr),
    ('masked_scatter', array_ops.MaskedScatter),
    ('index_put', array_ops.IndexPut),
    ('quantile', quantile),
    ('nanquantile', nanquantile),
    ('orgqr', orgqr),
    # Comparison dunders are routed to functional ops because ms cannot
    # support Tensor(True) compare.
    ('__eq__', equal),
    ('__ne__', not_equal),
    ('__neg__', neg),
    ('__lt__', tensor_lt),
    ('__le__', tensor_le),
)
for _key, _op in _registrations:
    setattr(tensor_operator_registry, _key, _op)
# Drop the helpers so they do not linger in the module namespace.
del _registrations, _key, _op
# Tensor registrations: remaining comparison dunders, shape helpers, casting,
# fills, gathers, stacking and logarithms. Each (key, op) pair is bound on
# ``tensor_operator_registry`` in the order listed.
_registrations = (
    ('__gt__', tensor_gt),
    ('__ge__', tensor_ge),
    ('__logical_not__', logical_not),
    ('gt', gt),
    ('ge', ge),
    ('shape', shape),
    ('squeeze', squeeze),
    ('unsqueeze', unsqueeze),
    ('expand_dims', expand_dims),
    ('contiguous', auto_generate.contiguous),
    # Non-comparison operators registered for GE backend support.
    ('cast', cast),
    ('shape_mul', shape_mul),
    ('concatenate', concat),
    ('fill', fill),
    ('fills', fills),
    ('fill_diagonal', P.FillDiagonal),
    ('eye', eye),
    ('eigvals', eigvals),
    ('reduce_sum', reduce_sum),
    ('reducesum', P.ReduceSum),
    ('tensor_slice', tensor_slice),
    ('select', select),
    ('uniform', uniform_ext),
    ('gather', gather),
    ('gather_d', gather_d),
    ('gather_elements', gather_elements),
    ('gather_nd', gather_nd),
    ('stack', stack),
    ('unstack', unstack),
    ('unbind', unstack),
    ('log', log),
    ('log10', log10),
    ('log2', log2),
)
for _key, _op in _registrations:
    setattr(tensor_operator_registry, _key, _op)
# Drop the helpers so they do not linger in the module namespace.
del _registrations, _key, _op
# Tensor registrations: floor ops, sparse (CSR/COO) operators, sorting,
# NaN-aware reductions, segment/scatter ops, sampling, dtype conversions,
# Cholesky, and elementwise logic/arithmetic. Each (key, op) pair is bound
# on ``tensor_operator_registry`` in the order listed.
_registrations = (
    ('lerp', lerp),
    ('floor', floor),
    ('floor_divide', floor_divide),
    # Sparse tensor operators.
    ('csr_add', csr_add),
    ('csr_mul', csr_mul),
    ('csr2coo', csr2coo),
    ('coo2csr', coo2csr),
    ('csr_div', csr_div),
    ('csr_mv', csr_mv),
    ('csr_mm_akg', _csr_ops.CSRMM),
    ('csr_mm', csr_mm),
    ('csr_reduce_sum', csr_reduce_sum),
    ('dense_to_sparse_csr', dense_to_sparse_csr),
    ('dense_to_sparse_coo', dense_to_sparse_coo),
    ('csr_to_dense', csr_to_dense),
    ('narrow', narrow),
    ('sort', sort),
    ('argsort', argsort),
    ('msort', msort),
    ('mm', mm),
    ('nan_to_num', nan_to_num),
    ('nansum', nansum),
    ('nanmean', nanmean),
    ('nanmedian', nanmedian),
    ('csr_to_coo', csr_to_coo),
    ('zeros', zeros),
    ('ones', ones),
    ('unsorted_segment_min', unsorted_segment_min),
    ('unsorted_segment_max', unsorted_segment_max),
    ('unsorted_segment_prod', unsorted_segment_prod),
    ('scatter', scatter),
    ('tensor_scatter_update', tensor_scatter_update),
    ('tensor_scatter_mul', tensor_scatter_mul),
    ('tensor_scatter_div', tensor_scatter_div),
    ('tensor_scatter_min', tensor_scatter_min),
    ('tensor_scatter_max', tensor_scatter_max),
    ('tensor_scatter_sub', tensor_scatter_sub),
    ('tensor_scatter_add', tensor_scatter_add),
    ('slice_scatter', slice_scatter),
    ('select_scatter', select_scatter),
    ('bernoulli', bernoulli),
    ('poisson', P.Poisson),
    ('randperm', P.Randperm),
    ('multinomial', multinomial),
    ('norm', norm),
    ('renorm', renorm),
    ('adaptive_max_pool2d', AdaptiveMaxPool2D),
    ('coalesce', coalesce),
    ('argmax_with_value', max),
    ('argmin_with_value', min),
    ('argwhere', argwhere),
    ('coo_add', coo_add),
    ('topk', topk),
    ('isfinite', isfinite),
    # All dtype-conversion keys map to the same Cast instance.
    ('to', cast),
    ('bool', cast),
    ('float', cast),
    ('half', cast),
    ('int', cast),
    ('long', cast),
    ('cholesky', cholesky),
    ('cholesky_inverse', cholesky_inverse),
    ('cholesky_solve', cholesky_solve),
    ('expand', broadcast_to),
    ('tensortotuple', TensorToTuple),
    ('cumprod', cumprod),
    ('diff', diff),
    ('div', div),
    ('equal', equal),
    ('expm1', expm1),
    ('frac', frac),
    ('isinf', isinf),
    ('isnan', isnan),
    ('is_complex', is_complex),
    ('le', le),
    ('less', less),
    ('logical_and', logical_and),
    ('logical_not', logical_not),
    ('logical_or', logical_or),
    ('logical_xor', logical_xor),
    ('lstsq', lstsq),
    ('mvlgamma', mvlgamma),
    ('maximum', maximum),
    ('max', max),
    ('min', min),
    ('mul', mul),
    ('multiply', multiply),
    ('moveaxis', moveaxis),
    ('movedim', movedim),
    ('neg', neg),
    ('ne', ne),
    ('not_equal', not_equal),
    ('sgn', sgn),
    ('sign', sign),
    ('signbit', signbit),
    ('sinh', sinh),
    ('trunc', trunc),
    ('where', where),
    ('imag', imag),
)
for _key, _op in _registrations:
    setattr(tensor_operator_registry, _key, _op)
# Drop the helpers so they do not linger in the module namespace.
del _registrations, _key, _op
# Final tensor registrations. Each (key, op) pair is bound on
# ``tensor_operator_registry`` in the order listed.
_registrations = (
    ('repeat_interleave', repeat_interleave),
    ('rad2deg', rad2deg),
    ('deg2rad', deg2rad),
    ('copysign', copysign),
    ('roll', Roll),
    ('rot90', rot90),
    ('swapaxes', swapaxes),
    ('swapdims', swapdims),
    ('repeat_elements', repeat_elements),
    ('top_k', top_k),
)
for _key, _op in _registrations:
    setattr(tensor_operator_registry, _key, _op)
# Drop the helpers so they do not linger in the module namespace.
del _registrations, _key, _op

# Public API: every module-level name without a leading underscore, except
# a few names that are deliberately kept out of the export list.
__all__ = [_public for _public in dir() if not _public.startswith("_")]
for _hidden in ('Primitive', 'argmax_ext', 'uniform_ext'):
    __all__.remove(_hidden)
del _hidden