1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/resource.h"
20 #include "mindspore/core/ops/sparse_tensor_ops.h"
21 #include "mindspore/core/ops/sequence_ops.h"
22 #include "mindspore/core/ops/comparison_ops.h"
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/arithmetic_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "include/common/utils/compile_cache_context.h"
27 #include "ir/dtype.h"
28 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
29 #include "pipeline/jit/ps/debug/trace.h"
30 #include "pipeline/jit/ps/parse/data_converter.h"
31 #include "frontend/operator/ops.h"
32 #include "frontend/optimizer/ad/dfunctor.h"
33 #include "frontend/parallel/step_parallel_utils.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "utils/ms_utils.h"
36
37 namespace mindspore {
38 namespace pipeline {
GetMethodMap()39 BuiltInTypeMap &GetMethodMap() {
40 static BuiltInTypeMap method_map = {
41 {kObjectTypeString,
42 {{"__bool__", std::string("str_bool")}, // C.str_bool
43 {"format", std::string("_format")},
44 {"upper", prim::kPrimStringUpper},
45 {"lower", prim::kPrimStringLower}}},
46 {kMetaTypeNone,
47 {
48 {"__bool__", std::string("none_bool")} // C.none_bool
49 }},
50 {kObjectTypeFunction,
51 {
52 {"__bool__", std::string("func_bool")} // C.str_bool
53 }},
54 {kNumberTypeBool,
55 {
56 {"__and__", prim::kPrimBoolAnd}, // P.bool_and
57 {"__or__", prim::kPrimBoolOr}, // P.bool_or
58 {"__eq__", prim::kPrimBoolEq}, // P.bool_eq
59 {"__ne__", std::string("bool_ne")}, // C.bool_ne
60 {"__bool__", prim::kPrimidentity} // P.identity
61 }},
62 {kNumberTypeInt,
63 {
64 {"__add__", prim::kPrimScalarAdd}, // P.scalar_add
65 {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub
66 {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul
67 {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv
68 {"__truediv__", std::string("int_truediv")}, // C.int_truediv
69 {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod
70 {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow
71 {"__floor__", prim::kPrimidentity}, // P.identity
72 {"__trunc__", prim::kPrimidentity}, // P.identity
73 {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd
74 {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub
75 {"__eq__", prim::kPrimScalarEq}, // P.ScalarEq
76 {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne
77 {"__lt__", prim::kPrimScalarLt}, // P.ScalarLt
78 {"__gt__", prim::kPrimScalarGt}, // P.ScalarGt
79 {"__le__", prim::kPrimScalarLe}, // P.ScalarLe
80 {"__ge__", prim::kPrimScalarGe}, // P.ScalarGe
81 {"__bool__", std::string("int_bool")}, // C.int_bool
82 {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
83 }},
84 {kNumberTypeUInt,
85 {
86 {"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
87 {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
88 {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
89 {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div,
90 {"__truediv__", std::string("int_truediv")}, // C.int_truediv
91 {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
92 {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
93 {"__floor__", prim::kPrimidentity}, // P.identity,
94 {"__trunc__", prim::kPrimidentity}, // P.identity,
95 {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
96 {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
97 {"__eq__", prim::kPrimScalarEq}, // P.ScalarEq,
98 {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
99 {"__lt__", prim::kPrimScalarLt}, // P.ScalarLt,
100 {"__gt__", prim::kPrimScalarGt}, // P.ScalarGt,
101 {"__le__", prim::kPrimScalarLe}, // P.ScalarLe,
102 {"__ge__", prim::kPrimScalarGe}, // P.ScalarGe,
103 {"__bool__", std::string("int_bool")}, // C.int_bool
104 {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array
105 }},
106 {kNumberTypeFloat,
107 {
108 {"__add__", prim::kPrimScalarAdd}, // P.scalar_add,
109 {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub,
110 {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul,
111 {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv
112 {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div,
113 {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod,
114 {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow,
115 {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor,
116 {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc,
117 {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd,
118 {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub,
119 {"__eq__", prim::kPrimScalarEq}, // P.ScalarEq,
120 {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne,
121 {"__lt__", prim::kPrimScalarLt}, // P.ScalarLt,
122 {"__gt__", prim::kPrimScalarGt}, // P.ScalarGt,
123 {"__le__", prim::kPrimScalarLe}, // P.ScalarLe,
124 {"__ge__", prim::kPrimScalarGe}, // P.ScalarGe,
125 {"__bool__", std::string("float_bool")}, // C.float_bool
126 {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array,
127 }},
128 {kObjectTypeTuple,
129 {
130 {"__len__", prim::kPrimSequenceLen}, // P.sequence_len,
131 {"__getitem__", std::string("_getitem")}, // C.getitem,
132 {"__setitem__", std::string("_setitem")}, // C.setitem,
133 {"count", prim::kPrimSequenceCount}, // P.sequence_count
134 {"index", std::string("sequence_index")}, // C.sequence_index
135 }},
136 {kObjectTypeList,
137 {
138 {"__len__", prim::kPrimSequenceLen}, // P.sequence_len,
139 {"__getitem__", std::string("_getitem")}, // C.getitem,
140 {"__setitem__", std::string("_setitem")}, // C.setitem,
141 {"append", std::string("list_append")}, // C.list_append
142 {"insert", std::string("list_insert")}, // C.list_insert
143 {"pop", std::string("list_pop")}, // C.list_pop
144 {"clear", std::string("list_clear")}, // C.list_clear
145 {"reverse", std::string("list_reverse")}, // C.list_reverse
146 {"extend", std::string("list_extend")}, // C.list_extend
147 {"count", prim::kPrimSequenceCount}, // P.sequence_count
148 {"index", std::string("sequence_index")}, // C.sequence_index
149 }},
150 {kObjectTypeDictionary,
151 {
152 {"__len__", prim::kPrimDictLen}, // P.dict_len
153 {"__getitem__", std::string("_getitem")}, // C.getitem,
154 {"__setitem__", std::string("_setitem")}, // C.setitem,
155 {"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
156 {"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
157 {"items", prim::kPrimDictItems}, // P.dict_items
158 {"get", std::string("dict_get")}, // C.dict_get
159 {"has_key", std::string("dict_haskey")}, // C.dict_haskey
160 {"clear", std::string("dict_clear")}, // C.dict_clear
161 {"update", std::string("dict_update")}, // C.dict_update
162 {"fromkeys", std::string("dict_fromkeys")} // C.dict_fromkeys
163 }},
164 {kObjectTypeTensorType,
165 {
166 {"addcdiv", std::string("addcdiv")}, // C.addcdiv
167 {"addcmul", std::string("addcmul")}, // C.addcmul
168 {"all", std::string("all_")}, // C.reduce_all
169 {"atan2", std::string("atan2")}, // P.Atan2
170 {"angle", std::string("angle")}, // C.reduce_any
171 {"any", std::string("any_")}, // C.reduce_any
172 {"bincount", std::string("bincount")}, // bincount
173 {"chunk", std::string("chunk")}, // chunk
174 {"contiguous", prim::kPrimidentity}, // contiguous
175 {"slogdet", std::string("slogdet")}, // slogdet
176 {"trace", std::string("trace")}, // trace
177 {"tril", std::string("tril")}, // tril
178 {"__add__", std::string("add")}, // C.add
179 {"__sub__", std::string("sub")}, // C.sub
180 {"__mul__", std::string("mul")}, // C.mul
181 {"__matmul__", std::string("matmul")}, // F.matmul
182 {"xdivy", std::string("xdivy")}, // P.Xdivy
183 {"abs", std::string("abs_")}, // C.abs_
184 {"absolute", std::string("abs_")}, // C.abs_
185 {"mean", std::string("mean")}, // C.mean
186 {"prod", std::string("prod")}, // C.reduce_prod
187 {"__truediv__", std::string("truediv")}, // C.truediv
188 {"__floordiv__", std::string("floordiv")}, // C.floordiv
189 {"__mod__", std::string("mod")}, // C.mod
190 {"__pow__", std::string("pow_")}, // C.pow
191 {"__floor__", std::string("floor")}, // P.floor
192 {"__trunc__", std::string("array_trunc")}, // C.array_trunc
193 {"__pos__", std::string("array_uadd")}, // C.array_uadd
194 {"__neg__", std::string("array_usub")}, // C.array_usub
195 {"__eq__", std::string("eq")}, // C.eq
196 {"__ne__", std::string("ne")}, // C.ne
197 {"__lt__", std::string("lt")}, // C.lt
198 {"__gt__", std::string("gt")}, // C.gt
199 {"__le__", std::string("le")}, // C.le
200 {"__ge__", std::string("ge")}, // C.ge
201 {"gt", std::string("gt")}, // P.Greater
202 {"ge", std::string("ge")}, // P.GreaterEqual
203 {"expand_as", std::string("expand_tensor_as")}, // C.expand_as
204 {"broadcast_to", std::string("broadcast_to")}, // P.BroadcastTo
205 {"view", std::string("view")}, // C.view
206 {"view_as", std::string("view_as")}, // view_as()
207 {"__len__", prim::kPrimArrayLen}, // P.array_len,
208 {"__getitem__", std::string("_getitem")}, // C.getitem,
209 {"__setitem__", std::string("_setitem")}, // C.setitem,
210 {"__ms_to_array__", prim::kPrimidentity}, // P.identity,
211 {"gather_elements", std::string("gather_elements")}, // P.GatherD
212 {"item", std::string("item")}, // P.item,
213 {"itemset", std::string("itemset")}, // P.itemset,
214 {"transpose", std::string("transpose")}, // P.transpose
215 {"flatten", std::string("flatten")}, // P.reshape(,-1)
216 {"reshape", std::string("reshape")}, // P.reshape()
217 {"reshape_as", std::string("reshape_as")}, // P.reshape()
218 {"reverse", std::string("reverse")}, // P.ReverseV2()
219 {"reverse_sequence", std::string("reverse_sequence")}, // P.ReverseSequence()
220 {"bitwise_and", std::string("bitwise_and")}, // P.BitwiseAnd()
221 {"bitwise_or", std::string("bitwise_or")}, // P.BitwiseOr()
222 {"bitwise_xor", std::string("bitwise_xor")}, // P.BitwiseXor()
223 {"bitwise_left_shift", std::string("bitwise_left_shift")}, // bitwise_left_shift
224 {"bitwise_right_shift", std::string("bitwise_right_shift")}, // bitwise_right_shift
225 {"tan", std::string("tan")}, // P.Tan()
226 {"ger", std::string("ger")}, // P.Ger()
227 {"ravel", std::string("ravel")}, // P.reshape(,(-1,))
228 {"swapaxes", std::string("swapaxes")}, // P.transpose()
229 {"swapdims", std::string("swapdims")}, // P.transpose()
230 {"narrow", std::string("narrow")}, // narrow()
231 {"masked_fill", std::string("masked_fill")}, // masked_fill()
232 {"masked_select", std::string("masked_select")}, // masked_select()
233 {"nonzero", std::string("nonzero")}, // nonzero()
234 {"expand_dims", std::string("expand_dims")}, // P.expand_dims()
235 {"squeeze", std::string("squeeze")}, // P.squeeze()
236 {"unbind", std::string("unbind")}, // P.Unstack()
237 {"unsqueeze", std::string("unsqueeze")}, // P.expand_dims()
238 {"astype", std::string("astype")}, // P.cast()
239 {"short", std::string("short")}, // P.cast()
240 {"median", std::string("median")}, // P.median()
241 {"cumsum", std::string("cumsum")}, // P.cumsum()
242 {"cummin", std::string("cummin")}, // cummin()
243 {"cummax", std::string("cummax")}, // cummax()
244 {"index_fill", std::string("index_fill")}, // index_fill()
245 {"index_select", std::string("index_select")}, // index_select()
246 {"repeat_interleave", std::string("repeat_interleave")}, // repeat_interleave()
247 {"copy", std::string("copy")}, // copy()
248 {"copysign", std::string("copysign")}, // copysign()
249 {"inplace_update", std::string("inplace_update")}, // P.InplaceUpdateV2
250 {"lerp", std::string("lerp")}, // lerp()
251 {"lcm", std::string("lcm")}, // F.lcm()
252 {"ldexp", std::string("ldexp")}, // F.ldexp()
253 {"log1p", std::string("log1p")}, // P.Log1p()
254 {"logcumsumexp", std::string("logcumsumexp")}, // logcumsumexp()
255 {"logit", std::string("logit")}, // Logit()
256 {"negative", std::string("negative")}, // neg()
257 {"logdet", std::string("logdet")}, // logdet()
258 {"log_matrix_determinant", std::string("log_matrix_determinant")}, // log_matrix_determinant()
259 {"matrix_determinant", std::string("matrix_determinant")}, // matrix_determinant()
260 {"matrix_power", std::string("matrix_power")}, // P.MatrixPower()
261 {"det", std::string("det")}, // matrix_determinant()
262 {"ndimension", std::string("ndim_")}, // ndimension()
263 {"max", std::string("max")}, // P.reduce_max()
264 {"min", std::string("min")}, // P.reduce_min()
265 {"pow", std::string("pow")}, // P.Pow()
266 {"log", std::string("log")}, // P.Log()
267 {"nelement", std::string("numel")}, // numel()
268 {"numel", std::string("numel")}, // numel()
269 {"permute", std::string("permute")}, // permute()
270 {"positive", std::string("positive")}, // positive()
271 {"remainder", std::string("remainder")}, // remainder()
272 {"log10", std::string("log10")}, // F.log10()
273 {"log2", std::string("log2")}, // F.log2()
274 {"logaddexp", std::string("logaddexp")}, // logaddexp()
275 {"logaddexp2", std::string("logaddexp2")}, // logaddexp2()
276 {"logsumexp", std::string("logsumexp")}, // logsumexp()
277 {"isneginf", std::string("isneginf")}, // isneginf()
278 {"isposinf", std::string("isposinf")}, // isposinf()
279 {"isreal", std::string("isreal")}, // isreal()
280 {"minimum", std::string("minimum")}, // P.Minimum()
281 {"cosh", std::string("cosh")}, // P.Cosh()
282 {"tanh", std::string("tanh")}, // P.Tanh()
283 {"rad2deg", std::string("rad2deg")}, // F.rad2deg()
284 {"deg2rad", std::string("deg2rad")}, // F.deg2rad()
285 {"dot", std::string("dot")}, // composite.dot()
286 {"round", std::string("round_")}, // P.Round()
287 {"roll", std::string("roll")}, // P.Roll()
288 {"rot90", std::string("rot90")}, // rot90()
289 {"fill", std::string("fill")}, // P.fill()
290 {"fills", std::string("fills")}, // P.fills
291 {"fill_diagonal", std::string("fill_diagonal")}, // P.FillDiagonal()
292 {"uniform", std::string("uniform")}, // P.UniformExt()
293 {"ptp", std::string("ptp")}, // P.reduce_max() - P.reduce_min()
294 {"clamp", std::string("clamp")}, // clamp()
295 {"clip", std::string("clamp")}, // clamp()
296 {"__bool__", std::string("tensor_bool")}, // C.tensor_bool
297 {"argmax", std::string("argmax")}, // P.Argmax()
298 {"argmin", std::string("argmin")}, // P.Argmax()
299 {"resize", std::string("resize")}, // P.Reshape()
300 {"crop_and_resize", std::string("crop_and_resize")}, // P.crop_and_resize
301 {"select", std::string("select")}, // P.Select()
302 {"choose", std::string("choose")}, // P.Select()
303 {"diagonal", std::string("diagonal")}, // P.Eye()
304 {"diagonal_scatter", std::string("diagonal_scatter")}, // diagonal_scatter()
305 {"i0", std::string("i0")}, // F.i0()
306 {"isclose", std::string("isclose")}, // P.IsClose()
307 {"is_floating_point", std::string("is_floating_point")}, // is_floating_point()
308 {"is_signed", std::string("is_signed")}, // is_signed()
309 {"is_complex", std::string("is_complex")}, // F.is_complex()
310 {"inv", std::string("inv")}, // inv()
311 {"inverse", std::string("inverse")}, // inverse()
312 {"invert", std::string("invert")}, // invert()
313 {"searchsorted", std::string("searchsorted")}, // P.Select()
314 {"take", std::string("take")}, // P.GatherNd()
315 {"gather", std::string("gather")}, // P.Gather()
316 {"scatter", std::string("scatter")}, // P.TensorScatterElements()
317 {"scatter_add", std::string("tensor_scatter_add")}, // P.TensorScatterAdd()
318 {"scatter_mul", std::string("tensor_scatter_mul")}, // tensor_scatter_mul()
319 {"scatter_sub", std::string("tensor_scatter_sub")}, // P.TensorScatterSub()
320 {"scatter_min", std::string("tensor_scatter_min")}, // P.TensorScatterMin()
321 {"scatter_max", std::string("tensor_scatter_max")}, // P.TensorScatterMax()
322 {"scatter_div", std::string("tensor_scatter_div")}, // P.TensorScatterDiv()
323 {"slice_scatter", std::string("slice_scatter")}, // slice_scatter()
324 {"select_scatter", std::string("select_scatter")}, // select_scatter()
325 {"norm", std::string("norm")}, // norm()
326 {"unsorted_segment_min", std::string("unsorted_segment_min")}, // P.UnsortedSegmentMin()
327 {"unsorted_segment_max", std::string("unsorted_segment_max")}, // P.UnsortedSegmentMax()
328 {"unsorted_segment_prod", std::string("unsorted_segment_prod")}, // P.UnsortedSegmentProd()
329 {"renorm", std::string("renorm")}, // renorm()
330 {"real", std::string("real")}, // real()
331 {"reciprocal", std::string("reciprocal")}, // reciprocal()
332 {"rsqrt", std::string("rsqrt")}, // rsqrt()
333 {"trace", std::string("trace")}, // P.Eye()
334 {"var", std::string("var")}, // P.ReduceSum
335 {"std", std::string("std")}, // P.ReduceSum
336 {"sum", std::string("sum")}, // P.ReduceSum
337 {"sqrt", std::string("sqrt")}, // P.Sqrt()
338 {"square", std::string("square")}, // P.Square()
339 {"sub", std::string("sub")}, // P.Sub()
340 {"true_divide", std::string("true_divide")}, // true_divide()
341 {"triu", std::string("triu")}, // triu()
342 {"subtract", std::string("subtract")}, // true_divide()
343 {"sum_to_size", std::string("sum_to_size")}, // sum_to_size()
344 {"exp", std::string("exp")}, // P.Exp()
345 {"repeat", std::string("repeat")}, // C.repeat_elements
346 {"bernoulli", std::string("bernoulli")}, // P.Bernoulli()
347 {"ceil", std::string("ceil")}, // P.Ceil
348 {"floor", std::string("floor")}, // P.floor
349 {"floor_divide", std::string("floor_divide")}, // floor_divide
350 {"flip", std::string("flip")}, // flip
351 {"fliplr", std::string("fliplr")}, // fliplr
352 {"flipud", std::string("flipud")}, // flipud
353 {"float_power", std::string("float_power")}, // F.float_power
354 {"fmax", std::string("fmax")}, // fmax()
355 {"fmin", std::string("fmin")}, // fmin()
356 {"fmod", std::string("fmod")}, // F.fmod
357 {"hardshrink", std::string("hardshrink")}, // P.hshrink
358 {"heaviside", std::string("heaviside")}, // F.heaviside
359 {"hypot", std::string("hypot")}, // F.hypot
360 {"soft_shrink", std::string("soft_shrink")}, // P.SoftShrink
361 {"gather_nd", std::string("gather_nd")}, // P.GatherNd()
362 {"unique_consecutive", std::string("unique_consecutive")}, // UniqueConsecutive()
363 {"unique_with_pad", std::string("unique_with_pad")}, // P.UniqueWithPad()
364 {"diag", std::string("diag")}, // P.Diag()
365 {"diagflat", std::string("diagflat")}, // diagflat()
366 {"digamma", std::string("digamma")}, // digamma()
367 {"lgamma", std::string("lgamma")}, // lgamma()
368 {"adaptive_max_pool2d", std::string("adaptive_max_pool2d")}, // P.AdaptiveMaxPool2D
369 {"to_coo", std::string("to_coo")}, // dense_to_sparse_coo()
370 {"to_csr", std::string("to_csr")}, // dense_to_sparse_csr()
371 {"tolist", std::string("tolist")}, // tolist()
372 {"col2im", std::string("col2im")}, // P.Col2Im
373 {"count_nonzero", std::string("count_nonzero")}, // count_nonzero
374 {"split", std::string("split")}, // split
375 {"tensor_split", std::string("tensor_split")}, // tensor_split
376 {"vsplit", std::string("vsplit")}, // vsplit
377 {"hsplit", std::string("hsplit")}, // hsplit
378 {"dsplit", std::string("dsplit")}, // dplit
379 {"random_categorical", std::string("random_categorical")}, // P.RandomCategorical
380 {"xlogy", std::string("xlogy")}, // P.Xlogy()
381 {"erf", std::string("erf")}, // P.Erf()
382 {"erfc", std::string("erfc")}, // P.Erfc()
383 {"argmax_with_value", std::string("argmax_with_value")}, // P.ArgMaxWithValue
384 {"argmin_with_value", std::string("argmin_with_value")}, // P.ArgMinWithValue
385 {"tile", std::string("tile")}, // P.Tile
386 {"topk", std::string("topk")}, // P.TopK()
387 {"top_k", std::string("top_k")}, // P.TopK()
388 {"isfinite", std::string("isfinite")}, // P.isfinite()
389 {"cos", std::string("cos")}, // cos()
390 {"cov", std::string("cov")}, // cov()
391 {"acos", std::string("acos")}, // acos()
392 {"arccos", std::string("acos")}, // acos()
393 {"acosh", std::string("acosh")}, // acosh()
394 {"sigmoid", std::string("sigmoid")}, // P.Sigmoid()
395 {"addr", std::string("addr")}, // addr()
396 {"add", std::string("add")}, // P.Add()
397 {"addbmm", std::string("addbmm")}, // addbmm()
398 {"addmm", std::string("addmm")}, // addmm()
399 {"addmv", std::string("addmv")}, // addmv()
400 {"adjoint", std::string("adjoint")}, // adjoint()
401 {"t", std::string("t")}, // t()
402 {"arccosh", std::string("acosh")}, // arccosh()
403 {"sin", std::string("sin")}, // sin()
404 {"sinc", std::string("sinc")}, // sinc()
405 {"arcsin", std::string("asin")}, // arcsin()
406 {"arctan", std::string("atan")}, // arctan()
407 {"arctan2", std::string("atan2")}, // arctan2()
408 {"asin", std::string("asin")}, // asin()
409 {"asinh", std::string("asinh")}, // asinh()
410 {"arcsinh", std::string("asinh")}, // arcsinh()
411 {"atan", std::string("atan")}, // atan()
412 {"atanh", std::string("atanh")}, // atanh()
413 {"arctanh", std::string("atanh")}, // arctanh()
414 {"baddbmm", std::string("baddbmm")}, // baddbmm
415 {"bmm", std::string("bmm")}, // bmm()
416 {"value", std::string("value_")}, // P.Load(param, U)
417 {"to", std::string("to")}, // to()
418 {"bool", std::string("to_bool")}, // bool()
419 {"float", std::string("to_float")}, // float()
420 {"half", std::string("to_half")}, // half()
421 {"int", std::string("to_int")}, // int()
422 {"long", std::string("to_long")}, // long()
423 {"cholesky", std::string("cholesky")}, // cholesky()
424 {"cholesky_inverse", std::string("cholesky_inverse")}, // cholesky_inverse()
425 {"cholesky_solve", std::string("cholesky_solve")}, // cholesky_solve()
426 {"conj", std::string("conj")}, // conj()
427 {"cross", std::string("cross")}, // cross()
428 {"erfinv", std::string("erfinv")}, // erfinv()
429 {"less_equal", std::string("less_equal")}, // less_equal()
430 {"fold", std::string("fold")}, // fold()
431 {"unfold", std::string("unfold")}, // unfold()
432 {"expand", std::string("expand")}, // expand()
433 {"cumprod", std::string("cumprod")}, // cumprod()
434 {"div", std::string("div")}, // div()
435 {"divide", std::string("div")}, // divide()
436 {"eq", std::string("eq")}, // eq()
437 {"equal", std::string("equal")}, // equal()
438 {"expm1", std::string("expm1")}, // expm1()
439 {"eig", std::string("eig")}, // eig()
440 {"eigvals", std::string("eigvals")}, // eigvals()
441 {"geqrf", std::string("geqrf")}, // geqrf()
442 {"histc", std::string("histc")}, // histc()
443 {"type", std::string("ms_type")}, // astype()
444 {"type_as", std::string("type_as")}, // astype()
445 {"dim", prim::kPrimRank}, // P.Rank()
446 {"index_add", std::string("index_add")}, // index_add()
447 {"greater", std::string("greater")}, // greater()
448 {"greater_equal", std::string("greater_equal")}, // greater_equal()
449 {"igamma", std::string("igamma")}, // igamma()
450 {"igammac", std::string("igammac")}, // igammac()
451 {"isinf", std::string("isinf")}, // isinf()
452 {"isnan", std::string("isnan")}, // isnan()
453 {"le", std::string("le")}, // le()
454 {"less", std::string("less")}, // less()
455 {"lt", std::string("less")}, // lt()
456 {"logical_and", std::string("logical_and")}, // logical_and()
457 {"logical_not", std::string("logical_not")}, // logical_not()
458 {"logical_or", std::string("logical_or")}, // logical_or()
459 {"logical_xor", std::string("logical_xor")}, // logical_xor()
460 {"lstsq", std::string("lstsq")}, // lstsq()
461 {"mvlgamma", std::string("mvlgamma")}, // mvlgamma()
462 {"matmul", std::string("matmul")}, // matmul()
463 {"inner", std::string("inner")}, // inner()
464 {"maximum", std::string("maximum")}, // maximum()
465 {"msort", std::string("msort")}, // msort()
466 {"mm", std::string("mm")}, // mm()
467 {"mul", std::string("mul")}, // mul()
468 {"multiply", std::string("multiply")}, // multiply()
469 {"nan_to_num", std::string("nan_to_num")}, // nan_to_num()
470 {"nansum", std::string("nansum")}, // nansum()
471 {"nanmean", std::string("nanmean")}, // nanmean()
472 {"nanmedian", std::string("nanmedian")}, // nanmedian()
473 {"neg", std::string("neg")}, // neg()
474 {"ne", std::string("ne")}, // ne()
475 {"not_equal", std::string("not_equal")}, // not_equal()
476 {"new_zeros", std::string("new_zeros")}, // new_zeros()
477 {"new_ones", std::string("new_ones")}, // new_ones()
478 {"sgn", std::string("sgn")}, // sgn()
479 {"sign", std::string("sign")}, // sign()
480 {"signbit", std::string("signbit")}, // signbit()
481 {"sinh", std::string("sinh")}, // sinh()
482 {"sort", std::string("sort")}, // sort()
483 {"cauchy", std::string("cauchy")}, // P.cauchy()
484 {"log_normal", std::string("log_normal")}, // P.LogNormalReverse()
485 {"argsort", std::string("argsort")}, // argsort()
486 {"trunc", std::string("trunc")}, // trunc()
487 {"where", std::string("where")}, // where()
488 {"imag", std::string("imag")}, // imag()
489 {"diff", std::string("diff")}, // diff()
490 {"frac", std::string("frac")}, // frac()
491 {"argwhere", std::string("argwhere")}, // argwhere()
492 {"moveaxis", std::string("moveaxis")}, // moveaxis()
493 {"multinomial", std::string("multinomial")}, // multinomial()
494 {"movedim", std::string("movedim")}, // movedim()
495 {"nextafter", std::string("nextafter")}, // nextafter()
496 {"qr", std::string("qr")}, // qr()
497 {"ormqr", std::string("ormqr")}, // ormqr()
498 {"amax", std::string("amax")}, // amax()
499 {"amin", std::string("amin")}, // amin()
500 {"lu_solve", std::string("lu_solve")}, // lu_solve()
501 {"masked_scatter", std::string("masked_scatter")}, // masked_scatter()
502 {"index_put", std::string("index_put")}, // index_input()
503 {"aminmax", std::string("aminmax")}, // aminmax()
504 {"quantile", std::string("quantile")}, // quantile()
505 {"nanquantile", std::string("nanquantile")}, // nanquantile()
506 {"orgqr", std::string("orgqr")}, // orgqr()
507 {"outer", std::string("outer")}, // outer()
508 {"softmax", std::string("softmax")}, // softmax()
509 }},
510 {kObjectTypeRowTensorType,
511 {
512 {"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add
513 }},
514 {kObjectTypeCSRTensorType,
515 {
516 {"astype", std::string("csr_astype")}, // C.csr_astype
517 {"abs", std::string("csr_abs")}, // C.csr_abs
518 {"sum", std::string("csr_sum")}, // C.csr_sum
519 {"mv", std::string("csr_mv")}, // C.csr_mv
520 {"to_tuple", std::string("csr_to_tuple")}, // C.csr_to_tuple
521 {"to_coo", std::string("csr_to_coo")}, // C.csr_to_coo
522 {"to_dense", std::string("csr_to_dense")}, // C.csr_to_dense
523 {"mm", std::string("csr_mm")}, // C.csr_mm
524 {"add", std::string("csr_add")}, // C.csr_add
525 {"softmax", std::string("csr_softmax")}, // C.csr_softmax
526 }},
527 {kObjectTypeCOOTensorType,
528 {
529 {"astype", std::string("coo_astype")}, // C.coo_astype
530 {"abs", std::string("coo_abs")}, // C.coo_abs
531 {"to_tuple", std::string("coo_to_tuple")}, // C.coo_to_tuple
532 {"to_csr", std::string("coo_to_csr")}, // C.coo_to_csr
533 {"to_dense", std::string("coo_to_dense")}, // C.coo_to_dense
534 {"coalesce", std::string("coo_coalesce")}, // C.coo_coalesce
535 {"add", std::string("coo_add")}, // C.coo_add
536 }},
537 {kObjectTypeMapTensorType,
538 {
539 {"get", std::string("map_tensor_get")}, // C.map_tensor_get
540 {"put", std::string("map_tensor_put")}, // C.map_tensor_put
541 {"erase", std::string("map_tensor_erase")}, // C.map_tensor_erase
542 {"get_keys", std::string("map_tensor_get_keys")}, // C.map_tensor_get_keys
543 {"get_values", std::string("map_tensor_get_values")}, // C.map_tensor_get_values
544 {"get_data", std::string("map_tensor_get_data")}, // C.map_tensor_get_data
545 }},
546 {kObjectTypeJTagged, {}},
547 {kObjectTypeSymbolicKeyType, {}},
548 {kObjectTypeEnvType, {}}};
549 return method_map;
550 }
551
// Returns the process-wide table of built-in attributes, keyed by type id.
//
// Each inner table maps an attribute name to either a primitive (prim::kPrim...) or a
// std::string naming a function; the trailing "C." / "F." comments are the authors'
// notes on what each name resolves to.
// The table is a function-local static: it is built on the first call and the returned
// reference is to that single instance.
BuiltInTypeMap &GetAttrMap() {
  static BuiltInTypeMap attr_map = {
    // Dense tensors.
    {kObjectTypeTensorType,
     {
       {"shape", prim::kPrimShape},             // C.shape_
       {"dtype", prim::kPrimDType},             // C.dtype_
       {"size", std::string("size_")},          // C.size_
       {"ndim", std::string("ndim_")},          // C.ndim_
       {"H", std::string("H")},                 // C.H
       {"T", std::string("T_")},                // C.T_
       {"itemsize", std::string("itemsize_")},  // C.itemsize_
       {"nbytes", std::string("nbytes_")},      // C.nbytes_
       {"strides", std::string("strides_")},    // C.strides_
       {"mH", std::string("adjoint")},          // C.adjoint
       {"mT", std::string("mT")},               // C.mT_
     }},
    // Row tensors: every attribute is served by a primitive.
    {kObjectTypeRowTensorType,
     {
       {"values", prim::kPrimRowTensorGetValues},           // F.row_tensor_get_values
       {"indices", prim::kPrimRowTensorGetIndices},         // F.row_tensor_get_indices
       {"dense_shape", prim::kPrimRowTensorGetDenseShape},  // F.row_tensor_get_dense_shape
     }},
    // COO sparse tensors: "shape" is served by the dense-shape getter primitive.
    {kObjectTypeCOOTensorType,
     {
       {"values", prim::kPrimCOOTensorGetValues},      // F.coo_tensor_get_values
       {"indices", prim::kPrimCOOTensorGetIndices},    // F.coo_tensor_get_indices
       {"shape", prim::kPrimCOOTensorGetDenseShape},   // F.coo_tensor_get_dense_shape
       {"dtype", std::string("dtype_")},               // C.dtype_
       {"size", std::string("sparse_size_")},          // C.sparse_size_
       {"ndim", std::string("sparse_ndim_")},          // C.sparse_ndim_
       {"itemsize", std::string("itemsize_")},         // C.itemsize_
     }},
    // CSR sparse tensors: same layout as COO plus "indptr".
    {kObjectTypeCSRTensorType,
     {
       {"indptr", prim::kPrimCSRTensorGetIndptr},      // F.csr_tensor_get_indptr
       {"values", prim::kPrimCSRTensorGetValues},      // F.csr_tensor_get_values
       {"indices", prim::kPrimCSRTensorGetIndices},    // F.csr_tensor_get_indices
       {"shape", prim::kPrimCSRTensorGetDenseShape},   // F.csr_tensor_get_shape
       {"dtype", std::string("dtype_")},               // C.dtype_
       {"size", std::string("sparse_size_")},          // C.sparse_size_
       {"ndim", std::string("sparse_ndim_")},          // C.sparse_ndim_
       {"itemsize", std::string("itemsize_")},         // C.itemsize_
     }},
    // Map tensors.
    {kObjectTypeMapTensorType,
     {
       {"default_value", prim::kPrimMapTensorGetDefaultValue},             // F.map_tensor_get_default_value
       {"permit_filter_value", prim::kPrimMapTensorGetPermitFilterValue},  // F.map_tensor_get_permit_filter_value
       {"evict_filter_value", prim::kPrimMapTensorGetEvictFilterValue},    // F.map_tensor_get_evict_filter_value
     }},
  };
  return attr_map;
}
604
// Out-of-class definition of the static data member Resource::backend_init_mutex_.
// NOTE(review): judging by the name it guards backend initialization; its users are not
// visible in this part of the file — confirm before relying on that.
std::mutex Resource::backend_init_mutex_;
606
// Builds a Resource around the Python object `obj`.
//  - engine_ gets a fresh AnalysisEngine made from the primitive evaluator constructors
//    and manager_.
//    NOTE(review): manager_ is not initialized in this list, so it is presumably a
//    base-class or earlier-declared member that is already initialized when it is read
//    here — confirm in the class declaration.
//  - source_input_ stores `obj`.
//  - is_cleaned_ starts as false; the destructor calls Clean() while it is still false.
Resource::Resource(const py::object &obj)
    : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
      source_input_(obj),
      is_cleaned_(false) {}
611
~Resource()612 Resource::~Resource() {
613 MS_LOG(DEBUG) << "Resource clear";
614
615 try {
616 mindspore::HashMap<std::string, Any>().swap(results_);
617 } catch (const std::exception &e) {
618 MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
619 }
620
621 // If exit normally, these global variables will be cleaned
622 // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION,
623 // these global variables may not being cleaned, it may
624 // cause segmentfault when free python object inside these global variables
625 // after python interpreter got freed, so these global variables
626 // are cleaned here.
627 // So if exit normally, these global variable will be cleaned twice,
628 // care be taken to prevent double free in the following functions.
629 if (!is_cleaned_) {
630 try {
631 Clean();
632 } catch (const std::exception &e) {
633 MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what();
634 } catch (...) {
635 MS_LOG(ERROR) << "Exception when cleaning resource.";
636 }
637 }
638 }
639
GetMethodOrAttr(const string & name,const TypeId & type_id,const BuiltInTypeMap & method_map)640 Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
641 auto type_method_map = method_map.find(static_cast<int64_t>(type_id));
642 if (type_method_map == method_map.end()) {
643 return Any();
644 }
645 auto method = type_method_map->second.find(name);
646 if (method == type_method_map->second.end()) {
647 return Any();
648 }
649 return method->second;
650 }
651
IsTypeInBuiltInMap(const TypeId & type)652 bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
653 TypeId type_id = NormalizeTypeId(type);
654 const BuiltInTypeMap &method_map = GetMethodMap();
655 auto iter = method_map.find(static_cast<int64_t>(type_id));
656 if (iter == method_map.end()) {
657 const BuiltInTypeMap &attr_map = GetAttrMap();
658 iter = attr_map.find(static_cast<int64_t>(type_id));
659 if (iter == attr_map.end()) {
660 return false;
661 }
662 }
663 return true;
664 }
665
GetMethodPtr(const TypeId & type,const std::string & name)666 Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
667 TypeId type_id = NormalizeTypeId(type);
668 const BuiltInTypeMap &method_map = GetMethodMap();
669 return GetMethodOrAttr(name, type_id, method_map);
670 }
671
GetAttrPtr(const TypeId & type,const std::string & name)672 Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
673 TypeId type_id = NormalizeTypeId(type);
674 const BuiltInTypeMap &attr_map = GetAttrMap();
675 return GetMethodOrAttr(name, type_id, attr_map);
676 }
677
GetCompileCacheResource(const py::list & compile_cache_dep_files,const py::dict & weights,const std::string & queue_name,size_t compile_cache_id,bool * compile_cache_consistent,bool has_python_script)678 void Resource::GetCompileCacheResource(const py::list &compile_cache_dep_files, const py::dict &weights,
679 const std::string &queue_name, size_t compile_cache_id,
680 bool *compile_cache_consistent, bool has_python_script) {
681 compile_cache_manager_ = std::make_shared<CompileCacheManager>(compile_cache_id);
682 compile_cache_manager_->InitParallelGroupCkptSaveFile();
683 const bool force_use_compile_cache = (common::GetEnv("MS_DEV_FORCE_USE_COMPILE_CACHE") == "1");
684 auto &context = CompileCacheContext::GetInstance();
685 // When enabling compile cache, it is possible to enable it even without Python script.
686 if (force_use_compile_cache || !has_python_script) {
687 context.set_init_compile_cache(true);
688 MS_LOG(WARNING)
689 << "The env MS_DEV_FORCE_USE_COMPILE_CACHE has been set. It will force to use the compile cache without "
690 "checking whether the network has been changed. Please note the correctness.";
691 } else {
692 MS_EXCEPTION_IF_NULL(compile_cache_consistent);
693 if (!*compile_cache_consistent) {
694 MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions.";
695 return;
696 }
697 context.set_init_compile_cache(true);
698 compile_cache_manager_->InitCompileCacheHash(compile_cache_dep_files);
699 *compile_cache_consistent = compile_cache_manager_->CanLoadCache();
700 if (!*compile_cache_consistent) {
701 MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions.";
702 return;
703 }
704 }
705 func_graph_ = compile_cache_manager_->GetCachedFuncGraph(manager_, weights, queue_name);
706 layout_map_ = compile_cache_manager_->layout_map();
707 }
708
CacheFuncGraph() const709 void Resource::CacheFuncGraph() const {
710 FuncGraphPtr layout_fg = nullptr;
711 if (parallel::IsAutoParallelCareGraph(func_graph_)) {
712 layout_fg = GetResult(kStepParallelGraph).cast<FuncGraphPtr>();
713 }
714 compile_cache_manager_->CacheFuncGraph(func_graph_, layout_fg);
715 }
716
Clean()717 void Resource::Clean() {
718 // Ensure that async backend creating task is finished before clean resource.
719 if (backend_ == nullptr && backend_future_.valid()) {
720 backend_ = backend_future_.get();
721 }
722 // AbstractTensor->elements() will be saved in AbstractBasePtrList
723 args_abs_.clear();
724 arguments_.clear();
725 source_input_ = py::none();
726 // Context with AbstractBasePtrList may be saved in GraphEvaluator
727 // some Evaluator like ResolveEvaluator may save Python object in cache,
728 // it should be cleaned before Python Interpreter destructed.
729 MS_EXCEPTION_IF_NULL(engine_);
730 engine_->ClearEvaluatorCache();
731 engine_->Clear();
732 // Clean cache used for parse. As static variable is released after
733 // Python threads is released.
734 parse::data_converter::ClearObjectCache();
735 parse::Parser::CleanParserResource();
736 // Clear all graphs' holding for python object(such as Cell),
737 // otherwise it will result to circular reference between the func_graph and cell.
738 for (auto graph : manager()->func_graphs()) {
739 graph->set_python_obj(nullptr);
740 }
741 trace::ClearTraceStack();
742 is_cleaned_ = true;
743 }
744
GetBackend() const745 compile::BackendPtr Resource::GetBackend() const {
746 if (backend_ == nullptr && backend_future_.valid()) {
747 backend_ = backend_future_.get();
748 }
749 return backend_;
750 }
751
SetBackendAsync(std::function<compile::BackendPtr ()> func)752 void Resource::SetBackendAsync(std::function<compile::BackendPtr()> func) {
753 static const bool is_enable_async = (common::GetEnv("MS_DEV_ASYNC_BACKEND_INIT") == "1");
754 auto context_ptr = MsContext::GetInstance();
755 MS_EXCEPTION_IF_NULL(context_ptr);
756 static const bool is_enable_ge = context_ptr->backend_policy() == "ge";
757 if (!is_enable_async || is_enable_ge) {
758 // Disable async backend init if required.
759 std::lock_guard<std::mutex> guard(GetBackendInitMutex());
760 backend_ = func();
761 return;
762 }
763 if (backend_ == nullptr && backend_future_.valid()) {
764 (void)backend_future_.get();
765 }
766 backend_ = nullptr;
767 backend_future_ = std::async(std::launch::async, [func]() {
768 std::lock_guard<std::mutex> guard(Resource::GetBackendInitMutex());
769 return func();
770 });
771 }
772 } // namespace pipeline
773 } // namespace mindspore
774