1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 #include "nlohmann/json.hpp" 23 24 #include "minddata/dataset/core/tensor.h" 25 #include "minddata/dataset/core/tensor_row.h" 26 #include "minddata/dataset/util/status.h" 27 #include "minddata/dataset/core/device_tensor.h" 28 #include "minddata/dataset/core/device_resource.h" 29 30 #define IO_CHECK(input, output) \ 31 do { \ 32 if (input == nullptr || output == nullptr) { \ 33 RETURN_STATUS_UNEXPECTED("input or output is null."); \ 34 } \ 35 } while (false) 36 37 #define IO_CHECK_VECTOR(input, output) \ 38 do { \ 39 if (output == nullptr) { \ 40 RETURN_STATUS_UNEXPECTED("output is null."); \ 41 } \ 42 for (auto &_i : input) { \ 43 if (_i == nullptr) { \ 44 RETURN_STATUS_UNEXPECTED("input is null."); \ 45 } \ 46 } \ 47 } while (false) 48 49 namespace mindspore { 50 namespace dataset { 51 52 // base class 53 constexpr char kTensorOp[] = "TensorOp"; 54 55 // image 56 constexpr char kAdjustGammaOp[] = "AdjustGammaOp"; 57 constexpr char kAffineOp[] = "AffineOp"; 58 constexpr char kAutoContrastOp[] = "AutoContrastOp"; 59 constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; 60 constexpr char kDecodeOp[] = "DecodeOp"; 61 constexpr char kCenterCropOp[] = "CenterCropOp"; 62 constexpr char kConvertColorOp[] = "ConvertColorOp"; 63 constexpr char kCutMixBatchOp[] = "CutMixBatchOp"; 64 constexpr char kCutOutOp[] = "CutOutOp"; 65 constexpr char kCropOp[] = "CropOp"; 66 constexpr char kDvppCropJpegOp[] = "DvppCropJpegOp"; 67 constexpr char kDvppDecodeResizeCropJpegOp[] = "DvppDecodeResizeCropJpegOp"; 68 constexpr char kDvppDecodeResizeJpegOp[] = "DvppDecodeResizeJpegOp"; 69 constexpr char kDvppDecodeJpegOp[] = "DvppDecodeJpegOp"; 70 constexpr char kDvppDecodePngOp[] = "DvppDecodePngOp"; 71 constexpr char kDvppNormalizeOp[] = "DvppNormalizeOp"; 72 constexpr char kDvppResizeJpegOp[] = "DvppResizeJpegOp"; 73 constexpr char kEqualizeOp[] = "EqualizeOp"; 74 constexpr char kGaussianBlurOp[] = "GaussianBlurOp"; 75 constexpr char kHorizontalFlipOp[] = "HorizontalFlipOp"; 76 constexpr char kHwcToChwOp[] = "HWC2CHWOp"; 77 constexpr char kInvertOp[] = "InvertOp"; 78 constexpr char kMixUpBatchOp[] = "MixUpBatchOp"; 79 constexpr char kNormalizeOp[] = "NormalizeOp"; 80 constexpr char kNormalizePadOp[] = "NormalizePadOp"; 81 constexpr char kPadOp[] = "PadOp"; 82 constexpr char kRandomAffineOp[] = "RandomAffineOp"; 83 constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; 84 constexpr char kRandomColorOp[] = "RandomColorOp"; 85 constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; 86 constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp"; 87 constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp"; 88 constexpr char kRandomCropOp[] = "RandomCropOp"; 89 constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp"; 90 constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp"; 91 constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp"; 92 constexpr char kRandomResizeOp[] = "RandomResizeOp"; 93 constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp"; 94 constexpr char kRandomRotationOp[] = "RandomRotationOp"; 95 constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp"; 96 constexpr char kRandomSharpnessOp[] = "RandomSharpnessOp"; 97 constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp"; 98 constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp"; 99 constexpr char kRescaleOp[] = "RescaleOp"; 100 constexpr char kResizeBilinearOp[] = "ResizeBilinearOp"; 101 constexpr char kResizeOp[] = "ResizeOp"; 102 constexpr char kResizePreserveAROp[] = "ResizePreserveAROp"; 103 constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; 104 constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp"; 105 constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp"; 106 constexpr char kRgbToBgrOp[] = "RgbToBgrOp"; 107 constexpr char kRgbToGrayOp[] = "RgbToGrayOp"; 108 constexpr char kRotateOp[] = "RotateOp"; 109 constexpr char kSharpnessOp[] = "SharpnessOp"; 110 constexpr char kSlicePatchesOp[] = "SlicePatchesOp"; 111 constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp"; 112 constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp"; 113 constexpr char kSolarizeOp[] = "SolarizeOp"; 114 constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; 115 constexpr char kUniformAugOp[] = "UniformAugOp"; 116 constexpr char kVerticalFlipOp[] = "VerticalFlipOp"; 117 118 // text 119 constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; 120 constexpr char kBertTokenizerOp[] = "BertTokenizerOp"; 121 constexpr char kCaseFoldOp[] = "CaseFoldOp"; 122 constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; 123 constexpr char kLookupOp[] = "LookupOp"; 124 constexpr char kNgramOp[] = "NgramOp"; 125 constexpr char kSlidingWindowOp[] = "SlidingWindowOp"; 126 constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; 127 constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; 128 constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; 129 constexpr char kToNumberOp[] = "ToNumberOp"; 130 constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp"; 131 constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp"; 132 constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp"; 133 constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp"; 134 constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp"; 135 constexpr char kRandomChoiceOp[] = "RandomChoiceOp"; 136 constexpr char kRandomApplyOp[] = "RandomApplyOp"; 137 constexpr char kComposeOp[] = "Compose"; 138 constexpr char kRandomSelectSubpolicyOp[] = "RandomSelectSubpolicyOp"; 139 constexpr char kSentencepieceTokenizerOp[] = "SentencepieceTokenizerOp"; 140 141 // audio 142 constexpr char kAllpassBiquadOp[] = "AllpassBiquadOp"; 143 constexpr char kAmplitudeToDBOp[] = "AmplitudeToDBOp"; 144 constexpr char kAngleOp[] = "AngleOp"; 145 constexpr char kBandBiquadOp[] = "BandBiquadOp"; 146 constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp"; 147 constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp"; 148 constexpr char kBassBiquadOp[] = "BassBiquadOp"; 149 constexpr char kBiquadOp[] = "BiquadOp"; 150 constexpr char kComplexNormOp[] = "ComplexNormOp"; 151 constexpr char kContrastOp[] = "ContrastOp"; 152 constexpr char kDCShiftOp[] = "DCShiftOp"; 153 constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp"; 154 constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp"; 155 constexpr char kFadeOp[] = "FadeOp"; 156 constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp"; 157 constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp"; 158 constexpr char kLFilterOp[] = "LFilterOp"; 159 constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp"; 160 constexpr char kMagphaseOp[] = "MagphaseOp"; 161 constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp"; 162 constexpr char kTimeMaskingOp[] = "TimeMaskingOp"; 163 constexpr char kTimeStretchOp[] = "TimeStretchOp"; 164 constexpr char kVolOp[] = "VolOp"; 165 166 // data 167 constexpr char kConcatenateOp[] = "ConcatenateOp"; 168 constexpr char kDuplicateOp[] = "DuplicateOp"; 169 constexpr char kFillOp[] = "FillOp"; 170 constexpr char kMaskOp[] = "MaskOp"; 171 constexpr char kOneHotOp[] = "OneHotOp"; 172 constexpr char kPadEndOp[] = "PadEndOp"; 173 constexpr char kSliceOp[] = "SliceOp"; 174 constexpr char kToFloat16Op[] = "ToFloat16Op"; 175 constexpr char kTypeCastOp[] = "TypeCastOp"; 176 constexpr char kUniqueOp[] = "UniqueOp"; 177 178 // other 179 constexpr char kCFuncOp[] = "CFuncOp"; 180 constexpr char kPyFuncOp[] = "PyFuncOp"; 181 constexpr char kPluginOp[] = "PluginOp"; 182 constexpr char kNoOp[] = "NoOp"; 183 184 // A class that does a computation on a Tensor 185 class TensorOp { 186 public: 187 TensorOp() = default; 188 189 virtual ~TensorOp() = default; 190 191 // A function that prints info about the tensor operation 192 // @param out Print(std::ostream & out)193 virtual void Print(std::ostream &out) const { out << Name() << std::endl; } 194 195 // Provide stream operator for displaying it 196 // @param output stream 197 // @param so the TensorOp object to be printed 198 // @return output stream 199 friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) { 200 so.Print(out); 201 return out; 202 } 203 204 // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp 205 // @param input shares the ownership of the Tensor (increase the ref count). 206 // @param output the address to a shared_ptr where the result will be placed. 207 // @return Status 208 virtual Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); 209 210 // Perform an operation on Tensors from multiple columns, and produce multiple Tensors. 211 // This is for m-to-n column MapOp. 212 // @param input is a vector of shared_ptr to Tensor (pass by const reference). 213 // @param output is the address to an empty vector of shared_ptr to Tensor. 214 // @return Status 215 virtual Status Compute(const TensorRow &input, TensorRow *output); 216 217 // Perform an operation on one DeviceTensor and produce one DeviceTensor. This is for 1-to-1 column MapOp 218 // @param input shares the ownership of the Tensor (increase the ref count). 219 // @param output the address to a shared_ptr where the result will be placed. 220 // @return Status 221 virtual Status Compute(const std::shared_ptr<DeviceTensor> &input, std::shared_ptr<DeviceTensor> *output); 222 223 // Returns true oif the TensorOp takes one input and returns one output. 224 // @return true/false OneToOne()225 bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } 226 227 // Returns true oif the TensorOp produces deterministic result. 228 // @return true/false Deterministic()229 bool Deterministic() { return is_deterministic_; } 230 231 // Function to determine the number of inputs the TensorOp can take. 0: means undefined. 232 // @return uint32_t NumInput()233 virtual uint32_t NumInput() { return 1; } 234 235 // Function to determine the number of output the TensorOp generates. 0: means undefined. 236 // @return uint32_t NumOutput()237 virtual uint32_t NumOutput() { return 1; } 238 239 // Function to determine the shapes of the output tensor given the input tensors' shapes. 240 // If a subclass did not override this function, it means that the shape does not change. 241 // @param inputs in: vector of the shapes of the input tensors. 242 // @param outputs out: vector of the shapes of the output tensors to be filled. 243 // @return Status 244 virtual Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs); 245 246 // Function to determine the types of the output tensor given the input tensor's types. 247 // If a subclass did not override this function, it means that the type does not change. 248 // @param inputs in: vector of the types of the input tensors. 249 // @param outputs out: vector of the types of the output tensors to be filled. 250 // @return Status 251 virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs); 252 253 virtual std::string Name() const = 0; 254 to_json(nlohmann::json * out_json)255 virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } 256 257 virtual Status SetAscendResource(const std::shared_ptr<DeviceResource> &resource); 258 259 protected: 260 bool is_deterministic_{true}; 261 }; 262 } // namespace dataset 263 } // namespace mindspore 264 265 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_ 266