• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_
19 
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
28 
29 namespace mindspore {
30 namespace dataset {
31 // Transform operations for performing data transformation.
32 namespace transforms {
// Char arrays storing the name of each corresponding operation class (in alphabetical order).
// The inline Name() overrides below return these, relying on the implicit const char[] -> std::string
// conversion, so their type must stay a character array.
// A namespace-scope constexpr variable is implicitly const, so each translation unit that includes this
// header gets its own internal-linkage copy.
constexpr char kComposeOperation[] = "Compose";
constexpr char kConcatenateOperation[] = "Concatenate";
constexpr char kDuplicateOperation[] = "Duplicate";
constexpr char kFillOperation[] = "Fill";
constexpr char kMaskOperation[] = "Mask";
constexpr char kOneHotOperation[] = "OneHot";
constexpr char kPadEndOperation[] = "PadEnd";
constexpr char kParseExampleOperation[] = "ParseExample";
constexpr char kPluginOperation[] = "Plugin";
constexpr char kPreBuiltOperation[] = "PreBuilt";
constexpr char kRandomApplyOperation[] = "RandomApply";
constexpr char kRandomChoiceOperation[] = "RandomChoice";
constexpr char kSliceOperation[] = "Slice";
constexpr char kTypeCastOperation[] = "TypeCast";
constexpr char kUniqueOperation[] = "Unique";
49 /* ####################################### Derived TensorOperation classes ################################# */
50 
51 class ComposeOperation : public TensorOperation {
52  public:
53   explicit ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
54 
55   ~ComposeOperation() override = default;
56 
57   std::shared_ptr<TensorOp> Build() override;
58 
59   Status ValidateParams() override;
60 
Name()61   std::string Name() const override { return kComposeOperation; }
62 
63   Status to_json(nlohmann::json *out_json) override;
64 
65   static Status from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation);
66 
67   // Get the compose type: kInvalid / kAscend910B / kCpu.
Type()68   virtual MapTargetDevice Type() {
69     bool have_dvpp = false;
70     bool have_cpu = false;
71     for (auto &item : transforms_) {
72       if (item->Type() == MapTargetDevice::kAscend910B) {
73         have_dvpp = true;
74       } else if (item->Type() == MapTargetDevice::kCpu) {
75         have_cpu = true;
76       } else {
77         MS_LOG(ERROR) << "The transform: " << item->Name() << " is not Ascend or Cpu.";
78         return MapTargetDevice::kInvalid;
79       }
80     }
81 
82     if (have_dvpp && have_cpu) {
83       MS_LOG(ERROR) << "Currently, it is not supported to mix DVPP transforms with CPU transforms in Compose.";
84       return MapTargetDevice::kInvalid;
85     } else if (have_dvpp) {
86       return MapTargetDevice::kAscend910B;
87     } else {
88       return MapTargetDevice::kCpu;
89     }
90   }
91 
92  private:
93   std::vector<std::shared_ptr<TensorOperation>> transforms_;
94 };
95 
96 class ConcatenateOperation : public TensorOperation {
97  public:
98   ConcatenateOperation(int8_t axis, const std::shared_ptr<Tensor> &prepend, const std::shared_ptr<Tensor> &append);
99 
100   ~ConcatenateOperation() override = default;
101 
102   std::shared_ptr<TensorOp> Build() override;
103 
104   Status ValidateParams() override;
105 
Name()106   std::string Name() const override { return kConcatenateOperation; }
107 
108   Status to_json(nlohmann::json *out_json) override;
109 
110   static Status from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation);
111 
112  private:
113   int8_t axis_;
114   std::shared_ptr<Tensor> prepend_;
115   std::shared_ptr<Tensor> append_;
116 };
117 
118 class DuplicateOperation : public TensorOperation {
119  public:
120   DuplicateOperation() = default;
121 
122   ~DuplicateOperation() override = default;
123 
124   std::shared_ptr<TensorOp> Build() override;
125 
126   Status ValidateParams() override;
127 
Name()128   std::string Name() const override { return kDuplicateOperation; }
129 
130   static Status from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation);
131 };
132 
133 class FillOperation : public TensorOperation {
134  public:
135   explicit FillOperation(const std::shared_ptr<Tensor> &fill_value);
136 
137   ~FillOperation() override = default;
138 
139   std::shared_ptr<TensorOp> Build() override;
140 
141   Status ValidateParams() override;
142 
Name()143   std::string Name() const override { return kFillOperation; }
144 
145   Status to_json(nlohmann::json *out_json) override;
146 
147   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
148 
149  private:
150   std::shared_ptr<Tensor> fill_value_;
151 };
152 
153 class MaskOperation : public TensorOperation {
154  public:
155   MaskOperation(RelationalOp op, const std::shared_ptr<Tensor> &constant, const DataType &dtype);
156 
157   ~MaskOperation() override = default;
158 
159   std::shared_ptr<TensorOp> Build() override;
160 
161   Status ValidateParams() override;
162 
Name()163   std::string Name() const override { return kMaskOperation; }
164 
165   Status to_json(nlohmann::json *out_json) override;
166 
167   static Status from_json(const nlohmann::json &op_params, std::shared_ptr<TensorOperation> *operation);
168 
169  private:
170   RelationalOp op_;
171   std::shared_ptr<Tensor> constant_;
172   DataType dtype_;
173 };
174 
175 class OneHotOperation : public TensorOperation {
176  public:
177   explicit OneHotOperation(int32_t num_classes, double smoothing_rate = 0.0);
178 
179   ~OneHotOperation() override = default;
180 
181   std::shared_ptr<TensorOp> Build() override;
182 
183   Status ValidateParams() override;
184 
Name()185   std::string Name() const override { return kOneHotOperation; }
186 
187   Status to_json(nlohmann::json *out_json) override;
188 
189   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
190 
191  private:
192   int32_t num_classes_;
193   double smoothing_rate_;
194 };
195 
196 class PadEndOperation : public TensorOperation {
197  public:
198   PadEndOperation(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value);
199 
200   ~PadEndOperation() override = default;
201 
202   std::shared_ptr<TensorOp> Build() override;
203 
204   Status ValidateParams() override;
205 
Name()206   std::string Name() const override { return kPadEndOperation; }
207 
208   Status to_json(nlohmann::json *out_json) override;
209 
210   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
211 
212  private:
213   TensorShape pad_shape_;
214   std::shared_ptr<Tensor> pad_value_;
215 };
216 
217 class ParseExampleOperation : public TensorOperation {
218  public:
219   ParseExampleOperation(DataSchema schema, std::vector<std::string> column_list, bool parallel_parse);
220 
221   ~ParseExampleOperation() override = default;
222 
223   std::shared_ptr<TensorOp> Build() override;
224 
Name()225   std::string Name() const override { return kParseExampleOperation; }
226 
227  private:
228   DataSchema schema_;
229   std::vector<std::string> column_list_;
230   bool parallel_parse_;
231 };
232 
233 class PreBuiltOperation : public TensorOperation {
234  public:
235   explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op);
236 
237   ~PreBuiltOperation() override = default;
238 
239   std::shared_ptr<TensorOp> Build() override;
240 
241   Status ValidateParams() override;
242 
243   std::string Name() const override;
244 
245   Status to_json(nlohmann::json *out_json) override;
246 
247  private:
248   std::shared_ptr<TensorOp> op_;
249 };
250 
251 class RandomApplyOperation : public TensorOperation {
252  public:
253   RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
254 
255   ~RandomApplyOperation() override = default;
256 
257   std::shared_ptr<TensorOp> Build() override;
258 
259   Status ValidateParams() override;
260 
Name()261   std::string Name() const override { return kRandomApplyOperation; }
262 
263   Status to_json(nlohmann::json *out_json) override;
264 
265   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
266 
267  private:
268   std::vector<std::shared_ptr<TensorOperation>> transforms_;
269   double prob_;
270 };
271 
272 class RandomChoiceOperation : public TensorOperation {
273  public:
274   explicit RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
275 
276   ~RandomChoiceOperation() override = default;
277 
278   std::shared_ptr<TensorOp> Build() override;
279 
280   Status ValidateParams() override;
281 
Name()282   std::string Name() const override { return kRandomChoiceOperation; }
283 
284   Status to_json(nlohmann::json *out_json) override;
285 
286   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
287 
288  private:
289   std::vector<std::shared_ptr<TensorOperation>> transforms_;
290 };
291 
292 class SliceOperation : public TensorOperation {
293  public:
294   explicit SliceOperation(const std::vector<SliceOption> &slice_input);
295 
296   ~SliceOperation() override = default;
297 
298   std::shared_ptr<TensorOp> Build() override;
299 
300   Status ValidateParams() override;
301 
Name()302   std::string Name() const override { return kSliceOperation; }
303 
304  private:
305   std::vector<SliceOption> slice_input_;
306 };
307 
308 class TypeCastOperation : public TensorOperation {
309  public:
310   explicit TypeCastOperation(const DataType &data_type);  // Used for C++ API
311 
312   explicit TypeCastOperation(const std::string &data_type);  // Used for Pybind
313 
314   ~TypeCastOperation() override = default;
315 
316   std::shared_ptr<TensorOp> Build() override;
317 
318   Status ValidateParams() override;
319 
Name()320   std::string Name() const override { return kTypeCastOperation; }
321 
322   Status to_json(nlohmann::json *out_json) override;
323 
324   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
325 
326  private:
327   DataType data_type_;
328 };
329 
330 #ifndef ENABLE_ANDROID
331 class UniqueOperation : public TensorOperation {
332  public:
333   UniqueOperation() = default;
334 
335   ~UniqueOperation() override = default;
336 
337   std::shared_ptr<TensorOp> Build() override;
338 
339   Status ValidateParams() override;
340 
Name()341   std::string Name() const override { return kUniqueOperation; }
342 
343   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
344 };
345 
346 class PluginOperation : public TensorOperation {
347  public:
PluginOperation(const std::string & lib_path,const std::string & func_name,const std::string & user_args)348   explicit PluginOperation(const std::string &lib_path, const std::string &func_name, const std::string &user_args)
349       : lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {}
350 
351   ~PluginOperation() override = default;
352 
353   std::shared_ptr<TensorOp> Build() override;
354 
355   Status ValidateParams() override;
356 
Name()357   std::string Name() const override { return kPluginOperation; }
358 
359   Status to_json(nlohmann::json *out_json) override;
360 
361   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
362 
363  private:
364   std::string lib_path_;
365   std::string func_name_;
366   std::string user_args_;
367 };
368 #endif
369 }  // namespace transforms
370 }  // namespace dataset
371 }  // namespace mindspore
372 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_
373