• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_
18 
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <vector>
#include "nlohmann/json.hpp"

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/device_tensor.h"
#include "minddata/dataset/core/device_resource.h"
29 
// Guard for 1-to-1 Compute-style functions: if either `input` or `output` compares equal to
// nullptr, RETURN_STATUS_UNEXPECTED is invoked with a fixed message (that macro is defined
// elsewhere; presumably it returns an error Status from the enclosing function).
// Both arguments are parenthesised so that expressions such as `cond ? a : b` can be passed.
// Each argument is evaluated exactly once.
#define IO_CHECK(input, output)                             \
  do {                                                      \
    if ((input) == nullptr || (output) == nullptr) {        \
      RETURN_STATUS_UNEXPECTED("input or output is null."); \
    }                                                       \
  } while (false)
36 
// Guard for m-to-n Compute-style functions: `output` must not compare equal to nullptr and no
// element of the iterable `input` may compare equal to nullptr; otherwise
// RETURN_STATUS_UNEXPECTED is invoked with a fixed message (that macro is defined elsewhere;
// presumably it returns an error Status from the enclosing function).
// `output` is checked first, so a null output is reported even when `input` is empty.
// Both arguments are parenthesised so that arbitrary expressions can be passed.
#define IO_CHECK_VECTOR(input, output)              \
  do {                                              \
    if ((output) == nullptr) {                      \
      RETURN_STATUS_UNEXPECTED("output is null.");  \
    }                                               \
    for (const auto &_i : (input)) {                \
      if (_i == nullptr) {                          \
        RETURN_STATUS_UNEXPECTED("input is null."); \
      }                                             \
    }                                               \
  } while (false)
