• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/kernels/ir/validators.h"
18 
19 namespace mindspore {
20 namespace dataset {
21 /* ####################################### Validator Functions ############################################ */
ValidateProbability(const std::string & op_name,double probability)22 Status ValidateProbability(const std::string &op_name, double probability) {
23   if (probability < 0.0 || probability > 1.0) {
24     std::string err_msg = op_name + ": probability must be between 0.0 and 1.0, got: " + std::to_string(probability);
25     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
26   }
27 
28   return Status::OK();
29 }
30 
ValidateIntScalarPositive(const std::string & op_name,const std::string & scalar_name,int32_t scalar)31 Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
32   RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
33   return Status::OK();
34 }
35 
ValidateFloatScalarPositive(const std::string & op_name,const std::string & scalar_name,float scalar)36 Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
37   RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, true));
38   return Status::OK();
39 }
40 
ValidateFloatScalarNonNegative(const std::string & op_name,const std::string & scalar_name,float scalar)41 Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
42   RETURN_IF_NOT_OK(ValidateScalar(op_name, scalar_name, scalar, {0}, false));
43   return Status::OK();
44 }
45 
ValidateVectorFillvalue(const std::string & op_name,const std::vector<uint8_t> & fill_value)46 Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value) {
47   const size_t kMaxFillValueSize = 3;
48   if (fill_value.empty() || (fill_value.size() != 1 && fill_value.size() != kMaxFillValueSize)) {
49     std::string err_msg =
50       op_name + ": fill_value expecting size 1 or 3, got fill_value.size(): " + std::to_string(fill_value.size());
51     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
52   }
53   // Note that fill_value need to be in range [0, 255],
54   // but we omit the check since its type is uint8_t
55   return Status::OK();
56 }
57 
ValidateVectorColorAttribute(const std::string & op_name,const std::string & attr_name,const std::vector<float> & attr,const std::vector<float> & range)58 Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
59                                     const std::vector<float> &attr, const std::vector<float> &range) {
60   const size_t kMaxAttrSize = 2;
61   if (attr.empty() || attr.size() > kMaxAttrSize) {
62     std::string err_msg = op_name + ":" + attr_name + " expecting size 1 or 2, but got: " + std::to_string(attr.size());
63     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
64   }
65   for (auto &attr_val : attr) {
66     RETURN_IF_NOT_OK(ValidateScalar(op_name, attr_name, attr_val, range, false, false));
67   }
68   if (attr.size() == kMaxAttrSize && (attr[0] > attr[1])) {
69     std::string err_msg = op_name + ":" + attr_name +
70                           " lower bound must be less or equal to upper bound, got lb: " + std::to_string(attr[0]) +
71                           ", ub: " + std::to_string(attr[1]);
72     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
73   }
74 
75   return Status::OK();
76 }
77 
ValidateVectorMeanStd(const std::string & op_name,const std::vector<float> & mean,const std::vector<float> & std)78 Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean,
79                              const std::vector<float> &std) {
80   if (mean.empty()) {
81     std::string err_msg = op_name + ": mean expecting non-empty vector";
82     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
83   }
84   if (std.empty()) {
85     std::string err_msg = op_name + ": std expecting non-empty vector";
86     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
87   }
88   if (mean.size() != std.size()) {
89     std::string err_msg = op_name + ": mean and std vectors are expected to be of the same size";
90     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
91   }
92   // check std/mean value
93   for (int32_t i = 0; i < std.size(); ++i) {
94     RETURN_IF_NOT_OK(ValidateScalar(op_name, "mean", mean[i], {0.0, 255.0}, false, false));
95     RETURN_IF_NOT_OK(ValidateScalar(op_name, "std", std[i], {0.0, 255.0}, true, false));
96   }
97 
98   return Status::OK();
99 }
100 
ValidateVectorOdd(const std::string & op_name,const std::string & vec_name,const std::vector<int32_t> & value)101 Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value) {
102   constexpr int64_t divided_two = 2;
103   for (int i = 0; i < value.size(); i++) {
104     if (value[i] % divided_two != 1) {
105       std::string err_msg = op_name + ":" + vec_name + " must be odd value, got: " + vec_name + "[" +
106                             std::to_string(i) + "]=" + std::to_string(value[i]);
107       MS_LOG(ERROR) << err_msg;
108       RETURN_SYNTAX_ERROR(err_msg);
109     }
110   }
111   return Status::OK();
112 }
113 
ValidateVectorPadding(const std::string & op_name,const std::vector<int32_t> & padding)114 Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding) {
115   const size_t kDefaultPaddingSize = 2;
116   const size_t kMaxPaddingSize = 4;
117   if (padding.size() != 1 && padding.size() != kDefaultPaddingSize && padding.size() != kMaxPaddingSize) {
118     std::string err_msg = op_name + ": padding expecting size 1, 2 or 4, got size: " + std::to_string(padding.size());
119     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
120   }
121   for (const auto &pad_val : padding) {
122     RETURN_IF_NOT_OK(ValidateScalar(op_name, "padding", pad_val, {0, INT_MAX}, false, false));
123   }
124 
125   return Status::OK();
126 }
127 
ValidateVectorPositive(const std::string & op_name,const std::string & vec_name,const std::vector<int32_t> & vec)128 Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name,
129                               const std::vector<int32_t> &vec) {
130   for (const auto &vec_val : vec) {
131     RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, true));
132   }
133 
134   return Status::OK();
135 }
136 
ValidateVectorNonNegative(const std::string & op_name,const std::string & vec_name,const std::vector<int32_t> & vec)137 Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
138                                  const std::vector<int32_t> &vec) {
139   for (const auto &vec_val : vec) {
140     RETURN_IF_NOT_OK(ValidateScalar(op_name, vec_name, vec_val, {0}, false));
141   }
142 
143   return Status::OK();
144 }
145 
ValidateVectorSigma(const std::string & op_name,const std::vector<float> & sigma)146 Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma) {
147   const size_t kMaxSigmaSize = 2;
148   if (sigma.empty() || sigma.size() > kMaxSigmaSize) {
149     std::string err_msg = op_name + ": sigma expecting size 2, got sigma.size(): " + std::to_string(sigma.size());
150     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
151   }
152   for (const auto &sigma_val : sigma) {
153     RETURN_IF_NOT_OK(ValidateScalar(op_name, "sigma", sigma_val, {0}, false));
154   }
155 
156   return Status::OK();
157 }
158 
ValidateVectorSize(const std::string & op_name,const std::vector<int32_t> & size)159 Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size) {
160   const size_t kMaxSizeSize = 2;
161   if (size.empty() || size.size() > kMaxSizeSize) {
162     std::string err_msg = op_name + ": size expecting size 2, got size.size(): " + std::to_string(size.size());
163     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
164   }
165   for (const auto &size_val : size) {
166     RETURN_IF_NOT_OK(ValidateScalar(op_name, "size", size_val, {0, INT_MAX}, true, false));
167   }
168 
169   return Status::OK();
170 }
171 
ValidateVectorScale(const std::string & op_name,const std::vector<float> & scale)172 Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale) {
173   const size_t kScaleSize = 2;
174   if (scale.size() != kScaleSize) {
175     std::string err_msg = op_name + ": scale expecting size 2, got scale.size(): " + std::to_string(scale.size());
176     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
177   }
178   RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[0], {0}, false));
179   RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", scale[1], {0}, true));
180   if (scale[1] < scale[0]) {
181     std::string err_msg = op_name + ": scale must be in the format of (min, max), but got: (" +
182                           std::to_string(scale[0]) + ", " + std::to_string(scale[1]) + ").";
183     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
184   }
185 
186   return Status::OK();
187 }
188 
ValidateVectorRatio(const std::string & op_name,const std::vector<float> & ratio)189 Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio) {
190   const size_t kRatioSize = 2;
191   if (ratio.size() != kRatioSize) {
192     std::string err_msg = op_name + ": ratio expecting size 2, got ratio.size(): " + std::to_string(ratio.size());
193     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
194   }
195   RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[0], {0}, true));
196   RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[1], {0}, true));
197   if (ratio[1] < ratio[0]) {
198     std::string err_msg = op_name + ": ratio must be in the format of (min, max), but got: (" +
199                           std::to_string(ratio[0]) + ", " + std::to_string(ratio[1]) + ").";
200     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
201   }
202 
203   return Status::OK();
204 }
205 
ValidateVectorTransforms(const std::string & op_name,const std::vector<std::shared_ptr<TensorOperation>> & transforms)206 Status ValidateVectorTransforms(const std::string &op_name,
207                                 const std::vector<std::shared_ptr<TensorOperation>> &transforms) {
208   if (transforms.empty()) {
209     std::string err_msg = op_name + ": transform list must not be empty.";
210     LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
211   }
212   for (int32_t i = 0; i < transforms.size(); ++i) {
213     if (transforms[i] == nullptr) {
214       std::string err_msg =
215         op_name + ": transform ops must not be null, got transform[" + std::to_string(i) + "] == nullptr.";
216       LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
217     } else {
218       RETURN_IF_NOT_OK(transforms[i]->ValidateParams());
219     }
220   }
221 
222   return Status::OK();
223 }
224 
/// \brief Approximate equality test for two floats.
/// \param[in] a First value.
/// \param[in] b Second value.
/// \param[in] epsilon Tolerance; distinct values compare equal when |a - b| < epsilon.
/// \return true when a and b are exactly equal (including equal infinities) or differ by less than
///     epsilon; false otherwise, and always false if either argument is NaN.
bool CmpFloat(const float a, const float b, float epsilon) {
  // Exact equality first: for matching infinities `a - b` is NaN, and with epsilon == 0 the strict `<`
  // below would otherwise report identical values as different.
  if (a == b) {
    return true;
  }
  return std::fabs(a - b) < epsilon;
}
226 }  // namespace dataset
227 }  // namespace mindspore
228