• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
18 #define MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
19 
20 #include <vector>
21 #include <string>
22 #include <type_traits>
23 #include "mindapi/base/base.h"
24 #include "mindapi/ir/common.h"
25 
26 namespace mindspore::api {
/// \brief Maps a C++ primitive type to the SharedPtr of the immediate Value class that wraps it.
///
/// The primary template deliberately declares no `type` member. Helpers that default a
/// template argument to `ImmTrait<T>::type` therefore drop out of overload resolution for
/// primitive types that have no specialization.
template <class T>
struct ImmTrait {
};

/// \brief Specializes ImmTrait so that `prototype` maps to `SharedPtr<typeimm>`.
///
/// Invocations must supply the trailing semicolon, e.g. `MIND_API_IMM_TRAIT(BoolImm, bool);`.
#define MIND_API_IMM_TRAIT(typeimm, prototype) \
  template <>                                  \
  struct ImmTrait<prototype> { using type = SharedPtr<typeimm>; }
35 
36 /// \brief Value represents a value in expression.
37 class MIND_API Value : public Base {
38  public:
39   MIND_API_BASE_MEMBER(Value);
40 
41   /// \brief Get the type of this Value.
42   ///
43   /// \return The type.
44   TypePtr type() const;
45 
46   /// \brief Get the abstract of this Value.
47   ///
48   /// \return Abstract of this Value.
49   AbstractBasePtr ToAbstract() const;
50 };
51 
52 /// \brief ValueSequence represents a sequence of values.
53 class MIND_API ValueSequence : public Value {
54  public:
55   MIND_API_BASE_MEMBER(ValueSequence);
56 
57   /// \brief Get the size of this ValueSequence.
58   ///
59   /// \return The size as the number of elements.
60   std::size_t size() const;
61 
62   /// \brief Get the list of values in this ValueSequence.
63   ///
64   /// \return The list of element values.
65   std::vector<ValuePtr> value() const;
66 };
67 
68 using ValueSequencePtr = SharedPtr<ValueSequence>;
69 
70 /// \brief ValueTuple represents a value tuple.
71 class MIND_API ValueTuple : public ValueSequence {
72  public:
73   MIND_API_BASE_MEMBER(ValueTuple);
74 
75   /// \brief Constructor of ValueTuple.
76   ///
77   /// \param[in] elements The elements of the tuple.
78   explicit ValueTuple(const std::vector<ValuePtr> &elements);
79 };
80 
81 using ValueTuplePtr = SharedPtr<ValueTuple>;
82 
83 /// \brief StringImm defines a Value whose type is string.
84 class MIND_API StringImm : public Value {
85  public:
86   MIND_API_BASE_MEMBER(StringImm);
87 
88   /// \brief Create StringImm with the given string.
89   ///
90   /// \param[in] str The given string value.
91   explicit StringImm(const std::string &str);
92 
93   /// \brief Get the string value of this StringImm.
94   ///
95   /// \return The string value of this StringImm.
96   const std::string &value() const;
97 };
98 
99 using StringImmPtr = SharedPtr<StringImm>;
100 
101 MIND_API_IMM_TRAIT(StringImm, std::string);
102 
103 /// \brief Scalar defines interface for scalar data.
104 class MIND_API Scalar : public Value {
105  public:
106   MIND_API_BASE_MEMBER(Scalar);
107 };
108 
109 /// \brief BoolImm defines interface for bool data.
110 class MIND_API BoolImm : public Scalar {
111  public:
112   MIND_API_BASE_MEMBER(BoolImm);
113 
114   /// \brief Create BoolImm with the given bool value.
115   ///
116   /// \param[in] b The given bool value.
117   explicit BoolImm(bool b);
118 
119   /// \brief Get the bool value of this BoolImm.
120   ///
121   /// \return The bool value of this BoolImm.
122   bool value() const;
123 };
124 
125 using BoolImmPtr = SharedPtr<BoolImm>;
126 
127 MIND_API_IMM_TRAIT(BoolImm, bool);
128 
129 /// \brief IntegerImm defines interface for integer data.
130 class MIND_API IntegerImm : public Scalar {
131  public:
132   MIND_API_BASE_MEMBER(IntegerImm);
133 };
134 
135 /// \brief Int8Imm defines interface for int8 data.
136 class MIND_API Int8Imm : public IntegerImm {
137  public:
138   MIND_API_BASE_MEMBER(Int8Imm);
139 
140   /// \brief Create Int8Imm with the given int8 value.
141   ///
142   /// \param[in] value The given int8 value.
143   explicit Int8Imm(int8_t value);
144 
145   /// \brief Get the int8 value of this Int8Imm.
146   ///
147   /// \return The int8 value of this Int8Imm.
148   int8_t value() const;
149 };
150 
151 using Int8ImmPtr = SharedPtr<Int8Imm>;
152 
153 MIND_API_IMM_TRAIT(Int8Imm, int8_t);
154 
155 /// \brief Int16Imm defines interface for int16 data.
156 class MIND_API Int16Imm : public IntegerImm {
157  public:
158   MIND_API_BASE_MEMBER(Int16Imm);
159 
160   /// \brief Create Int1I6mm with the given int16 value.
161   ///
162   /// \param[in] value The given int16 value.
163   explicit Int16Imm(int16_t value);
164 
165   /// \brief Get the int16 value of this Int16Imm.
166   ///
167   /// \return The int16 value of this Int16Imm.
168   int16_t value() const;
169 };
170 
171 using Int16ImmPtr = SharedPtr<Int16Imm>;
172 
173 MIND_API_IMM_TRAIT(Int16Imm, int16_t);
174 
175 /// \brief Int32Imm defines interface for int32 data.
176 class MIND_API Int32Imm : public IntegerImm {
177  public:
178   MIND_API_BASE_MEMBER(Int32Imm);
179 
180   /// \brief Create Int32Imm with the given int32 value.
181   ///
182   /// \param[in] value The given int32 value.
183   explicit Int32Imm(int32_t value);
184 
185   /// \brief Get the int32 value of this Int32Imm.
186   ///
187   /// \return The int32 value of this Int32Imm.
188   int32_t value() const;
189 };
190 
191 using Int32ImmPtr = SharedPtr<Int32Imm>;
192 
193 MIND_API_IMM_TRAIT(Int32Imm, int32_t);
194 
195 /// \brief Int64Imm defines interface for int64 data.
196 class MIND_API Int64Imm : public IntegerImm {
197  public:
198   MIND_API_BASE_MEMBER(Int64Imm);
199 
200   /// \brief Create Int64Imm with the given int64 value.
201   ///
202   /// \param[in] value The given int64 value.
203   explicit Int64Imm(int64_t value);
204 
205   /// \brief Get the int64 value of this Int64Imm.
206   ///
207   /// \return The int64 value of this Int64Imm.
208   int64_t value() const;
209 };
210 
211 using Int64ImmPtr = SharedPtr<Int64Imm>;
212 
213 MIND_API_IMM_TRAIT(Int64Imm, int64_t);
214 
215 /// \brief UInt8Imm defines interface for uint8 data.
216 class MIND_API UInt8Imm : public IntegerImm {
217  public:
218   MIND_API_BASE_MEMBER(UInt8Imm);
219 
220   /// \brief Create UInt8Imm with the given uint8 value.
221   ///
222   /// \param[in] value The given uint8 value.
223   explicit UInt8Imm(uint8_t value);
224 
225   /// \brief Get the uint8 value of this UInt8Imm.
226   ///
227   /// \return The uint8 value of this UInt8Imm.
228   uint8_t value() const;
229 };
230 
231 using UInt8ImmPtr = SharedPtr<UInt8Imm>;
232 
233 MIND_API_IMM_TRAIT(UInt8Imm, uint8_t);
234 
235 /// \brief FloatImm defines interface for float data.
236 class MIND_API FloatImm : public Scalar {
237  public:
238   MIND_API_BASE_MEMBER(FloatImm);
239 };
240 
241 /// \brief FP32Imm defines interface for float32 data.
242 class MIND_API FP32Imm : public FloatImm {
243  public:
244   MIND_API_BASE_MEMBER(FP32Imm);
245 
246   /// \brief Create FP32Imm with the given float value.
247   ///
248   /// \param[in] value The given float value.
249   explicit FP32Imm(float value);
250 
251   /// \brief Get the float value of this FP32Imm.
252   ///
253   /// \return The float value of this FP32Imm.
254   float value() const;
255 };
256 
257 using FP32ImmPtr = SharedPtr<FP32Imm>;
258 
259 MIND_API_IMM_TRAIT(FP32Imm, float);
260 
261 /// \brief FP64Imm defines interface for float64 data.
262 class MIND_API FP64Imm : public FloatImm {
263  public:
264   MIND_API_BASE_MEMBER(FP64Imm);
265 
266   /// \brief Create FP64Imm with the given float value.
267   ///
268   /// \param[in] value The given float value.
269   explicit FP64Imm(double value);
270 
271   /// \brief Get the float value of this FP64Imm.
272   ///
273   /// \return The float value of this FP64Imm.
274   double value() const;
275 };
276 
277 using FP64ImmPtr = SharedPtr<FP64Imm>;
278 
279 MIND_API_IMM_TRAIT(FP64Imm, double);
280 
281 // === Utility functions for Value === //
282 
283 /// \brief Create a Value object from a primitive type value.
284 ///
285 /// \param[in] v The primitive type value.
286 ///
287 /// \return The created Value object with the given primitive type value.
288 template <typename T, typename U = typename ImmTrait<T>::type::element_type>
MakeValue(T v)289 inline ValuePtr MakeValue(T v) {
290   return MakeShared<U>(v);
291 }
292 
293 /// \brief Create a StringImm Value object from a C string.
294 ///
295 /// \param[in] s The C string.
296 ///
297 /// \return The created StringImm Value object.
MakeValue(const char * s)298 inline ValuePtr MakeValue(const char *s) { return MakeShared<StringImm>(std::string(s)); }
299 
300 /// \brief Create an Int64Imm Value object from a int value.
301 ///
302 /// \param[in] i The int value.
303 ///
304 /// \return The created Int64Imm Value object.
MakeValue(int i)305 inline ValuePtr MakeValue(int i) { return MakeShared<Int64Imm>(static_cast<int64_t>(i)); }
306 
307 /// \brief Create a ValueSequence object from a vector of values.
308 ///
309 /// \param[in] values The vector of values.
310 ///
311 /// \return The created ValueSequence object.
MakeValue(const std::vector<ValuePtr> & values)312 inline ValuePtr MakeValue(const std::vector<ValuePtr> &values) { return MakeShared<ValueTuple>(values); }
313 
314 /// \brief Create a ValueSequence object from a vector of primitive type values.
315 ///
316 /// \param[in] values The vector of primitive values.
317 ///
318 /// \return The created ValueSequence object.
319 template <typename T, typename = typename std::enable_if_t<is_vector<T>::value, T>>
MakeValue(const T & values)320 inline ValuePtr MakeValue(const T &values) {
321   std::vector<ValuePtr> value_vector;
322   value_vector.reserve(values.size());
323   for (auto value : values) {
324     value_vector.emplace_back(MakeValue(value));
325   }
326   return MakeShared<ValueTuple>(value_vector);
327 }
328 
329 /// \brief Get primitive type value from a Value object.
330 ///
331 /// \param[in] value The pointer to the Value object.
332 ///
333 /// \return The primitive type value of the Value object.
334 template <typename T, typename U = typename ImmTrait<T>::type>
GetValue(const ValuePtr & value)335 inline T GetValue(const ValuePtr &value) {
336   if (value == nullptr) {
337     return T();
338   }
339   U imm = value->cast<U>();
340   if (imm == nullptr) {
341     return T();
342   }
343   return imm->value();
344 }
345 
346 /// \brief Get element values from a ValueSequence object.
347 ///
348 /// \param[in] value The pointer to the ValueSequence object.
349 ///
350 /// \return The values as a vector, empty if the input is not a ValueSequence.
351 template <typename T, typename S = typename std::decay_t<T>,
352           typename U = typename std::enable_if_t<is_vector<S>::value, typename S::value_type>>
GetValue(const ValuePtr & value)353 std::vector<U> GetValue(const ValuePtr &value) {
354   if (value == nullptr) {
355     return {};
356   }
357   auto seq = value->cast<ValueSequencePtr>();
358   if (seq == nullptr) {
359     return {};
360   }
361   if constexpr (std::is_same_v<ValuePtr, U>) {
362     return seq->value();
363   } else {
364     auto elements = seq->value();
365     std::vector<U> result;
366     result.reserve(elements.size());
367     for (auto &e : elements) {
368       result.emplace_back(GetValue<U>(e));
369     }
370     return result;
371   }
372 }
373 }  // namespace mindspore::api
374 #endif  // MINDSPORE_CORE_MINDAPI_IR_VALUE_H_
375