• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
19 
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
#include "minddata/dataset/util/status.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 // Helper function to validate probability
31 Status ValidateProbability(const std::string &op_name, double probability);
32 
33 // Helper function to positive int scalar
34 Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
35 
36 // Helper function to positive int scalar
37 Status ValidateIntScalarNonNegative(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
38 
39 // Helper function to positive float scalar
40 Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
41 
42 // Helper function to non-negative float scalar
43 Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar);
44 
45 // Helper function to validate scalar
46 template <typename T>
47 Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
48                       const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
49   const size_t kRangeSize = 2;
50   if (range.empty() || range.size() > kRangeSize) {
51     std::string err_msg = op_name + ": expecting range size 1 or 2, but got: " + std::to_string(range.size());
52     MS_LOG(ERROR) << err_msg;
53     RETURN_SYNTAX_ERROR(err_msg);
54   }
55   if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
56     std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
57     std::string err_msg = op_name + ": '" + scalar_name + "' must be" + interval_description +
58                           std::to_string(range[0]) + ", got: " + std::to_string(scalar);
59     MS_LOG(ERROR) << err_msg;
60     RETURN_SYNTAX_ERROR(err_msg);
61   }
62   if (range.size() == kRangeSize) {
63     if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
64       std::string left_bracket = left_open_interval ? "(" : "[";
65       std::string right_bracket = right_open_interval ? ")" : "]";
66       std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
67                             std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
68                             ", got: " + std::to_string(scalar);
69       MS_LOG(ERROR) << err_msg;
70       RETURN_SYNTAX_ERROR(err_msg);
71     }
72   }
73   return Status::OK();
74 }
75 
76 // Helper function to validate enum
77 template <typename T>
ValidateEnum(const std::string & op_name,const std::string & enum_name,const T enumeration,const std::vector<T> & enum_list)78 Status ValidateEnum(const std::string &op_name, const std::string &enum_name, const T enumeration,
79                     const std::vector<T> &enum_list) {
80   auto existed = std::find(enum_list.begin(), enum_list.end(), enumeration);
81   std::string err_msg = op_name + ": Invalid " + enum_name + ", check input value of enum.";
82   if (existed != enum_list.end()) {
83     return Status::OK();
84   }
85   RETURN_SYNTAX_ERROR(err_msg);
86 }
87 
88 // Helper function to validate color attribute
89 Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
90                                     const std::vector<float> &attr, const std::vector<float> &range);
91 
92 // Helper function to validate fill value
93 Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);
94 
95 // Helper function to validate mean/std value
96 Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
97 
98 // Helper function to validate odd value
99 Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value);
100 
101 // Helper function to validate padding
102 Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
103 
104 // Helper function to validate positive value
105 Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);
106 
107 // Helper function to validate non-negative value
108 Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
109                                  const std::vector<int32_t> &vec);
110 
111 // Helper function to validate size of sigma
112 Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma);
113 
114 // Helper function to validate size of size
115 Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
116 
117 // Helper function to validate scale
118 Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);
119 
120 // Helper function to validate ratio
121 Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);
122 
123 // Helper function to validate transforms
124 Status ValidateVectorTransforms(const std::string &op_name,
125                                 const std::vector<std::shared_ptr<TensorOperation>> &transforms);
126 
127 // Helper function to compare float value
128 bool CmpFloat(float a, float b, float epsilon = 0.0000000001F);
129 }  // namespace dataset
130 }  // namespace mindspore
131 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
132