• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_FUNCTION_VALUE_CONVERTER_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_FUNCTION_VALUE_CONVERTER_H_
19 
20 #include <optional>
21 #include "ir/tensor.h"
22 #include "ir/value.h"
23 #include "include/backend/visible.h"
24 #include "runtime/pynative/op_runner.h"
25 
26 namespace mindspore::runtime {
27 class BACKEND_EXPORT ValueConverter {
28  public:
29   template <typename T>
Convert(const ValuePtrList & inputs,size_t i)30   static T Convert(const ValuePtrList &inputs, size_t i) {
31     const auto &input = inputs[i];
32     MS_EXCEPTION_IF_NULL(input);
33     auto t = input->template cast<T>();
34     if (t == nullptr) {
35       MS_LOG(EXCEPTION) << "Get input type " << input->ToString() << ", but want to get " << typeid(T).name();
36     }
37     return t;
38   }
39   static Int64ImmPtr ToInt(const ValuePtrList &inputs, size_t i);
40   static FP32ImmPtr ToFloat(const ValuePtrList &inputs, size_t i);
41   static BoolImmPtr ToBool(const ValuePtrList &inputs, size_t i);
42   static ScalarPtr ToScalar(const ValuePtrList &inputs, size_t i);
43   static tensor::BaseTensorPtr ToTensor(const ValuePtrList &inputs, size_t i);
44   static StringImmPtr ToString(const ValuePtrList &inputs, size_t i);
45   static TypePtr ToDtype(const ValuePtrList &inputs, size_t i);
46   static ValueTuplePtr ToValueTuple(const ValuePtrList &inputs, size_t i);
47 
48   template <typename T>
ConvertOptional(const ValuePtrList & inputs,size_t i)49   static std::optional<T> ConvertOptional(const ValuePtrList &inputs, size_t i) {
50     const auto &input = inputs[i];
51     if (input->template isa<None>()) {
52       return std::nullopt;
53     }
54     auto t = input->template cast<T>();
55     MS_EXCEPTION_IF_NULL(t);
56     return std::make_optional<T>(t);
57   }
58   static std::optional<Int64ImmPtr> ToIntOptional(const ValuePtrList &inputs, size_t i);
59   static std::optional<FP32ImmPtr> ToFloatOptional(const ValuePtrList &inputs, size_t i);
60   static std::optional<BoolImmPtr> ToBoolOptional(const ValuePtrList &inputs, size_t i);
61   static std::optional<ScalarPtr> ToScalarOptional(const ValuePtrList &inputs, size_t i);
62   static std::optional<tensor::BaseTensorPtr> ToTensorOptional(const ValuePtrList &inputs, size_t i);
63   static std::optional<StringImmPtr> ToStringOptional(const ValuePtrList &inputs, size_t i);
64   static std::optional<TypePtr> ToDtypeOptional(const ValuePtrList &inputs, size_t i);
65   static std::optional<ValueTuplePtr> ToValueTupleOptional(const ValuePtrList &inputs, size_t i);
66 
67   static tensor::BaseTensorPtr ContiguousTensorValue(OpRunnerInfo *op_runner_info, const tensor::BaseTensorPtr &tensor);
68   static ValueTuplePtr ContiguousTensorValue(OpRunnerInfo *op_runner_info, const ValueTuplePtr &tuple);
69   template <typename T>
ContiguousTensorValue(OpRunnerInfo * op_runner_info,const std::optional<T> & val)70   static std::optional<T> ContiguousTensorValue(OpRunnerInfo *op_runner_info, const std::optional<T> &val) {
71     if (!val.has_value()) {
72       return val;
73     }
74     return std::make_optional<T>(ContiguousTensorValue(op_runner_info, val.value()));
75   }
76 };
77 }  // namespace mindspore::runtime
78 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_OP_FUNCTION_VALUE_CONVERTER_H_
79