48 
49 namespace mindspore {
50 namespace dataset {
51 
// base class
// Name string for the TensorOp base class itself.
constexpr char kTensorOp[] = "TensorOp";
54 
// image
// Name strings for image operations (one constant per operation).
constexpr char kAdjustGammaOp[] = "AdjustGammaOp";
constexpr char kAffineOp[] = "AffineOp";
constexpr char kAutoContrastOp[] = "AutoContrastOp";
constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
constexpr char kDecodeOp[] = "DecodeOp";
constexpr char kCenterCropOp[] = "CenterCropOp";
constexpr char kConvertColorOp[] = "ConvertColorOp";
constexpr char kCutMixBatchOp[] = "CutMixBatchOp";
constexpr char kCutOutOp[] = "CutOutOp";
constexpr char kCropOp[] = "CropOp";
constexpr char kDvppCropJpegOp[] = "DvppCropJpegOp";
constexpr char kDvppDecodeResizeCropJpegOp[] = "DvppDecodeResizeCropJpegOp";
constexpr char kDvppDecodeResizeJpegOp[] = "DvppDecodeResizeJpegOp";
constexpr char kDvppDecodeJpegOp[] = "DvppDecodeJpegOp";
constexpr char kDvppDecodePngOp[] = "DvppDecodePngOp";
constexpr char kDvppNormalizeOp[] = "DvppNormalizeOp";
constexpr char kDvppResizeJpegOp[] = "DvppResizeJpegOp";
constexpr char kEqualizeOp[] = "EqualizeOp";
constexpr char kGaussianBlurOp[] = "GaussianBlurOp";
constexpr char kHorizontalFlipOp[] = "HorizontalFlipOp";
constexpr char kHwcToChwOp[] = "HWC2CHWOp";
constexpr char kInvertOp[] = "InvertOp";
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kNormalizePadOp[] = "NormalizePadOp";
constexpr char kPadOp[] = "PadOp";
constexpr char kRandomAffineOp[] = "RandomAffineOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
constexpr char kRandomColorOp[] = "RandomColorOp";
constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";
constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp";
constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp";
constexpr char kRandomCropOp[] = "RandomCropOp";
constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp";
constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
constexpr char kRandomResizeOp[] = "RandomResizeOp";
constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
constexpr char kRandomRotationOp[] = "RandomRotationOp";
constexpr char kRandomSolarizeOp[] = "RandomSolarizeOp";
constexpr char kRandomSharpnessOp[] = "RandomSharpnessOp";
constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp";
constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
constexpr char kRescaleOp[] = "RescaleOp";
constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
constexpr char kResizeOp[] = "ResizeOp";
constexpr char kResizePreserveAROp[] = "ResizePreserveAROp";
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp";
constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp";
constexpr char kRgbToBgrOp[] = "RgbToBgrOp";
constexpr char kRgbToGrayOp[] = "RgbToGrayOp";
constexpr char kRotateOp[] = "RotateOp";
constexpr char kSharpnessOp[] = "SharpnessOp";
constexpr char kSlicePatchesOp[] = "SlicePatchesOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
// NOTE(review): "Reisze" looks like a misspelling of "Resize". Both the identifier and the
// string value are kept unchanged because code outside this header may depend on them.
constexpr char kSoftDvppDecodeReiszeJpegOp[] = "SoftDvppDecodeReiszeJpegOp";
constexpr char kSolarizeOp[] = "SolarizeOp";
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
constexpr char kUniformAugOp[] = "UniformAugOp";
constexpr char kVerticalFlipOp[] = "VerticalFlipOp";
117 
// text
// Name strings for text operations (one constant per operation).
constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp";
constexpr char kBertTokenizerOp[] = "BertTokenizerOp";
constexpr char kCaseFoldOp[] = "CaseFoldOp";
constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp";
constexpr char kLookupOp[] = "LookupOp";
constexpr char kNgramOp[] = "NgramOp";
constexpr char kSlidingWindowOp[] = "SlidingWindowOp";
constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op";
constexpr char kRegexReplaceOp[] = "RegexReplaceOp";
constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp";
constexpr char kToNumberOp[] = "ToNumberOp";
constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp";
constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp";
constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp";
constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp";
constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp";
// NOTE(review): the next four names do not look text-specific; they are kept in this group
// and in this order to match the original layout - confirm whether they belong elsewhere.
constexpr char kRandomChoiceOp[] = "RandomChoiceOp";
constexpr char kRandomApplyOp[] = "RandomApplyOp";
// The value "Compose" has no "Op" suffix, unlike the other names here; kept as-is.
constexpr char kComposeOp[] = "Compose";
constexpr char kRandomSelectSubpolicyOp[] = "RandomSelectSubpolicyOp";
constexpr char kSentencepieceTokenizerOp[] = "SentencepieceTokenizerOp";
140 
// audio
// Name strings for audio operations (one constant per operation).
constexpr char kAllpassBiquadOp[] = "AllpassBiquadOp";
constexpr char kAmplitudeToDBOp[] = "AmplitudeToDBOp";
constexpr char kAngleOp[] = "AngleOp";
constexpr char kBandBiquadOp[] = "BandBiquadOp";
constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp";
constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp";
constexpr char kBassBiquadOp[] = "BassBiquadOp";
constexpr char kBiquadOp[] = "BiquadOp";
constexpr char kComplexNormOp[] = "ComplexNormOp";
constexpr char kContrastOp[] = "ContrastOp";
constexpr char kDCShiftOp[] = "DCShiftOp";
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
constexpr char kFadeOp[] = "FadeOp";
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
constexpr char kLFilterOp[] = "LFilterOp";
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
constexpr char kMagphaseOp[] = "MagphaseOp";
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";
constexpr char kVolOp[] = "VolOp";
165 
// data
// Name strings for generic data operations (one constant per operation).
constexpr char kConcatenateOp[] = "ConcatenateOp";
constexpr char kDuplicateOp[] = "DuplicateOp";
constexpr char kFillOp[] = "FillOp";
constexpr char kMaskOp[] = "MaskOp";
constexpr char kOneHotOp[] = "OneHotOp";
constexpr char kPadEndOp[] = "PadEndOp";
constexpr char kSliceOp[] = "SliceOp";
constexpr char kToFloat16Op[] = "ToFloat16Op";
constexpr char kTypeCastOp[] = "TypeCastOp";
constexpr char kUniqueOp[] = "UniqueOp";
177 
// other
// Name strings for operations that fit none of the groups above.
constexpr char kCFuncOp[] = "CFuncOp";
constexpr char kPyFuncOp[] = "PyFuncOp";
constexpr char kPluginOp[] = "PluginOp";
constexpr char kNoOp[] = "NoOp";
183 
184 // A class that does a computation on a Tensor
185 class TensorOp {
186  public:
187   TensorOp() = default;
188 
189   virtual ~TensorOp() = default;
190 
191   // A function that prints info about the tensor operation
192   // @param out
Print(std::ostream & out)193   virtual void Print(std::ostream &out) const { out << Name() << std::endl; }
194 
195   // Provide stream operator for displaying it
196   // @param output stream
197   // @param so the TensorOp object to be printed
198   // @return output stream
199   friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) {
200     so.Print(out);
201     return out;
202   }
203 
204   // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp
205   // @param input  shares the ownership of the Tensor (increase the ref count).
206   // @param output the address to a shared_ptr where the result will be placed.
207   // @return Status
208   virtual Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
209 
210   // Perform an operation on Tensors from multiple columns, and produce multiple Tensors.
211   // This is for m-to-n column MapOp.
212   // @param input is a vector of shared_ptr to Tensor (pass by const reference).
213   // @param output is the address to an empty vector of shared_ptr to Tensor.
214   // @return Status
215   virtual Status Compute(const TensorRow &input, TensorRow *output);
216 
217   // Perform an operation on one DeviceTensor and produce one DeviceTensor. This is for 1-to-1 column MapOp
218   // @param input shares the ownership of the Tensor (increase the ref count).
219   // @param output the address to a shared_ptr where the result will be placed.
220   // @return Status
221   virtual Status Compute(const std::shared_ptr<DeviceTensor> &input, std::shared_ptr<DeviceTensor> *output);
222 
223   // Returns true oif the TensorOp takes one input and returns one output.
224   // @return true/false
OneToOne()225   bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; }
226 
227   // Returns true oif the TensorOp produces deterministic result.
228   // @return true/false
Deterministic()229   bool Deterministic() { return is_deterministic_; }
230 
231   // Function to determine the number of inputs the TensorOp can take. 0: means undefined.
232   // @return uint32_t
NumInput()233   virtual uint32_t NumInput() { return 1; }
234 
235   // Function to determine the number of output the TensorOp generates. 0: means undefined.
236   // @return uint32_t
NumOutput()237   virtual uint32_t NumOutput() { return 1; }
238 
239   // Function to determine the shapes of the output tensor given the input tensors' shapes.
240   // If a subclass did not override this function, it means that the shape does not change.
241   // @param inputs in: vector of the shapes of the input tensors.
242   // @param outputs out: vector of the shapes of the output tensors to be filled.
243   // @return Status
244   virtual Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs);
245 
246   // Function to determine the types of the output tensor given the input tensor's types.
247   // If a subclass did not override this function, it means that the type does not change.
248   // @param inputs in: vector of the types of the input tensors.
249   // @param outputs out: vector of the types of the output tensors to be filled.
250   // @return Status
251   virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs);
252 
253   virtual std::string Name() const = 0;
254 
to_json(nlohmann::json * out_json)255   virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
256 
257   virtual Status SetAscendResource(const std::shared_ptr<DeviceResource> &resource);
258 
259  protected:
260   bool is_deterministic_{true};
261 };
262 }  // namespace dataset
263 }  // namespace mindspore
264 
265 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TENSOR_OP_H_
266