1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
19
20 #include <string>
21 #include <vector>
22 #include <memory>
23
24 #include "transform/graph_ir/op_adapter_base.h"
25 #include "ir/scalar.h"
26 #include "ops/op_utils.h"
27
28 namespace mindspore {
29 class GeDataTypeImm final : public IntegerImm {
30 public:
31 GeDataTypeImm();
32 explicit GeDataTypeImm(::ge::DataType v);
33 ~GeDataTypeImm() override = default;
MS_DECLARE_PARENT(GeDataTypeImm,IntegerImm)34 MS_DECLARE_PARENT(GeDataTypeImm, IntegerImm)
35 std::size_t hash() const override { return hash_; }
IsZero()36 bool IsZero() override { return v_ == static_cast<::ge::DataType>(0); }
IsOne()37 bool IsOne() override { return v_ == static_cast<::ge::DataType>(1); }
value()38 ::ge::DataType value() const { return v_; }
39 bool operator==(const Value &other) const override;
40 bool operator==(const GeDataTypeImm &other) const;
ToString()41 std::string ToString() const override { return scalar_to_string(v_); }
42 std::string DumpText() const override;
43
44 private:
45 ::ge::DataType v_;
46 };
47 using GeDataTypeImmPtr = std::shared_ptr<GeDataTypeImm>;
IMM_TRAITS(GeDataTypeImmPtr,::ge::DataType)48 IMM_TRAITS(GeDataTypeImmPtr, ::ge::DataType)
49
50 namespace transform {
51 template <typename T>
52 inline ValuePtr GetRealValue(const T &value) {
53 return MakeValue(value);
54 }
55
56 template <>
57 inline ValuePtr GetRealValue<GeDataType>(const GeDataType &value) {
58 return MakeValue<GeDataType>(value);
59 }
60
61 template <>
62 inline ValuePtr GetRealValue<GeTensor>(const GeTensor &) {
63 return nullptr;
64 }
65
66 // Get integral value from ValuePtr and cast to integral type T
67 template <typename T, typename std::enable_if<std::is_integral_v<std::decay_t<T>>>::type * = nullptr>
68 T GetCastIntegralValue(const ValuePtr &value) {
69 MS_EXCEPTION_IF_NULL(value);
70 TypeId type_id = value->type()->type_id();
71
72 switch (type_id) {
73 case kNumberTypeBool:
74 return static_cast<T>(ops::GetValueWithCheck<bool>(value));
75 case kNumberTypeInt8:
76 return static_cast<T>(ops::GetValueWithCheck<int8_t>(value));
77 case kNumberTypeInt16:
78 return static_cast<T>(ops::GetValueWithCheck<int16_t>(value));
79 case kNumberTypeInt32:
80 return static_cast<T>(ops::GetValueWithCheck<int32_t>(value));
81 case kNumberTypeInt64:
82 return static_cast<T>(ops::GetValueWithCheck<int64_t>(value));
83 case kNumberTypeUInt8:
84 return static_cast<T>(ops::GetValueWithCheck<uint8_t>(value));
85 case kNumberTypeUInt16:
86 return static_cast<T>(ops::GetValueWithCheck<uint16_t>(value));
87 case kNumberTypeUInt32:
88 return static_cast<T>(ops::GetValueWithCheck<uint32_t>(value));
89 case kNumberTypeUInt64:
90 return static_cast<T>(ops::GetValueWithCheck<uint64_t>(value));
91 default:
92 MS_LOG(EXCEPTION) << "Get and cast value of type " << value->type()->ToString() << " to integral type fail.";
93 }
94 }
95
96 // Get floating point value from ValuePtr and cast to floating point type T
97 template <typename T, typename std::enable_if<std::is_floating_point_v<std::decay_t<T>>>::type * = nullptr>
98 T GetCastFloatValue(const ValuePtr &value) {
99 MS_EXCEPTION_IF_NULL(value);
100 TypeId type_id = value->type()->type_id();
101
102 switch (type_id) {
103 case kNumberTypeFloat32:
104 return static_cast<T>(ops::GetValueWithCheck<float>(value));
105 case kNumberTypeFloat64:
106 return static_cast<T>(ops::GetValueWithCheck<double>(value));
107 default:
108 MS_LOG(EXCEPTION) << "Get and cast value of type " << value->type()->ToString()
109 << " to floating point type fail.";
110 }
111 }
112
113 template <typename P, typename Q>
114 static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits<P> &, const AnyTraits<Q> &) {
115 return static_cast<Q>(GetValue<P>(value));
116 }
117
118 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &traits);
119
120 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
121 const AnyTraits<std::vector<int64_t>>);
122
123 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>);
124
125 std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>);
126
127 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
128 const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>);
129
130 GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>);
131
132 std::vector<GeDataType> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<GEType>>);
133
134 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEDataFormat>);
135
136 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEPadMod>);
137
138 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEReduction>);
139
140 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AscendQuantRoundMode>);
141
142 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FASInputLayoutMode>);
143
144 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FFNActivationMode>);
145
146 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ScatterReduceMode>);
147
148 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GECoordinateTransformMode>);
149
150 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEEnumToStr>, const std::vector<std::string> &);
151
152 template <typename P, typename Q>
153 std::vector<Q> ConvertAnyUtil(const ValuePtr &value, AnyTraits<P>, const AnyTraits<std::vector<Q>>) {
154 MS_EXCEPTION_IF_NULL(value);
155 std::vector<Q> data;
156 if (!value->isa<ValueTuple>() && !value->isa<ValueList>()) {
157 MS_LOG(WARNING) << "error convert Value to vector for value: " << value->ToString()
158 << ", type: " << value->type_name() << ", value should be a tuple or list";
159 data.emplace_back(ConvertAnyUtil(value, AnyTraits<P>(), AnyTraits<Q>()));
160 return data;
161 }
162 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
163 for (auto &it : vec) {
164 data.emplace_back(ConvertAnyUtil(it, AnyTraits<P>(), AnyTraits<Q>()));
165 }
166 return data;
167 }
168
169 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ValueAny>);
170
171 bool IsCustomPrim(const PrimitivePtr &prim);
172 bool IsCustomCNode(const AnfNodePtr &node);
173 bool IsNoNeedConstantFoldCNode(const PrimitivePtr &prim);
174 std::string GetOpIOFormat(const AnfNodePtr &node);
175 } // namespace transform
176 } // namespace mindspore
177 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
178