• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""builtin_operations"""
16import numpy as np
17from mindspore.ops import functional as F
18from mindspore.ops import composite as C
19from mindspore.common.tensor import Tensor
20import mindspore.common.dtype as mstype
21from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
22
23
def ScalarAdd(x, y):
    """Add two scalars (Python implementation of `scalar_add`).

    Args:
        x: Left operand.
        y: Right operand.

    Returns:
        The result of `x + y`.
    """
    total = x + y
    return total
27
28
def ScalarMul(x, y):
    """Multiply two scalars (Python implementation of `scalar_mul`).

    Args:
        x: Left operand.
        y: Right operand.

    Returns:
        The result of `x * y`.
    """
    product = x * y
    return product
32
33
def ScalarMod(x, y):
    """Implement `scalar_mod`.

    Args:
        x: Dividend.
        y: Divisor.

    Returns:
        The result of `x % y`, following Python's `%` semantics.
    """
    return x % y
37
38
def ScalarSub(x, y):
    """Subtract one scalar from another (Python implementation of `scalar_sub`).

    Args:
        x: Minuend.
        y: Subtrahend.

    Returns:
        The result of `x - y`.
    """
    difference = x - y
    return difference
42
43
def ScalarUsub(x):
    """Negate a scalar (Python implementation of `scalar_usub`).

    Args:
        x: Operand of the unary minus.

    Returns:
        The result of `-x`.
    """
    negated = -x
    return negated
47
48
def TupleGetItem(x, index):
    """Fetch `x[index]` (Python implementation of `tuple_getitem`).

    Args:
        x: Container to index; a `Tensor` is handled specially.
        index: Index applied to `x`.

    Returns:
        `x[index]` for non-Tensor inputs. For a `Tensor`, the element taken
        from `x.asnumpy()` wrapped in a new `Tensor`.
    """
    if not isinstance(x, Tensor):
        return x[index]
    # Tensors are indexed through their NumPy form, then wrapped again.
    element = x.asnumpy()[index]
    return Tensor(element)
56
57
def scalar_gt(x, y):
    """Compare two scalars with `>` (Python implementation of `scalar_gt`).

    Returns:
        The result of `x > y`.
    """
    is_greater = x > y
    return is_greater
61
62
def scalar_ne(x, y):
    """Compare two scalars with `!=` (Python implementation of `scalar_ne`).

    Returns:
        The result of `x != y`.
    """
    differs = x != y
    return differs
66
67
def scalar_eq(x, y):
    """Compare two scalars with `==` (Python implementation of `scalar_eq`).

    Returns:
        The result of `x == y`.
    """
    is_equal = x == y
    return is_equal
71
72
def scalar_le(x, y):
    """Compare two scalars with `<=` (Python implementation of `scalar_le`).

    Returns:
        The result of `x <= y`.
    """
    not_greater = x <= y
    return not_greater
76
77
def scalar_lt(x, y):
    """Compare two scalars with `<` (Python implementation of `scalar_lt`).

    Returns:
        The result of `x < y`.
    """
    is_less = x < y
    return is_less
81
82
def identity(x):
    """Return the argument unchanged (Python implementation of `identity`).

    Args:
        x: Any object.

    Returns:
        The very same object that was passed in.
    """
    same = x
    return same
86
87
def zeros_like_tensor(x):
    """Build an all-zero `Tensor` shaped like `x` (`zeros_like_tensor`).

    Args:
        x: Object exposing `asnumpy()`; only the shape of the resulting
            array is used.

    Returns:
        A `Tensor` of zeros with the shape of `x`. The data is always
        created as float32, whatever the dtype of `x` is.
    """
    shape = x.asnumpy().shape
    return Tensor(np.zeros(shape, dtype=np.float32))
93
94
def Switch(c, x, y):
    """Select one of two values by a condition (implementation of `switch`).

    Args:
        c: Condition, evaluated for truthiness.
        x: Value returned when `c` is truthy.
        y: Value returned when `c` is falsy.

    Returns:
        `x` or `y`, depending on `c`.
    """
    if c:
        return x
    return y
98
99
def list_getitem(data, item):
    """Index into a container (Python implementation of `list_getitem`).

    Args:
        data: Object supporting subscription.
        item: Index or key passed to the subscript.

    Returns:
        The result of `data[item]`.
    """
    element = data[item]
    return element
103
104
def bool_not(x):
    """Logically negate a value (Python implementation of `bool_not`).

    Args:
        x: Value evaluated for truthiness.

    Returns:
        `False` when `x` is truthy, `True` otherwise.
    """
    if x:
        return False
    return True
108
109
def bool_and(x, y):
    """Apply Python's `and` to two values (implementation of `bool_and`).

    Args:
        x: First operand, evaluated for truthiness.
        y: Second operand.

    Returns:
        `x` when `x` is falsy, otherwise `y` (same as `x and y`).
    """
    if x:
        return y
    return x
113
114
def bool_or(x, y):
    """Apply Python's `or` to two values (implementation of `bool_or`).

    Args:
        x: First operand, evaluated for truthiness.
        y: Second operand.

    Returns:
        `x` when `x` is truthy, otherwise `y` (same as `x or y`).
    """
    if x:
        return x
    return y
118
119
def make_list(*xs):
    """Collect the positional arguments into a list (`make_list`).

    Args:
        *xs: Elements of the new list, in order.

    Returns:
        A new list containing the given arguments.
    """
    return [*xs]
123
124
def list_len(x):
    """Return the length of a container (implementation of `list_len`).

    Args:
        x: Object supporting `len()`.

    Returns:
        The value of `len(x)`.
    """
    size = len(x)
    return size
128
129
def Depend(value, expr):
    """Return `value` unchanged (Python implementation of `Depend`).

    Args:
        value: Object handed back to the caller.
        expr: Accepted but not used by this implementation.

    Returns:
        The `value` argument itself.
    """
    del expr  # not used here
    return value
133
134
def UpdateState(monad, *exprs):
    """Return `monad` unchanged (Python implementation of `UpdateState`).

    Args:
        monad: Object handed back to the caller.
        *exprs: Accepted but not used by this implementation.

    Returns:
        The `monad` argument itself.
    """
    del exprs  # not used here
    return monad
138
139
def Load(value, u=None):
    """Return `value` unchanged (Python implementation of `Load`).

    Args:
        value: Object handed back to the caller.
        u: Optional argument, accepted but not used by this implementation.

    Returns:
        The `value` argument itself.
    """
    del u  # not used here
    return value
143
144
# only used in PyNative mode
def make_ref(key, value, ref):
    """Return `value` unchanged; `key` and `ref` are accepted but not used."""
    del key, ref  # not used here
    return value
148
149
def scalar_cast(x, t):
    """Implement scalar_cast.

    Converts `x` with the NumPy type that corresponds to `t` and returns the
    result as a native Python value.

    Args:
        x: Value to convert.
        t: Target type; passed to `dtype_to_nptype` to obtain the NumPy type.

    Returns:
        The converted value, as produced by the NumPy object's `item()`.
    """
    np_type = dtype_to_nptype(t)
    value = np_type(x)
    # Call `item()` on the instance. When `np_type` is a NumPy scalar type
    # (e.g. `np.float32`), `value` is an `np.generic`, which is not an
    # `ndarray`, so the unbound `np.ndarray.item(value)` raises `TypeError`.
    # The bound call works for NumPy scalars and arrays alike.
    return value.item()
156
157
def typeof(x):
    """Return the dtype that `get_py_obj_dtype` reports for `x` (typeof).

    Args:
        x: Python object to inspect.

    Returns:
        Whatever `get_py_obj_dtype(x)` returns.
    """
    obj_dtype = get_py_obj_dtype(x)
    return obj_dtype
161
162
def tuple_to_array(x):
    """Convert a sequence into a `Tensor` (implementation of `tuple_to_array`).

    Args:
        x: Input handed to `np.array`.

    Returns:
        A `Tensor` built from `np.array(x)`.
    """
    array = np.array(x)
    return Tensor(array)
166
167
def stop_gradient(x):
    """Return `x` unchanged (Python implementation of `stop_gradient`).

    Args:
        x: Any object.

    Returns:
        The `x` argument itself.
    """
    passthrough = x
    return passthrough
171
172
# Shared `HyperMap` instance; `mixed_precision_cast` uses it to apply its
# per-element cast helper to the input.
hyper_map = C.HyperMap()


def mixed_precision_cast(dst_type, x):
    """Cast floating-point tensors in `x` to `dst_type` (`mixed_precision_cast`).

    Args:
        dst_type: Target dtype handed to `F.cast`.
        x: Input passed to `hyper_map` together with the cast helper.

    Returns:
        The result of `hyper_map(cast_inner, x)`. Elements that are not a
        `Tensor`, or whose dtype is not float16/float32/float64, are
        returned unchanged by the helper.
    """

    def cast_inner(data):
        # Anything that is not a Tensor passes through untouched.
        if not isinstance(data, Tensor):
            return data
        # Only floating-point tensors are converted.
        if data.dtype not in (mstype.float32, mstype.float16, mstype.float64):
            return data
        return F.cast(data, dst_type)

    return hyper_map(cast_inner, x)
185