• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
2#
3# Copyright 2020-2024 Huawei Technologies Co., Ltd
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# ============================================================================
17"""Resources for ast tree parse."""
18from __future__ import absolute_import
19
20import ast
21
22from mindspore import RowTensor, SparseTensor, COOTensor, CSRTensor
23from mindspore.experimental import MapParameter
24from mindspore.common.sparse_tensor import RowTensorInner
25from mindspore.ops import functional as F, composite as C
26from mindspore.ops import Primitive
27from mindspore.ops.composite import multitype_ops
28from mindspore._c_expression import security
29from . import standard_method as M
30from . import trope as T
31from .namespace import ModuleNamespace
32
# Sentinel used as the "undefined variable / function" key in the symbol
# tables below.
SYMBOL_UNDEFINE = 0xFF

# Module namespaces paired with function names in the symbol tables below.
functional_ns = ModuleNamespace('mindspore.ops.functional')
composite_ns = ModuleNamespace('mindspore.ops.composite')
trope_ns = ModuleNamespace('mindspore._extends.parse.trope')
39
# Parser symbol table: each key (an AST operator node type, a string tag, or
# SYMBOL_UNDEFINE) maps to a (namespace, function name, source symbol) triple.
# All AST operators resolve into the trope namespace, so those entries are
# generated from (node type, trope function name, symbol) rows; insertion
# order matches the rows below.
parse_object_map = {
    **{
        node_type: (trope_ns, func_name, symbol)
        for node_type, func_name, symbol in (
            # binary arithmetic / bitwise operators
            (ast.Add, 'add', '+'),
            (ast.Sub, 'sub', '-'),
            (ast.Mult, 'mul', '*'),
            (ast.Div, 'truediv', '/'),
            (ast.FloorDiv, 'floordiv', '//'),
            (ast.Mod, 'mod', '%'),
            (ast.Pow, 'pow', '**'),
            (ast.MatMult, 'matmul', '@'),
            (ast.LShift, 'lshift', '<<'),
            (ast.RShift, 'rshift', '>>'),
            (ast.BitAnd, 'and_', '&'),
            (ast.BitOr, 'or_', '|'),
            (ast.BitXor, 'xor', '^'),
            # unary operators
            (ast.UAdd, 'pos', '+'),
            (ast.USub, 'neg', '-'),
            (ast.Invert, 'invert', '~'),
            (ast.Not, 'not_', 'not'),
            # comparison operators
            (ast.Eq, 'eq', '=='),
            (ast.NotEq, 'ne', '!='),
            (ast.Lt, 'lt', '<'),
            (ast.Gt, 'gt', '>'),
            (ast.LtE, 'le', '<='),
            (ast.GtE, 'ge', '>='),
            (ast.Is, 'is_', 'is'),
            (ast.IsNot, 'is_not', 'is not'),
            (ast.In, 'contains', 'in'),
            (ast.NotIn, 'not_contains', 'not in'),
        )
    },

    # operation symbol type: string tags resolved in the composite namespace
    'getitem': (composite_ns, 'getitem', ''),
    'ms_next': (composite_ns, 'ms_next', ''),

    # undefined type: no namespace
    SYMBOL_UNDEFINE: (None, 'undefine', ''),
}
78
# Source-text symbol for each binary-operator AST node type.
# This mirrors the binary-operator entries of parse_object_map, including
# `@` for ast.MatMult, which was previously missing here even though
# parse_object_map registers it. SYMBOL_UNDEFINE maps to an empty string.
ops_symbol_map = {
    # ast grammar
    ast.Add:        '+',
    ast.Sub:        '-',
    ast.Mult:       '*',
    ast.Div:        '/',
    ast.FloorDiv:   '//',
    ast.Mod:        '%',
    ast.Pow:        '**',
    ast.MatMult:    '@',
    ast.LShift:     '<<',
    ast.RShift:     '>>',
    ast.BitAnd:     '&',
    ast.BitOr:      '|',
    ast.BitXor:     '^',

    # undefined type
    SYMBOL_UNDEFINE: '',
}
98
# Maps ("escapes") Python-level objects -- trope stand-ins for operators and
# builtins (e.g. len), plus the sparse-tensor / MapParameter classes -- to the
# MindSpore objects that take their place.
# NOTE: the Primitive(...) and C.Map() values are instantiated once, when this
# module is imported, and those same instances are stored in the map.
# Keys and values are column-aligned for readability.
convert_object_map = {
    # operators: trope functions -> multitype ops / functional primitives
    T.add:          multitype_ops.add,
    T.sub:          multitype_ops.sub,
    T.mul:          multitype_ops.mul,
    T.truediv:      multitype_ops.div,
    T.getitem:      multitype_ops.getitem,
    T.setitem:      multitype_ops.setitem,
    T.floordiv:     multitype_ops.floordiv,
    T.mod:          multitype_ops.mod,
    T.pow:          multitype_ops.pow_,
    T.matmul:       F.matmul,
    T.lshift:       multitype_ops.left_shift,
    T.rshift:       multitype_ops.right_shift,
    T.and_:         multitype_ops.bitwise_and,
    T.or_:          multitype_ops.bitwise_or,
    T.xor:          multitype_ops.bitwise_xor,
    T.pos:          multitype_ops.uadd,
    T.neg:          multitype_ops.negative,
    # NOTE(review): Python `~` is bitwise inversion on ints, yet it maps to
    # F.logical_not here -- confirm this is intended for non-bool operands.
    T.invert:       F.logical_not,
    T.not_:         multitype_ops.logical_not,
    T.eq:           multitype_ops.equal,
    T.ne:           multitype_ops.not_equal,
    T.lt:           multitype_ops.less,
    T.gt:           multitype_ops.greater,
    T.le:           multitype_ops.less_equal,
    T.ge:           multitype_ops.greater_equal,
    T.is_:          F.is_,
    T.is_not:       F.is_not,
    T.contains:     multitype_ops.in_,
    T.not_contains: multitype_ops.not_in_,

    # system function
    T.abs:          Primitive('inner_abs'),
    T.round:        Primitive('inner_round'),
    T.len:          M.ms_len,
    T.bool_:        M.bool_,
    T.map:          C.Map(),
    T.filter:       M.filter_,
    T.partial:      F.partial,
    T.zip:          M.ms_zip,
    T.enumerate:    M.enumerate_,
    T.isinstance:   Primitive('isinstance'),
    T.max:          M.ms_max,
    T.min:          M.ms_min,
    T.sum:          M.ms_sum,
    T.getattr:      Primitive('getattr'),
    T.hasattr:      M.hasattr,

    # custom define operation: iteration protocol, container/slice/range builders
    T.iter:         C.iter_converter,
    T.next:         C.ms_next,
    T.hasnext:      C.ms_hasnext,
    T.MakeTuple:    F.make_tuple,
    T.make_dict:    F.make_dict,
    T.make_list:    F.make_list,
    T.make_slice:   F.make_slice,
    T.range:        F.make_range,
    T.mutable:      Primitive('mutable'),

    # user defined: classes mapped to their constructor functions
    RowTensorInner: F.make_row_tensor_inner,
    RowTensor:      F.make_row_tensor,
    SparseTensor:   F.make_sparse_tensor,
    COOTensor:      F.make_coo_tensor,
    CSRTensor:      F.make_csr_tensor,
    MapParameter:   F.make_map_parameter
}

# `print` is registered only when security mode is off
# (security.enable_security() returns a falsy value).
if not security.enable_security():
    convert_object_map[T.print] = F.print_
171
# Builtin classes, keyed by the text "class '<name>'", mapped to the callable
# functions that stand in for them. The key format is presumably str(cls)
# without its angle brackets -- confirm against the lookup site.
convert_class_to_function_map = {
    "class '{}'".format(class_name): replacement
    for class_name, replacement in (
        ('int', M.int_func),
        ('float', M.float_func),
        ('bool', M.bool_func),
        ('str', M.str_func),
    )
}
179
# Functions listed as candidates for constant folding: a set of Python
# builtins followed by trope operator functions. (How the list is consumed is
# not visible here -- presumably by membership test; confirm at the call site.)
constant_fold_functions = [
    # Python builtins
    abs, all, any, float, int, bool, len, max,
    min, pow, repr, round, str, sum, type,
    # trope arithmetic and unary operators
    T.add, T.sub, T.mul, T.truediv, T.floordiv, T.mod,
    T.pos, T.neg, T.not_,
    # trope bitwise operators
    T.and_, T.or_, T.xor, T.lshift, T.rshift,
    # trope matmul, subscript and invert
    T.matmul, T.getitem, T.invert,
]
214