1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
19
#include <limits>
#include <memory>
#include <numeric>
#include <set>
#include <string>
23
24 #include <nlohmann/json.hpp>
25
26 #include "minddata/dataset/core/tensor.h"
27 #include "minddata/dataset/util/status.h"
28
29 namespace mindspore {
30 namespace dataset {
31 // validator Parameter in json file
ValidateParamInJson(const nlohmann::json & json_obj,const std::string & param_name,const std::string & operator_name)32 inline Status ValidateParamInJson(const nlohmann::json &json_obj, const std::string ¶m_name,
33 const std::string &operator_name) {
34 if (json_obj.find(param_name) == json_obj.end()) {
35 std::string err_msg = "Failed to find key '" + param_name + "' in " + operator_name +
36 "' JSON file or input dict, check input content of deserialize().";
37 RETURN_STATUS_UNEXPECTED(err_msg);
38 }
39 return Status::OK();
40 }
41
42 inline Status ValidateTensorShape(const std::string &op_name, bool cond, const std::string &expected_shape = "",
43 const std::string &actual_dim = "") {
44 if (!cond) {
45 std::string err_msg = op_name + ": the shape of input tensor does not match the requirement of operator.";
46 if (expected_shape != "") {
47 err_msg += " Expecting tensor in shape of " + expected_shape + ".";
48 }
49 if (actual_dim != "") {
50 err_msg += " But got tensor with dimension " + actual_dim + ".";
51 }
52 RETURN_STATUS_UNEXPECTED(err_msg);
53 }
54 return Status::OK();
55 }
56
57 inline Status ValidateLowRank(const std::string &op_name, const std::shared_ptr<Tensor> &input, dsize_t threshold = 0,
58 const std::string &expected_shape = "") {
59 dsize_t dim = input->shape().Size();
60 return ValidateTensorShape(op_name, dim >= threshold, expected_shape, std::to_string(dim));
61 }
62
63 inline Status ValidateTensorType(const std::string &op_name, bool cond, const std::string &expected_type = "",
64 const std::string &actual_type = "") {
65 if (!cond) {
66 std::string err_msg = op_name + ": the data type of input tensor does not match the requirement of operator.";
67 if (expected_type != "") {
68 err_msg += " Expecting tensor in type of " + expected_type + ".";
69 }
70 if (actual_type != "") {
71 err_msg += " But got type " + actual_type + ".";
72 }
73 RETURN_STATUS_UNEXPECTED(err_msg);
74 }
75 return Status::OK();
76 }
77
ValidateTensorNumeric(const std::string & op_name,const std::shared_ptr<Tensor> & input)78 inline Status ValidateTensorNumeric(const std::string &op_name, const std::shared_ptr<Tensor> &input) {
79 return ValidateTensorType(op_name, input->type().IsNumeric(), "[int, float, double]", input->type().ToString());
80 }
81
ValidateTensorFloat(const std::string & op_name,const std::shared_ptr<Tensor> & input)82 inline Status ValidateTensorFloat(const std::string &op_name, const std::shared_ptr<Tensor> &input) {
83 return ValidateTensorType(op_name, input->type().IsFloat(), "[float, double]", input->type().ToString());
84 }
85
86 template <typename T>
ValidateEqual(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)87 inline Status ValidateEqual(const std::string &op_name, const std::string ¶m_name, T param_value,
88 const std::string &other_name, T other_value) {
89 if (param_value != other_value) {
90 std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be equal to '" + other_name +
91 "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
92 " " + std::to_string(other_value) + ".";
93 RETURN_STATUS_UNEXPECTED(err_msg);
94 }
95 return Status::OK();
96 }
97
98 template <typename T>
ValidateNotEqual(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)99 inline Status ValidateNotEqual(const std::string &op_name, const std::string ¶m_name, T param_value,
100 const std::string &other_name, T other_value) {
101 if (param_value == other_value) {
102 std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' can not be equal to '" + other_name +
103 "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
104 " " + std::to_string(other_value) + ".";
105 RETURN_STATUS_UNEXPECTED(err_msg);
106 }
107 return Status::OK();
108 }
109
110 template <typename T>
ValidateGreaterThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)111 inline Status ValidateGreaterThan(const std::string &op_name, const std::string ¶m_name, T param_value,
112 const std::string &other_name, T other_value) {
113 if (param_value <= other_value) {
114 std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be greater than '" + other_name +
115 "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
116 " " + std::to_string(other_value) + ".";
117 RETURN_STATUS_UNEXPECTED(err_msg);
118 }
119 return Status::OK();
120 }
121
122 template <typename T>
ValidateLessThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)123 inline Status ValidateLessThan(const std::string &op_name, const std::string ¶m_name, T param_value,
124 const std::string &other_name, T other_value) {
125 if (param_value >= other_value) {
126 std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be less than '" + other_name +
127 "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
128 " " + std::to_string(other_value) + ".";
129 RETURN_STATUS_UNEXPECTED(err_msg);
130 }
131 return Status::OK();
132 }
133
134 template <typename T>
ValidateNoGreaterThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)135 inline Status ValidateNoGreaterThan(const std::string &op_name, const std::string ¶m_name, T param_value,
136 const std::string &other_name, T other_value) {
137 if (param_value > other_value) {
138 std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no greater than '" +
139 other_name + "', but got: " + param_name + " " + std::to_string(param_value) + " while " +
140 other_name + " " + std::to_string(other_value) + ".";
141 RETURN_STATUS_UNEXPECTED(err_msg);
142 }
143 return Status::OK();
144 }
145
146 template <typename T>
ValidateNoLessThan(const std::string & op_name,const std::string & param_name,T param_value,const std::string & other_name,T other_value)147 inline Status ValidateNoLessThan(const std::string &op_name, const std::string ¶m_name, T param_value,
148 const std::string &other_name, T other_value) {
149 if (param_value < other_value) {
150 std::string err_msg = op_name + ": invalid parameter, '" + param_name + "' should be no less than '" + other_name +
151 "', but got: " + param_name + " " + std::to_string(param_value) + " while " + other_name +
152 " " + std::to_string(other_value) + ".";
153 RETURN_STATUS_UNEXPECTED(err_msg);
154 }
155 return Status::OK();
156 }
157
158 template <typename T>
ValidatePositive(const std::string & op_name,const std::string & param_name,T param_value)159 inline Status ValidatePositive(const std::string &op_name, const std::string ¶m_name, T param_value) {
160 if (param_value <= 0) {
161 std::string err_msg = op_name + ": invalid parameter, '" + param_name +
162 "' should be positive, but got: " + std::to_string(param_value) + ".";
163 RETURN_STATUS_UNEXPECTED(err_msg);
164 }
165 return Status::OK();
166 }
167
168 template <typename T>
ValidateNegative(const std::string & op_name,const std::string & param_name,T param_value)169 inline Status ValidateNegative(const std::string &op_name, const std::string ¶m_name, T param_value) {
170 if (param_value >= 0) {
171 std::string err_msg = op_name + ": invalid parameter, '" + param_name +
172 "' should be negative, but got: " + std::to_string(param_value) + ".";
173 RETURN_STATUS_UNEXPECTED(err_msg);
174 }
175 return Status::OK();
176 }
177
178 template <typename T>
ValidateNonPositive(const std::string & op_name,const std::string & param_name,T param_value)179 inline Status ValidateNonPositive(const std::string &op_name, const std::string ¶m_name, T param_value) {
180 if (param_value > 0) {
181 std::string err_msg = op_name + ": invalid parameter, '" + param_name +
182 "' should be non positive, but got: " + std::to_string(param_value) + ".";
183 RETURN_STATUS_UNEXPECTED(err_msg);
184 }
185 return Status::OK();
186 }
187
188 template <typename T>
ValidateNonNegative(const std::string & op_name,const std::string & param_name,T param_value)189 inline Status ValidateNonNegative(const std::string &op_name, const std::string ¶m_name, T param_value) {
190 if (param_value < 0) {
191 std::string err_msg = op_name + ": invalid parameter, '" + param_name +
192 "' should be non negative, but got: " + std::to_string(param_value) + ".";
193 RETURN_STATUS_UNEXPECTED(err_msg);
194 }
195 return Status::OK();
196 }
197
DataTypeSetToString(const std::set<uint8_t> & valid_dtype)198 inline std::string DataTypeSetToString(const std::set<uint8_t> &valid_dtype) {
199 std::string init;
200 std::string err_msg =
201 std::accumulate(valid_dtype.begin(), valid_dtype.end(), init, [](const std::string &str, uint8_t dtype) {
202 if (str.empty()) {
203 return DataType(DataType::Type(dtype)).ToString();
204 } else {
205 return str + ", " + DataType(DataType::Type(dtype)).ToString();
206 }
207 });
208 return "(" + err_msg + ")";
209 }
210
// Formats `valid_value` as a parenthesised, comma-separated list in the set's ascending order,
// e.g. {3, 1, 2} -> "(1, 2, 3)". An empty set yields "()".
// Values are rendered with std::to_string, so T must be accepted by it.
template <typename T>
std::string NumberSetToString(const std::set<T> &valid_value) {
  // Append in place rather than std::accumulate with `str + ...`, which rebuilt the whole string each step.
  std::string joined;
  for (const T &value : valid_value) {
    if (!joined.empty()) {
      joined += ", ";
    }
    joined += std::to_string(value);
  }
  return "(" + joined + ")";
}
224 } // namespace dataset
225 } // namespace mindspore
226 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_VALIDATORS_H_
227