• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_
19 
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 
31 // Transform operations for performing data transformation.
32 namespace transforms {
33 
// Char arrays storing the names of the corresponding operation classes, kept in alphabetical order.
// Each class below returns its constant from Name().
constexpr char kComposeOperation[] = "Compose";
constexpr char kConcatenateOperation[] = "Concatenate";
constexpr char kDuplicateOperation[] = "Duplicate";
constexpr char kFillOperation[] = "Fill";
constexpr char kMaskOperation[] = "Mask";
constexpr char kOneHotOperation[] = "OneHot";
constexpr char kPadEndOperation[] = "PadEnd";
constexpr char kPluginOperation[] = "Plugin";
constexpr char kPreBuiltOperation[] = "PreBuilt";
constexpr char kRandomApplyOperation[] = "RandomApply";
constexpr char kRandomChoiceOperation[] = "RandomChoice";
constexpr char kSliceOperation[] = "Slice";
constexpr char kTypeCastOperation[] = "TypeCast";
constexpr char kUniqueOperation[] = "Unique";
49 /* ####################################### Derived TensorOperation classes ################################# */
50 
51 class ComposeOperation : public TensorOperation {
52  public:
53   explicit ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
54 
55   ~ComposeOperation() = default;
56 
57   std::shared_ptr<TensorOp> Build() override;
58 
59   Status ValidateParams() override;
60 
Name()61   std::string Name() const override { return kComposeOperation; }
62 
63  private:
64   std::vector<std::shared_ptr<TensorOperation>> transforms_;
65 };
66 
67 class ConcatenateOperation : public TensorOperation {
68  public:
69   explicit ConcatenateOperation(int8_t axis, const std::shared_ptr<Tensor> &prepend,
70                                 const std::shared_ptr<Tensor> &append);
71 
72   ~ConcatenateOperation() = default;
73 
74   std::shared_ptr<TensorOp> Build() override;
75 
76   Status ValidateParams() override;
77 
Name()78   std::string Name() const override { return kConcatenateOperation; }
79 
80  private:
81   int8_t axis_;
82   std::shared_ptr<Tensor> prepend_;
83   std::shared_ptr<Tensor> append_;
84 };
85 
86 class DuplicateOperation : public TensorOperation {
87  public:
88   DuplicateOperation() = default;
89 
90   ~DuplicateOperation() = default;
91 
92   std::shared_ptr<TensorOp> Build() override;
93 
94   Status ValidateParams() override;
95 
Name()96   std::string Name() const override { return kDuplicateOperation; }
97 };
98 
99 class FillOperation : public TensorOperation {
100  public:
101   explicit FillOperation(const std::shared_ptr<Tensor> &fill_value);
102 
103   ~FillOperation() = default;
104 
105   std::shared_ptr<TensorOp> Build() override;
106 
107   Status ValidateParams() override;
108 
Name()109   std::string Name() const override { return kFillOperation; }
110 
111   Status to_json(nlohmann::json *out_json) override;
112 
113   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
114 
115  private:
116   std::shared_ptr<Tensor> fill_value_;
117 };
118 
119 class MaskOperation : public TensorOperation {
120  public:
121   explicit MaskOperation(RelationalOp op, const std::shared_ptr<Tensor> &constant, const DataType &dtype);
122 
123   ~MaskOperation() = default;
124 
125   std::shared_ptr<TensorOp> Build() override;
126 
127   Status ValidateParams() override;
128 
Name()129   std::string Name() const override { return kMaskOperation; }
130 
131  private:
132   RelationalOp op_;
133   std::shared_ptr<Tensor> constant_;
134   DataType dtype_;
135 };
136 
137 class OneHotOperation : public TensorOperation {
138  public:
139   explicit OneHotOperation(int32_t num_classes);
140 
141   ~OneHotOperation() = default;
142 
143   std::shared_ptr<TensorOp> Build() override;
144 
145   Status ValidateParams() override;
146 
Name()147   std::string Name() const override { return kOneHotOperation; }
148 
149   Status to_json(nlohmann::json *out_json) override;
150 
151   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
152 
153  private:
154   int32_t num_classes_;
155 };
156 
157 class PadEndOperation : public TensorOperation {
158  public:
159   explicit PadEndOperation(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value);
160 
161   ~PadEndOperation() = default;
162 
163   std::shared_ptr<TensorOp> Build() override;
164 
165   Status ValidateParams() override;
166 
Name()167   std::string Name() const override { return kPadEndOperation; }
168 
169  private:
170   TensorShape pad_shape_;
171   std::shared_ptr<Tensor> pad_value_;
172 };
173 
174 class PreBuiltOperation : public TensorOperation {
175  public:
176   explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op);
177 
178   ~PreBuiltOperation() = default;
179 
180   std::shared_ptr<TensorOp> Build() override;
181 
182   Status ValidateParams() override;
183 
184   std::string Name() const override;
185 
186   Status to_json(nlohmann::json *out_json) override;
187 
188  private:
189   std::shared_ptr<TensorOp> op_;
190 };
191 
192 class RandomApplyOperation : public TensorOperation {
193  public:
194   explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob);
195 
196   ~RandomApplyOperation() = default;
197 
198   std::shared_ptr<TensorOp> Build() override;
199 
200   Status ValidateParams() override;
201 
Name()202   std::string Name() const override { return kRandomApplyOperation; }
203 
204  private:
205   std::vector<std::shared_ptr<TensorOperation>> transforms_;
206   double prob_;
207 };
208 
209 class RandomChoiceOperation : public TensorOperation {
210  public:
211   explicit RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms);
212 
213   ~RandomChoiceOperation() = default;
214 
215   std::shared_ptr<TensorOp> Build() override;
216 
217   Status ValidateParams() override;
218 
Name()219   std::string Name() const override { return kRandomChoiceOperation; }
220 
221  private:
222   std::vector<std::shared_ptr<TensorOperation>> transforms_;
223 };
224 
225 class SliceOperation : public TensorOperation {
226  public:
227   explicit SliceOperation(const std::vector<SliceOption> &slice_input);
228 
229   ~SliceOperation() = default;
230 
231   std::shared_ptr<TensorOp> Build() override;
232 
233   Status ValidateParams() override;
234 
Name()235   std::string Name() const override { return kSliceOperation; }
236 
237  private:
238   std::vector<SliceOption> slice_input_;
239 };
240 
241 class TypeCastOperation : public TensorOperation {
242  public:
243   explicit TypeCastOperation(const DataType &data_type);     // Used for C++ API
244   explicit TypeCastOperation(const std::string &data_type);  // Used for Pybind
245 
246   ~TypeCastOperation() = default;
247 
248   std::shared_ptr<TensorOp> Build() override;
249 
250   Status ValidateParams() override;
251 
Name()252   std::string Name() const override { return kTypeCastOperation; }
253 
254   Status to_json(nlohmann::json *out_json) override;
255 
256   static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
257 
258  private:
259   DataType data_type_;
260 };
261 
262 #ifndef ENABLE_ANDROID
263 class UniqueOperation : public TensorOperation {
264  public:
265   UniqueOperation() = default;
266 
267   ~UniqueOperation() = default;
268 
269   std::shared_ptr<TensorOp> Build() override;
270 
271   Status ValidateParams() override;
272 
Name()273   std::string Name() const override { return kUniqueOperation; }
274 };
275 
276 class PluginOperation : public TensorOperation {
277  public:
PluginOperation(const std::string & lib_path,const std::string & func_name,const std::string & user_args)278   explicit PluginOperation(const std::string &lib_path, const std::string &func_name, const std::string &user_args)
279       : lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {}
280 
281   ~PluginOperation() = default;
282 
283   std::shared_ptr<TensorOp> Build() override;
284 
285   Status ValidateParams() override;
286 
Name()287   std::string Name() const override { return kPluginOperation; }
288 
289  private:
290   std::string lib_path_;
291   std::string func_name_;
292   std::string user_args_;
293 };
294 
295 #endif
296 
297 }  // namespace transforms
298 }  // namespace dataset
299 }  // namespace mindspore
300 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_
301