1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/resource.h"
20 #include "pipeline/jit/static_analysis/static_analysis.h"
21 #include "debug/trace.h"
22 #include "ir/dtype.h"
23 #include "pipeline/jit/parse/data_converter.h"
24 #include "frontend/operator/ops.h"
25 #include "frontend/optimizer/ad/dfunctor.h"
26
27 namespace mindspore {
28 // namespace to support opmap definition
29 namespace pipeline {
30
GetMethodMap()31 BuiltInTypeMap &GetMethodMap() {
32 static BuiltInTypeMap method_map = {{kObjectTypeString,
33 {
34 {"__bool__", std::string("str_bool")} // C.str_bool
35 }},
36 {kMetaTypeNone,
37 {
38 {"__bool__", std::string("none_bool")} // C.none_bool
39 }},
40 {kObjectTypeFunction,
41 {
42 {"__bool__", std::string("func_bool")} // C.str_bool
43 }},
44 {kNumberTypeBool,
45 {
46 {"__and__", prim::kPrimBoolAnd}, // P.bool_and
47 {"__or__", prim::kPrimBoolOr}, // P.bool_or
48 {"__eq__", prim::kPrimBoolEq}, // P.bool_eq
49 {"__ne__", std::string("bool_ne")}, // C.bool_ne
50 {"__bool__", prim::kPrimIdentity} // P.identity
51 }},
52 {kNumberTypeInt,
53 {
54 {"__add__", prim::kPrimScalarAdd}, // P.scalar_add
55 {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
56 {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
57 {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
58 {"__truediv__", std::string("int_truediv")}, // C.int_truediv
59 {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
60 {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
61 {"__floor__", prim::kPrimIdentity}, // P.identity
62 {"__trunc__", prim::kPrimIdentity}, // P.identity
63 {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
64 {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
65 {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq
66 {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
67 {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt
68 {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt
69 {"__le__", prim::kPrimScalarLe}, // P.scalar_le
70 {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge
71 {"__bool__", std::string("int_bool")}, // C.int_bool
72 {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
73 }},
74 {kNumberTypeUInt,
75 {
76 {"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
77 {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
78 {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
79 {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
80 {"__truediv__", std::string("int_truediv")}, // C.int_truediv
81 {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
82 {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
83 {"__floor__", prim::kPrimIdentity}, // P.identity,
84 {"__trunc__", prim::kPrimIdentity}, // P.identity,
85 {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
86 {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
87 {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
88 {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
89 {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
90 {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
91 {"__le__", prim::kPrimScalarLe}, // P.scalar_le,
92 {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
93 {"__bool__", std::string("int_bool")}, // C.int_bool
94 {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
95 }},
96 {kNumberTypeFloat,
97 {
98 {"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
99 {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
100 {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
101 {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
102 {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
103 {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
104 {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
105 {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
106 {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
107 {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
108 {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
109 {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq,
110 {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
111 {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt,
112 {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt,
113 {"__le__", prim::kPrimScalarLe}, // P.scalar_le,
114 {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge,
115 {"__bool__", std::string("float_bool")}, // C.float_bool
116 {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
117 }},
118 {kObjectTypeTuple,
119 {
120 {"__len__", prim::kPrimTupleLen}, // P.tuple_len,
121 {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem,
122 {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem,
123 {"__ms_iter__", prim::kPrimIdentity}, // P.identity,
124 {"__ms_next__", std::string("tuple_next")}, // C.tuple_next,
125 {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext
126 {"__bool__", std::string("tuple_bool")} // C.tuple_bool
127 }},
128 {kObjectTypeList,
129 {
130 {"__len__", prim::kPrimListLen}, // P.list_len,
131 {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem,
132 {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem,
133 {"__ms_iter__", prim::kPrimIdentity}, // P.identity
134 {"__ms_next__", std::string("list_next")}, // C.list_next
135 {"append", std::string("list_append")}, // C.list_next
136 {"__bool__", std::string("list_bool")}, // C.list_bool
137 {"__ms_hasnext__", std::string("list_hasnext")},
138 }},
139 {kObjectTypeDictionary,
140 {
141 {"__len__", prim::kPrimDictLen}, // P.dict_len
142 {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
143 {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
144 {"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
145 {"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
146 {"__bool__", std::string("dict_bool")} // C.dict_bool
147 }},
148 {kObjectTypeTensorType,
149 {
150 {"all", std::string("all_")}, // C.reduce_all
151 {"any", std::string("any_")}, // C.reduce_any
152 {"__add__", std::string("add")}, // C.add
153 {"__sub__", std::string("sub")}, // C.sub
154 {"__mul__", std::string("mul")}, // C.mul
155 {"abs", std::string("abs_")}, // C.abs_
156 {"mean", std::string("mean")}, // C.mean
157 {"__truediv__", std::string("truediv")}, // C.truediv
158 {"__floordiv__", std::string("floordiv")}, // C.floordiv
159 {"__mod__", std::string("mod")}, // C.mod
160 {"__pow__", std::string("pow_")}, // C.pow
161 {"__floor__", std::string("array_floor")}, // C.array_floor
162 {"__trunc__", std::string("array_trunc")}, // C.array_trunc
163 {"__pos__", std::string("array_uadd")}, // C.array_uadd
164 {"__neg__", std::string("array_usub")}, // C.array_usub
165 {"__eq__", std::string("eq")}, // C.eq
166 {"__ne__", std::string("ne")}, // C.ne
167 {"__lt__", std::string("lt")}, // C.lt
168 {"__gt__", std::string("gt")}, // C.gt
169 {"__le__", std::string("le")}, // C.le
170 {"__ge__", std::string("ge")}, // C.ge
171 {"expand_as", std::string("expand_tensor_as")}, // C.expand_as
172 {"view", std::string("view")}, // C.view
173 {"__len__", prim::kPrimArrayLen}, // P.array_len,
174 {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem,
175 {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem,
176 {"__ms_iter__", std::string("array_iter")}, // C.array_iter
177 {"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
178 {"item", std::string("item")}, // P.item,
179 {"itemset", std::string("itemset")}, // P.itemset,
180 {"transpose", std::string("transpose")}, // P.transpose
181 {"flatten", std::string("flatten")}, // P.reshape(,-1)
182 {"reshape", std::string("reshape")}, // P.reshape()
183 {"ravel", std::string("ravel")}, // P.reshape(,(-1,))
184 {"swapaxes", std::string("swapaxes")}, // P.transpose()
185 {"squeeze", std::string("squeeze")}, // P.squeeze()
186 {"astype", std::string("astype")}, // P.cast()
187 {"cumsum", std::string("cumsum")}, // P.cumsum()
188 {"copy", std::string("copy")}, // copy()
189 {"max", std::string("max")}, // P.reduce_max()
190 {"min", std::string("min")}, // P.reduce_min()
191 {"fill", std::string("fill")}, // P.fill()
192 {"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
193 {"clip", std::string("clip")}, // P.maximum(P.minimum)
194 {"__bool__", std::string("tensor_bool")}, // C.tensor_bool
195 {"argmax", std::string("argmax")}, // P.Argmax()
196 {"argmin", std::string("argmin")}, // P.Argmax()
197 {"resize", std::string("resize")}, // P.Reshape()
198 {"choose", std::string("choose")}, // P.Select()
199 {"diagonal", std::string("diagonal")}, // P.Eye()
200 {"searchsorted", std::string("searchsorted")}, // P.Select()
201 {"take", std::string("take")}, // P.GatherNd()
202 {"trace", std::string("trace")}, // P.Eye()
203 {"var", std::string("var")}, // P.ReduceSum
204 {"std", std::string("std")}, // P.ReduceSum
205 {"sum", std::string("sum")}, // P.ReduceSum
206 {"repeat", std::string("repeat")}, // C.repeat_elements
207 }},
208 {kObjectTypeRowTensorType,
209 {
210 {"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add
211 }},
212 {kObjectTypeJTagged, {}},
213 {kObjectTypeSymbolicKeyType, {}},
214 {kObjectTypeEnvType, {}}};
215 return method_map;
216 }
217
GetAttrMap()218 BuiltInTypeMap &GetAttrMap() {
219 static BuiltInTypeMap attr_map = {
220 {kObjectTypeTensorType,
221 {
222 {"shape", std::string("shape_")}, // C.shape_
223 {"dtype", std::string("dtype_")}, // C.dtype_
224 {"size", std::string("size_")}, // C.size_
225 {"ndim", std::string("ndim_")}, // C.ndim_
226 {"T", std::string("T_")}, // C.T_
227 {"itemsize", std::string("itemsize_")}, // C.itemsize_
228 {"nbytes", std::string("nbytes_")}, // C.nbytes_
229 {"strides", std::string("strides_")}, // C.strides_
230 }},
231 {kObjectTypeRowTensorType,
232 {
233 {"values", prim::kPrimRowTensorGetValues}, // F.row_tensor_get_values
234 {"indices", prim::kPrimRowTensorGetIndices}, // F.row_tensor_get_indices
235 {"dense_shape", prim::kPrimRowTensorGetDenseShape}, // F.row_tensor_get_dense_shape
236 }},
237 {kObjectTypeSparseTensorType,
238 {
239 {"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values
240 {"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
241 {"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
242 }},
243 };
244 return attr_map;
245 }
246
Resource(const py::object & obj)247 Resource::Resource(const py::object &obj)
248 : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
249 source_input_(obj),
250 is_cleaned_(false) {}
251
~Resource()252 Resource::~Resource() {
253 MS_LOG(DEBUG) << "Resource clear";
254
255 std::unordered_map<std::string, Any>().swap(results_);
256 // If exit normally, these global variables will be cleaned
257 // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
258 // these global variables may not being cleaned, it may
259 // cause segmentfault when free python object inside these global variables
260 // after python interpreter got freed, so these global variables
261 // are cleaned here.
262 // So if exit normally, these global variable will be cleaned twice,
263 // care be taken to prevent double free in the following functions.
264 if (!is_cleaned_) {
265 try {
266 Clean();
267 } catch (const std::exception &e) {
268 MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
269 } catch (...) {
270 MS_LOG(ERROR) << "Exception when cleaning resource.";
271 }
272 }
273 }
274
GetMethodOrAttr(const string & name,const TypeId & type_id,const BuiltInTypeMap & method_map)275 Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
276 auto type_method_map = method_map.find(static_cast<int64_t>(type_id));
277 if (type_method_map == method_map.end()) {
278 return Any();
279 }
280 auto method = type_method_map->second.find(name);
281 if (method == type_method_map->second.end()) {
282 return Any();
283 }
284 return method->second;
285 }
286
IsTypeInBuiltInMap(const TypeId & type)287 bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
288 TypeId type_id = NormalizeTypeId(type);
289 const BuiltInTypeMap &method_map = GetMethodMap();
290 auto iter = method_map.find(static_cast<int64_t>(type_id));
291 if (iter == method_map.end()) {
292 const BuiltInTypeMap &attr_map = GetAttrMap();
293 iter = attr_map.find(static_cast<int64_t>(type_id));
294 if (iter == attr_map.end()) {
295 return false;
296 }
297 }
298 return true;
299 }
300
GetMethodPtr(const TypeId & type,const std::string & name)301 Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
302 TypeId type_id = NormalizeTypeId(type);
303 const BuiltInTypeMap &method_map = GetMethodMap();
304 return GetMethodOrAttr(name, type_id, method_map);
305 }
306
GetAttrPtr(const TypeId & type,const std::string & name)307 Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
308 TypeId type_id = NormalizeTypeId(type);
309 const BuiltInTypeMap &attr_map = GetAttrMap();
310 return GetMethodOrAttr(name, type_id, attr_map);
311 }
312
Clean()313 void Resource::Clean() {
314 // AbstractTensor->elements() will be saved in AbstractBasePtrList
315 args_spec_.clear();
316 source_input_ = py::none();
317 // Context with AbstractBasePtrList may be saved in GraphEvaluator
318 // some Evaluator like ResolveEvaluator may save Python object in cache,
319 // it should be cleaned before Python Interpreter destructed.
320 MS_EXCEPTION_IF_NULL(engine_);
321 engine_->ClearEvaluatorCache();
322 // clean static variable to prevent from crash. As static variable is released after
323 // Python threads is released.
324 parse::data_converter::ClearObjectCache();
325 parse::Parser::CleanParserResource();
326 parse::CleanDataClassToClassMap();
327 trace::ClearTraceStack();
328 is_cleaned_ = true;
329 }
330
331 } // namespace pipeline
332 } // namespace mindspore
333