• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
19 
#include <limits>
#include <memory>
#include <numeric>
#include <set>
#include <string>

#include <nlohmann/json.hpp>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/status.h"
28 
29 namespace mindspore {
30 namespace dataset {
31 // validator Parameter in json file
ValidateParamInJson(const nlohmann::json & json_obj,const std::string & param_name,const std::string & operator_name)32 inline Status ValidateParamInJson(const nlohmann::json &json_obj, const std::string &param_name,
33                                   const std::string &operator_name) {
34   if (json_obj.find(param_name) == json_obj.end()) {
35     std::string err_msg = "Failed to find key '" + param_name + "' in " + operator_name +
36                           "' JSON file or input dict, check input content of deserialize().";
37     RETURN_STATUS_UNEXPECTED(err_msg);
38   }
39   return Status::OK();
40 }
41 
42 inline Status ValidateTensorShape(const std::string &op_name, bool cond, const std::string &expected_shape = "",
43                                   const std::string &actual_dim = "") {
44   if (!cond) {
45     std::string err_msg = op_name + ": the shape of input tensor does not match the requirement of operator.";
46     if (expected_shape != "") {
47       err_msg += " Expecting tensor in shape of " + expected_shape + ".";
48     }
49     if (actual_dim != "") {
50       err_msg += " But got tensor with dimension " + actual_dim + ".";
51     }
52     RETURN_STATUS_UNEXPECTED(err_msg);
53   }
54   return Status::OK();
55 }
56 
57 inline Status ValidateLowRank(const std::string &op_name, const std::shared_ptr<Tensor> &input, dsize_t threshold = 0,
58                               const std::string &expected_shape = "") {
59   dsize_t dim = input->shape().Size();
60   return ValidateTensorShape(op_name, dim >= threshold, expected_shape, std::to_string(dim));
61 }
62 
63 inline Status ValidateTensorType(const std::string &op_name, bool cond, const std::string &expected_type = "",
64                                  const std::string &actual_type = "") {
65   if (!cond) {
66     std::string err_msg = op_name + ": the data type of input tensor does not match the requirement of operator.";
67     if (expected_type != "") {
68       err_msg += " Expecting tensor in type of " + expected_type + ".";
69     }
70     if (actual_type != "") {
71       err_msg += " But got type " + actual_type + ".";
72     }
73     RETURN_STATUS_UNEXPECTED(err_msg);
74   }
75   return Status::OK();
76 }
77 
ValidateTensorNumeric(const std::string & op_name,const std::shared_ptr<Tensor> & input)78 inline Status ValidateTensorNumeric(const std::string &op_name, const std::shared_ptr<Tensor> &input) {
79   return ValidateTensorType(op_name, input->type().IsNumeric(), "[int, float, double]", input->type().ToString());
80 }
81 
ValidateTensorFloat(const std::string & op_name,const std::shared_ptr<Tensor> & input)82 inline Status ValidateTensorFloat(const std::string &op_name, const std::shared_ptr<Tensor> &input) {
83   return ValidateTensorType(op_name, input->type().IsFloat(), "[float, double]", input->type().ToString());
84 }
85 
86 template <typename T>
ValidateEqual(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)87 inline Status ValidateEqual(const std::string &op_name, const std::string &param_name, T param_value,
88                             const std::string &other_name, T other_value) {
89   if (param_value != other_value) {
90     std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be equal to '" + other_name +
91                           "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
92                           " " + std::to_string(other_value) + ".";
93     RETURN_STATUS_UNEXPECTED(err_msg);
94   }
95   return Status::OK();
96 }
97 
98 template <typename T>
ValidateNotEqual(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)99 inline Status ValidateNotEqual(const std::string &op_name, const std::string &param_name, T param_value,
100                                const std::string &other_name, T other_value) {
101   if (param_value == other_value) {
102     std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' can not be equal to '" + other_name +
103                           "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
104                           " " + std::to_string(other_value) + ".";
105     RETURN_STATUS_UNEXPECTED(err_msg);
106   }
107   return Status::OK();
108 }
109 
110 template <typename T>
ValidateGreaterThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)111 inline Status ValidateGreaterThan(const std::string &op_name, const std::string &param_name, T param_value,
112                                   const std::string &other_name, T other_value) {
113   if (param_value <= other_value) {
114     std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be greater than '" + other_name +
115                           "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
116                           " " + std::to_string(other_value) + ".";
117     RETURN_STATUS_UNEXPECTED(err_msg);
118   }
119   return Status::OK();
120 }
121 
122 template <typename T>
ValidateLessThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)123 inline Status ValidateLessThan(const std::string &op_name, const std::string &param_name, T param_value,
124                                const std::string &other_name, T other_value) {
125   if (param_value >= other_value) {
126     std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be less than '" + other_name +
127                           "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
128                           " " + std::to_string(other_value) + ".";
129     RETURN_STATUS_UNEXPECTED(err_msg);
130   }
131   return Status::OK();
132 }
133 
134 template <typename T>
ValidateNoGreaterThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)135 inline Status ValidateNoGreaterThan(const std::string &op_name, const std::string &param_name, T param_value,
136                                     const std::string &other_name, T other_value) {
137   if (param_value > other_value) {
138     std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no greater than '" +
139                           other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " +
140                           other_name + " " + std::to_string(other_value) + ".";
141     RETURN_STATUS_UNEXPECTED(err_msg);
142   }
143   return Status::OK();
144 }
145 
146 template <typename T>
ValidateNoLessThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)147 inline Status ValidateNoLessThan(const std::string &op_name, const std::string &param_name, T param_value,
148                                  const std::string &other_name, T other_value) {
149   if (param_value < other_value) {
150     std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no less than '" + other_name +
151                           "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
152                           " " + std::to_string(other_value) + ".";
153     RETURN_STATUS_UNEXPECTED(err_msg);
154   }
155   return Status::OK();
156 }
157 
158 template <typename T>
ValidatePositive(const std::string & op_name,const std::string & param_name,T param_value)159 inline Status ValidatePositive(const std::string &op_name, const std::string &param_name, T param_value) {
160   if (param_value <= 0) {
161     std::string err_msg = op_name + ": invalid parameter, '" + param_name +
162                           "' should be positive, but got: " + std::to_string(param_value) + ".";
163     RETURN_STATUS_UNEXPECTED(err_msg);
164   }
165   return Status::OK();
166 }
167 
168 template <typename T>
ValidateNegative(const std::string & op_name,const std::string & param_name,T param_value)169 inline Status ValidateNegative(const std::string &op_name, const std::string &param_name, T param_value) {
170   if (param_value >= 0) {
171     std::string err_msg = op_name + ": invalid parameter, '" + param_name +
172                           "' should be negative, but got: " + std::to_string(param_value) + ".";
173     RETURN_STATUS_UNEXPECTED(err_msg);
174   }
175   return Status::OK();
176 }
177 
178 template <typename T>
ValidateNonPositive(const std::string & op_name,const std::string & param_name,T param_value)179 inline Status ValidateNonPositive(const std::string &op_name, const std::string &param_name, T param_value) {
180   if (param_value > 0) {
181     std::string err_msg = op_name + ": invalid parameter, '" + param_name +
182                           "' should be non positive, but got: " + std::to_string(param_value) + ".";
183     RETURN_STATUS_UNEXPECTED(err_msg);
184   }
185   return Status::OK();
186 }
187 
188 template <typename T>
ValidateNonNegative(const std::string & op_name,const std::string & param_name,T param_value)189 inline Status ValidateNonNegative(const std::string &op_name, const std::string &param_name, T param_value) {
190   if (param_value < 0) {
191     std::string err_msg = op_name + ": invalid parameter, '" + param_name +
192                           "' should be non negative, but got: " + std::to_string(param_value) + ".";
193     RETURN_STATUS_UNEXPECTED(err_msg);
194   }
195   return Status::OK();
196 }
197 
DataTypeSetToString(const std::set<uint8_t> & valid_dtype)198 inline std::string DataTypeSetToString(const std::set<uint8_t> &valid_dtype) {
199   std::string init;
200   std::string err_msg =
201     std::accumulate(valid_dtype.begin(), valid_dtype.end(), init, [](const std::string &str, uint8_t dtype) {
202       if (str.empty()) {
203         return DataType(DataType::Type(dtype)).ToString();
204       } else {
205         return str + ", " + DataType(DataType::Type(dtype)).ToString();
206       }
207     });
208   return "(" + err_msg + ")";
209 }
210 
/// \brief Renders a set of numbers as "(v1, v2, ...)" for error messages.
/// \tparam T Element type; must be accepted by std::to_string.
/// \param[in] valid_value Values to list; emitted in the set's ascending order.
/// \return Parenthesized, comma-separated list of the values; "()" for an empty set.
template <typename T>
std::string NumberSetToString(const std::set<T> &valid_value) {
  std::string joined;
  for (const T &value : valid_value) {
    // A separator is written only once the buffer holds something, exactly like a string join.
    if (!joined.empty()) {
      joined += ", ";
    }
    joined += std::to_string(value);
  }
  return "(" + joined + ")";
}
224 }  // namespace dataset
225 }  // namespace mindspore
226 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
227