1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/pynative/op_function/value_converter.h"

#include <vector>
#include <memory>
#include <utility>
#include "kernel/pyboost/auto_generate/contiguous.h"
22
23 namespace mindspore::runtime {
24 namespace {
GetContiguousTensor(OpRunnerInfo * op_runner_info,const tensor::BaseTensorPtr & tensor)25 tensor::BaseTensorPtr GetContiguousTensor(OpRunnerInfo *op_runner_info, const tensor::BaseTensorPtr &tensor) {
26 MS_EXCEPTION_IF_NULL(tensor);
27 auto device_address = tensor->device_address();
28 if (device_address == nullptr || device_address->GetTensorStorageInfo() == nullptr) {
29 return tensor;
30 }
31
32 auto op = CREATE_PYBOOST_OP(Contiguous, op_runner_info->device_target);
33 return op->Call(tensor);
34 }
35 } // namespace
36
ToInt(const ValuePtrList & inputs,size_t i)37 Int64ImmPtr ValueConverter::ToInt(const ValuePtrList &inputs, size_t i) { return Convert<Int64ImmPtr>(inputs, i); }
38
ToFloat(const ValuePtrList & inputs,size_t i)39 FP32ImmPtr ValueConverter::ToFloat(const ValuePtrList &inputs, size_t i) { return Convert<FP32ImmPtr>(inputs, i); }
40
ToBool(const ValuePtrList & inputs,size_t i)41 BoolImmPtr ValueConverter::ToBool(const ValuePtrList &inputs, size_t i) { return Convert<BoolImmPtr>(inputs, i); }
42
ToScalar(const ValuePtrList & inputs,size_t i)43 ScalarPtr ValueConverter::ToScalar(const ValuePtrList &inputs, size_t i) { return Convert<ScalarPtr>(inputs, i); }
44
ToTensor(const ValuePtrList & inputs,size_t i)45 tensor::BaseTensorPtr ValueConverter::ToTensor(const ValuePtrList &inputs, size_t i) {
46 return Convert<tensor::BaseTensorPtr>(inputs, i);
47 }
48
ToString(const ValuePtrList & inputs,size_t i)49 StringImmPtr ValueConverter::ToString(const ValuePtrList &inputs, size_t i) { return Convert<StringImmPtr>(inputs, i); }
50
ToDtype(const ValuePtrList & inputs,size_t i)51 TypePtr ValueConverter::ToDtype(const ValuePtrList &inputs, size_t i) { return Convert<TypePtr>(inputs, i); }
52
ToValueTuple(const ValuePtrList & inputs,size_t i)53 ValueTuplePtr ValueConverter::ToValueTuple(const ValuePtrList &inputs, size_t i) {
54 return Convert<ValueTuplePtr>(inputs, i);
55 }
56
ToIntOptional(const ValuePtrList & inputs,size_t i)57 std::optional<Int64ImmPtr> ValueConverter::ToIntOptional(const ValuePtrList &inputs, size_t i) {
58 return ConvertOptional<Int64ImmPtr>(inputs, i);
59 }
60
ToFloatOptional(const ValuePtrList & inputs,size_t i)61 std::optional<FP32ImmPtr> ValueConverter::ToFloatOptional(const ValuePtrList &inputs, size_t i) {
62 return ConvertOptional<FP32ImmPtr>(inputs, i);
63 }
64
ToBoolOptional(const ValuePtrList & inputs,size_t i)65 std::optional<BoolImmPtr> ValueConverter::ToBoolOptional(const ValuePtrList &inputs, size_t i) {
66 return ConvertOptional<BoolImmPtr>(inputs, i);
67 }
68
ToScalarOptional(const ValuePtrList & inputs,size_t i)69 std::optional<ScalarPtr> ValueConverter::ToScalarOptional(const ValuePtrList &inputs, size_t i) {
70 return ConvertOptional<ScalarPtr>(inputs, i);
71 }
72
ToTensorOptional(const ValuePtrList & inputs,size_t i)73 std::optional<tensor::BaseTensorPtr> ValueConverter::ToTensorOptional(const ValuePtrList &inputs, size_t i) {
74 return ConvertOptional<tensor::BaseTensorPtr>(inputs, i);
75 }
76
ToStringOptional(const ValuePtrList & inputs,size_t i)77 std::optional<StringImmPtr> ValueConverter::ToStringOptional(const ValuePtrList &inputs, size_t i) {
78 return ConvertOptional<StringImmPtr>(inputs, i);
79 }
80
ToDtypeOptional(const ValuePtrList & inputs,size_t i)81 std::optional<TypePtr> ValueConverter::ToDtypeOptional(const ValuePtrList &inputs, size_t i) {
82 return ConvertOptional<TypePtr>(inputs, i);
83 }
84
ToValueTupleOptional(const ValuePtrList & inputs,size_t i)85 std::optional<ValueTuplePtr> ValueConverter::ToValueTupleOptional(const ValuePtrList &inputs, size_t i) {
86 return ConvertOptional<ValueTuplePtr>(inputs, i);
87 }
88
ContiguousTensorValue(OpRunnerInfo * op_runner_info,const tensor::BaseTensorPtr & tensor)89 tensor::BaseTensorPtr ValueConverter::ContiguousTensorValue(OpRunnerInfo *op_runner_info,
90 const tensor::BaseTensorPtr &tensor) {
91 MS_EXCEPTION_IF_NULL(op_runner_info);
92 if (op_runner_info->device_target == kAscendDevice) {
93 return tensor;
94 }
95
96 return GetContiguousTensor(op_runner_info, tensor);
97 }
98
ContiguousTensorValue(OpRunnerInfo * op_runner_info,const ValueTuplePtr & tuple)99 ValueTuplePtr ValueConverter::ContiguousTensorValue(OpRunnerInfo *op_runner_info, const ValueTuplePtr &tuple) {
100 MS_EXCEPTION_IF_NULL(op_runner_info);
101 MS_EXCEPTION_IF_NULL(tuple);
102 if (op_runner_info->device_target == kAscendDevice) {
103 return tuple;
104 }
105
106 const auto &value_list = tuple->value();
107 if (value_list.empty()) {
108 return tuple;
109 }
110
111 std::vector<ValuePtr> new_value_list(value_list);
112 bool need_rebuild_tuple = false;
113 for (size_t i = 0; i < value_list.size(); i++) {
114 auto val = value_list[i];
115 MS_EXCEPTION_IF_NULL(val);
116 if (!val->isa<tensor::BaseTensor>()) {
117 // No need to contiguous, when tuple is not tensor tuple.
118 break;
119 }
120
121 const auto &tensor = val->cast<tensor::BaseTensorPtr>();
122 auto contiguous_tensor = GetContiguousTensor(op_runner_info, tensor);
123 if (contiguous_tensor != tensor) {
124 need_rebuild_tuple = true;
125 new_value_list[i] = contiguous_tensor;
126 }
127 }
128
129 if (need_rebuild_tuple) {
130 return std::make_shared<ValueTuple>(new_value_list);
131 }
132 return tuple;
133 }
134 } // namespace mindspore::runtime
135