• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/pynative/op_function/value_converter.h"
18 
19 #include <vector>
20 #include <memory>
21 #include "kernel/pyboost/auto_generate/contiguous.h"
22 
23 namespace mindspore::runtime {
24 namespace {
GetContiguousTensor(OpRunnerInfo * op_runner_info,const tensor::BaseTensorPtr & tensor)25 tensor::BaseTensorPtr GetContiguousTensor(OpRunnerInfo *op_runner_info, const tensor::BaseTensorPtr &tensor) {
26   MS_EXCEPTION_IF_NULL(tensor);
27   auto device_address = tensor->device_address();
28   if (device_address == nullptr || device_address->GetTensorStorageInfo() == nullptr) {
29     return tensor;
30   }
31 
32   auto op = CREATE_PYBOOST_OP(Contiguous, op_runner_info->device_target);
33   return op->Call(tensor);
34 }
35 }  // namespace
36 
ToInt(const ValuePtrList & inputs,size_t i)37 Int64ImmPtr ValueConverter::ToInt(const ValuePtrList &inputs, size_t i) { return Convert<Int64ImmPtr>(inputs, i); }
38 
ToFloat(const ValuePtrList & inputs,size_t i)39 FP32ImmPtr ValueConverter::ToFloat(const ValuePtrList &inputs, size_t i) { return Convert<FP32ImmPtr>(inputs, i); }
40 
ToBool(const ValuePtrList & inputs,size_t i)41 BoolImmPtr ValueConverter::ToBool(const ValuePtrList &inputs, size_t i) { return Convert<BoolImmPtr>(inputs, i); }
42 
ToScalar(const ValuePtrList & inputs,size_t i)43 ScalarPtr ValueConverter::ToScalar(const ValuePtrList &inputs, size_t i) { return Convert<ScalarPtr>(inputs, i); }
44 
ToTensor(const ValuePtrList & inputs,size_t i)45 tensor::BaseTensorPtr ValueConverter::ToTensor(const ValuePtrList &inputs, size_t i) {
46   return Convert<tensor::BaseTensorPtr>(inputs, i);
47 }
48 
ToString(const ValuePtrList & inputs,size_t i)49 StringImmPtr ValueConverter::ToString(const ValuePtrList &inputs, size_t i) { return Convert<StringImmPtr>(inputs, i); }
50 
ToDtype(const ValuePtrList & inputs,size_t i)51 TypePtr ValueConverter::ToDtype(const ValuePtrList &inputs, size_t i) { return Convert<TypePtr>(inputs, i); }
52 
ToValueTuple(const ValuePtrList & inputs,size_t i)53 ValueTuplePtr ValueConverter::ToValueTuple(const ValuePtrList &inputs, size_t i) {
54   return Convert<ValueTuplePtr>(inputs, i);
55 }
56 
ToIntOptional(const ValuePtrList & inputs,size_t i)57 std::optional<Int64ImmPtr> ValueConverter::ToIntOptional(const ValuePtrList &inputs, size_t i) {
58   return ConvertOptional<Int64ImmPtr>(inputs, i);
59 }
60 
ToFloatOptional(const ValuePtrList & inputs,size_t i)61 std::optional<FP32ImmPtr> ValueConverter::ToFloatOptional(const ValuePtrList &inputs, size_t i) {
62   return ConvertOptional<FP32ImmPtr>(inputs, i);
63 }
64 
ToBoolOptional(const ValuePtrList & inputs,size_t i)65 std::optional<BoolImmPtr> ValueConverter::ToBoolOptional(const ValuePtrList &inputs, size_t i) {
66   return ConvertOptional<BoolImmPtr>(inputs, i);
67 }
68 
ToScalarOptional(const ValuePtrList & inputs,size_t i)69 std::optional<ScalarPtr> ValueConverter::ToScalarOptional(const ValuePtrList &inputs, size_t i) {
70   return ConvertOptional<ScalarPtr>(inputs, i);
71 }
72 
ToTensorOptional(const ValuePtrList & inputs,size_t i)73 std::optional<tensor::BaseTensorPtr> ValueConverter::ToTensorOptional(const ValuePtrList &inputs, size_t i) {
74   return ConvertOptional<tensor::BaseTensorPtr>(inputs, i);
75 }
76 
ToStringOptional(const ValuePtrList & inputs,size_t i)77 std::optional<StringImmPtr> ValueConverter::ToStringOptional(const ValuePtrList &inputs, size_t i) {
78   return ConvertOptional<StringImmPtr>(inputs, i);
79 }
80 
ToDtypeOptional(const ValuePtrList & inputs,size_t i)81 std::optional<TypePtr> ValueConverter::ToDtypeOptional(const ValuePtrList &inputs, size_t i) {
82   return ConvertOptional<TypePtr>(inputs, i);
83 }
84 
ToValueTupleOptional(const ValuePtrList & inputs,size_t i)85 std::optional<ValueTuplePtr> ValueConverter::ToValueTupleOptional(const ValuePtrList &inputs, size_t i) {
86   return ConvertOptional<ValueTuplePtr>(inputs, i);
87 }
88 
ContiguousTensorValue(OpRunnerInfo * op_runner_info,const tensor::BaseTensorPtr & tensor)89 tensor::BaseTensorPtr ValueConverter::ContiguousTensorValue(OpRunnerInfo *op_runner_info,
90                                                             const tensor::BaseTensorPtr &tensor) {
91   MS_EXCEPTION_IF_NULL(op_runner_info);
92   if (op_runner_info->device_target == kAscendDevice) {
93     return tensor;
94   }
95 
96   return GetContiguousTensor(op_runner_info, tensor);
97 }
98 
ContiguousTensorValue(OpRunnerInfo * op_runner_info,const ValueTuplePtr & tuple)99 ValueTuplePtr ValueConverter::ContiguousTensorValue(OpRunnerInfo *op_runner_info, const ValueTuplePtr &tuple) {
100   MS_EXCEPTION_IF_NULL(op_runner_info);
101   MS_EXCEPTION_IF_NULL(tuple);
102   if (op_runner_info->device_target == kAscendDevice) {
103     return tuple;
104   }
105 
106   const auto &value_list = tuple->value();
107   if (value_list.empty()) {
108     return tuple;
109   }
110 
111   std::vector<ValuePtr> new_value_list(value_list);
112   bool need_rebuild_tuple = false;
113   for (size_t i = 0; i < value_list.size(); i++) {
114     auto val = value_list[i];
115     MS_EXCEPTION_IF_NULL(val);
116     if (!val->isa<tensor::BaseTensor>()) {
117       // No need to contiguous, when tuple is not tensor tuple.
118       break;
119     }
120 
121     const auto &tensor = val->cast<tensor::BaseTensorPtr>();
122     auto contiguous_tensor = GetContiguousTensor(op_runner_info, tensor);
123     if (contiguous_tensor != tensor) {
124       need_rebuild_tuple = true;
125       new_value_list[i] = contiguous_tensor;
126     }
127   }
128 
129   if (need_rebuild_tuple) {
130     return std::make_shared<ValueTuple>(new_value_list);
131   }
132   return tuple;
133 }
134 }  // namespace mindspore::runtime
135