• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/resource.h"
20 #include "pipeline/jit/static_analysis/static_analysis.h"
21 #include "debug/trace.h"
22 #include "ir/dtype.h"
23 #include "pipeline/jit/parse/data_converter.h"
24 #include "frontend/operator/ops.h"
25 #include "frontend/optimizer/ad/dfunctor.h"
26 
27 namespace mindspore {
28 // namespace to support opmap definition
29 namespace pipeline {
30 
GetMethodMap()31 BuiltInTypeMap &GetMethodMap() {
32   static BuiltInTypeMap method_map = {{kObjectTypeString,
33                                        {
34                                          {"__bool__", std::string("str_bool")}  // C.str_bool
35                                        }},
36                                       {kMetaTypeNone,
37                                        {
38                                          {"__bool__", std::string("none_bool")}  // C.none_bool
39                                        }},
40                                       {kObjectTypeFunction,
41                                        {
42                                          {"__bool__", std::string("func_bool")}  // C.str_bool
43                                        }},
44                                       {kNumberTypeBool,
45                                        {
46                                          {"__and__", prim::kPrimBoolAnd},     // P.bool_and
47                                          {"__or__", prim::kPrimBoolOr},       // P.bool_or
48                                          {"__eq__", prim::kPrimBoolEq},       // P.bool_eq
49                                          {"__ne__", std::string("bool_ne")},  // C.bool_ne
50                                          {"__bool__", prim::kPrimIdentity}    // P.identity
51                                        }},
52                                       {kNumberTypeInt,
53                                        {
54                                          {"__add__", prim::kPrimScalarAdd},              // P.scalar_add
55                                          {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub
56                                          {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul
57                                          {"__floordiv__", std::string("int_floordiv")},  // C.int_floordiv
58                                          {"__truediv__", std::string("int_truediv")},    // C.int_truediv
59                                          {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod
60                                          {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow
61                                          {"__floor__", prim::kPrimIdentity},             // P.identity
62                                          {"__trunc__", prim::kPrimIdentity},             // P.identity
63                                          {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd
64                                          {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub
65                                          {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq
66                                          {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne
67                                          {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt
68                                          {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt
69                                          {"__le__", prim::kPrimScalarLe},                // P.scalar_le
70                                          {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge
71                                          {"__bool__", std::string("int_bool")},          // C.int_bool
72                                          {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array
73                                        }},
74                                       {kNumberTypeUInt,
75                                        {
76                                          {"__add__", prim::kPrimScalarAdd},              // P.scalar_add,
77                                          {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub,
78                                          {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul,
79                                          {"__floordiv__", prim::kPrimScalarDiv},         // P.scalar_div,
80                                          {"__truediv__", std::string("int_truediv")},    // C.int_truediv
81                                          {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod,
82                                          {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow,
83                                          {"__floor__", prim::kPrimIdentity},             // P.identity,
84                                          {"__trunc__", prim::kPrimIdentity},             // P.identity,
85                                          {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd,
86                                          {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub,
87                                          {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq,
88                                          {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne,
89                                          {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt,
90                                          {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt,
91                                          {"__le__", prim::kPrimScalarLe},                // P.scalar_le,
92                                          {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge,
93                                          {"__bool__", std::string("int_bool")},          // C.int_bool
94                                          {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array,
95                                        }},
96                                       {kNumberTypeFloat,
97                                        {
98                                          {"__add__", prim::kPrimScalarAdd},                // P.scalar_add,
99                                          {"__sub__", prim::kPrimScalarSub},                // P.scalar_sub,
100                                          {"__mul__", prim::kPrimScalarMul},                // P.scalar_mul,
101                                          {"__floordiv__", std::string("float_floordiv")},  // C.float_floordiv
102                                          {"__truediv__", prim::kPrimScalarDiv},            // P.scalar_div,
103                                          {"__mod__", prim::kPrimScalarMod},                // P.scalar_mod,
104                                          {"__pow__", prim::kPrimScalarPow},                // P.scalar_pow,
105                                          {"__floor__", prim::kPrimScalarFloor},            // P.scalar_floor,
106                                          {"__trunc__", prim::kPrimScalarTrunc},            // P.scalar_trunc,
107                                          {"__pos__", prim::kPrimScalarUadd},               // P.scalar_uadd,
108                                          {"__neg__", prim::kPrimScalarUsub},               // P.scalar_usub,
109                                          {"__eq__", prim::kPrimScalarEq},                  // P.scalar_eq,
110                                          {"__ne__", prim::kPrimScalarNe},                  // P.scalar_ne,
111                                          {"__lt__", prim::kPrimScalarLt},                  // P.scalar_lt,
112                                          {"__gt__", prim::kPrimScalarGt},                  // P.scalar_gt,
113                                          {"__le__", prim::kPrimScalarLe},                  // P.scalar_le,
114                                          {"__ge__", prim::kPrimScalarGe},                  // P.scalar_ge,
115                                          {"__bool__", std::string("float_bool")},          // C.float_bool
116                                          {"__ms_to_array__", prim::kPrimScalarToArray},    // P.scalar_to_array,
117                                        }},
118                                       {kObjectTypeTuple,
119                                        {
120                                          {"__len__", prim::kPrimTupleLen},                  // P.tuple_len,
121                                          {"__getitem__", prim::kPrimTupleGetItem},          // P.tuple_getitem,
122                                          {"__setitem__", prim::kPrimTupleSetItem},          // P.tuple_setitem,
123                                          {"__ms_iter__", prim::kPrimIdentity},              // P.identity,
124                                          {"__ms_next__", std::string("tuple_next")},        // C.tuple_next,
125                                          {"__ms_hasnext__", std::string("tuple_hasnext")},  // C.tuple_hasnext
126                                          {"__bool__", std::string("tuple_bool")}            // C.tuple_bool
127                                        }},
128                                       {kObjectTypeList,
129                                        {
130                                          {"__len__", prim::kPrimListLen},            // P.list_len,
131                                          {"__getitem__", prim::kPrimListGetItem},    // P.list_getitem,
132                                          {"__setitem__", prim::kPrimListSetItem},    // P.list_setitem,
133                                          {"__ms_iter__", prim::kPrimIdentity},       // P.identity
134                                          {"__ms_next__", std::string("list_next")},  // C.list_next
135                                          {"append", std::string("list_append")},     // C.list_next
136                                          {"__bool__", std::string("list_bool")},     // C.list_bool
137                                          {"__ms_hasnext__", std::string("list_hasnext")},
138                                        }},
139                                       {kObjectTypeDictionary,
140                                        {
141                                          {"__len__", prim::kPrimDictLen},          // P.dict_len
142                                          {"__getitem__", prim::kPrimDictGetItem},  // P.dict_getitem
143                                          {"__setitem__", prim::kPrimDictSetItem},  // P.dict_setitem,
144                                          {"keys", prim::kPrimDictGetKeys},         // P.dict_getkeys,
145                                          {"values", prim::kPrimDictGetValues},     // P.dict_getvalues,
146                                          {"__bool__", std::string("dict_bool")}    // C.dict_bool
147                                        }},
148                                       {kObjectTypeTensorType,
149                                        {
150                                          {"all", std::string("all_")},                    // C.reduce_all
151                                          {"any", std::string("any_")},                    // C.reduce_any
152                                          {"__add__", std::string("add")},                 // C.add
153                                          {"__sub__", std::string("sub")},                 // C.sub
154                                          {"__mul__", std::string("mul")},                 // C.mul
155                                          {"abs", std::string("abs_")},                    // C.abs_
156                                          {"mean", std::string("mean")},                   // C.mean
157                                          {"__truediv__", std::string("truediv")},         // C.truediv
158                                          {"__floordiv__", std::string("floordiv")},       // C.floordiv
159                                          {"__mod__", std::string("mod")},                 // C.mod
160                                          {"__pow__", std::string("pow_")},                // C.pow
161                                          {"__floor__", std::string("array_floor")},       // C.array_floor
162                                          {"__trunc__", std::string("array_trunc")},       // C.array_trunc
163                                          {"__pos__", std::string("array_uadd")},          // C.array_uadd
164                                          {"__neg__", std::string("array_usub")},          // C.array_usub
165                                          {"__eq__", std::string("eq")},                   // C.eq
166                                          {"__ne__", std::string("ne")},                   // C.ne
167                                          {"__lt__", std::string("lt")},                   // C.lt
168                                          {"__gt__", std::string("gt")},                   // C.gt
169                                          {"__le__", std::string("le")},                   // C.le
170                                          {"__ge__", std::string("ge")},                   // C.ge
171                                          {"expand_as", std::string("expand_tensor_as")},  // C.expand_as
172                                          {"view", std::string("view")},                   // C.view
173                                          {"__len__", prim::kPrimArrayLen},                // P.array_len,
174                                          {"__getitem__", prim::kPrimArrayGetItem},        // P.array_getitem,
175                                          {"__setitem__", prim::kPrimArraySetItem},        // P.array_setitem,
176                                          {"__ms_iter__", std::string("array_iter")},      // C.array_iter
177                                          {"__ms_to_array__", prim::kPrimIdentity},        // P.identity,
178                                          {"item", std::string("item")},                   // P.item,
179                                          {"itemset", std::string("itemset")},             // P.itemset,
180                                          {"transpose", std::string("transpose")},         // P.transpose
181                                          {"flatten", std::string("flatten")},             // P.reshape(,-1)
182                                          {"reshape", std::string("reshape")},             // P.reshape()
183                                          {"ravel", std::string("ravel")},                 // P.reshape(,(-1,))
184                                          {"swapaxes", std::string("swapaxes")},           // P.transpose()
185                                          {"squeeze", std::string("squeeze")},             // P.squeeze()
186                                          {"astype", std::string("astype")},               // P.cast()
187                                          {"cumsum", std::string("cumsum")},               // P.cumsum()
188                                          {"copy", std::string("copy")},                   // copy()
189                                          {"max", std::string("max")},                     // P.reduce_max()
190                                          {"min", std::string("min")},                     // P.reduce_min()
191                                          {"fill", std::string("fill")},                   // P.fill()
192                                          {"ptp", std::string("ptp")},               // P.reduce_max() - P.reduce_min()
193                                          {"clip", std::string("clip")},             // P.maximum(P.minimum)
194                                          {"__bool__", std::string("tensor_bool")},  // C.tensor_bool
195                                          {"argmax", std::string("argmax")},         // P.Argmax()
196                                          {"argmin", std::string("argmin")},         // P.Argmax()
197                                          {"resize", std::string("resize")},         // P.Reshape()
198                                          {"choose", std::string("choose")},         // P.Select()
199                                          {"diagonal", std::string("diagonal")},     // P.Eye()
200                                          {"searchsorted", std::string("searchsorted")},  // P.Select()
201                                          {"take", std::string("take")},                  // P.GatherNd()
202                                          {"trace", std::string("trace")},                // P.Eye()
203                                          {"var", std::string("var")},                    // P.ReduceSum
204                                          {"std", std::string("std")},                    // P.ReduceSum
205                                          {"sum", std::string("sum")},                    // P.ReduceSum
206                                          {"repeat", std::string("repeat")},              // C.repeat_elements
207                                        }},
208                                       {kObjectTypeRowTensorType,
209                                        {
210                                          {"__add__", prim::kPrimRowTensorAdd},  // P.row_tensor_add
211                                        }},
212                                       {kObjectTypeJTagged, {}},
213                                       {kObjectTypeSymbolicKeyType, {}},
214                                       {kObjectTypeEnvType, {}}};
215   return method_map;
216 }
217 
GetAttrMap()218 BuiltInTypeMap &GetAttrMap() {
219   static BuiltInTypeMap attr_map = {
220     {kObjectTypeTensorType,
221      {
222        {"shape", std::string("shape_")},        // C.shape_
223        {"dtype", std::string("dtype_")},        // C.dtype_
224        {"size", std::string("size_")},          // C.size_
225        {"ndim", std::string("ndim_")},          // C.ndim_
226        {"T", std::string("T_")},                // C.T_
227        {"itemsize", std::string("itemsize_")},  // C.itemsize_
228        {"nbytes", std::string("nbytes_")},      // C.nbytes_
229        {"strides", std::string("strides_")},    // C.strides_
230      }},
231     {kObjectTypeRowTensorType,
232      {
233        {"values", prim::kPrimRowTensorGetValues},           // F.row_tensor_get_values
234        {"indices", prim::kPrimRowTensorGetIndices},         // F.row_tensor_get_indices
235        {"dense_shape", prim::kPrimRowTensorGetDenseShape},  // F.row_tensor_get_dense_shape
236      }},
237     {kObjectTypeSparseTensorType,
238      {
239        {"values", prim::kPrimSparseTensorGetValues},           // F.sparse_tensor_get_values
240        {"indices", prim::kPrimSparseTensorGetIndices},         // F.sparse_tensor_get_indices
241        {"dense_shape", prim::kPrimSparseTensorGetDenseShape},  // F.sparse_tensor_get_dense_shape
242      }},
243   };
244   return attr_map;
245 }
246 
Resource(const py::object & obj)247 Resource::Resource(const py::object &obj)
248     : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
249       source_input_(obj),
250       is_cleaned_(false) {}
251 
~Resource()252 Resource::~Resource() {
253   MS_LOG(DEBUG) << "Resource clear";
254 
255   std::unordered_map<std::string, Any>().swap(results_);
256   // If exit normally, these global variables will be cleaned
257   // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
258   // these global variables may not being cleaned, it may
259   // cause segmentfault when free python object inside these global variables
260   // after python interpreter got freed, so these global variables
261   // are cleaned here.
262   // So if exit normally, these global variable will be cleaned twice,
263   // care be taken to prevent double free in the following functions.
264   if (!is_cleaned_) {
265     try {
266       Clean();
267     } catch (const std::exception &e) {
268       MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
269     } catch (...) {
270       MS_LOG(ERROR) << "Exception when cleaning resource.";
271     }
272   }
273 }
274 
GetMethodOrAttr(const string & name,const TypeId & type_id,const BuiltInTypeMap & method_map)275 Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
276   auto type_method_map = method_map.find(static_cast<int64_t>(type_id));
277   if (type_method_map == method_map.end()) {
278     return Any();
279   }
280   auto method = type_method_map->second.find(name);
281   if (method == type_method_map->second.end()) {
282     return Any();
283   }
284   return method->second;
285 }
286 
IsTypeInBuiltInMap(const TypeId & type)287 bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
288   TypeId type_id = NormalizeTypeId(type);
289   const BuiltInTypeMap &method_map = GetMethodMap();
290   auto iter = method_map.find(static_cast<int64_t>(type_id));
291   if (iter == method_map.end()) {
292     const BuiltInTypeMap &attr_map = GetAttrMap();
293     iter = attr_map.find(static_cast<int64_t>(type_id));
294     if (iter == attr_map.end()) {
295       return false;
296     }
297   }
298   return true;
299 }
300 
GetMethodPtr(const TypeId & type,const std::string & name)301 Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
302   TypeId type_id = NormalizeTypeId(type);
303   const BuiltInTypeMap &method_map = GetMethodMap();
304   return GetMethodOrAttr(name, type_id, method_map);
305 }
306 
GetAttrPtr(const TypeId & type,const std::string & name)307 Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
308   TypeId type_id = NormalizeTypeId(type);
309   const BuiltInTypeMap &attr_map = GetAttrMap();
310   return GetMethodOrAttr(name, type_id, attr_map);
311 }
312 
Clean()313 void Resource::Clean() {
314   // AbstractTensor->elements() will be saved in AbstractBasePtrList
315   args_spec_.clear();
316   source_input_ = py::none();
317   // Context with AbstractBasePtrList may be saved in GraphEvaluator
318   // some Evaluator like ResolveEvaluator may save Python object in cache,
319   // it should be cleaned before Python Interpreter destructed.
320   MS_EXCEPTION_IF_NULL(engine_);
321   engine_->ClearEvaluatorCache();
322   // clean static variable to prevent from crash. As static variable is released after
323   // Python threads is released.
324   parse::data_converter::ClearObjectCache();
325   parse::Parser::CleanParserResource();
326   parse::CleanDataClassToClassMap();
327   trace::ClearTraceStack();
328   is_cleaned_ = true;
329 }
330 
331 }  // namespace pipeline
332 }  // namespace mindspore
333