1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "minddata/dataset/core/tensor.h" 25 #include "minddata/dataset/kernels/ir/tensor_operation.h" 26 #include "minddata/dataset/util/status.h" 27 28 namespace mindspore { 29 namespace dataset { 30 // Helper function to validate probability 31 Status ValidateProbability(const std::string &op_name, const double probability); 32 33 // Helper function to positive int scalar 34 Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar); 35 36 // Helper function to positive float scalar 37 Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar); 38 39 // Helper function to non-negative float scalar 40 Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar); 41 42 // Helper function to validate scalar 43 template <typename T> 44 Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar, 45 const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) { 46 if (range.empty() || range.size() > 2) { 47 std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size()); 48 MS_LOG(ERROR) << err_msg; 49 return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); 50 } 51 if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) { 52 std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to "; 53 std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) + 54 ", got: " + std::to_string(scalar); 55 MS_LOG(ERROR) << err_msg; 56 return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); 57 } 58 if (range.size() == 2) { 59 if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) { 60 std::string left_bracket = left_open_interval ? "(" : "["; 61 std::string right_bracket = right_open_interval ? ")" : "]"; 62 std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket + 63 std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket + 64 ", got: " + std::to_string(scalar); 65 MS_LOG(ERROR) << err_msg; 66 return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); 67 } 68 } 69 return Status::OK(); 70 } 71 72 // Helper function to validate color attribute 73 Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name, 74 const std::vector<float> &attr, const std::vector<float> &range); 75 76 // Helper function to validate fill value 77 Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value); 78 79 // Helper function to validate mean/std value 80 Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std); 81 82 // Helper function to validate odd value 83 Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value); 84 85 // Helper function to validate padding 86 Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding); 87 88 // Helper function to validate positive value 89 Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec); 90 91 // Helper function to validate non-negative value 92 Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name, 93 const std::vector<int32_t> &vec); 94 95 // Helper function to validate size of sigma 96 Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma); 97 98 // Helper function to validate size of size 99 Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size); 100 101 // Helper function to validate scale 102 Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale); 103 104 // Helper function to validate ratio 105 Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio); 106 107 // Helper function to validate transforms 108 Status ValidateVectorTransforms(const std::string &op_name, 109 const std::vector<std::shared_ptr<TensorOperation>> &transforms); 110 111 // Helper function to compare float value 112 bool CmpFloat(const float a, const float b, float epsilon = 0.0000000001f); 113 } // namespace dataset 114 } // namespace mindspore 115 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_ 116