• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
19 
20 #include <string>
21 #include <vector>
22 #include <memory>
23 
24 #include "transform/graph_ir/op_adapter_base.h"
25 #include "ir/scalar.h"
26 #include "ops/op_utils.h"
27 
28 namespace mindspore {
29 class GeDataTypeImm final : public IntegerImm {
30  public:
31   GeDataTypeImm();
32   explicit GeDataTypeImm(::ge::DataType v);
33   ~GeDataTypeImm() override = default;
MS_DECLARE_PARENT(GeDataTypeImm,IntegerImm)34   MS_DECLARE_PARENT(GeDataTypeImm, IntegerImm)
35   std::size_t hash() const override { return hash_; }
IsZero()36   bool IsZero() override { return v_ == static_cast<::ge::DataType>(0); }
IsOne()37   bool IsOne() override { return v_ == static_cast<::ge::DataType>(1); }
value()38   ::ge::DataType value() const { return v_; }
39   bool operator==(const Value &other) const override;
40   bool operator==(const GeDataTypeImm &other) const;
ToString()41   std::string ToString() const override { return scalar_to_string(v_); }
42   std::string DumpText() const override;
43 
44  private:
45   ::ge::DataType v_;
46 };
47 using GeDataTypeImmPtr = std::shared_ptr<GeDataTypeImm>;
IMM_TRAITS(GeDataTypeImmPtr,::ge::DataType)48 IMM_TRAITS(GeDataTypeImmPtr, ::ge::DataType)
49 
50 namespace transform {
51 template <typename T>
52 inline ValuePtr GetRealValue(const T &value) {
53   return MakeValue(value);
54 }
55 
56 template <>
57 inline ValuePtr GetRealValue<GeDataType>(const GeDataType &value) {
58   return MakeValue<GeDataType>(value);
59 }
60 
61 template <>
62 inline ValuePtr GetRealValue<GeTensor>(const GeTensor &) {
63   return nullptr;
64 }
65 
66 // Get integral value from ValuePtr and cast to integral type T
67 template <typename T, typename std::enable_if<std::is_integral_v<std::decay_t<T>>>::type * = nullptr>
68 T GetCastIntegralValue(const ValuePtr &value) {
69   MS_EXCEPTION_IF_NULL(value);
70   TypeId type_id = value->type()->type_id();
71 
72   switch (type_id) {
73     case kNumberTypeBool:
74       return static_cast<T>(ops::GetValueWithCheck<bool>(value));
75     case kNumberTypeInt8:
76       return static_cast<T>(ops::GetValueWithCheck<int8_t>(value));
77     case kNumberTypeInt16:
78       return static_cast<T>(ops::GetValueWithCheck<int16_t>(value));
79     case kNumberTypeInt32:
80       return static_cast<T>(ops::GetValueWithCheck<int32_t>(value));
81     case kNumberTypeInt64:
82       return static_cast<T>(ops::GetValueWithCheck<int64_t>(value));
83     case kNumberTypeUInt8:
84       return static_cast<T>(ops::GetValueWithCheck<uint8_t>(value));
85     case kNumberTypeUInt16:
86       return static_cast<T>(ops::GetValueWithCheck<uint16_t>(value));
87     case kNumberTypeUInt32:
88       return static_cast<T>(ops::GetValueWithCheck<uint32_t>(value));
89     case kNumberTypeUInt64:
90       return static_cast<T>(ops::GetValueWithCheck<uint64_t>(value));
91     default:
92       MS_LOG(EXCEPTION) << "Get and cast value of type " << value->type()->ToString() << " to integral type fail.";
93   }
94 }
95 
96 // Get floating point value from ValuePtr and cast to floating point type T
97 template <typename T, typename std::enable_if<std::is_floating_point_v<std::decay_t<T>>>::type * = nullptr>
98 T GetCastFloatValue(const ValuePtr &value) {
99   MS_EXCEPTION_IF_NULL(value);
100   TypeId type_id = value->type()->type_id();
101 
102   switch (type_id) {
103     case kNumberTypeFloat32:
104       return static_cast<T>(ops::GetValueWithCheck<float>(value));
105     case kNumberTypeFloat64:
106       return static_cast<T>(ops::GetValueWithCheck<double>(value));
107     default:
108       MS_LOG(EXCEPTION) << "Get and cast value of type " << value->type()->ToString()
109                         << " to floating point type fail.";
110   }
111 }
112 
113 template <typename P, typename Q>
114 static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits<P> &, const AnyTraits<Q> &) {
115   return static_cast<Q>(GetValue<P>(value));
116 }
117 
118 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &traits);
119 
120 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
121                                     const AnyTraits<std::vector<int64_t>>);
122 
123 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>);
124 
125 std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>);
126 
127 std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
128                                     const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>);
129 
130 GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>);
131 
132 std::vector<GeDataType> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<GEType>>);
133 
134 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEDataFormat>);
135 
136 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEPadMod>);
137 
138 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEReduction>);
139 
140 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AscendQuantRoundMode>);
141 
142 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FASInputLayoutMode>);
143 
144 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<FFNActivationMode>);
145 
146 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ScatterReduceMode>);
147 
148 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GECoordinateTransformMode>);
149 
150 std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEEnumToStr>, const std::vector<std::string> &);
151 
152 template <typename P, typename Q>
153 std::vector<Q> ConvertAnyUtil(const ValuePtr &value, AnyTraits<P>, const AnyTraits<std::vector<Q>>) {
154   MS_EXCEPTION_IF_NULL(value);
155   std::vector<Q> data;
156   if (!value->isa<ValueTuple>() && !value->isa<ValueList>()) {
157     MS_LOG(WARNING) << "error convert Value to vector for value: " << value->ToString()
158                     << ", type: " << value->type_name() << ", value should be a tuple or list";
159     data.emplace_back(ConvertAnyUtil(value, AnyTraits<P>(), AnyTraits<Q>()));
160     return data;
161   }
162   auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
163   for (auto &it : vec) {
164     data.emplace_back(ConvertAnyUtil(it, AnyTraits<P>(), AnyTraits<Q>()));
165   }
166   return data;
167 }
168 
169 GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<ValueAny>);
170 
171 bool IsCustomPrim(const PrimitivePtr &prim);
172 bool IsCustomCNode(const AnfNodePtr &node);
173 bool IsNoNeedConstantFoldCNode(const PrimitivePtr &prim);
174 std::string GetOpIOFormat(const AnfNodePtr &node);
175 }  // namespace transform
176 }  // namespace mindspore
177 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
178