• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "pipeline/pynative/op_function/converter.h"
17 #include <unordered_map>
18 #include "include/common/utils/convert_utils_py.h"
19 #include "pipeline/jit/ps/parse/data_converter.h"
20 #include "pipeline/pynative/pynative_utils.h"
21 
22 namespace mindspore {
23 namespace pynative {
24 
25 namespace {
26 using OP_DTYPE = mindspore::ops::OP_DTYPE;
27 template <typename T, typename U>
PyCast(const py::object & obj)28 std::shared_ptr<U> PyCast(const py::object &obj) {
29   return std::make_shared<U>(py::cast<T>(obj));
30 }
31 
ConvertBool(const py::object & obj)32 BoolImmPtr ConvertBool(const py::object &obj) {
33   if (!py::isinstance<py::bool_>(obj)) {
34     // The mutable _Bool class inherits from int, because base class 'bool' is a marked final.
35     if (py::isinstance<py::int_>(obj) && py::hasattr(obj, "__ms_mutable_bool__")) {
36       auto obj_int64 = py::cast<int64_t>(obj);
37       bool obj_bool = obj_int64 != 0;
38       return std::make_shared<BoolImm>(obj_bool);
39     } else {
40       return nullptr;
41     }
42   }
43   return PyCast<bool, BoolImm>(obj);
44 }
45 
ConvertInt(const py::object & obj)46 Int64ImmPtr ConvertInt(const py::object &obj) {
47   // bool is also an instance of py::int_
48   if (py::isinstance<py::bool_>(obj) || !py::isinstance<py::int_>(obj)) {
49     return nullptr;
50   }
51   return PyCast<int64_t, Int64Imm>(obj);
52 }
53 
ConvertFloat(const py::object & obj)54 FP32ImmPtr ConvertFloat(const py::object &obj) {
55   if (!py::isinstance<py::float_>(obj)) {
56     return nullptr;
57   }
58   return PyCast<double, FP32Imm>(obj);
59 }
60 
ConvertNumber(const py::object & obj)61 ScalarPtr ConvertNumber(const py::object &obj) {
62   if (py::isinstance<py::float_>(obj)) {
63     return std::make_shared<FP32Imm>(py::cast<double>(obj));
64   } else if (py::isinstance<py::bool_>(obj)) {
65     return std::make_shared<BoolImm>(py::cast<bool>(obj));
66   } else if (py::isinstance<py::int_>(obj)) {
67     return std::make_shared<Int64Imm>(py::cast<int64_t>(obj));
68   }
69   return nullptr;
70 }
71 
ConvertStr(const py::object & obj)72 StringImmPtr ConvertStr(const py::object &obj) {
73   if (!py::isinstance<py::str>(obj)) {
74     return nullptr;
75   }
76   return PyCast<string, StringImm>(obj);
77 }
78 
79 template <typename T, typename U, typename N>
ConvertList(const py::object & obj)80 ValueTuplePtr ConvertList(const py::object &obj) {
81   if (!py::isinstance<T>(obj)) {
82     return nullptr;
83   }
84   auto seq = py::cast<T>(obj);
85   size_t size = seq.size();
86   std::vector<ValuePtr> convert(size);
87   for (size_t i = 0; i < size; ++i) {
88     if (!py::isinstance<U>(seq[i])) {
89       return nullptr;
90     }
91     auto out = PyCast<U, N>(seq[i]);
92     if (out == nullptr) {
93       return nullptr;
94     }
95     convert[i] = out;
96   }
97   return std::make_shared<ValueTuple>(std::move(convert));
98 }
99 
100 template <typename T>
ConvertIntList(const py::object & obj)101 ValueTuplePtr ConvertIntList(const py::object &obj) {
102   if (!py::isinstance<T>(obj)) {
103     return nullptr;
104   }
105   auto seq = py::cast<T>(obj);
106   size_t size = seq.size();
107   std::vector<ValuePtr> convert(size);
108   for (size_t i = 0; i < size; ++i) {
109     // bool is also an instance of py::int_
110     if (py::isinstance<py::bool_>(seq[i]) || !py::isinstance<py::int_>(seq[i])) {
111       return nullptr;
112     }
113     auto out = PyCast<py::int_, Int64Imm>(seq[i]);
114     if (out == nullptr) {
115       return nullptr;
116     }
117     convert[i] = out;
118   }
119   return std::make_shared<ValueTuple>(std::move(convert));
120 }
121 
122 template <>
ConvertList(const py::object & obj)123 ValueTuplePtr ConvertList<py::tuple, py::int_, Int64Imm>(const py::object &obj) {
124   return ConvertIntList<py::tuple>(obj);
125 }
126 
127 template <>
ConvertList(const py::object & obj)128 ValueTuplePtr ConvertList<py::list, py::int_, Int64Imm>(const py::object &obj) {
129   return ConvertIntList<py::list>(obj);
130 }
131 }  // namespace
132 
Converter(ops::OpDef * op_def)133 Converter::Converter(ops::OpDef *op_def)
134     : op_def_(op_def), source_type_(std::vector<ops::OP_DTYPE>(op_def->args_.size())) {}
135 
Parse(const py::list & python_args)136 void Converter::Parse(const py::list &python_args) {
137   if (op_def_->args_.size() != python_args.size()) {
138     MS_LOG(EXCEPTION) << "For operator " << op_def_->name_ << ", it requires " << op_def_->args_.size()
139                       << "parameters, bug got " << python_args.size() << "parameters!";
140   }
141 }
142 
ToTensor(const py::list & python_args,size_t i)143 ValuePtr Converter::ToTensor(const py::list &python_args, size_t i) {
144   const auto &op_arg = op_def_->args_[i];
145   const py::object &obj = (python_args)[i];
146   source_type_[i] = OP_DTYPE::DT_BEGIN;
147   auto tensor = parse::ConvertTensor(obj);
148   if (tensor != nullptr) {
149     if (tensor->isa<tensor::BaseTensor>()) {
150       tensor->cast<tensor::BaseTensorPtr>()->set_need_pipeline_sync(true);
151     }
152     return tensor;
153   }
154   if (!op_arg.cast_dtype_.empty()) {
155     auto convert = ConvertByCastDtype(obj, op_arg, i);
156     if (convert != nullptr && convert->isa<tensor::BaseTensor>()) {
157       return convert->cast<tensor::BaseTensorPtr>();
158     }
159   }
160 
161   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
162   return nullptr;
163 }
164 
ToTensorOptional(const py::list & python_args,size_t i)165 std::optional<ValuePtr> Converter::ToTensorOptional(const py::list &python_args, size_t i) {
166   const py::object &obj = (python_args)[i];
167   if (py::isinstance<py::none>(obj)) {
168     return std::nullopt;
169   }
170   return std::make_optional(ToTensor(python_args, i));
171 }
172 
173 template <typename T>
ToTensorList(const py::list & python_args,size_t i)174 ValueTuplePtr Converter::ToTensorList(const py::list &python_args, size_t i) {
175   const auto &op_arg = op_def_->args_[i];
176   const py::object &obj = python_args[i];
177   source_type_[i] = OP_DTYPE::DT_BEGIN;
178   auto val_seq = parse::ConvertSequence<py::tuple, ValueTuple, parse::ConvertTensor>(obj);
179   if (val_seq != nullptr && val_seq->isa<ValueTuple>()) {
180     return val_seq->cast<ValueTuplePtr>();
181   }
182   return ConvertValueTupleByCastDtype(python_args, op_arg, i);
183 }
184 
185 template <typename T>
ToTensorListOptional(const py::list & python_args,size_t i)186 std::optional<ValueTuplePtr> Converter::ToTensorListOptional(const py::list &python_args, size_t i) {
187   const py::object &obj = (python_args)[i];
188   if (py::isinstance<py::none>(obj)) {
189     return std::nullopt;
190   }
191   return std::make_optional(ToTensorList<T>(python_args, i));
192 }
193 
ToInt(const py::list & python_args,size_t i)194 Int64ImmPtr Converter::ToInt(const py::list &python_args, size_t i) {
195   const auto &op_arg = op_def_->args_[i];
196   const py::object &obj = python_args[i];
197   source_type_[i] = OP_DTYPE::DT_BEGIN;
198   auto convert = ConvertInt(obj);
199   if (convert != nullptr) {
200     return convert;
201   }
202   if (!op_arg.cast_dtype_.empty()) {
203     auto convert_value = ConvertByCastDtype(obj, op_arg, i);
204     if (convert_value != nullptr && convert_value->isa<Int64Imm>()) {
205       return convert_value->cast<Int64ImmPtr>();
206     }
207   }
208   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
209   return nullptr;
210 }
211 
ToIntOptional(const py::list & python_args,size_t i)212 std::optional<Int64ImmPtr> Converter::ToIntOptional(const py::list &python_args, size_t i) {
213   const py::object &obj = python_args[i];
214   if (py::isinstance<py::none>(obj)) {
215     return std::nullopt;
216   }
217   return std::make_optional(ToInt(python_args, i));
218 }
219 
220 template <typename T>
ToIntList(const py::list & python_args,size_t i)221 ValueTuplePtr Converter::ToIntList(const py::list &python_args, size_t i) {
222   const auto &op_arg = op_def_->args_[i];
223   const py::object &obj = python_args[i];
224   ValueTuplePtr convert = ConvertList<T, py::int_, Int64Imm>(obj);
225   if (convert != nullptr) {
226     return convert;
227   }
228   return ConvertValueTupleByCastDtype(python_args, op_arg, i);
229 }
230 
231 template <typename T>
ToIntListOptional(const py::list & python_args,size_t i)232 std::optional<ValueTuplePtr> Converter::ToIntListOptional(const py::list &python_args, size_t i) {
233   const py::object &obj = python_args[i];
234   if (py::isinstance<py::none>(obj)) {
235     return std::nullopt;
236   }
237   return std::make_optional(ToIntList<T>(python_args, i));
238 }
239 
ToBool(const py::list & python_args,size_t i)240 BoolImmPtr Converter::ToBool(const py::list &python_args, size_t i) {
241   const auto &op_arg = op_def_->args_[i];
242   const py::object &obj = python_args[i];
243   source_type_[i] = OP_DTYPE::DT_BEGIN;
244   auto convert = ConvertBool(obj);
245   if (convert != nullptr) {
246     return convert;
247   }
248   if (!op_arg.cast_dtype_.empty()) {
249     auto convert_value = ConvertByCastDtype(obj, op_arg, i);
250     if (convert_value != nullptr && convert_value->isa<BoolImm>()) {
251       return convert_value->cast<BoolImmPtr>();
252     }
253   }
254   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
255   return nullptr;
256 }
257 
ToBoolOptional(const py::list & python_args,size_t i)258 std::optional<BoolImmPtr> Converter::ToBoolOptional(const py::list &python_args, size_t i) {
259   const py::object &obj = python_args[i];
260   if (py::isinstance<py::none>(obj)) {
261     return std::nullopt;
262   }
263   return std::make_optional(ToBool(python_args, i));
264 }
265 
266 template <typename T>
ToBoolList(const py::list & python_args,size_t i)267 ValueTuplePtr Converter::ToBoolList(const py::list &python_args, size_t i) {
268   const auto &op_arg = op_def_->args_[i];
269   const py::object &obj = python_args[i];
270   source_type_[i] = OP_DTYPE::DT_BEGIN;
271   ValueTuplePtr convert = ConvertList<T, py::bool_, BoolImm>(obj);
272   if (convert != nullptr) {
273     return convert;
274   }
275   return ConvertValueTupleByCastDtype(python_args, op_arg, i);
276 }
277 
278 template <typename T>
ToBoolListOptional(const py::list & python_args,size_t i)279 std::optional<ValueTuplePtr> Converter::ToBoolListOptional(const py::list &python_args, size_t i) {
280   const py::object &obj = python_args[i];
281   source_type_[i] = OP_DTYPE::DT_BEGIN;
282   if (py::isinstance<py::none>(obj)) {
283     return std::nullopt;
284   }
285   return std::make_optional(ToBoolList<T>(python_args, i));
286 }
287 
ToFloat(const py::list & python_args,size_t i)288 FP32ImmPtr Converter::ToFloat(const py::list &python_args, size_t i) {
289   const auto &op_arg = op_def_->args_[i];
290   const py::object &obj = python_args[i];
291   source_type_[i] = OP_DTYPE::DT_BEGIN;
292   auto convert = ConvertFloat(obj);
293   if (convert != nullptr) {
294     return convert;
295   }
296   if (!op_arg.cast_dtype_.empty()) {
297     auto convert_value = ConvertByCastDtype(obj, op_arg, i);
298     if (convert_value != nullptr && convert_value->isa<FP32Imm>()) {
299       return convert_value->cast<FP32ImmPtr>();
300     }
301   }
302   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
303   return nullptr;
304 }
305 
ToFloatOptional(const py::list & python_args,size_t i)306 std::optional<FP32ImmPtr> Converter::ToFloatOptional(const py::list &python_args, size_t i) {
307   const py::object &obj = python_args[i];
308   if (py::isinstance<py::none>(obj)) {
309     return std::nullopt;
310   }
311   return std::make_optional(ToFloat(python_args, i));
312 }
313 
314 template <typename T>
ToFloatList(const py::list & python_args,size_t i)315 ValueTuplePtr Converter::ToFloatList(const py::list &python_args, size_t i) {
316   const auto &op_arg = op_def_->args_[i];
317   const py::object &obj = python_args[i];
318   source_type_[i] = OP_DTYPE::DT_BEGIN;
319   ValueTuplePtr convert = ConvertList<T, py::float_, FP32Imm>(obj);
320   if (convert != nullptr) {
321     return convert;
322   }
323   return ConvertValueTupleByCastDtype(python_args, op_arg, i);
324 }
325 
326 template <typename T>
ToFloatListOptional(const py::list & python_args,size_t i)327 std::optional<ValueTuplePtr> Converter::ToFloatListOptional(const py::list &python_args, size_t i) {
328   const py::object &obj = python_args[i];
329   if (py::isinstance<py::none>(obj)) {
330     return std::nullopt;
331   }
332   return std::make_optional(ToFloatList<T>(python_args, i));
333 }
334 
ToScalar(const py::list & python_args,size_t i)335 ScalarPtr Converter::ToScalar(const py::list &python_args, size_t i) {
336   const auto &op_arg = op_def_->args_[i];
337   const py::object &obj = python_args[i];
338   source_type_[i] = OP_DTYPE::DT_BEGIN;
339   auto convert = ConvertNumber(obj);
340   if (convert != nullptr) {
341     return convert;
342   }
343   if (!op_arg.cast_dtype_.empty()) {
344     auto convert_value = ConvertByCastDtype(obj, op_arg, i);
345     if (convert_value != nullptr && convert_value->isa<Scalar>()) {
346       return convert_value->cast<ScalarPtr>();
347     }
348   }
349   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
350   return nullptr;
351 }
352 
ToScalarOptional(const py::list & python_args,size_t i)353 std::optional<ScalarPtr> Converter::ToScalarOptional(const py::list &python_args, size_t i) {
354   const py::object &obj = python_args[i];
355   if (py::isinstance<py::none>(obj)) {
356     return std::nullopt;
357   }
358   return std::make_optional(ToScalar(python_args, i));
359 }
360 
ToString(const py::list & python_args,size_t i)361 StringImmPtr Converter::ToString(const py::list &python_args, size_t i) {
362   const auto &op_arg = op_def_->args_[i];
363   const py::object &obj = python_args[i];
364   source_type_[i] = OP_DTYPE::DT_BEGIN;
365   auto convert = ConvertStr(obj);
366   if (convert != nullptr) {
367     return convert;
368   }
369   if (!op_arg.cast_dtype_.empty()) {
370     auto convert_value = ConvertByCastDtype(obj, op_arg, i);
371     if (convert_value != nullptr && convert_value->isa<StringImm>()) {
372       return convert_value->cast<StringImmPtr>();
373     }
374   }
375   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
376   return nullptr;
377 }
378 
ToStringOptional(const py::list & python_args,size_t i)379 std::optional<StringImmPtr> Converter::ToStringOptional(const py::list &python_args, size_t i) {
380   const py::object &obj = python_args[i];
381   if (py::isinstance<py::none>(obj)) {
382     return std::nullopt;
383   }
384   return std::make_optional(ToString(python_args, i));
385 }
386 
ToDtype(const py::list & python_args,size_t i)387 Int64ImmPtr Converter::ToDtype(const py::list &python_args, size_t i) {
388   const py::object &obj = python_args[i];
389   source_type_[i] = OP_DTYPE::DT_BEGIN;
390   auto convert = ConvertInt(obj);
391   if (convert != nullptr) {
392     return convert;
393   }
394   if (py::isinstance<mindspore::Type>(obj)) {
395     TypePtr type = py::cast<mindspore::TypePtr>(obj);
396     return std::make_shared<Int64Imm>(static_cast<int>(type->type_id()));
397   }
398   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
399   return nullptr;
400 }
401 
ToDtypeOptional(const py::list & python_args,size_t i)402 std::optional<Int64ImmPtr> Converter::ToDtypeOptional(const py::list &python_args, size_t i) {
403   const py::object &obj = python_args[i];
404   if (py::isinstance<py::none>(obj)) {
405     return std::nullopt;
406   }
407   return std::make_optional(ToDtype(python_args, i));
408 }
409 
ConvertByCastDtype(const py::object & input,const ops::OpInputArg & op_arg,size_t index)410 ValuePtr Converter::ConvertByCastDtype(const py::object &input, const ops::OpInputArg &op_arg, size_t index) {
411   for (auto &cast_dtype : op_arg.cast_dtype_) {
412     auto convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
413     if (convert_func == nullptr) {
414       MS_LOG(EXCEPTION) << "Can't find convert function for src_dtype[" << cast_dtype << "] and dst_type"
415                         << op_arg.arg_dtype_ << "].";
416     }
417     auto value = convert_func(input);
418     if (value != nullptr) {
419       source_type_[index] = cast_dtype;
420       return value;
421     }
422   }
423   return nullptr;
424 }
425 
ConvertValueTupleByCastDtype(const py::list & python_args,const ops::OpInputArg & op_arg,size_t index)426 ValueTuplePtr Converter::ConvertValueTupleByCastDtype(const py::list &python_args, const ops::OpInputArg &op_arg,
427                                                       size_t index) {
428   const auto &input = python_args[index];
429   if (!op_arg.cast_dtype_.empty()) {
430     auto convert_value = ConvertByCastDtype(input, op_arg, index);
431     if (convert_value != nullptr && convert_value->isa<ValueTuple>()) {
432       return convert_value->cast<ValueTuplePtr>();
433     }
434   }
435   PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, index);
436   return nullptr;
437 }
438 
439 // Declare template to compile corresponding method.
440 template ValueTuplePtr Converter::ToTensorList<py::tuple>(const py::list &python_args, size_t i);
441 template ValueTuplePtr Converter::ToTensorList<py::list>(const py::list &python_args, size_t i);
442 template std::optional<ValueTuplePtr> Converter::ToTensorListOptional<py::tuple>(const py::list &python_args, size_t i);
443 template std::optional<ValueTuplePtr> Converter::ToTensorListOptional<py::list>(const py::list &python_args, size_t i);
444 template std::optional<ValueTuplePtr> Converter::ToIntListOptional<py::tuple>(const py::list &python_args, size_t i);
445 template std::optional<ValueTuplePtr> Converter::ToIntListOptional<py::list>(const py::list &python_args, size_t i);
446 template std::optional<ValueTuplePtr> Converter::ToBoolListOptional<py::tuple>(const py::list &python_args, size_t i);
447 template std::optional<ValueTuplePtr> Converter::ToBoolListOptional<py::list>(const py::list &python_args, size_t i);
448 template std::optional<ValueTuplePtr> Converter::ToFloatListOptional<py::tuple>(const py::list &python_args, size_t i);
449 template std::optional<ValueTuplePtr> Converter::ToFloatListOptional<py::list>(const py::list &python_args, size_t i);
450 
451 }  // namespace pynative
452 }  // namespace mindspore
453