• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2#
3# Copyright 2021-2022 Huawei Technologies Co., Ltd
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# ============================================================================
17"""The names of functional part are summarized here."""
18
19from mindspore.common._register_for_tensor import tensor_operator_registry
20from mindspore.ops import _constants
21from mindspore.ops.function import *
22from mindspore.ops.function.array_func import narrow, flatten
23from mindspore.ops.function.math_func import all, argmax_ext
24from mindspore.ops.function.random_func import uniform_ext
25from mindspore.ops import operations as P
26from mindspore.ops.operations import array_ops
27from mindspore.ops.operations._sequence_ops import TensorToTuple
28from mindspore.ops.primitive import Primitive
29from mindspore.ops.operations import _grad_ops, _csr_ops, _inner_ops, linalg_ops, _sequence_ops, other_ops
30from mindspore.ops.operations.math_ops import Median
31from mindspore.ops.operations.array_ops import UniqueConsecutive
32from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
33from mindspore.ops.operations.math_ops import Roll
34from mindspore.ops.composite.math_ops import mm
35from mindspore.ops.function.math_func import dot
36from mindspore.ops import auto_generate
37from mindspore.ops_generate.gen_ops_inner_prim import DtypeToEnum
38from mindspore.ops.operations.manually_defined.ops_def import scalar_div, scalar_mod, scalar_add, scalar_mul,\
39    scalar_sub, scalar_gt, scalar_ge, scalar_le, scalar_lt, scalar_eq, scalar_floordiv, scalar_log, scalar_pow,\
40    scalar_uadd, scalar_usub, flash_attention_score
41
# Type introspection and dtype conversion primitives.
typeof, hastype = Primitive('typeof'), Primitive('hastype')
cast, dtype = P.Cast(), P.DType()

# IsConstant is flagged with set_const_prim(True) right after construction
# (presumably so it is treated as a compile-time constant primitive -- confirm).
isconstant = _inner_ops.IsConstant()
isconstant.set_const_prim(True)

merge, geswitch = P.Merge(), P.GeSwitch()

# Reduction operator instances (distinct from the P.Reduce* classes registered below).
reduce_sum, reduce_max, reduce_min, reduce_mean = (
    P.ReduceSum(), P.ReduceMax(), P.ReduceMin(), P.ReduceMean())

tensor_range = P.Range()
tensor_scatter_update, scatter_nd_update = P.TensorScatterUpdate(), P.ScatterNdUpdate()
mixed_precision_cast = _inner_ops.MixedPrecisionCast()

# Underscore-prefixed module-private helpers.
_py_interpret, _dtype_to_enum = other_ops.PyInterpret(), DtypeToEnum()
60
# Dynamic shape: query primitives, each built from its registered primitive name.
(is_sequence_value_unknown,
 is_sequence_shape_unknown,
 is_dynamic_sequence_element_unknown,
 is_tensor_bool_cond) = map(Primitive, ("IsShapeUnKnown", "IsDimUnKnown", "IsElementUnknown", "IsTensorBoolCond"))
66
# Partial application, node dependency (depend mounts one node onto another) and identity.
partial, depend, identity = P.Partial(), P.Depend(), P.identity()
# tuple/list/dict element access and update.
tuple_setitem, tuple_getitem = Primitive('tuple_setitem'), Primitive(_constants.kTupleGetItem)
list_getitem, list_setitem = Primitive('list_getitem'), Primitive('list_setitem')
dict_getitem, dict_setitem = Primitive('dict_getitem'), Primitive('dict_setitem')
tuple_div = Primitive("tuple_div")
# tuple_len and list_len are two separate Primitive objects that share the "sequence_len" name.
tuple_len, list_len = Primitive("sequence_len"), Primitive("sequence_len")
tuple_reversed = Primitive("tuple_reversed")

# Constructors.
make_range, make_tuple = Primitive("make_range"), Primitive('MakeTuple')
make_dict, make_list, make_slice = Primitive('make_dict'), Primitive('make_list'), Primitive('make_slice')

# Equality and string helpers.
tuple_equal, list_equal = Primitive("tuple_equal"), Primitive("list_equal")
scalar_ne = Primitive('scalar_ne')
string_eq, string_concat = Primitive('string_eq'), Primitive('string_concat')

# Boolean logic.
bool_not, bool_or, bool_and, bool_eq = (
    Primitive('BoolNot'), Primitive("bool_or"), Primitive("bool_and"), Primitive("bool_eq"))
array_to_scalar = Primitive('array_to_scalar')

# Identity and dict-membership tests.
is_, is_not = Primitive("is_"), Primitive("is_not")
in_dict, not_in_dict = Primitive("in_dict"), Primitive("not_in_dict")

