1 /**
2 * Copyright 2022-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "pipeline/pynative/op_function/converter.h"
17 #include <unordered_map>
18 #include "include/common/utils/convert_utils_py.h"
19 #include "pipeline/jit/ps/parse/data_converter.h"
20 #include "pipeline/pynative/pynative_utils.h"
21
22 namespace mindspore {
23 namespace pynative {
24
25 namespace {
26 using OP_DTYPE = mindspore::ops::OP_DTYPE;
27 template <typename T, typename U>
PyCast(const py::object & obj)28 std::shared_ptr<U> PyCast(const py::object &obj) {
29 return std::make_shared<U>(py::cast<T>(obj));
30 }
31
ConvertBool(const py::object & obj)32 BoolImmPtr ConvertBool(const py::object &obj) {
33 if (!py::isinstance<py::bool_>(obj)) {
34 // The mutable _Bool class inherits from int, because base class 'bool' is a marked final.
35 if (py::isinstance<py::int_>(obj) && py::hasattr(obj, "__ms_mutable_bool__")) {
36 auto obj_int64 = py::cast<int64_t>(obj);
37 bool obj_bool = obj_int64 != 0;
38 return std::make_shared<BoolImm>(obj_bool);
39 } else {
40 return nullptr;
41 }
42 }
43 return PyCast<bool, BoolImm>(obj);
44 }
45
ConvertInt(const py::object & obj)46 Int64ImmPtr ConvertInt(const py::object &obj) {
47 // bool is also an instance of py::int_
48 if (py::isinstance<py::bool_>(obj) || !py::isinstance<py::int_>(obj)) {
49 return nullptr;
50 }
51 return PyCast<int64_t, Int64Imm>(obj);
52 }
53
ConvertFloat(const py::object & obj)54 FP32ImmPtr ConvertFloat(const py::object &obj) {
55 if (!py::isinstance<py::float_>(obj)) {
56 return nullptr;
57 }
58 return PyCast<double, FP32Imm>(obj);
59 }
60
ConvertNumber(const py::object & obj)61 ScalarPtr ConvertNumber(const py::object &obj) {
62 if (py::isinstance<py::float_>(obj)) {
63 return std::make_shared<FP32Imm>(py::cast<double>(obj));
64 } else if (py::isinstance<py::bool_>(obj)) {
65 return std::make_shared<BoolImm>(py::cast<bool>(obj));
66 } else if (py::isinstance<py::int_>(obj)) {
67 return std::make_shared<Int64Imm>(py::cast<int64_t>(obj));
68 }
69 return nullptr;
70 }
71
ConvertStr(const py::object & obj)72 StringImmPtr ConvertStr(const py::object &obj) {
73 if (!py::isinstance<py::str>(obj)) {
74 return nullptr;
75 }
76 return PyCast<string, StringImm>(obj);
77 }
78
79 template <typename T, typename U, typename N>
ConvertList(const py::object & obj)80 ValueTuplePtr ConvertList(const py::object &obj) {
81 if (!py::isinstance<T>(obj)) {
82 return nullptr;
83 }
84 auto seq = py::cast<T>(obj);
85 size_t size = seq.size();
86 std::vector<ValuePtr> convert(size);
87 for (size_t i = 0; i < size; ++i) {
88 if (!py::isinstance<U>(seq[i])) {
89 return nullptr;
90 }
91 auto out = PyCast<U, N>(seq[i]);
92 if (out == nullptr) {
93 return nullptr;
94 }
95 convert[i] = out;
96 }
97 return std::make_shared<ValueTuple>(std::move(convert));
98 }
99
100 template <typename T>
ConvertIntList(const py::object & obj)101 ValueTuplePtr ConvertIntList(const py::object &obj) {
102 if (!py::isinstance<T>(obj)) {
103 return nullptr;
104 }
105 auto seq = py::cast<T>(obj);
106 size_t size = seq.size();
107 std::vector<ValuePtr> convert(size);
108 for (size_t i = 0; i < size; ++i) {
109 // bool is also an instance of py::int_
110 if (py::isinstance<py::bool_>(seq[i]) || !py::isinstance<py::int_>(seq[i])) {
111 return nullptr;
112 }
113 auto out = PyCast<py::int_, Int64Imm>(seq[i]);
114 if (out == nullptr) {
115 return nullptr;
116 }
117 convert[i] = out;
118 }
119 return std::make_shared<ValueTuple>(std::move(convert));
120 }
121
122 template <>
ConvertList(const py::object & obj)123 ValueTuplePtr ConvertList<py::tuple, py::int_, Int64Imm>(const py::object &obj) {
124 return ConvertIntList<py::tuple>(obj);
125 }
126
127 template <>
ConvertList(const py::object & obj)128 ValueTuplePtr ConvertList<py::list, py::int_, Int64Imm>(const py::object &obj) {
129 return ConvertIntList<py::list>(obj);
130 }
131 } // namespace
132
Converter(ops::OpDef * op_def)133 Converter::Converter(ops::OpDef *op_def)
134 : op_def_(op_def), source_type_(std::vector<ops::OP_DTYPE>(op_def->args_.size())) {}
135
Parse(const py::list & python_args)136 void Converter::Parse(const py::list &python_args) {
137 if (op_def_->args_.size() != python_args.size()) {
138 MS_LOG(EXCEPTION) << "For operator " << op_def_->name_ << ", it requires " << op_def_->args_.size()
139 << "parameters, bug got " << python_args.size() << "parameters!";
140 }
141 }
142
ToTensor(const py::list & python_args,size_t i)143 ValuePtr Converter::ToTensor(const py::list &python_args, size_t i) {
144 const auto &op_arg = op_def_->args_[i];
145 const py::object &obj = (python_args)[i];
146 source_type_[i] = OP_DTYPE::DT_BEGIN;
147 auto tensor = parse::ConvertTensor(obj);
148 if (tensor != nullptr) {
149 if (tensor->isa<tensor::BaseTensor>()) {
150 tensor->cast<tensor::BaseTensorPtr>()->set_need_pipeline_sync(true);
151 }
152 return tensor;
153 }
154 if (!op_arg.cast_dtype_.empty()) {
155 auto convert = ConvertByCastDtype(obj, op_arg, i);
156 if (convert != nullptr && convert->isa<tensor::BaseTensor>()) {
157 return convert->cast<tensor::BaseTensorPtr>();
158 }
159 }
160
161 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
162 return nullptr;
163 }
164
ToTensorOptional(const py::list & python_args,size_t i)165 std::optional<ValuePtr> Converter::ToTensorOptional(const py::list &python_args, size_t i) {
166 const py::object &obj = (python_args)[i];
167 if (py::isinstance<py::none>(obj)) {
168 return std::nullopt;
169 }
170 return std::make_optional(ToTensor(python_args, i));
171 }
172
173 template <typename T>
ToTensorList(const py::list & python_args,size_t i)174 ValueTuplePtr Converter::ToTensorList(const py::list &python_args, size_t i) {
175 const auto &op_arg = op_def_->args_[i];
176 const py::object &obj = python_args[i];
177 source_type_[i] = OP_DTYPE::DT_BEGIN;
178 auto val_seq = parse::ConvertSequence<py::tuple, ValueTuple, parse::ConvertTensor>(obj);
179 if (val_seq != nullptr && val_seq->isa<ValueTuple>()) {
180 return val_seq->cast<ValueTuplePtr>();
181 }
182 return ConvertValueTupleByCastDtype(python_args, op_arg, i);
183 }
184
185 template <typename T>
ToTensorListOptional(const py::list & python_args,size_t i)186 std::optional<ValueTuplePtr> Converter::ToTensorListOptional(const py::list &python_args, size_t i) {
187 const py::object &obj = (python_args)[i];
188 if (py::isinstance<py::none>(obj)) {
189 return std::nullopt;
190 }
191 return std::make_optional(ToTensorList<T>(python_args, i));
192 }
193
ToInt(const py::list & python_args,size_t i)194 Int64ImmPtr Converter::ToInt(const py::list &python_args, size_t i) {
195 const auto &op_arg = op_def_->args_[i];
196 const py::object &obj = python_args[i];
197 source_type_[i] = OP_DTYPE::DT_BEGIN;
198 auto convert = ConvertInt(obj);
199 if (convert != nullptr) {
200 return convert;
201 }
202 if (!op_arg.cast_dtype_.empty()) {
203 auto convert_value = ConvertByCastDtype(obj, op_arg, i);
204 if (convert_value != nullptr && convert_value->isa<Int64Imm>()) {
205 return convert_value->cast<Int64ImmPtr>();
206 }
207 }
208 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
209 return nullptr;
210 }
211
ToIntOptional(const py::list & python_args,size_t i)212 std::optional<Int64ImmPtr> Converter::ToIntOptional(const py::list &python_args, size_t i) {
213 const py::object &obj = python_args[i];
214 if (py::isinstance<py::none>(obj)) {
215 return std::nullopt;
216 }
217 return std::make_optional(ToInt(python_args, i));
218 }
219
220 template <typename T>
ToIntList(const py::list & python_args,size_t i)221 ValueTuplePtr Converter::ToIntList(const py::list &python_args, size_t i) {
222 const auto &op_arg = op_def_->args_[i];
223 const py::object &obj = python_args[i];
224 ValueTuplePtr convert = ConvertList<T, py::int_, Int64Imm>(obj);
225 if (convert != nullptr) {
226 return convert;
227 }
228 return ConvertValueTupleByCastDtype(python_args, op_arg, i);
229 }
230
231 template <typename T>
ToIntListOptional(const py::list & python_args,size_t i)232 std::optional<ValueTuplePtr> Converter::ToIntListOptional(const py::list &python_args, size_t i) {
233 const py::object &obj = python_args[i];
234 if (py::isinstance<py::none>(obj)) {
235 return std::nullopt;
236 }
237 return std::make_optional(ToIntList<T>(python_args, i));
238 }
239
ToBool(const py::list & python_args,size_t i)240 BoolImmPtr Converter::ToBool(const py::list &python_args, size_t i) {
241 const auto &op_arg = op_def_->args_[i];
242 const py::object &obj = python_args[i];
243 source_type_[i] = OP_DTYPE::DT_BEGIN;
244 auto convert = ConvertBool(obj);
245 if (convert != nullptr) {
246 return convert;
247 }
248 if (!op_arg.cast_dtype_.empty()) {
249 auto convert_value = ConvertByCastDtype(obj, op_arg, i);
250 if (convert_value != nullptr && convert_value->isa<BoolImm>()) {
251 return convert_value->cast<BoolImmPtr>();
252 }
253 }
254 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
255 return nullptr;
256 }
257
ToBoolOptional(const py::list & python_args,size_t i)258 std::optional<BoolImmPtr> Converter::ToBoolOptional(const py::list &python_args, size_t i) {
259 const py::object &obj = python_args[i];
260 if (py::isinstance<py::none>(obj)) {
261 return std::nullopt;
262 }
263 return std::make_optional(ToBool(python_args, i));
264 }
265
266 template <typename T>
ToBoolList(const py::list & python_args,size_t i)267 ValueTuplePtr Converter::ToBoolList(const py::list &python_args, size_t i) {
268 const auto &op_arg = op_def_->args_[i];
269 const py::object &obj = python_args[i];
270 source_type_[i] = OP_DTYPE::DT_BEGIN;
271 ValueTuplePtr convert = ConvertList<T, py::bool_, BoolImm>(obj);
272 if (convert != nullptr) {
273 return convert;
274 }
275 return ConvertValueTupleByCastDtype(python_args, op_arg, i);
276 }
277
278 template <typename T>
ToBoolListOptional(const py::list & python_args,size_t i)279 std::optional<ValueTuplePtr> Converter::ToBoolListOptional(const py::list &python_args, size_t i) {
280 const py::object &obj = python_args[i];
281 source_type_[i] = OP_DTYPE::DT_BEGIN;
282 if (py::isinstance<py::none>(obj)) {
283 return std::nullopt;
284 }
285 return std::make_optional(ToBoolList<T>(python_args, i));
286 }
287
ToFloat(const py::list & python_args,size_t i)288 FP32ImmPtr Converter::ToFloat(const py::list &python_args, size_t i) {
289 const auto &op_arg = op_def_->args_[i];
290 const py::object &obj = python_args[i];
291 source_type_[i] = OP_DTYPE::DT_BEGIN;
292 auto convert = ConvertFloat(obj);
293 if (convert != nullptr) {
294 return convert;
295 }
296 if (!op_arg.cast_dtype_.empty()) {
297 auto convert_value = ConvertByCastDtype(obj, op_arg, i);
298 if (convert_value != nullptr && convert_value->isa<FP32Imm>()) {
299 return convert_value->cast<FP32ImmPtr>();
300 }
301 }
302 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
303 return nullptr;
304 }
305
ToFloatOptional(const py::list & python_args,size_t i)306 std::optional<FP32ImmPtr> Converter::ToFloatOptional(const py::list &python_args, size_t i) {
307 const py::object &obj = python_args[i];
308 if (py::isinstance<py::none>(obj)) {
309 return std::nullopt;
310 }
311 return std::make_optional(ToFloat(python_args, i));
312 }
313
314 template <typename T>
ToFloatList(const py::list & python_args,size_t i)315 ValueTuplePtr Converter::ToFloatList(const py::list &python_args, size_t i) {
316 const auto &op_arg = op_def_->args_[i];
317 const py::object &obj = python_args[i];
318 source_type_[i] = OP_DTYPE::DT_BEGIN;
319 ValueTuplePtr convert = ConvertList<T, py::float_, FP32Imm>(obj);
320 if (convert != nullptr) {
321 return convert;
322 }
323 return ConvertValueTupleByCastDtype(python_args, op_arg, i);
324 }
325
326 template <typename T>
ToFloatListOptional(const py::list & python_args,size_t i)327 std::optional<ValueTuplePtr> Converter::ToFloatListOptional(const py::list &python_args, size_t i) {
328 const py::object &obj = python_args[i];
329 if (py::isinstance<py::none>(obj)) {
330 return std::nullopt;
331 }
332 return std::make_optional(ToFloatList<T>(python_args, i));
333 }
334
ToScalar(const py::list & python_args,size_t i)335 ScalarPtr Converter::ToScalar(const py::list &python_args, size_t i) {
336 const auto &op_arg = op_def_->args_[i];
337 const py::object &obj = python_args[i];
338 source_type_[i] = OP_DTYPE::DT_BEGIN;
339 auto convert = ConvertNumber(obj);
340 if (convert != nullptr) {
341 return convert;
342 }
343 if (!op_arg.cast_dtype_.empty()) {
344 auto convert_value = ConvertByCastDtype(obj, op_arg, i);
345 if (convert_value != nullptr && convert_value->isa<Scalar>()) {
346 return convert_value->cast<ScalarPtr>();
347 }
348 }
349 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
350 return nullptr;
351 }
352
ToScalarOptional(const py::list & python_args,size_t i)353 std::optional<ScalarPtr> Converter::ToScalarOptional(const py::list &python_args, size_t i) {
354 const py::object &obj = python_args[i];
355 if (py::isinstance<py::none>(obj)) {
356 return std::nullopt;
357 }
358 return std::make_optional(ToScalar(python_args, i));
359 }
360
ToString(const py::list & python_args,size_t i)361 StringImmPtr Converter::ToString(const py::list &python_args, size_t i) {
362 const auto &op_arg = op_def_->args_[i];
363 const py::object &obj = python_args[i];
364 source_type_[i] = OP_DTYPE::DT_BEGIN;
365 auto convert = ConvertStr(obj);
366 if (convert != nullptr) {
367 return convert;
368 }
369 if (!op_arg.cast_dtype_.empty()) {
370 auto convert_value = ConvertByCastDtype(obj, op_arg, i);
371 if (convert_value != nullptr && convert_value->isa<StringImm>()) {
372 return convert_value->cast<StringImmPtr>();
373 }
374 }
375 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
376 return nullptr;
377 }
378
ToStringOptional(const py::list & python_args,size_t i)379 std::optional<StringImmPtr> Converter::ToStringOptional(const py::list &python_args, size_t i) {
380 const py::object &obj = python_args[i];
381 if (py::isinstance<py::none>(obj)) {
382 return std::nullopt;
383 }
384 return std::make_optional(ToString(python_args, i));
385 }
386
ToDtype(const py::list & python_args,size_t i)387 Int64ImmPtr Converter::ToDtype(const py::list &python_args, size_t i) {
388 const py::object &obj = python_args[i];
389 source_type_[i] = OP_DTYPE::DT_BEGIN;
390 auto convert = ConvertInt(obj);
391 if (convert != nullptr) {
392 return convert;
393 }
394 if (py::isinstance<mindspore::Type>(obj)) {
395 TypePtr type = py::cast<mindspore::TypePtr>(obj);
396 return std::make_shared<Int64Imm>(static_cast<int>(type->type_id()));
397 }
398 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, i);
399 return nullptr;
400 }
401
ToDtypeOptional(const py::list & python_args,size_t i)402 std::optional<Int64ImmPtr> Converter::ToDtypeOptional(const py::list &python_args, size_t i) {
403 const py::object &obj = python_args[i];
404 if (py::isinstance<py::none>(obj)) {
405 return std::nullopt;
406 }
407 return std::make_optional(ToDtype(python_args, i));
408 }
409
ConvertByCastDtype(const py::object & input,const ops::OpInputArg & op_arg,size_t index)410 ValuePtr Converter::ConvertByCastDtype(const py::object &input, const ops::OpInputArg &op_arg, size_t index) {
411 for (auto &cast_dtype : op_arg.cast_dtype_) {
412 auto convert_func = parse::GetConverterByType(parse::CombineTypesForTypeCast(cast_dtype, op_arg.arg_dtype_));
413 if (convert_func == nullptr) {
414 MS_LOG(EXCEPTION) << "Can't find convert function for src_dtype[" << cast_dtype << "] and dst_type"
415 << op_arg.arg_dtype_ << "].";
416 }
417 auto value = convert_func(input);
418 if (value != nullptr) {
419 source_type_[index] = cast_dtype;
420 return value;
421 }
422 }
423 return nullptr;
424 }
425
ConvertValueTupleByCastDtype(const py::list & python_args,const ops::OpInputArg & op_arg,size_t index)426 ValueTuplePtr Converter::ConvertValueTupleByCastDtype(const py::list &python_args, const ops::OpInputArg &op_arg,
427 size_t index) {
428 const auto &input = python_args[index];
429 if (!op_arg.cast_dtype_.empty()) {
430 auto convert_value = ConvertByCastDtype(input, op_arg, index);
431 if (convert_value != nullptr && convert_value->isa<ValueTuple>()) {
432 return convert_value->cast<ValueTuplePtr>();
433 }
434 }
435 PyNativeAlgo::PyParser::PrintTypeCastError(op_def_, python_args, index);
436 return nullptr;
437 }
438
439 // Declare template to compile corresponding method.
440 template ValueTuplePtr Converter::ToTensorList<py::tuple>(const py::list &python_args, size_t i);
441 template ValueTuplePtr Converter::ToTensorList<py::list>(const py::list &python_args, size_t i);
442 template std::optional<ValueTuplePtr> Converter::ToTensorListOptional<py::tuple>(const py::list &python_args, size_t i);
443 template std::optional<ValueTuplePtr> Converter::ToTensorListOptional<py::list>(const py::list &python_args, size_t i);
444 template std::optional<ValueTuplePtr> Converter::ToIntListOptional<py::tuple>(const py::list &python_args, size_t i);
445 template std::optional<ValueTuplePtr> Converter::ToIntListOptional<py::list>(const py::list &python_args, size_t i);
446 template std::optional<ValueTuplePtr> Converter::ToBoolListOptional<py::tuple>(const py::list &python_args, size_t i);
447 template std::optional<ValueTuplePtr> Converter::ToBoolListOptional<py::list>(const py::list &python_args, size_t i);
448 template std::optional<ValueTuplePtr> Converter::ToFloatListOptional<py::tuple>(const py::list &python_args, size_t i);
449 template std::optional<ValueTuplePtr> Converter::ToFloatListOptional<py::list>(const py::list &python_args, size_t i);
450
451 } // namespace pynative
452 } // namespace mindspore
453