• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/resource.h"
20 #include "mindspore/core/ops/sparse_tensor_ops.h"
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/comparison_ops.h"
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/arithmetic_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "include/common/utils/compile_cache_context.h"
27 #include "ir/dtype.h"
28 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
29 #include "pipeline/jit/ps/debug/trace.h"
30 #include "pipeline/jit/ps/parse/data_converter.h"
31 #include "frontend/operator/ops.h"
32 #include "frontend/optimizer/ad/dfunctor.h"
33 #include "frontend/parallel/step_parallel_utils.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "utils/ms_utils.h"
36 
37 namespace mindspore {
38 namespace pipeline {
GetMethodMap()39 BuiltInTypeMap &GetMethodMap() {
40   static BuiltInTypeMap method_map = {
41     {kObjectTypeString,
42      {{"__bool__", std::string("str_bool")},  // C.str_bool
43       {"format", std::string("_format")},
44       {"upper", prim::kPrimStringUpper},
45       {"lower", prim::kPrimStringLower}}},
46     {kMetaTypeNone,
47      {
48        {"__bool__", std::string("none_bool")}  // C.none_bool
49      }},
50     {kObjectTypeFunction,
51      {
52        {"__bool__", std::string("func_bool")}  // C.func_bool
53      }},
54     {kNumberTypeBool,
55      {
56        {"__and__", prim::kPrimBoolAnd},     // P.bool_and
57        {"__or__", prim::kPrimBoolOr},       // P.bool_or
58        {"__eq__", prim::kPrimBoolEq},       // P.bool_eq
59        {"__ne__", std::string("bool_ne")},  // C.bool_ne
60        {"__bool__", prim::kPrimidentity}    // P.identity
61      }},
62     {kNumberTypeInt,
63      {
64        {"__add__", prim::kPrimScalarAdd},              // P.scalar_add
65        {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub
66        {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul
67        {"__floordiv__", std::string("int_floordiv")},  // C.int_floordiv
68        {"__truediv__", std::string("int_truediv")},    // C.int_truediv
69        {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod
70        {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow
71        {"__floor__", prim::kPrimidentity},             // P.identity
72        {"__trunc__", prim::kPrimidentity},             // P.identity
73        {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd
74        {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub
75        {"__eq__", prim::kPrimScalarEq},                // P.ScalarEq
76        {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne
77        {"__lt__", prim::kPrimScalarLt},                // P.ScalarLt
78        {"__gt__", prim::kPrimScalarGt},                // P.ScalarGt
79        {"__le__", prim::kPrimScalarLe},                // P.ScalarLe
80        {"__ge__", prim::kPrimScalarGe},                // P.ScalarGe
81        {"__bool__", std::string("int_bool")},          // C.int_bool
82        {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array
83      }},
84     {kNumberTypeUInt,
85      {
86        {"__add__", prim::kPrimScalarAdd},              // P.scalar_add,
87        {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub,
88        {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul,
89        {"__floordiv__", prim::kPrimScalarDiv},         // P.scalar_div,
90        {"__truediv__", std::string("int_truediv")},    // C.int_truediv
91        {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod,
92        {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow,
93        {"__floor__", prim::kPrimidentity},             // P.identity,
94        {"__trunc__", prim::kPrimidentity},             // P.identity,
95        {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd,
96        {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub,
97        {"__eq__", prim::kPrimScalarEq},                // P.ScalarEq,
98        {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne,
99        {"__lt__", prim::kPrimScalarLt},                // P.ScalarLt,
100        {"__gt__", prim::kPrimScalarGt},                // P.ScalarGt,
101        {"__le__", prim::kPrimScalarLe},                // P.ScalarLe,
102        {"__ge__", prim::kPrimScalarGe},                // P.ScalarGe,
103        {"__bool__", std::string("int_bool")},          // C.int_bool
104        {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array
105      }},
106     {kNumberTypeFloat,
107      {
108        {"__add__", prim::kPrimScalarAdd},                // P.scalar_add,
109        {"__sub__", prim::kPrimScalarSub},                // P.scalar_sub,
110        {"__mul__", prim::kPrimScalarMul},                // P.scalar_mul,
111        {"__floordiv__", std::string("float_floordiv")},  // C.float_floordiv
112        {"__truediv__", prim::kPrimScalarDiv},            // P.scalar_div,
113        {"__mod__", prim::kPrimScalarMod},                // P.scalar_mod,
114        {"__pow__", prim::kPrimScalarPow},                // P.scalar_pow,
115        {"__floor__", prim::kPrimScalarFloor},            // P.scalar_floor,
116        {"__trunc__", prim::kPrimScalarTrunc},            // P.scalar_trunc,
117        {"__pos__", prim::kPrimScalarUadd},               // P.scalar_uadd,
118        {"__neg__", prim::kPrimScalarUsub},               // P.scalar_usub,
119        {"__eq__", prim::kPrimScalarEq},                  // P.ScalarEq,
120        {"__ne__", prim::kPrimScalarNe},                  // P.scalar_ne,
121        {"__lt__", prim::kPrimScalarLt},                  // P.ScalarLt,
122        {"__gt__", prim::kPrimScalarGt},                  // P.ScalarGt,
123        {"__le__", prim::kPrimScalarLe},                  // P.ScalarLe,
124        {"__ge__", prim::kPrimScalarGe},                  // P.ScalarGe,
125        {"__bool__", std::string("float_bool")},          // C.float_bool
126        {"__ms_to_array__", prim::kPrimScalarToArray},    // P.scalar_to_array,
127      }},
128     {kObjectTypeTuple,
129      {
130        {"__len__", prim::kPrimSequenceLen},       // P.sequence_len,
131        {"__getitem__", std::string("_getitem")},  // C.getitem,
132        {"__setitem__", std::string("_setitem")},  // C.setitem,
133        {"count", prim::kPrimSequenceCount},       // P.sequence_count
134        {"index", std::string("sequence_index")},  // C.sequence_index
135      }},
136     {kObjectTypeList,
137      {
138        {"__len__", prim::kPrimSequenceLen},       // P.sequence_len,
139        {"__getitem__", std::string("_getitem")},  // C.getitem,
140        {"__setitem__", std::string("_setitem")},  // C.setitem,
141        {"append", std::string("list_append")},    // C.list_append
142        {"insert", std::string("list_insert")},    // C.list_insert
143        {"pop", std::string("list_pop")},          // C.list_pop
144        {"clear", std::string("list_clear")},      // C.list_clear
145        {"reverse", std::string("list_reverse")},  // C.list_reverse
146        {"extend", std::string("list_extend")},    // C.list_extend
147        {"count", prim::kPrimSequenceCount},       // P.sequence_count
148        {"index", std::string("sequence_index")},  // C.sequence_index
149      }},
150     {kObjectTypeDictionary,
151      {
152        {"__len__", prim::kPrimDictLen},            // P.dict_len
153        {"__getitem__", std::string("_getitem")},   // C.getitem,
154        {"__setitem__", std::string("_setitem")},   // C.setitem,
155        {"keys", prim::kPrimDictGetKeys},           // P.dict_getkeys,
156        {"values", prim::kPrimDictGetValues},       // P.dict_getvalues,
157        {"items", prim::kPrimDictItems},            // P.dict_items
158        {"get", std::string("dict_get")},           // C.dict_get
159        {"has_key", std::string("dict_haskey")},    // C.dict_haskey
160        {"clear", std::string("dict_clear")},       // C.dict_clear
161        {"update", std::string("dict_update")},     // C.dict_update
162        {"fromkeys", std::string("dict_fromkeys")}  // C.dict_fromkeys
163      }},
164     {kObjectTypeTensorType,
165      {
166        {"addcdiv", std::string("addcdiv")},                                // C.addcdiv
167        {"addcmul", std::string("addcmul")},                                // C.addcmul
168        {"all", std::string("all_")},                                       // C.reduce_all
169        {"atan2", std::string("atan2")},                                    // P.Atan2
170        {"angle", std::string("angle")},                                    // angle()
171        {"any", std::string("any_")},                                       // C.reduce_any
172        {"bincount", std::string("bincount")},                              // bincount
173        {"chunk", std::string("chunk")},                                    // chunk
174        {"contiguous", prim::kPrimidentity},                                // contiguous
175        {"slogdet", std::string("slogdet")},                                // slogdet
176        {"trace", std::string("trace")},                                    // trace
177        {"tril", std::string("tril")},                                      // tril
178        {"__add__", std::string("add")},                                    // C.add
179        {"__sub__", std::string("sub")},                                    // C.sub
180        {"__mul__", std::string("mul")},                                    // C.mul
181        {"__matmul__", std::string("matmul")},                              // F.matmul
182        {"xdivy", std::string("xdivy")},                                    // P.Xdivy
183        {"abs", std::string("abs_")},                                       // C.abs_
184        {"absolute", std::string("abs_")},                                  // C.abs_
185        {"mean", std::string("mean")},                                      // C.mean
186        {"prod", std::string("prod")},                                      // C.reduce_prod
187        {"__truediv__", std::string("truediv")},                            // C.truediv
188        {"__floordiv__", std::string("floordiv")},                          // C.floordiv
189        {"__mod__", std::string("mod")},                                    // C.mod
190        {"__pow__", std::string("pow_")},                                   // C.pow
191        {"__floor__", std::string("floor")},                                // P.floor
192        {"__trunc__", std::string("array_trunc")},                          // C.array_trunc
193        {"__pos__", std::string("array_uadd")},                             // C.array_uadd
194        {"__neg__", std::string("array_usub")},                             // C.array_usub
195        {"__eq__", std::string("eq")},                                      // C.eq
196        {"__ne__", std::string("ne")},                                      // C.ne
197        {"__lt__", std::string("lt")},                                      // C.lt
198        {"__gt__", std::string("gt")},                                      // C.gt
199        {"__le__", std::string("le")},                                      // C.le
200        {"__ge__", std::string("ge")},                                      // C.ge
201        {"gt", std::string("gt")},                                          // P.Greater
202        {"ge", std::string("ge")},                                          // P.GreaterEqual
203        {"expand_as", std::string("expand_tensor_as")},                     // C.expand_as
204        {"broadcast_to", std::string("broadcast_to")},                      // P.BroadcastTo
205        {"view", std::string("view")},                                      // C.view
206        {"view_as", std::string("view_as")},                                // view_as()
207        {"__len__", prim::kPrimArrayLen},                                   // P.array_len,
208        {"__getitem__", std::string("_getitem")},                           // C.getitem,
209        {"__setitem__", std::string("_setitem")},                           // C.setitem,
210        {"__ms_to_array__", prim::kPrimidentity},                           // P.identity,
211        {"gather_elements", std::string("gather_elements")},                // P.GatherD
212        {"item", std::string("item")},                                      // P.item,
213        {"itemset", std::string("itemset")},                                // P.itemset,
214        {"transpose", std::string("transpose")},                            // P.transpose
215        {"flatten", std::string("flatten")},                                // P.reshape(,-1)
216        {"reshape", std::string("reshape")},                                // P.reshape()
217        {"reshape_as", std::string("reshape_as")},                          // P.reshape()
218        {"reverse", std::string("reverse")},                                // P.ReverseV2()
219        {"reverse_sequence", std::string("reverse_sequence")},              // P.ReverseSequence()
220        {"bitwise_and", std::string("bitwise_and")},                        // P.BitwiseAnd()
221        {"bitwise_or", std::string("bitwise_or")},                          // P.BitwiseOr()
222        {"bitwise_xor", std::string("bitwise_xor")},                        // P.BitwiseXor()
223        {"bitwise_left_shift", std::string("bitwise_left_shift")},          // bitwise_left_shift
224        {"bitwise_right_shift", std::string("bitwise_right_shift")},        // bitwise_right_shift
225        {"tan", std::string("tan")},                                        // P.Tan()
226        {"ger", std::string("ger")},                                        // P.Ger()
227        {"ravel", std::string("ravel")},                                    // P.reshape(,(-1,))
228        {"swapaxes", std::string("swapaxes")},                              // P.transpose()
229        {"swapdims", std::string("swapdims")},                              // P.transpose()
230        {"narrow", std::string("narrow")},                                  // narrow()
231        {"masked_fill", std::string("masked_fill")},                        // masked_fill()
232        {"masked_select", std::string("masked_select")},                    // masked_select()
233        {"nonzero", std::string("nonzero")},                                // nonzero()
234        {"expand_dims", std::string("expand_dims")},                        // P.expand_dims()
235        {"squeeze", std::string("squeeze")},                                // P.squeeze()
236        {"unbind", std::string("unbind")},                                  // P.Unstack()
237        {"unsqueeze", std::string("unsqueeze")},                            // P.expand_dims()
238        {"astype", std::string("astype")},                                  // P.cast()
239        {"short", std::string("short")},                                    // P.cast()
240        {"median", std::string("median")},                                  // P.median()
241        {"cumsum", std::string("cumsum")},                                  // P.cumsum()
242        {"cummin", std::string("cummin")},                                  // cummin()
243        {"cummax", std::string("cummax")},                                  // cummax()
244        {"index_fill", std::string("index_fill")},                          // index_fill()
245        {"index_select", std::string("index_select")},                      // index_select()
246        {"repeat_interleave", std::string("repeat_interleave")},            // repeat_interleave()
247        {"copy", std::string("copy")},                                      // copy()
248        {"copysign", std::string("copysign")},                              // copysign()
249        {"inplace_update", std::string("inplace_update")},                  // P.InplaceUpdateV2
250        {"lerp", std::string("lerp")},                                      // lerp()
251        {"lcm", std::string("lcm")},                                        // F.lcm()
252        {"ldexp", std::string("ldexp")},                                    // F.ldexp()
253        {"log1p", std::string("log1p")},                                    // P.Log1p()
254        {"logcumsumexp", std::string("logcumsumexp")},                      // logcumsumexp()
255        {"logit", std::string("logit")},                                    // Logit()
256        {"negative", std::string("negative")},                              // neg()
257        {"logdet", std::string("logdet")},                                  // logdet()
258        {"log_matrix_determinant", std::string("log_matrix_determinant")},  // log_matrix_determinant()
259        {"matrix_determinant", std::string("matrix_determinant")},          // matrix_determinant()
260        {"matrix_power", std::string("matrix_power")},                      // P.MatrixPower()
261        {"det", std::string("det")},                                        // matrix_determinant()
262        {"ndimension", std::string("ndim_")},                               // ndimension()
263        {"max", std::string("max")},                                        // P.reduce_max()
264        {"min", std::string("min")},                                        // P.reduce_min()
265        {"pow", std::string("pow")},                                        // P.Pow()
266        {"log", std::string("log")},                                        // P.Log()
267        {"nelement", std::string("numel")},                                 // numel()
268        {"numel", std::string("numel")},                                    // numel()
269        {"permute", std::string("permute")},                                // permute()
270        {"positive", std::string("positive")},                              // positive()
271        {"remainder", std::string("remainder")},                            // remainder()
272        {"log10", std::string("log10")},                                    // F.log10()
273        {"log2", std::string("log2")},                                      // F.log2()
274        {"logaddexp", std::string("logaddexp")},                            // logaddexp()
275        {"logaddexp2", std::string("logaddexp2")},                          // logaddexp2()
276        {"logsumexp", std::string("logsumexp")},                            // logsumexp()
277        {"isneginf", std::string("isneginf")},                              // isneginf()
278        {"isposinf", std::string("isposinf")},                              // isposinf()
279        {"isreal", std::string("isreal")},                                  // isreal()
280        {"minimum", std::string("minimum")},                                // P.Minimum()
281        {"cosh", std::string("cosh")},                                      // P.Cosh()
282        {"tanh", std::string("tanh")},                                      // P.Tanh()
283        {"rad2deg", std::string("rad2deg")},                                // F.rad2deg()
284        {"deg2rad", std::string("deg2rad")},                                // F.deg2rad()
285        {"dot", std::string("dot")},                                        // composite.dot()
286        {"round", std::string("round_")},                                   // P.Round()
287        {"roll", std::string("roll")},                                      // P.Roll()
288        {"rot90", std::string("rot90")},                                    // rot90()
289        {"fill", std::string("fill")},                                      // P.fill()
290        {"fills", std::string("fills")},                                    // P.fills
291        {"fill_diagonal", std::string("fill_diagonal")},                    // P.FillDiagonal()
292        {"uniform", std::string("uniform")},                                // P.UniformExt()
293        {"ptp", std::string("ptp")},                                        // P.reduce_max() - P.reduce_min()
294        {"clamp", std::string("clamp")},                                    // clamp()
295        {"clip", std::string("clamp")},                                     // clamp()
296        {"__bool__", std::string("tensor_bool")},                           // C.tensor_bool
297        {"argmax", std::string("argmax")},                                  // P.Argmax()
298        {"argmin", std::string("argmin")},                                  // argmin()
299        {"resize", std::string("resize")},                                  // P.Reshape()
300        {"crop_and_resize", std::string("crop_and_resize")},                // P.crop_and_resize
301        {"select", std::string("select")},                                  // P.Select()
302        {"choose", std::string("choose")},                                  // P.Select()
303        {"diagonal", std::string("diagonal")},                              // P.Eye()
304        {"diagonal_scatter", std::string("diagonal_scatter")},              // diagonal_scatter()
305        {"i0", std::string("i0")},                                          // F.i0()
306        {"isclose", std::string("isclose")},                                // P.IsClose()
307        {"is_floating_point", std::string("is_floating_point")},            // is_floating_point()
308        {"is_signed", std::string("is_signed")},                            // is_signed()
309        {"is_complex", std::string("is_complex")},                          // F.is_complex()
310        {"inv", std::string("inv")},                                        // inv()
311        {"inverse", std::string("inverse")},                                // inverse()
312        {"invert", std::string("invert")},                                  // invert()
313        {"searchsorted", std::string("searchsorted")},                      // searchsorted()
314        {"take", std::string("take")},                                      // P.GatherNd()
315        {"gather", std::string("gather")},                                  // P.Gather()
316        {"scatter", std::string("scatter")},                                // P.TensorScatterElements()
317        {"scatter_add", std::string("tensor_scatter_add")},                 // P.TensorScatterAdd()
318        {"scatter_mul", std::string("tensor_scatter_mul")},                 // tensor_scatter_mul()
319        {"scatter_sub", std::string("tensor_scatter_sub")},                 // P.TensorScatterSub()
320        {"scatter_min", std::string("tensor_scatter_min")},                 // P.TensorScatterMin()
321        {"scatter_max", std::string("tensor_scatter_max")},                 // P.TensorScatterMax()
322        {"scatter_div", std::string("tensor_scatter_div")},                 // P.TensorScatterDiv()
323        {"slice_scatter", std::string("slice_scatter")},                    // slice_scatter()
324        {"select_scatter", std::string("select_scatter")},                  // select_scatter()
325        {"norm", std::string("norm")},                                      // norm()
326        {"unsorted_segment_min", std::string("unsorted_segment_min")},      // P.UnsortedSegmentMin()
327        {"unsorted_segment_max", std::string("unsorted_segment_max")},      // P.UnsortedSegmentMax()
328        {"unsorted_segment_prod", std::string("unsorted_segment_prod")},    // P.UnsortedSegmentProd()
329        {"renorm", std::string("renorm")},                                  // renorm()
330        {"real", std::string("real")},                                      // real()
331        {"reciprocal", std::string("reciprocal")},                          // reciprocal()
332        {"rsqrt", std::string("rsqrt")},                                    // rsqrt()
333        {"trace", std::string("trace")},                                    // P.Eye()
334        {"var", std::string("var")},                                        // P.ReduceSum
335        {"std", std::string("std")},                                        // P.ReduceSum
336        {"sum", std::string("sum")},                                        // P.ReduceSum
337        {"sqrt", std::string("sqrt")},                                      // P.Sqrt()
338        {"square", std::string("square")},                                  // P.Square()
339        {"sub", std::string("sub")},                                        // P.Sub()
340        {"true_divide", std::string("true_divide")},                        // true_divide()
341        {"triu", std::string("triu")},                                      // triu()
342        {"subtract", std::string("subtract")},                              // subtract()
343        {"sum_to_size", std::string("sum_to_size")},                        // sum_to_size()
344        {"exp", std::string("exp")},                                        // P.Exp()
345        {"repeat", std::string("repeat")},                                  // C.repeat_elements
346        {"bernoulli", std::string("bernoulli")},                            // P.Bernoulli()
347        {"ceil", std::string("ceil")},                                      // P.Ceil
348        {"floor", std::string("floor")},                                    // P.floor
349        {"floor_divide", std::string("floor_divide")},                      // floor_divide
350        {"flip", std::string("flip")},                                      // flip
351        {"fliplr", std::string("fliplr")},                                  // fliplr
352        {"flipud", std::string("flipud")},                                  // flipud
353        {"float_power", std::string("float_power")},                        // F.float_power
354        {"fmax", std::string("fmax")},                                      // fmax()
355        {"fmin", std::string("fmin")},                                      // fmin()
356        {"fmod", std::string("fmod")},                                      // F.fmod
357        {"hardshrink", std::string("hardshrink")},                          // P.hshrink
358        {"heaviside", std::string("heaviside")},                            // F.heaviside
359        {"hypot", std::string("hypot")},                                    // F.hypot
360        {"soft_shrink", std::string("soft_shrink")},                        // P.SoftShrink
361        {"gather_nd", std::string("gather_nd")},                            // P.GatherNd()
362        {"unique_consecutive", std::string("unique_consecutive")},          // UniqueConsecutive()
363        {"unique_with_pad", std::string("unique_with_pad")},                // P.UniqueWithPad()
364        {"diag", std::string("diag")},                                      // P.Diag()
365        {"diagflat", std::string("diagflat")},                              // diagflat()
366        {"digamma", std::string("digamma")},                                // digamma()
367        {"lgamma", std::string("lgamma")},                                  // lgamma()
368        {"adaptive_max_pool2d", std::string("adaptive_max_pool2d")},        // P.AdaptiveMaxPool2D
369        {"to_coo", std::string("to_coo")},                                  // dense_to_sparse_coo()
370        {"to_csr", std::string("to_csr")},                                  // dense_to_sparse_csr()
371        {"tolist", std::string("tolist")},                                  // tolist()
372        {"col2im", std::string("col2im")},                                  // P.Col2Im
373        {"count_nonzero", std::string("count_nonzero")},                    // count_nonzero
374        {"split", std::string("split")},                                    // split
375        {"tensor_split", std::string("tensor_split")},                      // tensor_split
376        {"vsplit", std::string("vsplit")},                                  // vsplit
377        {"hsplit", std::string("hsplit")},                                  // hsplit
378        {"dsplit", std::string("dsplit")},                                  // dsplit
379        {"random_categorical", std::string("random_categorical")},          // P.RandomCategorical
380        {"xlogy", std::string("xlogy")},                                    // P.Xlogy()
381        {"erf", std::string("erf")},                                        // P.Erf()
382        {"erfc", std::string("erfc")},                                      // P.Erfc()
383        {"argmax_with_value", std::string("argmax_with_value")},            // P.ArgMaxWithValue
384        {"argmin_with_value", std::string("argmin_with_value")},            // P.ArgMinWithValue
385        {"tile", std::string("tile")},                                      // P.Tile
386        {"topk", std::string("topk")},                                      // P.TopK()
387        {"top_k", std::string("top_k")},                                    // P.TopK()
388        {"isfinite", std::string("isfinite")},                              // P.isfinite()
389        {"cos", std::string("cos")},                                        // cos()
390        {"cov", std::string("cov")},                                        // cov()
391        {"acos", std::string("acos")},                                      // acos()
392        {"arccos", std::string("acos")},                                    // acos()
393        {"acosh", std::string("acosh")},                                    // acosh()
394        {"sigmoid", std::string("sigmoid")},                                // P.Sigmoid()
395        {"addr", std::string("addr")},                                      // addr()
396        {"add", std::string("add")},                                        // P.Add()
397        {"addbmm", std::string("addbmm")},                                  // addbmm()
398        {"addmm", std::string("addmm")},                                    // addmm()
399        {"addmv", std::string("addmv")},                                    // addmv()
400        {"adjoint", std::string("adjoint")},                                // adjoint()
401        {"t", std::string("t")},                                            // t()
402        {"arccosh", std::string("acosh")},                                  // arccosh()
403        {"sin", std::string("sin")},                                        // sin()
404        {"sinc", std::string("sinc")},                                      // sinc()
405        {"arcsin", std::string("asin")},                                    // arcsin()
406        {"arctan", std::string("atan")},                                    // arctan()
407        {"arctan2", std::string("atan2")},                                  // arctan2()
408        {"asin", std::string("asin")},                                      // asin()
409        {"asinh", std::string("asinh")},                                    // asinh()
410        {"arcsinh", std::string("asinh")},                                  // arcsinh()
411        {"atan", std::string("atan")},                                      // atan()
412        {"atanh", std::string("atanh")},                                    // atanh()
413        {"arctanh", std::string("atanh")},                                  // arctanh()
414        {"baddbmm", std::string("baddbmm")},                                // baddbmm
415        {"bmm", std::string("bmm")},                                        // bmm()
416        {"value", std::string("value_")},                                   // P.Load(param, U)
417        {"to", std::string("to")},                                          // to()
418        {"bool", std::string("to_bool")},                                   // bool()
419        {"float", std::string("to_float")},                                 // float()
420        {"half", std::string("to_half")},                                   // half()
421        {"int", std::string("to_int")},                                     // int()
422        {"long", std::string("to_long")},                                   // long()
423        {"cholesky", std::string("cholesky")},                              // cholesky()
424        {"cholesky_inverse", std::string("cholesky_inverse")},              // cholesky_inverse()
425        {"cholesky_solve", std::string("cholesky_solve")},                  // cholesky_solve()
426        {"conj", std::string("conj")},                                      // conj()
427        {"cross", std::string("cross")},                                    // cross()
428        {"erfinv", std::string("erfinv")},                                  // erfinv()
429        {"less_equal", std::string("less_equal")},                          // less_equal()
430        {"fold", std::string("fold")},                                      // fold()
431        {"unfold", std::string("unfold")},                                  // unfold()
432        {"expand", std::string("expand")},                                  // expand()
433        {"cumprod", std::string("cumprod")},                                // cumprod()
434        {"div", std::string("div")},                                        // div()
435        {"divide", std::string("div")},                                     // divide()
436        {"eq", std::string("eq")},                                          // eq()
437        {"equal", std::string("equal")},                                    // equal()
438        {"expm1", std::string("expm1")},                                    // expm1()
439        {"eig", std::string("eig")},                                        // eig()
440        {"eigvals", std::string("eigvals")},                                // eigvals()
441        {"geqrf", std::string("geqrf")},                                    // geqrf()
442        {"histc", std::string("histc")},                                    // histc()
443        {"type", std::string("ms_type")},                                   // astype()
444        {"type_as", std::string("type_as")},                                // astype()
445        {"dim", prim::kPrimRank},                                           // P.Rank()
446        {"index_add", std::string("index_add")},                            // index_add()
447        {"greater", std::string("greater")},                                // greater()
448        {"greater_equal", std::string("greater_equal")},                    // greater_equal()
449        {"igamma", std::string("igamma")},                                  // igamma()
450        {"igammac", std::string("igammac")},                                // igammac()
451        {"isinf", std::string("isinf")},                                    // isinf()
452        {"isnan", std::string("isnan")},                                    // isnan()
453        {"le", std::string("le")},                                          // le()
454        {"less", std::string("less")},                                      // less()
455        {"lt", std::string("less")},                                        // lt()
456        {"logical_and", std::string("logical_and")},                        // logical_and()
457        {"logical_not", std::string("logical_not")},                        // logical_not()
458        {"logical_or", std::string("logical_or")},                          // logical_or()
459        {"logical_xor", std::string("logical_xor")},                        // logical_xor()
460        {"lstsq", std::string("lstsq")},                                    // lstsq()
461        {"mvlgamma", std::string("mvlgamma")},                              // mvlgamma()
462        {"matmul", std::string("matmul")},                                  // matmul()
463        {"inner", std::string("inner")},                                    // inner()
464        {"maximum", std::string("maximum")},                                // maximum()
465        {"msort", std::string("msort")},                                    // msort()
466        {"mm", std::string("mm")},                                          // mm()
467        {"mul", std::string("mul")},                                        // mul()
468        {"multiply", std::string("multiply")},                              // multiply()
469        {"nan_to_num", std::string("nan_to_num")},                          // nan_to_num()
470        {"nansum", std::string("nansum")},                                  // nansum()
471        {"nanmean", std::string("nanmean")},                                // nanmean()
472        {"nanmedian", std::string("nanmedian")},                            // nanmedian()
473        {"neg", std::string("neg")},                                        // neg()
474        {"ne", std::string("ne")},                                          // ne()
475        {"not_equal", std::string("not_equal")},                            // not_equal()
476        {"new_zeros", std::string("new_zeros")},                            // new_zeros()
477        {"new_ones", std::string("new_ones")},                              // new_ones()
478        {"sgn", std::string("sgn")},                                        // sgn()
479        {"sign", std::string("sign")},                                      // sign()
480        {"signbit", std::string("signbit")},                                // signbit()
481        {"sinh", std::string("sinh")},                                      // sinh()
482        {"sort", std::string("sort")},                                      // sort()
483        {"cauchy", std::string("cauchy")},                                  // P.cauchy()
484        {"log_normal", std::string("log_normal")},                          // P.LogNormalReverse()
485        {"argsort", std::string("argsort")},                                // argsort()
486        {"trunc", std::string("trunc")},                                    // trunc()
487        {"where", std::string("where")},                                    // where()
488        {"imag", std::string("imag")},                                      // imag()
489        {"diff", std::string("diff")},                                      // diff()
490        {"frac", std::string("frac")},                                      // frac()
491        {"argwhere", std::string("argwhere")},                              // argwhere()
492        {"moveaxis", std::string("moveaxis")},                              // moveaxis()
493        {"multinomial", std::string("multinomial")},                        // multinomial()
494        {"movedim", std::string("movedim")},                                // movedim()
495        {"nextafter", std::string("nextafter")},                            // nextafter()
496        {"qr", std::string("qr")},                                          // qr()
497        {"ormqr", std::string("ormqr")},                                    // ormqr()
498        {"amax", std::string("amax")},                                      // amax()
499        {"amin", std::string("amin")},                                      // amin()
500        {"lu_solve", std::string("lu_solve")},                              // lu_solve()
501        {"masked_scatter", std::string("masked_scatter")},                  // masked_scatter()
502        {"index_put", std::string("index_put")},                            // index_input()
503        {"aminmax", std::string("aminmax")},                                // aminmax()
504        {"quantile", std::string("quantile")},                              // quantile()
505        {"nanquantile", std::string("nanquantile")},                        // nanquantile()
506        {"orgqr", std::string("orgqr")},                                    // orgqr()
507        {"outer", std::string("outer")},                                    // outer()
508        {"softmax", std::string("softmax")},                                // softmax()
509      }},
510     {kObjectTypeRowTensorType,
511      {
512        {"__add__", prim::kPrimRowTensorAdd},  // P.row_tensor_add
513      }},
514     {kObjectTypeCSRTensorType,
515      {
516        {"astype", std::string("csr_astype")},      // C.csr_astype
517        {"abs", std::string("csr_abs")},            // C.csr_abs
518        {"sum", std::string("csr_sum")},            // C.csr_sum
519        {"mv", std::string("csr_mv")},              // C.csr_mv
520        {"to_tuple", std::string("csr_to_tuple")},  // C.csr_to_tuple
521        {"to_coo", std::string("csr_to_coo")},      // C.csr_to_coo
522        {"to_dense", std::string("csr_to_dense")},  // C.csr_to_dense
523        {"mm", std::string("csr_mm")},              // C.csr_mm
524        {"add", std::string("csr_add")},            // C.csr_add
525        {"softmax", std::string("csr_softmax")},    // C.csr_softmax
526      }},
527     {kObjectTypeCOOTensorType,
528      {
529        {"astype", std::string("coo_astype")},      // C.coo_astype
530        {"abs", std::string("coo_abs")},            // C.coo_abs
531        {"to_tuple", std::string("coo_to_tuple")},  // C.coo_to_tuple
532        {"to_csr", std::string("coo_to_csr")},      // C.coo_to_csr
533        {"to_dense", std::string("coo_to_dense")},  // C.coo_to_dense
534        {"coalesce", std::string("coo_coalesce")},  // C.coo_coalesce
535        {"add", std::string("coo_add")},            // C.coo_add
536      }},
537     {kObjectTypeMapTensorType,
538      {
539        {"get", std::string("map_tensor_get")},                // C.map_tensor_get
540        {"put", std::string("map_tensor_put")},                // C.map_tensor_put
541        {"erase", std::string("map_tensor_erase")},            // C.map_tensor_erase
542        {"get_keys", std::string("map_tensor_get_keys")},      // C.map_tensor_get_keys
543        {"get_values", std::string("map_tensor_get_values")},  // C.map_tensor_get_values
544        {"get_data", std::string("map_tensor_get_data")},      // C.map_tensor_get_data
545      }},
546     {kObjectTypeJTagged, {}},
547     {kObjectTypeSymbolicKeyType, {}},
548     {kObjectTypeEnvType, {}}};
549   return method_map;
550 }
551 
// Returns the table of built-in attributes, keyed first by object type id and then by attribute name.
//
// Each attribute maps either to a primitive (prim::kPrim*) or to a std::string naming the implementation;
// the trailing comment on every entry records the function the original authors associate with it.
// How a string name is resolved to a callable is not visible in this file.
//
// The table is a function-local static: it is built on the first call and the same (mutable) reference
// is returned by every later call.
BuiltInTypeMap &GetAttrMap() {
  static BuiltInTypeMap attr_map = {
    // Attributes available on dense tensors.
    {kObjectTypeTensorType,
     {
       {"shape", prim::kPrimShape},             // C.shape_
       {"dtype", prim::kPrimDType},             // C.dtype_
       {"size", std::string("size_")},          // C.size_
       {"ndim", std::string("ndim_")},          // C.ndim_
       {"H", std::string("H")},                 // C.H
       {"T", std::string("T_")},                // C.T_
       {"itemsize", std::string("itemsize_")},  // C.itemsize_
       {"nbytes", std::string("nbytes_")},      // C.nbytes_
       {"strides", std::string("strides_")},    // C.strides_
       {"mH", std::string("adjoint")},          // C.adjoint
       {"mT", std::string("mT")},               // C.mT_
     }},
    // Attributes available on row tensors.
    {kObjectTypeRowTensorType,
     {
       {"values", prim::kPrimRowTensorGetValues},           // F.row_tensor_get_values
       {"indices", prim::kPrimRowTensorGetIndices},         // F.row_tensor_get_indices
       {"dense_shape", prim::kPrimRowTensorGetDenseShape},  // F.row_tensor_get_dense_shape
     }},
    // Attributes available on COO sparse tensors.
    {kObjectTypeCOOTensorType,
     {
       {"values", prim::kPrimCOOTensorGetValues},     // F.coo_tensor_get_values
       {"indices", prim::kPrimCOOTensorGetIndices},   // F.coo_tensor_get_indices
       {"shape", prim::kPrimCOOTensorGetDenseShape},  // F.coo_tensor_get_dense_shape
       {"dtype", std::string("dtype_")},              // C.dtype_
       {"size", std::string("sparse_size_")},         // C.sparse_size_
       {"ndim", std::string("sparse_ndim_")},         // C.sparse_ndim_
       {"itemsize", std::string("itemsize_")},        // C.itemsize_
     }},
    // Attributes available on CSR sparse tensors.
    {kObjectTypeCSRTensorType,
     {
       {"indptr", prim::kPrimCSRTensorGetIndptr},     // F.csr_tensor_get_indptr
       {"values", prim::kPrimCSRTensorGetValues},     // F.csr_tensor_get_values
       {"indices", prim::kPrimCSRTensorGetIndices},   // F.csr_tensor_get_indices
       {"shape", prim::kPrimCSRTensorGetDenseShape},  // F.csr_tensor_get_shape
       {"dtype", std::string("dtype_")},              // C.dtype_
       {"size", std::string("sparse_size_")},         // C.sparse_size_
       {"ndim", std::string("sparse_ndim_")},         // C.sparse_ndim_
       {"itemsize", std::string("itemsize_")},        // C.itemsize_
     }},
    // Attributes available on map tensors.
    {kObjectTypeMapTensorType,
     {
       {"default_value", prim::kPrimMapTensorGetDefaultValue},             // F.map_tensor_get_default_value
       {"permit_filter_value", prim::kPrimMapTensorGetPermitFilterValue},  // F.map_tensor_get_permit_filter_value
       {"evict_filter_value", prim::kPrimMapTensorGetEvictFilterValue},    // F.map_tensor_get_evict_filter_value
     }},
  };
  return attr_map;
}
604 
// Out-of-class definition of the static (class-wide) mutex member.
// NOTE(review): presumably this is the mutex handed out by Resource::GetBackendInitMutex(), which
// SetBackendAsync locks around backend creation -- confirm in resource.h.
std::mutex Resource::backend_init_mutex_;
606 
// Creates a Resource for the given Python object.
//  - engine_ is a fresh abstract::AnalysisEngine built from abstract::GetPrimEvaluatorConstructors() and manager_.
//  - source_input_ keeps a handle to `obj`.
//  - is_cleaned_ starts as false, so the destructor will call Clean() unless it has been run before.
// NOTE(review): manager_ is read while initializing engine_ but is not set in this initializer list, so it
// must already be initialized at this point (presumably by a base class) -- confirm in resource.h.
Resource::Resource(const py::object &obj)
    : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
      source_input_(obj),
      is_cleaned_(false) {}
611 
~Resource()612 Resource::~Resource() {
613   MS_LOG(DEBUG) << "Resource clear";
614 
615   try {
616     mindspore::HashMap<std::string, Any>().swap(results_);
617   } catch (const std::exception &e) {
618     MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
619   }
620 
621   // If exit normally, these global variables will be cleaned
622   // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
623   // these global variables may not being cleaned, it may
624   // cause segmentfault when free python object inside these global variables
625   // after python interpreter got freed, so these global variables
626   // are cleaned here.
627   // So if exit normally, these global variable will be cleaned twice,
628   // care be taken to prevent double free in the following functions.
629   if (!is_cleaned_) {
630     try {
631       Clean();
632     } catch (const std::exception &e) {
633       MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
634     } catch (...) {
635       MS_LOG(ERROR) << "Exception when cleaning resource.";
636     }
637   }
638 }
639 
GetMethodOrAttr(const string & name,const TypeId & type_id,const BuiltInTypeMap & method_map)640 Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
641   auto type_method_map = method_map.find(static_cast<int64_t>(type_id));
642   if (type_method_map == method_map.end()) {
643     return Any();
644   }
645   auto method = type_method_map->second.find(name);
646   if (method == type_method_map->second.end()) {
647     return Any();
648   }
649   return method->second;
650 }
651 
IsTypeInBuiltInMap(const TypeId & type)652 bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
653   TypeId type_id = NormalizeTypeId(type);
654   const BuiltInTypeMap &method_map = GetMethodMap();
655   auto iter = method_map.find(static_cast<int64_t>(type_id));
656   if (iter == method_map.end()) {
657     const BuiltInTypeMap &attr_map = GetAttrMap();
658     iter = attr_map.find(static_cast<int64_t>(type_id));
659     if (iter == attr_map.end()) {
660       return false;
661     }
662   }
663   return true;
664 }
665 
GetMethodPtr(const TypeId & type,const std::string & name)666 Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
667   TypeId type_id = NormalizeTypeId(type);
668   const BuiltInTypeMap &method_map = GetMethodMap();
669   return GetMethodOrAttr(name, type_id, method_map);
670 }
671 
GetAttrPtr(const TypeId & type,const std::string & name)672 Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
673   TypeId type_id = NormalizeTypeId(type);
674   const BuiltInTypeMap &attr_map = GetAttrMap();
675   return GetMethodOrAttr(name, type_id, attr_map);
676 }
677 
GetCompileCacheResource(const py::list & compile_cache_dep_files,const py::dict & weights,const std::string & queue_name,size_t compile_cache_id,bool * compile_cache_consistent,bool has_python_script)678 void Resource::GetCompileCacheResource(const py::list &compile_cache_dep_files, const py::dict &weights,
679                                        const std::string &queue_name, size_t compile_cache_id,
680                                        bool *compile_cache_consistent, bool has_python_script) {
681   compile_cache_manager_ = std::make_shared<CompileCacheManager>(compile_cache_id);
682   compile_cache_manager_->InitParallelGroupCkptSaveFile();
683   const bool force_use_compile_cache = (common::GetEnv("MS_DEV_FORCE_USE_COMPILE_CACHE") == "1");
684   auto &context = CompileCacheContext::GetInstance();
685   // When enabling compile cache, it is possible to enable it even without Python script.
686   if (force_use_compile_cache || !has_python_script) {
687     context.set_init_compile_cache(true);
688     MS_LOG(WARNING)
689       << "The env MS_DEV_FORCE_USE_COMPILE_CACHE has been set. It will force to use the compile cache without "
690          "checking whether the network has been changed. Please note the correctness.";
691   } else {
692     MS_EXCEPTION_IF_NULL(compile_cache_consistent);
693     if (!*compile_cache_consistent) {
694       MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions.";
695       return;
696     }
697     context.set_init_compile_cache(true);
698     compile_cache_manager_->InitCompileCacheHash(compile_cache_dep_files);
699     *compile_cache_consistent = compile_cache_manager_->CanLoadCache();
700     if (!*compile_cache_consistent) {
701       MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions.";
702       return;
703     }
704   }
705   func_graph_ = compile_cache_manager_->GetCachedFuncGraph(manager_, weights, queue_name);
706   layout_map_ = compile_cache_manager_->layout_map();
707 }
708 
CacheFuncGraph() const709 void Resource::CacheFuncGraph() const {
710   FuncGraphPtr layout_fg = nullptr;
711   if (parallel::IsAutoParallelCareGraph(func_graph_)) {
712     layout_fg = GetResult(kStepParallelGraph).cast<FuncGraphPtr>();
713   }
714   compile_cache_manager_->CacheFuncGraph(func_graph_, layout_fg);
715 }
716 
// Releases the state held by this Resource and by the parser/analysis helpers it uses, in particular
// everything that may reference Python objects, and finally marks the Resource as cleaned.
// Raises (via MS_EXCEPTION_IF_NULL) when engine_ is null.
void Resource::Clean() {
  // Make sure the asynchronous backend creation task has finished before cleaning the resource:
  // if a task is pending, wait for it and keep its result.
  if (backend_ == nullptr && backend_future_.valid()) {
    backend_ = backend_future_.get();
  }
  // Drop the argument abstracts/values (AbstractTensor->elements() is saved in an AbstractBasePtrList)
  // and replace the stored Python input object with None.
  args_abs_.clear();
  arguments_.clear();
  source_input_ = py::none();
  // A context holding an AbstractBasePtrList may be saved in GraphEvaluator, and some evaluators such as
  // ResolveEvaluator may keep Python objects in their cache; these must be cleared before the Python
  // interpreter is destructed.
  MS_EXCEPTION_IF_NULL(engine_);
  engine_->ClearEvaluatorCache();
  engine_->Clear();
  // Clear the caches used for parsing explicitly, because static variables are only released after
  // the Python threads have been released.
  parse::data_converter::ClearObjectCache();
  parse::Parser::CleanParserResource();
  // Drop every graph's reference to its Python object (such as a Cell); otherwise there would be a
  // circular reference between the func_graph and the cell.
  for (auto graph : manager()->func_graphs()) {
    graph->set_python_obj(nullptr);
  }
  trace::ClearTraceStack();
  // Remember that cleaning has been done, so the destructor does not run it again.
  is_cleaned_ = true;
}
744 
GetBackend() const745 compile::BackendPtr Resource::GetBackend() const {
746   if (backend_ == nullptr && backend_future_.valid()) {
747     backend_ = backend_future_.get();
748   }
749   return backend_;
750 }
751 
SetBackendAsync(std::function<compile::BackendPtr ()> func)752 void Resource::SetBackendAsync(std::function<compile::BackendPtr()> func) {
753   static const bool is_enable_async = (common::GetEnv("MS_DEV_ASYNC_BACKEND_INIT") == "1");
754   auto context_ptr = MsContext::GetInstance();
755   MS_EXCEPTION_IF_NULL(context_ptr);
756   static const bool is_enable_ge = context_ptr->backend_policy() == "ge";
757   if (!is_enable_async || is_enable_ge) {
758     // Disable async backend init if required.
759     std::lock_guard<std::mutex> guard(GetBackendInitMutex());
760     backend_ = func();
761     return;
762   }
763   if (backend_ == nullptr && backend_future_.valid()) {
764     (void)backend_future_.get();
765   }
766   backend_ = nullptr;
767   backend_future_ = std::async(std::launch::async, [func]() {
768     std::lock_guard<std::mutex> guard(Resource::GetBackendInitMutex());
769     return func();
770   });
771 }
772 }  // namespace pipeline
773 }  // namespace mindspore
774