broadcast_gradient_args = Primitive('BroadcastGradientArgs')
array_reduce = Primitive('array_reduce')
distribute, embed = Primitive('distribute'), Primitive('embed')
ref_to_embed = _grad_ops.RefToEmbed()
# Environ primitives. They are looked up by name, so each string must be the exact
# primitive name; all four belong to the same Environ{Create,Set,Get,Add} family.
environ_create = Primitive('EnvironCreate')
environ_set = Primitive('EnvironSet')
environ_get = Primitive('EnvironGet')
environ_add = Primitive('EnvironAdd')
J = Primitive('J')
SliceGetItem = Primitive("SliceGetItem")
switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# Used for the backward (bprop) of sum.
reduced_shape = Primitive("reduced_shape")
# shape_mul: takes a shape tuple and multiplies the elements of tuple(shape) together.
shape_mul = _sequence_ops.shape_mul()
118
# Bind each functional implementation onto tensor_operator_registry under the given
# attribute name (presumably resolved later by Tensor methods by that name -- confirm).
# Entries written as P.<Name> register the primitive class itself, not an instance.
for _method_name, _method_impl in (
        ('tuple_to_tensor', _sequence_ops.TupleToTensor),
        ('add', add),
        ('softmax', softmax),
        ('addr', addr),
        ('addcdiv', addcdiv),
        ('addcmul', addcmul),
        ('all', all),
        ('angle', angle),
        ('any', any),
        ('atan2', atan2),
        ('abs', abs),
        ('baddbmm', baddbmm),
        ('geqrf', geqrf),
        ('histc', histc),
        ('real', real),
        ('reciprocal', reciprocal),
        ('rsqrt', rsqrt),
        ('bincount', bincount),
        ('slogdet', slogdet),
        ('trace', trace),
        ('tril', tril),
        ('chunk', chunk),
        ('count_nonzero', count_nonzero),
        ('sqrt', sqrt),
        ('square', square),
        ('sub', sub),
        ('triu', triu),
        ('tan', tan),
        ('t', t),
        ('cauchy', P.Cauchy),
        ('log_normal', P.LogNormalReverse),
        ('acos', acos),
        ('cos', cos),
        ('acosh', acosh),
        ('cosh', cosh),
        ('cov', cov),
        ('asin', asin),
        ('sin', sin),
        ('sinc', sinc),
        ('pow', pow),
        ('negative', neg),
        ('amin', amin),
        ('amax', amax),
        ('aminmax', aminmax),
        ('mean', mean),
        ('prod', prod),
        ('round', round),
        ('reshape', reshape),
        ('reverse', reverse),
        ('reverse_sequence', reverse_sequence),
        ('xlogy', xlogy),
        ('flatten', flatten),
        ('transpose', transpose),
        ('broadcast_to', broadcast_to),
        ('matmul', matmul),
        ('inner', inner),
        ('xdivy', xdivy),
        ('argmax', argmax),
        ('argmin', argmin),
        ('cumsum', P.CumSum),
        ('cummin', cummin),
        ('cummax', cummax),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
# More name -> implementation bindings on tensor_operator_registry.
# 'nelement' and 'numel' deliberately share one implementation.
for _method_name, _method_impl in (
        ('nelement', numel),
        ('numel', numel),
        ('positive', positive),
        ('permute', permute),
        ('remainder', remainder),
        ('index_fill', index_fill),
        ('index_select', index_select),
        ('flip', flip),
        ('fliplr', fliplr),
        ('flipud', flipud),
        ('float_power', float_power),
        ('fmax', fmax),
        ('fmin', fmin),
        ('fmod', fmod),
        ('is_floating_point', is_floating_point),
        ('bitwise_and', bitwise_and),
        ('bitwise_or', bitwise_or),
        ('bitwise_xor', bitwise_xor),
        ('bitwise_left_shift', bitwise_left_shift),
        ('bitwise_right_shift', bitwise_right_shift),
        ('ger', ger),
        ('reduce_max', P.ReduceMax),
        ('reduce_min', P.ReduceMin),
        ('random_categorical', random_categorical),
        ('mirror_pad', P.MirrorPad),
        ('minimum', minimum),
        ('matrix_power', matrix_power),
        ('det', det),
        ('dot', dot),
        ('outer', outer),
        ('log1p', log1p),
        ('logdet', logdet),
        ('log_matrix_determinant', log_matrix_determinant),
        ('matrix_determinant', matrix_determinant),
        ('ceil', ceil),
        ('fillv2', P.FillV2),
        ('tile', tile),
        ('logit', logit),
        ('sum', sum),
        ('split', split),
        ('tensor_split', tensor_split),
        ('vsplit', vsplit),
        ('hsplit', hsplit),
        ('dsplit', dsplit),
        ('zeros_like', zeros_like),
        ('scalar_to_tensor', scalar_to_tensor),
        ('stop_gradient', stop_gradient),
        ('masked_fill', masked_fill),
        ('masked_select', masked_select),
        ('nonzero', nonzero),
        ('i0', i0),
        ('isclose', isclose),
        ('isneginf', isneginf),
        ('isposinf', isposinf),
        ('isreal', isreal),
        ('inv', inv),
        ('digamma', digamma),
        ('lgamma', lgamma),
        ('logaddexp', logaddexp),
        ('logaddexp2', logaddexp2),
        ('logcumsumexp', logcumsumexp),
        ('logsumexp', logsumexp),
        ('inverse', inverse),
        ('invert', invert),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
# More name -> implementation bindings on tensor_operator_registry; several entries
# register primitive classes (P.*, linalg_ops.Svd, UniqueConsecutive, Median, array_ops.*).
for _method_name, _method_impl in (
        ('hardshrink', hardshrink),
        ('heaviside', heaviside),
        ('hypot', hypot),
        ('searchsorted', P.SearchSorted),
        ('soft_shrink', soft_shrink),
        ('svd', linalg_ops.Svd),
        ('diag', diag),
        ('diagflat', diagflat),
        ('unique_consecutive', UniqueConsecutive),
        ('unique_with_pad', unique_with_pad),
        ('inplace_update', inplace_update),
        ('col2im', col2im),
        ('standard_laplace', P.StandardLaplace),
        ('erf', erf),
        ('erfc', erfc),
        ('standard_normal', P.StandardNormal),
        ('sigmoid', sigmoid),
        ('median', Median),
        ('tanh', tanh),
        ('exp', exp),
        ('addbmm', addbmm),
        ('addmm', addmm),
        ('addmv', addmv),
        ('adjoint', adjoint),
        ('asinh', asinh),
        ('arcsinh', arcsinh),
        ('atan', atan),
        ('atanh', atanh),
        ('arctanh', arctanh),
        ('bmm', bmm),
        ('conj', conj),
        ('cross', cross),
        ('erfinv', erfinv),
        ('less_equal', less_equal),
        ('lcm', lcm),
        ('ldexp', ldexp),
        ('clamp', clamp),
        ('fold', fold),
        ('unfold', unfold),
        ('diagonal', diagonal),
        ('diagonal_scatter', diagonal_scatter),
        ('index_add', index_add),
        ('greater', greater),
        ('greater_equal', greater_equal),
        ('igamma', igamma),
        ('igammac', igammac),
        ('lu_solve', lu_solve),
        ('nextafter', nextafter),
        ('qr', qr),
        ('ormqr', ormqr),
        ('masked_scatter', array_ops.MaskedScatter),
        ('index_put', array_ops.IndexPut),
        ('quantile', quantile),
        ('nanquantile', nanquantile),
        ('orgqr', orgqr),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
# Comparison / unary dunder operators plus shape helpers.
# Original note: ms cannot support Tensor(True) compare.
for _method_name, _method_impl in (
        ('__eq__', equal),
        ('__ne__', not_equal),
        ('__neg__', neg),
        ('__lt__', tensor_lt),
        ('__le__', tensor_le),
        ('__gt__', tensor_gt),
        ('__ge__', tensor_ge),
        ('__logical_not__', logical_not),
        ('gt', gt),
        ('ge', ge),
        ('shape', shape),
        ('squeeze', squeeze),
        ('unsqueeze', unsqueeze),
        ('expand_dims', expand_dims),
        ('contiguous', auto_generate.contiguous),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
# Non-comparison operators; the original note says these support the GE backend.
# 'reduce_sum' binds the module-level ReduceSum instance, 'reducesum' the class;
# 'unstack' and 'unbind' share one implementation.
for _method_name, _method_impl in (
        ('cast', cast),
        ('shape_mul', shape_mul),
        ('concatenate', concat),
        ('fill', fill),
        ('fills', fills),
        ('fill_diagonal', P.FillDiagonal),
        ('eye', eye),
        ('eigvals', eigvals),
        ('reduce_sum', reduce_sum),
        ('reducesum', P.ReduceSum),
        ('tensor_slice', tensor_slice),
        ('select', select),
        ('uniform', uniform_ext),
        ('gather', gather),
        ('gather_d', gather_d),
        ('gather_elements', gather_elements),
        ('gather_nd', gather_nd),
        ('stack', stack),
        ('unstack', unstack),
        ('unbind', unstack),
        ('log', log),
        ('log10', log10),
        ('log2', log2),
        ('lerp', lerp),
        ('floor', floor),
        ('floor_divide', floor_divide),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
# Sparse tensor (CSR/COO) operators, followed by sorting, NaN-aware reductions,
# segment/scatter, random and assorted operators.
for _method_name, _method_impl in (
        ('csr_add', csr_add),
        ('csr_mul', csr_mul),
        ('csr2coo', csr2coo),
        ('coo2csr', coo2csr),
        ('csr_div', csr_div),
        ('csr_mv', csr_mv),
        ('csr_mm_akg', _csr_ops.CSRMM),
        ('csr_mm', csr_mm),
        ('csr_reduce_sum', csr_reduce_sum),
        ('dense_to_sparse_csr', dense_to_sparse_csr),
        ('dense_to_sparse_coo', dense_to_sparse_coo),
        ('csr_to_dense', csr_to_dense),
        ('narrow', narrow),
        ('sort', sort),
        ('argsort', argsort),
        ('msort', msort),
        ('mm', mm),
        ('nan_to_num', nan_to_num),
        ('nansum', nansum),
        ('nanmean', nanmean),
        ('nanmedian', nanmedian),
        ('csr_to_coo', csr_to_coo),
        ('zeros', zeros),
        ('ones', ones),
        ('unsorted_segment_min', unsorted_segment_min),
        ('unsorted_segment_max', unsorted_segment_max),
        ('unsorted_segment_prod', unsorted_segment_prod),
        ('scatter', scatter),
        ('tensor_scatter_update', tensor_scatter_update),
        ('tensor_scatter_mul', tensor_scatter_mul),
        ('tensor_scatter_div', tensor_scatter_div),
        ('tensor_scatter_min', tensor_scatter_min),
        ('tensor_scatter_max', tensor_scatter_max),
        ('tensor_scatter_sub', tensor_scatter_sub),
        ('tensor_scatter_add', tensor_scatter_add),
        ('slice_scatter', slice_scatter),
        ('select_scatter', select_scatter),
        ('bernoulli', bernoulli),
        ('poisson', P.Poisson),
        ('randperm', P.Randperm),
        ('multinomial', multinomial),
        ('norm', norm),
        ('renorm', renorm),
        ('adaptive_max_pool2d', AdaptiveMaxPool2D),
        ('coalesce', coalesce),
        ('argmax_with_value', max),
        ('argmin_with_value', min),
        ('argwhere', argwhere),
        ('coo_add', coo_add),
        ('topk', topk),
        ('isfinite', isfinite),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
# The dtype-conversion aliases below all resolve to the single module-level Cast instance.
for _cast_alias in ('to', 'bool', 'float', 'half', 'int', 'long'):
    setattr(tensor_operator_registry, _cast_alias, cast)
del _cast_alias

# Remaining name -> implementation bindings ('expand' reuses broadcast_to).
for _method_name, _method_impl in (
        ('cholesky', cholesky),
        ('cholesky_inverse', cholesky_inverse),
        ('cholesky_solve', cholesky_solve),
        ('expand', broadcast_to),
        ('tensortotuple', TensorToTuple),
        ('cumprod', cumprod),
        ('diff', diff),
        ('div', div),
        ('equal', equal),
        ('expm1', expm1),
        ('frac', frac),
        ('isinf', isinf),
        ('isnan', isnan),
        ('is_complex', is_complex),
        ('le', le),
        ('less', less),
        ('logical_and', logical_and),
        ('logical_not', logical_not),
        ('logical_or', logical_or),
        ('logical_xor', logical_xor),
        ('lstsq', lstsq),
        ('mvlgamma', mvlgamma),
        ('maximum', maximum),
        ('max', max),
        ('min', min),
        ('mul', mul),
        ('multiply', multiply),
        ('moveaxis', moveaxis),
        ('movedim', movedim),
        ('neg', neg),
        ('ne', ne),
        ('not_equal', not_equal),
        ('sgn', sgn),
        ('sign', sign),
        ('signbit', signbit),
        ('sinh', sinh),
        ('trunc', trunc),
        ('where', where),
        ('imag', imag),
        ('repeat_interleave', repeat_interleave),
        ('rad2deg', rad2deg),
        ('deg2rad', deg2rad),
        ('copysign', copysign),
        ('roll', Roll),
        ('rot90', rot90),
        ('swapaxes', swapaxes),
        ('swapdims', swapdims),
        ('repeat_elements', repeat_elements),
        ('top_k', top_k),
):
    setattr(tensor_operator_registry, _method_name, _method_impl)
del _method_name, _method_impl
450
# Export every module-level name that does not start with an underscore, except the
# few names this module deliberately keeps out of its star-export.
__all__ = [attr for attr in dir() if not attr.startswith("_")]
for _unexported in ('Primitive', 'argmax_ext', 'uniform_ext'):
    __all__.remove(_unexported)
del _unexported
455