1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "minddata/dataset/core/data_type.h" 26 #include "minddata/dataset/kernels/ir/tensor_operation.h" 27 28 namespace mindspore { 29 namespace dataset { 30 31 // Transform operations for performing data transformation. 32 namespace transforms { 33 34 // Char arrays storing name of corresponding classes (in alphabetical order) 35 constexpr char kComposeOperation[] = "Compose"; 36 constexpr char kConcatenateOperation[] = "Concatenate"; 37 constexpr char kDuplicateOperation[] = "Duplicate"; 38 constexpr char kFillOperation[] = "Fill"; 39 constexpr char kMaskOperation[] = "Mask"; 40 constexpr char kOneHotOperation[] = "OneHot"; 41 constexpr char kPadEndOperation[] = "PadEnd"; 42 constexpr char kPreBuiltOperation[] = "PreBuilt"; 43 constexpr char kSliceOperation[] = "Slice"; 44 constexpr char kRandomApplyOperation[] = "RandomApply"; 45 constexpr char kRandomChoiceOperation[] = "RandomChoice"; 46 constexpr char kTypeCastOperation[] = "TypeCast"; 47 constexpr char kUniqueOperation[] = "Unique"; 48 constexpr char kPluginOperation[] = "Plugin"; 49 /* ####################################### Derived TensorOperation classes ################################# */ 50 51 class ComposeOperation : public TensorOperation { 52 public: 53 explicit ComposeOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms); 54 55 ~ComposeOperation() = default; 56 57 std::shared_ptr<TensorOp> Build() override; 58 59 Status ValidateParams() override; 60 Name()61 std::string Name() const override { return kComposeOperation; } 62 63 private: 64 std::vector<std::shared_ptr<TensorOperation>> transforms_; 65 }; 66 67 class ConcatenateOperation : public TensorOperation { 68 public: 69 explicit ConcatenateOperation(int8_t axis, const std::shared_ptr<Tensor> &prepend, 70 const std::shared_ptr<Tensor> &append); 71 72 ~ConcatenateOperation() = default; 73 74 std::shared_ptr<TensorOp> Build() override; 75 76 Status ValidateParams() override; 77 Name()78 std::string Name() const override { return kConcatenateOperation; } 79 80 private: 81 int8_t axis_; 82 std::shared_ptr<Tensor> prepend_; 83 std::shared_ptr<Tensor> append_; 84 }; 85 86 class DuplicateOperation : public TensorOperation { 87 public: 88 DuplicateOperation() = default; 89 90 ~DuplicateOperation() = default; 91 92 std::shared_ptr<TensorOp> Build() override; 93 94 Status ValidateParams() override; 95 Name()96 std::string Name() const override { return kDuplicateOperation; } 97 }; 98 99 class FillOperation : public TensorOperation { 100 public: 101 explicit FillOperation(const std::shared_ptr<Tensor> &fill_value); 102 103 ~FillOperation() = default; 104 105 std::shared_ptr<TensorOp> Build() override; 106 107 Status ValidateParams() override; 108 Name()109 std::string Name() const override { return kFillOperation; } 110 111 Status to_json(nlohmann::json *out_json) override; 112 113 static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); 114 115 private: 116 std::shared_ptr<Tensor> fill_value_; 117 }; 118 119 class MaskOperation : public TensorOperation { 120 public: 121 explicit MaskOperation(RelationalOp op, const std::shared_ptr<Tensor> &constant, const DataType &dtype); 122 123 ~MaskOperation() = default; 124 125 std::shared_ptr<TensorOp> Build() override; 126 127 Status ValidateParams() override; 128 Name()129 std::string Name() const override { return kMaskOperation; } 130 131 private: 132 RelationalOp op_; 133 std::shared_ptr<Tensor> constant_; 134 DataType dtype_; 135 }; 136 137 class OneHotOperation : public TensorOperation { 138 public: 139 explicit OneHotOperation(int32_t num_classes); 140 141 ~OneHotOperation() = default; 142 143 std::shared_ptr<TensorOp> Build() override; 144 145 Status ValidateParams() override; 146 Name()147 std::string Name() const override { return kOneHotOperation; } 148 149 Status to_json(nlohmann::json *out_json) override; 150 151 static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); 152 153 private: 154 int32_t num_classes_; 155 }; 156 157 class PadEndOperation : public TensorOperation { 158 public: 159 explicit PadEndOperation(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value); 160 161 ~PadEndOperation() = default; 162 163 std::shared_ptr<TensorOp> Build() override; 164 165 Status ValidateParams() override; 166 Name()167 std::string Name() const override { return kPadEndOperation; } 168 169 private: 170 TensorShape pad_shape_; 171 std::shared_ptr<Tensor> pad_value_; 172 }; 173 174 class PreBuiltOperation : public TensorOperation { 175 public: 176 explicit PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op); 177 178 ~PreBuiltOperation() = default; 179 180 std::shared_ptr<TensorOp> Build() override; 181 182 Status ValidateParams() override; 183 184 std::string Name() const override; 185 186 Status to_json(nlohmann::json *out_json) override; 187 188 private: 189 std::shared_ptr<TensorOp> op_; 190 }; 191 192 class RandomApplyOperation : public TensorOperation { 193 public: 194 explicit RandomApplyOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms, double prob); 195 196 ~RandomApplyOperation() = default; 197 198 std::shared_ptr<TensorOp> Build() override; 199 200 Status ValidateParams() override; 201 Name()202 std::string Name() const override { return kRandomApplyOperation; } 203 204 private: 205 std::vector<std::shared_ptr<TensorOperation>> transforms_; 206 double prob_; 207 }; 208 209 class RandomChoiceOperation : public TensorOperation { 210 public: 211 explicit RandomChoiceOperation(const std::vector<std::shared_ptr<TensorOperation>> &transforms); 212 213 ~RandomChoiceOperation() = default; 214 215 std::shared_ptr<TensorOp> Build() override; 216 217 Status ValidateParams() override; 218 Name()219 std::string Name() const override { return kRandomChoiceOperation; } 220 221 private: 222 std::vector<std::shared_ptr<TensorOperation>> transforms_; 223 }; 224 225 class SliceOperation : public TensorOperation { 226 public: 227 explicit SliceOperation(const std::vector<SliceOption> &slice_input); 228 229 ~SliceOperation() = default; 230 231 std::shared_ptr<TensorOp> Build() override; 232 233 Status ValidateParams() override; 234 Name()235 std::string Name() const override { return kSliceOperation; } 236 237 private: 238 std::vector<SliceOption> slice_input_; 239 }; 240 241 class TypeCastOperation : public TensorOperation { 242 public: 243 explicit TypeCastOperation(const DataType &data_type); // Used for C++ API 244 explicit TypeCastOperation(const std::string &data_type); // Used for Pybind 245 246 ~TypeCastOperation() = default; 247 248 std::shared_ptr<TensorOp> Build() override; 249 250 Status ValidateParams() override; 251 Name()252 std::string Name() const override { return kTypeCastOperation; } 253 254 Status to_json(nlohmann::json *out_json) override; 255 256 static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation); 257 258 private: 259 DataType data_type_; 260 }; 261 262 #ifndef ENABLE_ANDROID 263 class UniqueOperation : public TensorOperation { 264 public: 265 UniqueOperation() = default; 266 267 ~UniqueOperation() = default; 268 269 std::shared_ptr<TensorOp> Build() override; 270 271 Status ValidateParams() override; 272 Name()273 std::string Name() const override { return kUniqueOperation; } 274 }; 275 276 class PluginOperation : public TensorOperation { 277 public: PluginOperation(const std::string & lib_path,const std::string & func_name,const std::string & user_args)278 explicit PluginOperation(const std::string &lib_path, const std::string &func_name, const std::string &user_args) 279 : lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {} 280 281 ~PluginOperation() = default; 282 283 std::shared_ptr<TensorOp> Build() override; 284 285 Status ValidateParams() override; 286 Name()287 std::string Name() const override { return kPluginOperation; } 288 289 private: 290 std::string lib_path_; 291 std::string func_name_; 292 std::string user_args_; 293 }; 294 295 #endif 296 297 } // namespace transforms 298 } // namespace dataset 299 } // namespace mindspore 300 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_DATA_TRANSFORMS_IR_H_ 301