1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
19
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
#include "minddata/dataset/util/status.h"
27
28 namespace mindspore {
29 namespace dataset {
// Helper function to validate a probability value.
// NOTE(review): the accepted interval is defined in the implementation file, not here; presumably [0, 1] — confirm.
Status ValidateProbability(const std::string &op_name, double probability);

// Helper function to validate that an int scalar is positive.
// op_name / scalar_name presumably only label the error message (as ValidateScalar does) — confirm in implementation.
Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);

// Helper function to validate that an int scalar is non-negative.
Status ValidateIntScalarNonNegative(const std::string &op_name, const std::string &scalar_name, int32_t scalar);

// Helper function to validate that a float scalar is positive.
Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);

// Helper function to validate that a float scalar is non-negative.
Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar);
44
45 // Helper function to validate scalar
46 template <typename T>
47 Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
48 const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
49 const size_t kRangeSize = 2;
50 if (range.empty() || range.size() > kRangeSize) {
51 std::string err_msg = op_name + ": expecting range size 1 or 2, but got: " + std::to_string(range.size());
52 MS_LOG(ERROR) << err_msg;
53 RETURN_SYNTAX_ERROR(err_msg);
54 }
55 if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
56 std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
57 std::string err_msg = op_name + ": '" + scalar_name + "' must be" + interval_description +
58 std::to_string(range[0]) + ", got: " + std::to_string(scalar);
59 MS_LOG(ERROR) << err_msg;
60 RETURN_SYNTAX_ERROR(err_msg);
61 }
62 if (range.size() == kRangeSize) {
63 if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
64 std::string left_bracket = left_open_interval ? "(" : "[";
65 std::string right_bracket = right_open_interval ? ")" : "]";
66 std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
67 std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
68 ", got: " + std::to_string(scalar);
69 MS_LOG(ERROR) << err_msg;
70 RETURN_SYNTAX_ERROR(err_msg);
71 }
72 }
73 return Status::OK();
74 }
75
76 // Helper function to validate enum
77 template <typename T>
ValidateEnum(const std::string & op_name,const std::string & enum_name,const T enumeration,const std::vector<T> & enum_list)78 Status ValidateEnum(const std::string &op_name, const std::string &enum_name, const T enumeration,
79 const std::vector<T> &enum_list) {
80 auto existed = std::find(enum_list.begin(), enum_list.end(), enumeration);
81 std::string err_msg = op_name + ": Invalid " + enum_name + ", check input value of enum.";
82 if (existed != enum_list.end()) {
83 return Status::OK();
84 }
85 RETURN_SYNTAX_ERROR(err_msg);
86 }
87
// Helper function to validate a color attribute vector against a value range.
// NOTE(review): expected vector length and how `range` is applied live in the implementation — confirm there.
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
                                    const std::vector<float> &attr, const std::vector<float> &range);

// Helper function to validate a fill value vector.
Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);

// Helper function to validate a pair of mean and std vectors.
Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);

// Helper function to validate that the values of a vector are odd.
Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value);

// Helper function to validate a padding vector.
Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);

// Helper function to validate that the values of a vector are positive.
Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);

// Helper function to validate that the values of a vector are non-negative.
Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
                                 const std::vector<int32_t> &vec);

// Helper function to validate a sigma vector, including its size.
Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma);

// Helper function to validate a size vector, including its number of elements.
Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);

// Helper function to validate a scale vector.
Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);

// Helper function to validate an aspect-ratio-style vector (presumably; confirm semantics in implementation).
Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);

// Helper function to validate a list of transform operations.
Status ValidateVectorTransforms(const std::string &op_name,
                               const std::vector<std::shared_ptr<TensorOperation>> &transforms);
126
// Helper function to compare two float values using a tolerance `epsilon`.
// NOTE(review): presumably returns true when |a - b| is within epsilon — confirm against the implementation.
// NOTE(review): the default epsilon (1e-10) is far smaller than FLT_EPSILON (~1.19e-7), so for values of
// ordinary magnitude an absolute-difference check with this default is effectively exact equality; confirm intent.
bool CmpFloat(float a, float b, float epsilon = 0.0000000001F);
129 } // namespace dataset
130 } // namespace mindspore
131 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
132