1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
18 #define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
19
20 #include <vector>
21 #include <string>
22 #include <type_traits>
23 #include "mindapi/base/base.h"
24 #include "mindapi/ir/common.h"
25
26 namespace mindspore::api {
/// \brief Compile-time map from a primitive C++ type to the SharedPtr of the Value class that stores it.
///
/// The primary template has no members on purpose: only types registered through
/// MIND_API_IMM_TRAIT expose a nested `type`, so templates that name
/// `ImmTrait<T>::type` (MakeValue / GetValue) are discarded for unregistered types.
template <typename T>
struct ImmTrait {};

/// \brief Registers `cpp_type` as being held by the Value class `imm_class`.
/// The expansion ends without a semicolon, so every use must supply one.
#define MIND_API_IMM_TRAIT(imm_class, cpp_type) \
  template <>                                   \
  struct ImmTrait<cpp_type> {                   \
    using type = SharedPtr<imm_class>;          \
  }
35
36 /// \brief Value represents a value in expression.
37 class MIND_API Value : public Base {
38 public:
39 MIND_API_BASE_MEMBER(Value);
40
41 /// \brief Get the type of this Value.
42 ///
43 /// \return The type.
44 TypePtr type() const;
45
46 /// \brief Get the abstract of this Value.
47 ///
48 /// \return Abstract of this Value.
49 AbstractBasePtr ToAbstract() const;
50 };
51
52 /// \brief ValueSequence represents a sequence of values.
53 class MIND_API ValueSequence : public Value {
54 public:
55 MIND_API_BASE_MEMBER(ValueSequence);
56
57 /// \brief Get the size of this ValueSequence.
58 ///
59 /// \return The size as the number of elements.
60 std::size_t size() const;
61
62 /// \brief Get the list of values in this ValueSequence.
63 ///
64 /// \return The list of element values.
65 std::vector<ValuePtr> value() const;
66 };
67
68 using ValueSequencePtr = SharedPtr<ValueSequence>;
69
70 /// \brief ValueTuple represents a value tuple.
71 class MIND_API ValueTuple : public ValueSequence {
72 public:
73 MIND_API_BASE_MEMBER(ValueTuple);
74
75 /// \brief Constructor of ValueTuple.
76 ///
77 /// \param[in] elements The elements of the tuple.
78 explicit ValueTuple(const std::vector<ValuePtr> &elements);
79 };
80
81 using ValueTuplePtr = SharedPtr<ValueTuple>;
82
83 /// \brief StringImm defines a Value whose type is string.
84 class MIND_API StringImm : public Value {
85 public:
86 MIND_API_BASE_MEMBER(StringImm);
87
88 /// \brief Create StringImm with the given string.
89 ///
90 /// \param[in] str The given string value.
91 explicit StringImm(const std::string &str);
92
93 /// \brief Get the string value of this StringImm.
94 ///
95 /// \return The string value of this StringImm.
96 const std::string &value() const;
97 };
98
99 using StringImmPtr = SharedPtr<StringImm>;
100
101 MIND_API_IMM_TRAIT(StringImm, std::string);
102
103 /// \brief Scalar defines interface for scalar data.
104 class MIND_API Scalar : public Value {
105 public:
106 MIND_API_BASE_MEMBER(Scalar);
107 };
108
109 /// \brief BoolImm defines interface for bool data.
110 class MIND_API BoolImm : public Scalar {
111 public:
112 MIND_API_BASE_MEMBER(BoolImm);
113
114 /// \brief Create BoolImm with the given bool value.
115 ///
116 /// \param[in] b The given bool value.
117 explicit BoolImm(bool b);
118
119 /// \brief Get the bool value of this BoolImm.
120 ///
121 /// \return The bool value of this BoolImm.
122 bool value() const;
123 };
124
125 using BoolImmPtr = SharedPtr<BoolImm>;
126
127 MIND_API_IMM_TRAIT(BoolImm, bool);
128
129 /// \brief IntegerImm defines interface for integer data.
130 class MIND_API IntegerImm : public Scalar {
131 public:
132 MIND_API_BASE_MEMBER(IntegerImm);
133 };
134
135 /// \brief Int8Imm defines interface for int8 data.
136 class MIND_API Int8Imm : public IntegerImm {
137 public:
138 MIND_API_BASE_MEMBER(Int8Imm);
139
140 /// \brief Create Int8Imm with the given int8 value.
141 ///
142 /// \param[in] value The given int8 value.
143 explicit Int8Imm(int8_t value);
144
145 /// \brief Get the int8 value of this Int8Imm.
146 ///
147 /// \return The int8 value of this Int8Imm.
148 int8_t value() const;
149 };
150
151 using Int8ImmPtr = SharedPtr<Int8Imm>;
152
153 MIND_API_IMM_TRAIT(Int8Imm, int8_t);
154
155 /// \brief Int16Imm defines interface for int16 data.
156 class MIND_API Int16Imm : public IntegerImm {
157 public:
158 MIND_API_BASE_MEMBER(Int16Imm);
159
160 /// \brief Create Int1I6mm with the given int16 value.
161 ///
162 /// \param[in] value The given int16 value.
163 explicit Int16Imm(int16_t value);
164
165 /// \brief Get the int16 value of this Int16Imm.
166 ///
167 /// \return The int16 value of this Int16Imm.
168 int16_t value() const;
169 };
170
171 using Int16ImmPtr = SharedPtr<Int16Imm>;
172
173 MIND_API_IMM_TRAIT(Int16Imm, int16_t);
174
175 /// \brief Int32Imm defines interface for int32 data.
176 class MIND_API Int32Imm : public IntegerImm {
177 public:
178 MIND_API_BASE_MEMBER(Int32Imm);
179
180 /// \brief Create Int32Imm with the given int32 value.
181 ///
182 /// \param[in] value The given int32 value.
183 explicit Int32Imm(int32_t value);
184
185 /// \brief Get the int32 value of this Int32Imm.
186 ///
187 /// \return The int32 value of this Int32Imm.
188 int32_t value() const;
189 };
190
191 using Int32ImmPtr = SharedPtr<Int32Imm>;
192
193 MIND_API_IMM_TRAIT(Int32Imm, int32_t);
194
195 /// \brief Int64Imm defines interface for int64 data.
196 class MIND_API Int64Imm : public IntegerImm {
197 public:
198 MIND_API_BASE_MEMBER(Int64Imm);
199
200 /// \brief Create Int64Imm with the given int64 value.
201 ///
202 /// \param[in] value The given int64 value.
203 explicit Int64Imm(int64_t value);
204
205 /// \brief Get the int64 value of this Int64Imm.
206 ///
207 /// \return The int64 value of this Int64Imm.
208 int64_t value() const;
209 };
210
211 using Int64ImmPtr = SharedPtr<Int64Imm>;
212
213 MIND_API_IMM_TRAIT(Int64Imm, int64_t);
214
215 /// \brief UInt8Imm defines interface for uint8 data.
216 class MIND_API UInt8Imm : public IntegerImm {
217 public:
218 MIND_API_BASE_MEMBER(UInt8Imm);
219
220 /// \brief Create UInt8Imm with the given uint8 value.
221 ///
222 /// \param[in] value The given uint8 value.
223 explicit UInt8Imm(uint8_t value);
224
225 /// \brief Get the uint8 value of this UInt8Imm.
226 ///
227 /// \return The uint8 value of this UInt8Imm.
228 uint8_t value() const;
229 };
230
231 using UInt8ImmPtr = SharedPtr<UInt8Imm>;
232
233 MIND_API_IMM_TRAIT(UInt8Imm, uint8_t);
234
235 /// \brief FloatImm defines interface for float data.
236 class MIND_API FloatImm : public Scalar {
237 public:
238 MIND_API_BASE_MEMBER(FloatImm);
239 };
240
241 /// \brief FP32Imm defines interface for float32 data.
242 class MIND_API FP32Imm : public FloatImm {
243 public:
244 MIND_API_BASE_MEMBER(FP32Imm);
245
246 /// \brief Create FP32Imm with the given float value.
247 ///
248 /// \param[in] value The given float value.
249 explicit FP32Imm(float value);
250
251 /// \brief Get the float value of this FP32Imm.
252 ///
253 /// \return The float value of this FP32Imm.
254 float value() const;
255 };
256
257 using FP32ImmPtr = SharedPtr<FP32Imm>;
258
259 MIND_API_IMM_TRAIT(FP32Imm, float);
260
261 /// \brief FP64Imm defines interface for float64 data.
262 class MIND_API FP64Imm : public FloatImm {
263 public:
264 MIND_API_BASE_MEMBER(FP64Imm);
265
266 /// \brief Create FP64Imm with the given float value.
267 ///
268 /// \param[in] value The given float value.
269 explicit FP64Imm(double value);
270
271 /// \brief Get the float value of this FP64Imm.
272 ///
273 /// \return The float value of this FP64Imm.
274 double value() const;
275 };
276
277 using FP64ImmPtr = SharedPtr<FP64Imm>;
278
279 MIND_API_IMM_TRAIT(FP64Imm, double);
280
281 // === Utility functions for Value === //
282
283 /// \brief Create a Value object from a primitive type value.
284 ///
285 /// \param[in] v The primitive type value.
286 ///
287 /// \return The created Value object with the given primitive type value.
288 template <typename T, typename U = typename ImmTrait<T>::type::element_type>
MakeValue(T v)289 inline ValuePtr MakeValue(T v) {
290 return MakeShared<U>(v);
291 }
292
293 /// \brief Create a StringImm Value object from a C string.
294 ///
295 /// \param[in] s The C string.
296 ///
297 /// \return The created StringImm Value object.
MakeValue(const char * s)298 inline ValuePtr MakeValue(const char *s) { return MakeShared<StringImm>(std::string(s)); }
299
300 /// \brief Create an Int64Imm Value object from a int value.
301 ///
302 /// \param[in] i The int value.
303 ///
304 /// \return The created Int64Imm Value object.
MakeValue(int i)305 inline ValuePtr MakeValue(int i) { return MakeShared<Int64Imm>(static_cast<int64_t>(i)); }
306
307 /// \brief Create a ValueSequence object from a vector of values.
308 ///
309 /// \param[in] values The vector of values.
310 ///
311 /// \return The created ValueSequence object.
MakeValue(const std::vector<ValuePtr> & values)312 inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) { return MakeShared<ValueTuple>(values); }
313
314 /// \brief Create a ValueSequence object from a vector of primitive type values.
315 ///
316 /// \param[in] values The vector of primitive values.
317 ///
318 /// \return The created ValueSequence object.
319 template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
MakeValue(const T & values)320 inline ValuePtr MakeValue(const T &values) {
321 std::vector<ValuePtr> value_vector;
322 value_vector.reserve(values.size());
323 for (auto value : values) {
324 value_vector.emplace_back(MakeValue(value));
325 }
326 return MakeShared<ValueTuple>(value_vector);
327 }
328
329 /// \brief Get primitive type value from a Value object.
330 ///
331 /// \param[in] value The pointer to the Value object.
332 ///
333 /// \return The primitive type value of the Value object.
334 template <typename T, typename U = typename ImmTrait<T>::type>
GetValue(const ValuePtr & value)335 inline T GetValue(const ValuePtr &value) {
336 if (value == nullptr) {
337 return T();
338 }
339 U imm = value->cast<U>();
340 if (imm == nullptr) {
341 return T();
342 }
343 return imm->value();
344 }
345
346 /// \brief Get element values from a ValueSequence object.
347 ///
348 /// \param[in] value The pointer to the ValueSequence object.
349 ///
350 /// \return The values as a vector, empty if the input is not a ValueSequence.
351 template <typename T, typename S = typename std::decay_t<T>,
352 typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
GetValue(const ValuePtr & value)353 std::vector<U> GetValue(const ValuePtr &value) {
354 if (value == nullptr) {
355 return {};
356 }
357 auto seq = value->cast<ValueSequencePtr>();
358 if (seq == nullptr) {
359 return {};
360 }
361 if constexpr (std::is_same_v<ValuePtr, U>) {
362 return seq->value();
363 } else {
364 auto elements = seq->value();
365 std::vector<U> result;
366 result.reserve(elements.size());
367 for (auto &e : elements) {
368 result.emplace_back(GetValue<U>(e));
369 }
370 return result;
371 }
372 }
373 } // namespace mindspore::api
374 #endif // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
375