1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/ir/validators.h"
17
18 namespace mindspore {
19 namespace dataset {
20 /* ####################################### Validator Functions ############################################ */
ValidateProbability(const std::string & op_name,const double probability)21 Status ValidateProbability(const std::string &op_name, const double probability) {
22 if (probability < 0.0 || probability > 1.0) {
23 std::string err_msg = op_name + ": probability must be between 0.0 and 1.0, got: " + std::to_string(probability);
24 MS_LOG(ERROR) << err_msg;
25 RETURN_STATUS_SYNTAX_ERROR(err_msg);
26 }
27
28 return Status::OK();
29 }
30
ValidateIntScalarPositive(const std::string & op_name,const std::string & scalar_name,int32_t scalar)31 Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
32 RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
33 return Status::OK();
34 }
35
ValidateFloatScalarPositive(const std::string & op_name,const std::string & scalar_name,float scalar)36 Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
37 RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
38 return Status::OK();
39 }
40
ValidateFloatScalarNonNegative(const std::string & op_name,const std::string & scalar_name,float scalar)41 Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
42 RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, false));
43 return Status::OK();
44 }
45
ValidateVectorFillvalue(const std::string & op_name,const std::vector<uint8_t> & fill_value)46 Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
47 if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != 3)) {
48 std::string err_msg =
49 op_name + ": fill_value expecting size 1 or 3, got fill_value.size(): " + std::to_string(fill_value.size());
50 MS_LOG(ERROR) << err_msg;
51 RETURN_STATUS_SYNTAX_ERROR(err_msg);
52 }
53 // Note that fill_value need to be in range [0, 255],
54 // but we omit the check since its type is uint8_t
55 return Status::OK();
56 }
57
ValidateVectorColorAttribute(const std::string & op_name,const std::string & attr_name,const std::vector<float> & attr,const std::vector<float> & range)58 Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
59 const std::vector<float> &attr, const std::vector<float> &range) {
60 if (attr.empty() || attr.size() > 2) {
61 std::string err_msg = op_name + ":" + attr_name + " expecting size 1 or 2, but got: " + std::to_string(attr.size());
62 MS_LOG(ERROR) << err_msg;
63 RETURN_STATUS_SYNTAX_ERROR(err_msg);
64 }
65 for (auto &attr_val : attr) {
66 RETURN_IF_NOT_OK(ValidateScalar(op_name, attr_name, attr_val, range, false, false));
67 }
68 constexpr size_t attr_size_two = 2;
69 if (attr.size() == attr_size_two && (attr[0] > attr[1])) {
70 std::string err_msg = op_name + ":" + attr_name +
71 " lower bound must be less or equal to upper bound, got lb: " + std::to_string(attr[0]) +
72 ", ub: " + std::to_string(attr[1]);
73 MS_LOG(ERROR) << err_msg;
74 RETURN_STATUS_SYNTAX_ERROR(err_msg);
75 }
76
77 return Status::OK();
78 }
79
ValidateVectorMeanStd(const std::string & op_name,const std::vector<float> & mean,const std::vector<float> & std)80 Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
81 const std::vector<float> &std) {
82 if (mean.size() == 0) {
83 std::string err_msg = op_name + ": mean expecting non-empty vector";
84 MS_LOG(ERROR) << err_msg;
85 RETURN_STATUS_SYNTAX_ERROR(err_msg);
86 }
87 if (std.size() == 0) {
88 std::string err_msg = op_name + ": std expecting non-empty vector";
89 MS_LOG(ERROR) << err_msg;
90 RETURN_STATUS_SYNTAX_ERROR(err_msg);
91 }
92 if (mean.size() != std.size()) {
93 std::string err_msg = op_name + ": mean and std vectors are expected to be of the same size";
94 MS_LOG(ERROR) << err_msg;
95 RETURN_STATUS_SYNTAX_ERROR(err_msg);
96 }
97 // check std/mean value
98 for (int32_t i = 0; i < std.size(); ++i) {
99 RETURN_IF_NOT_OK(ValidateScalar(op_name, "mean", mean[i], {0.0, 255.0}, false, false));
100 RETURN_IF_NOT_OK(ValidateScalar(op_name, "std", std[i], {0.0, 255.0}, true, false));
101 }
102
103 return Status::OK();
104 }
105
ValidateVectorOdd(const std::string & op_name,const std::string & vec_name,const std::vector<int32_t> & value)106 Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value) {
107 for (int i = 0; i < value.size(); i++) {
108 if (value[i] % 2 != 1) {
109 std::string err_msg = op_name + ":" + vec_name + " must be odd value, got: " + vec_name + "[" +
110 std::to_string(i) + "]=" + std::to_string(value[i]);
111 MS_LOG(ERROR) << err_msg;
112 return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
113 }
114 }
115 return Status::OK();
116 }
117
ValidateVectorPadding(const std::string & op_name,const std::vector<int32_t> & padding)118 Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
119 if (padding.empty() || padding.size() == 3 || padding.size() > 4) {
120 std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
121 MS_LOG(ERROR) << err_msg;
122 RETURN_STATUS_SYNTAX_ERROR(err_msg);
123 }
124 for (const auto &pad_val : padding) {
125 RETURN_IF_NOT_OK(ValidateScalar(op_name, "padding", pad_val, {0, INT_MAX}, false, false));
126 }
127
128 return Status::OK();
129 }
130
ValidateVectorPositive(const std::string & op_name,const std::string & vec_name,const std::vector<int32_t> & vec)131 Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name,
132 const std::vector<int32_t> &vec) {
133 for (const auto &vec_val : vec) {
134 RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, true));
135 }
136
137 return Status::OK();
138 }
139
ValidateVectorNonNegative(const std::string & op_name,const std::string & vec_name,const std::vector<int32_t> & vec)140 Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
141 const std::vector<int32_t> &vec) {
142 for (const auto &vec_val : vec) {
143 RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, false));
144 }
145
146 return Status::OK();
147 }
148
ValidateVectorSigma(const std::string & op_name,const std::vector<float> & sigma)149 Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma) {
150 if (sigma.empty() || sigma.size() > 2) {
151 std::string err_msg = op_name + ": sigma expecting size 2, got sigma.size(): " + std::to_string(sigma.size());
152 MS_LOG(ERROR) << err_msg;
153 RETURN_STATUS_SYNTAX_ERROR(err_msg);
154 }
155 for (const auto &sigma_val : sigma) {
156 RETURN_IF_NOT_OK(ValidateScalar(op_name, "sigma", sigma_val, {0}, false));
157 }
158
159 return Status::OK();
160 }
161
ValidateVectorSize(const std::string & op_name,const std::vector<int32_t> & size)162 Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
163 if (size.empty() || size.size() > 2) {
164 std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
165 MS_LOG(ERROR) << err_msg;
166 RETURN_STATUS_SYNTAX_ERROR(err_msg);
167 }
168 for (const auto &size_val : size) {
169 RETURN_IF_NOT_OK(ValidateScalar(op_name, "size", size_val, {0, INT_MAX}, true, false));
170 }
171
172 return Status::OK();
173 }
174
ValidateVectorScale(const std::string & op_name,const std::vector<float> & scale)175 Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale) {
176 if (scale.size() != 2) {
177 std::string err_msg = op_name + ": scale expecting size 2, got scale.size(): " + std::to_string(scale.size());
178 MS_LOG(ERROR) << err_msg;
179 RETURN_STATUS_SYNTAX_ERROR(err_msg);
180 }
181 RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[0], {0}, false));
182 RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[1], {0}, true));
183 if (scale[1] < scale[0]) {
184 std::string err_msg = op_name + ": scale must be in the format of (min, max).";
185 MS_LOG(ERROR) << op_name + ": scale must be in the format of (min, max), but got: " << scale;
186 RETURN_STATUS_SYNTAX_ERROR(err_msg);
187 }
188
189 return Status::OK();
190 }
191
ValidateVectorRatio(const std::string & op_name,const std::vector<float> & ratio)192 Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio) {
193 if (ratio.size() != 2) {
194 std::string err_msg = op_name + ": ratio expecting size 2, got ratio.size(): " + std::to_string(ratio.size());
195 MS_LOG(ERROR) << err_msg;
196 RETURN_STATUS_SYNTAX_ERROR(err_msg);
197 }
198 RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[0], {0}, true));
199 RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[1], {0}, true));
200 if (ratio[1] < ratio[0]) {
201 std::string err_msg = op_name + ": ratio must be in the format of (min, max).";
202 MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio;
203 RETURN_STATUS_SYNTAX_ERROR(err_msg);
204 }
205
206 return Status::OK();
207 }
208
ValidateVectorTransforms(const std::string & op_name,const std::vector<std::shared_ptr<TensorOperation>> & transforms)209 Status ValidateVectorTransforms(const std::string &op_name,
210 const std::vector<std::shared_ptr<TensorOperation>> &transforms) {
211 if (transforms.empty()) {
212 std::string err_msg = op_name + ": transform list must not be empty.";
213 MS_LOG(ERROR) << err_msg;
214 RETURN_STATUS_SYNTAX_ERROR(err_msg);
215 }
216 for (int32_t i = 0; i < transforms.size(); ++i) {
217 if (transforms[i] == nullptr) {
218 std::string err_msg =
219 op_name + ": transform ops must not be null, got transform[" + std::to_string(i) + "] == nullptr.";
220 MS_LOG(ERROR) << err_msg;
221 RETURN_STATUS_SYNTAX_ERROR(err_msg);
222 } else {
223 RETURN_IF_NOT_OK(transforms[i]->ValidateParams());
224 }
225 }
226
227 return Status::OK();
228 }
229
/// Compares two floats using an absolute tolerance.
/// @param a First value.
/// @param b Second value.
/// @param epsilon Tolerance; the values count as equal only when |a - b| is strictly below it.
/// @return true when |a - b| < epsilon, false otherwise (including when either value is NaN).
bool CmpFloat(const float a, const float b, float epsilon) {
  const float difference = std::fabs(a - b);
  return difference < epsilon;
}
231 } // namespace dataset
232 } // namespace mindspore
233