1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_FUNCTION_VALUE_CONVERTER_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_FUNCTION_VALUE_CONVERTER_H_ 19 20 #include <optional> 21 #include "ir/tensor.h" 22 #include "ir/value.h" 23 #include "include/backend/visible.h" 24 #include "runtime/pynative/op_runner.h" 25 26 namespace mindspore::runtime { 27 class BACKEND_EXPORT ValueConverter { 28 public: 29 template <typename T> Convert(const ValuePtrList & inputs,size_t i)30 static T Convert(const ValuePtrList &inputs, size_t i) { 31 const auto &input = inputs[i]; 32 MS_EXCEPTION_IF_NULL(input); 33 auto t = input->template cast<T>(); 34 if (t == nullptr) { 35 MS_LOG(EXCEPTION) << "Get input type " << input->ToString() << ", but want to get " << typeid(T).name(); 36 } 37 return t; 38 } 39 static Int64ImmPtr ToInt(const ValuePtrList &inputs, size_t i); 40 static FP32ImmPtr ToFloat(const ValuePtrList &inputs, size_t i); 41 static BoolImmPtr ToBool(const ValuePtrList &inputs, size_t i); 42 static ScalarPtr ToScalar(const ValuePtrList &inputs, size_t i); 43 static tensor::BaseTensorPtr ToTensor(const ValuePtrList &inputs, size_t i); 44 static StringImmPtr ToString(const ValuePtrList &inputs, size_t i); 45 static TypePtr ToDtype(const ValuePtrList &inputs, size_t i); 46 static ValueTuplePtr ToValueTuple(const ValuePtrList &inputs, size_t i); 47 48 template <typename T> ConvertOptional(const ValuePtrList & inputs,size_t i)49 static std::optional<T> ConvertOptional(const ValuePtrList &inputs, size_t i) { 50 const auto &input = inputs[i]; 51 if (input->template isa<None>()) { 52 return std::nullopt; 53 } 54 auto t = input->template cast<T>(); 55 MS_EXCEPTION_IF_NULL(t); 56 return std::make_optional<T>(t); 57 } 58 static std::optional<Int64ImmPtr> ToIntOptional(const ValuePtrList &inputs, size_t i); 59 static std::optional<FP32ImmPtr> ToFloatOptional(const ValuePtrList &inputs, size_t i); 60 static std::optional<BoolImmPtr> ToBoolOptional(const ValuePtrList &inputs, size_t i); 61 static std::optional<ScalarPtr> ToScalarOptional(const ValuePtrList &inputs, size_t i); 62 static std::optional<tensor::BaseTensorPtr> ToTensorOptional(const ValuePtrList &inputs, size_t i); 63 static std::optional<StringImmPtr> ToStringOptional(const ValuePtrList &inputs, size_t i); 64 static std::optional<TypePtr> ToDtypeOptional(const ValuePtrList &inputs, size_t i); 65 static std::optional<ValueTuplePtr> ToValueTupleOptional(const ValuePtrList &inputs, size_t i); 66 67 static tensor::BaseTensorPtr ContiguousTensorValue(OpRunnerInfo *op_runner_info, const tensor::BaseTensorPtr &tensor); 68 static ValueTuplePtr ContiguousTensorValue(OpRunnerInfo *op_runner_info, const ValueTuplePtr &tuple); 69 template <typename T> ContiguousTensorValue(OpRunnerInfo * op_runner_info,const std::optional<T> & val)70 static std::optional<T> ContiguousTensorValue(OpRunnerInfo *op_runner_info, const std::optional<T> &val) { 71 if (!val.has_value()) { 72 return val; 73 } 74 return std::make_optional<T>(ContiguousTensorValue(op_runner_info, val.value())); 75 } 76 }; 77 } // namespace mindspore::runtime 78 #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_FUNCTION_VALUE_CONVERTER_H_ 79