• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""
17Primitive operator classes.
18
19A collection of operators to build neural networks or to compute functions.
20"""
21
22from .image_ops import (CropAndResize)
23from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
24                        Diag, DiagPart, DType, ExpandDims, Eye,
25                        Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
26                        IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
27                        Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
28                        SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
29                        ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
30                        Shape, DynamicShape, Size, Slice, Split, SplitV, TransShape, ParallelConcat, Padding,
31                        UniqueWithPad, ScatterNdAdd, ScatterNdSub, ScatterNonAliasingAdd, ReverseV2, Rint,
32                        Squeeze, StridedSlice, Tile, TensorScatterUpdate, TensorScatterAdd, EditDistance, Sort,
33                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
34                        UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch,
35                        BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
36                        EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
37                        TensorScatterMax, TensorScatterMin, TensorScatterSub)
38from .comm_ops import (AllGather, AllReduce, NeighborExchange, AlltoAll, AllSwap, ReduceScatter, Broadcast,
39                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
40                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
41                       _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
42from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
43                        TensorSummary, HistogramSummary, Print, Assert)
44from .control_ops import GeSwitch, Merge
45from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
46                        MakeRefKey,
47                        FusedWeightScaleApplyMomentum, FusedCastAdamWeightDecay)
48
49from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
50                       BitwiseAnd, BitwiseOr,
51                       BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
52                       ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
53                       Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
54                       Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
55                       LogicalNot, LogicalOr, MatMul, Maximum, MulNoNan,
56                       Minimum, Mul, Neg, NMSWithMask, NotEqual,
57                       NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
58                       NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
59                       Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
60                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
61                       Square, Sub, TensorAdd, Add, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan,
62                       MatrixInverse, IndexAdd, Erfinv, Conj, Real, Imag)
63
64from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, Poisson, UniformInt, UniformReal,
65                         RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
66                         LogUniformCandidateSampler)
67
68from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSparseLazyAdam, AdamNoUpdateParam,
69                     ApplyMomentum, BatchNorm, BiasAdd, Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose,
70                     DepthwiseConv2dNative,
71                     DropoutDoMask, Dropout, Dropout2D, Dropout3D, DropoutGenMask, Flatten,
72                     InstanceNorm, BNTrainingReduce, BNTrainingUpdate,
73                     GeLU, Gelu, FastGeLU, FastGelu, Elu,
74                     GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCLossV2Grad, CTCGreedyDecoder,
75                     LogSoftmax, MaxPool3D, AvgPool3D,
76                     MaxPool, DataFormatDimMap,
77                     AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
78                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
79                     ResizeBilinear, Sigmoid, SeLU, HShrink,
80                     SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
81                     SmoothL1Loss, SoftMarginLoss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
82                     SoftmaxCrossEntropyWithLogits, ROIAlign,
83                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
84                     TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
85                     ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
86                     FusedSparseFtrl, FusedSparseProximalAdagrad, SparseApplyRMSProp,
87                     ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdagradDA,
88                     ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
89                     ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D, SoftShrink)
90from . import _quant_ops
91from ._quant_ops import *
92from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
93                        ConfusionMatrix, PopulationCount, UpdateState, Load,
94                        CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
95                        PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc)
96from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
97                        CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
98                        CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
99                        LoadIm2Col, UpdateThorGradient, Cholesky, CholeskyTrsm,
100                        DetTriangle, ProdForceSeA)
101from .sparse_ops import (SparseToDense, SparseTensorDenseMatmul)
102from ._embedding_cache_ops import (CacheSwapTable, UpdateCache, MapCacheIdx, SubAndFilter,
103                                   MapUniform, DynamicAssign, PadAndShift)
104from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAtomEnergy, BondForceWithAtomVirial,
105                         DihedralForce, DihedralEnergy, DihedralAtomEnergy, DihedralForceWithAtomEnergy, AngleForce,
106                         AngleEnergy, AngleAtomEnergy, AngleForceWithAtomEnergy, PMEReciprocalForce,
107                         LJForce, LJEnergy, LJForceWithPMEDirectForce, PMEExcludedForce, PMEEnergy, Dihedral14LJForce,
108                         Dihedral14LJForceWithDirectCF, Dihedral14LJEnergy, Dihedral14LJCFForceWithAtomEnergy,
109                         Dihedral14LJAtomEnergy, Dihedral14CFEnergy, Dihedral14CFAtomEnergy,
110                         MDTemperature, MDIterationLeapFrogLiujian,
111                         CrdToUintCrd, MDIterationSetupRandState, TransferCrd, FFT3D, IFFT3D, NeighborListUpdate)
112from .sponge_update_ops import (ConstrainForceCycleWithVirial, RefreshUintCrd, LastCrdToDr, RefreshCrdVel,
113                                CalculateNowrapCrd, RefreshBoxmapTimes, Totalc6get, CrdToUintCrdQuarter,
114                                MDIterationLeapFrogLiujianWithMaxVel, GetCenterOfMass, MapCenterOfMass,
115                                NeighborListRefresh, MDIterationLeapFrog, MDIterationLeapFrogWithMaxVel,
116                                MDIterationGradientDescent, BondForceWithAtomEnergyAndVirial, ConstrainForceCycle,
117                                LJForceWithVirialEnergy, LJForceWithPMEDirectForceUpdate, PMEReciprocalForceUpdate,
118                                PMEExcludedForceUpdate, LJForceWithVirialEnergyUpdate,
119                                Dihedral14ForceWithAtomEnergyVirial, PMEEnergyUpdate,
120                                ConstrainForceVirial, ConstrainForce, Constrain)
121from .rl_ops import (BufferAppend, BufferGetItem, BufferSample)
122from ._inner_ops import (MatmulDDS, DSDMatmul)
123
# Public operator names exported by ``from <this package> import *``.
#
# Ordering in this literal is not significant: the list is extended with the
# SPONGE operators and sorted alphabetically at the end of the module.
# Each name must appear exactly once.
#
# NOTE(review): the following names are imported above but are NOT listed
# here, so they are excluded from the star-export. Confirm this is intentional:
# TransShape, NeighborExchange, AlltoAll, LambApplyOptimizerAssign,
# LambApplyWeightAssign, InstanceNorm, CTCLossV2, CTCLossV2Grad,
# InplaceAssign, Load, PushMetrics, StartFLJob, UpdateModel, GetModel,
# MatmulDDS, DSDMatmul (plus the _thor_ops and _embedding_cache_ops imports).
__all__ = [
    'Unique',
    'ReverseSequence',
    'Sort',
    'EditDistance',
    'CropAndResize',
    'Add',
    'TensorAdd',
    'Argmax',
    'Argmin',
    'MaxPool3D',
    'AvgPool3D',
    'ArgMaxWithValue',
    'ArgMinWithValue',
    'AddN',
    'AccumulateNV2',
    'Sub',
    'CumSum',
    'MatMul',
    'BatchMatMul',
    'Mul',
    'MaskedFill',
    'MaskedSelect',
    'Meshgrid',
    'Pow',
    'Exp',
    'Expm1',
    'Rsqrt',
    'Sqrt',
    'Square',
    'DynamicGRUV2',
    'SquaredDifference',
    'Xdivy',
    'Xlogy',
    'Conv2D',
    'Conv3D',
    'Conv2DTranspose',
    'Conv3DTranspose',
    'Flatten',
    'MaxPoolWithArgmax',
    'BNTrainingReduce',
    'BNTrainingUpdate',
    'BatchNorm',
    'MaxPool',
    'TopK',
    'LinSpace',
    'Adam',
    'AdamWeightDecay',
    'FusedCastAdamWeightDecay',
    'FusedSparseAdam',
    'FusedSparseLazyAdam',
    'AdamNoUpdateParam',
    'Softplus',
    'Softmax',
    'Softsign',
    'LogSoftmax',
    'SoftmaxCrossEntropyWithLogits',
    'BCEWithLogitsLoss',
    'ROIAlign',
    'SparseSoftmaxCrossEntropyWithLogits',
    'NLLLoss',
    'SGD',
    'ApplyMomentum',
    'FusedWeightScaleApplyMomentum',
    'ExpandDims',
    'Cast',
    'IsSubClass',
    'IsInstance',
    'Reshape',
    'Squeeze',
    'Transpose',
    'OneHot',
    'GatherV2',
    'Gather',
    'SparseGatherV2',
    'EmbeddingLookup',
    'Padding',
    'GatherD',
    'Identity',
    'UniqueWithPad',
    'Concat',
    'Pack',
    'Stack',
    'Unpack',
    'Unstack',
    'Tile',
    'BiasAdd',
    'GeLU',
    'Gelu',
    'FastGeLU',
    'FastGelu',
    'Minimum',
    'Maximum',
    'StridedSlice',
    'ReduceSum',
    'ReduceMean',
    'LayerNorm',
    'Rank',
    'Lerp',
    'Less',
    'LessEqual',
    'RealDiv',
    'Div',
    'DivNoNan',
    'Inv',
    'Invert',
    'TruncatedNormal',
    'Fill',
    'Ones',
    'Zeros',
    'OnesLike',
    'ZerosLike',
    'Select',
    'Split',
    'SplitV',
    'Mish',
    'SeLU',
    'MulNoNan',
    'ReLU',
    'ReLU6',
    'Elu',
    'Erf',
    'Erfinv',
    'Erfc',
    'Sigmoid',
    'HSwish',
    'HSigmoid',
    'Tanh',
    'NoRepeatNGram',
    'Randperm',
    'RandomChoiceWithMask',
    'StandardNormal',
    'Multinomial',
    'Gamma',
    'Poisson',
    'UniformInt',
    'UniformReal',
    'StandardLaplace',
    'RandomCategorical',
    'ResizeBilinear',
    'ScalarSummary',
    'ImageSummary',
    'TensorSummary',
    'HistogramSummary',
    'Print',
    'Assert',
    'InsertGradientOf',
    'HookBackward',
    'InvertPermutation',
    'Shape',
    'DynamicShape',
    'DropoutDoMask',
    'DropoutGenMask',
    'Dropout',
    'Dropout2D',
    'Dropout3D',
    'Neg',
    'InplaceAdd',
    'InplaceSub',
    'Slice',
    'DType',
    'NPUAllocFloatStatus',
    'NPUGetFloatStatus',
    'NPUClearFloatStatus',
    'IsNan',
    'IsFinite',
    'IsInf',
    'FloatStatus',
    'Reciprocal',
    'SmoothL1Loss',
    'SoftMarginLoss',
    'L2Loss',
    'CTCLoss',
    'CTCGreedyDecoder',
    'RNNTLoss',
    'DynamicRNN',
    'ReduceAll',
    'ReduceAny',
    'ScalarToArray',
    'ScalarToTensor',
    'TupleToArray',
    'GeSwitch',
    'Merge',
    'SameTypeShape',
    'CheckBprop',
    'CheckValid',
    'BoundingBoxEncode',
    'BoundingBoxDecode',
    'L2Normalize',
    'ScatterAdd',
    'ScatterSub',
    'ScatterMul',
    'ScatterDiv',
    'ScatterNd',
    'ScatterMax',
    'ScatterMin',
    'ScatterNdAdd',
    'ScatterNdSub',
    'ScatterNonAliasingAdd',
    'ReverseV2',
    'Rint',
    'ResizeNearestNeighbor',
    'HistogramFixedWidth',
    'Pad',
    'MirrorPad',
    'GatherNd',
    'TensorScatterUpdate',
    'TensorScatterAdd',
    'ScatterUpdate',
    'ScatterNdUpdate',
    'Floor',
    'NMSWithMask',
    'IOU',
    'Partial',
    'MakeRefKey',
    'Depend',
    'UpdateState',
    'identity',
    'AvgPool',
    # Back Primitive
    'Equal',
    'EqualCount',
    'NotEqual',
    'Greater',
    'GreaterEqual',
    'LogicalNot',
    'LogicalAnd',
    'LogicalOr',
    'Size',
    'DepthwiseConv2dNative',
    'UnsortedSegmentSum',
    'UnsortedSegmentMin',
    'UnsortedSegmentMax',
    'UnsortedSegmentProd',
    'AllGather',
    'AllReduce',
    'AllSwap',
    'ReduceScatter',
    'Broadcast',
    'ReduceOp',
    'ScalarCast',
    'GetNext',
    'ReduceMax',
    'ReduceMin',
    'ReduceProd',
    'CumProd',
    'Cdist',
    'Log',
    'Log1p',
    'SigmoidCrossEntropyWithLogits',
    'FloorDiv',
    'FloorMod',
    'TruncateDiv',
    'TruncateMod',
    'Ceil',
    'Acosh',
    'Asinh',
    'PReLU',
    'Cos',
    'Cosh',
    'ACos',
    'Diag',
    'DiagPart',
    'Eye',
    'Assign',
    'AssignAdd',
    'AssignSub',
    'Sin',
    'Sinh',
    'Asin',
    'LSTM',
    'Abs',
    'BinaryCrossEntropy',
    'KLDivLoss',
    'SparseApplyAdagrad',
    'SparseApplyAdagradV2',
    'SpaceToDepth',
    'DepthToSpace',
    'Conv2DBackpropInput',
    'ComputeAccidentalHits',
    'Sign',
    'LARSUpdate',
    'Round',
    'Eps',
    'ApplyFtrl',
    'SpaceToBatch',
    'SparseApplyFtrl',
    'SparseApplyFtrlV2',
    'FusedSparseFtrl',
    'ApplyProximalAdagrad',
    'SparseApplyProximalAdagrad',
    'FusedSparseProximalAdagrad',
    'SparseApplyRMSProp',
    'ApplyAdaMax',
    'ApplyAdadelta',
    'ApplyAdagrad',
    'ApplyAdagradV2',
    'ApplyAdagradDA',
    'ApplyAddSign',
    'ApplyPowerSign',
    'ApplyGradientDescent',
    'ApplyProximalGradientDescent',
    'BatchToSpace',
    'Atan2',
    'ApplyRMSProp',
    'ApplyCenteredRMSProp',
    'SpaceToBatchND',
    'BatchToSpaceND',
    'SquareSumAll',
    'BitwiseAnd',
    'BitwiseOr',
    'BitwiseXor',
    'BesselI0e',
    'BesselI1e',
    'Atan',
    'Atanh',
    'Tan',
    'BasicLSTMCell',
    'BroadcastTo',
    'DataFormatDimMap',
    'ApproximateEqual',
    'InplaceUpdate',
    'InTopK',
    'UniformCandidateSampler',
    'LogUniformCandidateSampler',
    'LRN',
    'Mod',
    'ConfusionMatrix',
    'PopulationCount',
    'ParallelConcat',
    'Push',
    'Pull',
    'PullWeight',
    'PushWeight',
    'ReLUV2',
    'SparseToDense',
    'SparseTensorDenseMatmul',
    'MatrixInverse',
    'Range',
    'SearchSorted',
    'IndexAdd',
    'AdaptiveAvgPool2D',
    'TensorScatterMax',
    'TensorScatterMin',
    'TensorScatterSub',
    'SoftShrink',
    # FFT3D / IFFT3D are imported from sponge_ops but exported here rather
    # than through __sponge__.
    'FFT3D',
    'IFFT3D',
    'HShrink',
    'PyFunc',
    'BufferAppend',
    'BufferGetItem',
    'BufferSample',
    'Conj',
    'Real',
    'Imag',
]
482
# SPONGE molecular-dynamics operator names, appended to ``__all__`` below so
# they are part of the package's star-export. Each name must appear exactly
# once.
__sponge__ = [
    # Operators imported from ``sponge_ops``.
    'BondForce',
    'BondEnergy',
    'BondAtomEnergy',
    'BondForceWithAtomEnergy',
    'BondForceWithAtomVirial',
    'DihedralForce',
    'DihedralEnergy',
    'DihedralAtomEnergy',
    'DihedralForceWithAtomEnergy',
    'AngleForce',
    'AngleEnergy',
    'AngleAtomEnergy',
    'AngleForceWithAtomEnergy',
    'PMEReciprocalForce',
    'LJForce',
    'LJForceWithPMEDirectForce',
    'LJEnergy',
    'PMEExcludedForce',
    'PMEEnergy',
    'Dihedral14LJForce',
    'Dihedral14LJEnergy',
    'Dihedral14LJForceWithDirectCF',
    'Dihedral14LJCFForceWithAtomEnergy',
    'Dihedral14LJAtomEnergy',
    'Dihedral14CFEnergy',
    'Dihedral14CFAtomEnergy',
    'MDTemperature',
    'NeighborListUpdate',
    'MDIterationLeapFrogLiujian',
    'CrdToUintCrd',
    'MDIterationSetupRandState',
    'TransferCrd',
    # Operators imported from ``sponge_update_ops``.
    'ConstrainForceCycleWithVirial',
    'RefreshUintCrd',
    'LastCrdToDr',
    'RefreshCrdVel',
    'CalculateNowrapCrd',
    'RefreshBoxmapTimes',
    'Totalc6get',
    'CrdToUintCrdQuarter',
    'MDIterationLeapFrogLiujianWithMaxVel',
    'GetCenterOfMass',
    'MapCenterOfMass',
    'NeighborListRefresh',
    'MDIterationLeapFrog',
    'MDIterationLeapFrogWithMaxVel',
    'MDIterationGradientDescent',
    'BondForceWithAtomEnergyAndVirial',
    'ConstrainForceCycle',
    'LJForceWithVirialEnergy',
    'LJForceWithPMEDirectForceUpdate',
    'PMEReciprocalForceUpdate',
    'PMEExcludedForceUpdate',
    'LJForceWithVirialEnergyUpdate',
    'Dihedral14ForceWithAtomEnergyVirial',
    'PMEEnergyUpdate',
    'ConstrainForceVirial',
    'ConstrainForce',
    'Constrain',
]

__all__.extend(__sponge__)

# Publish the API in a stable alphabetical order. Going through a set also
# drops any name that was accidentally listed more than once across
# ``__all__`` and ``__sponge__``, so the exported list stays duplicate-free.
__all__ = sorted(set(__all__))
550