• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "minddata/dataset/core/tensor.h"
25 #include "minddata/dataset/kernels/ir/tensor_operation.h"
26 #include "minddata/dataset/util/status.h"
27 
28 namespace mindspore {
29 namespace dataset {
30 // Helper function to validate probability
31 Status ValidateProbability(const std::string &op_name, const double probability);
32 
33 // Helper function to positive int scalar
34 Status ValidateIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
35 
36 // Helper function to positive float scalar
37 Status ValidateFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
38 
39 // Helper function to non-negative float scalar
40 Status ValidateFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar);
41 
42 // Helper function to validate scalar
43 template <typename T>
44 Status ValidateScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
45                       const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false) {
46   if (range.empty() || range.size() > 2) {
47     std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size());
48     MS_LOG(ERROR) << err_msg;
49     return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
50   }
51   if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
52     std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
53     std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
54                           ", got: " + std::to_string(scalar);
55     MS_LOG(ERROR) << err_msg;
56     return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
57   }
58   if (range.size() == 2) {
59     if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
60       std::string left_bracket = left_open_interval ? "(" : "[";
61       std::string right_bracket = right_open_interval ? ")" : "]";
62       std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
63                             std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
64                             ", got: " + std::to_string(scalar);
65       MS_LOG(ERROR) << err_msg;
66       return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
67     }
68   }
69   return Status::OK();
70 }
71 
72 // Helper function to validate color attribute
73 Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
74                                     const std::vector<float> &attr, const std::vector<float> &range);
75 
76 // Helper function to validate fill value
77 Status ValidateVectorFillvalue(const std::string &op_name, const std::vector<uint8_t> &fill_value);
78 
79 // Helper function to validate mean/std value
80 Status ValidateVectorMeanStd(const std::string &op_name, const std::vector<float> &mean, const std::vector<float> &std);
81 
82 // Helper function to validate odd value
83 Status ValidateVectorOdd(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &value);
84 
85 // Helper function to validate padding
86 Status ValidateVectorPadding(const std::string &op_name, const std::vector<int32_t> &padding);
87 
88 // Helper function to validate positive value
89 Status ValidateVectorPositive(const std::string &op_name, const std::string &vec_name, const std::vector<int32_t> &vec);
90 
91 // Helper function to validate non-negative value
92 Status ValidateVectorNonNegative(const std::string &op_name, const std::string &vec_name,
93                                  const std::vector<int32_t> &vec);
94 
95 // Helper function to validate size of sigma
96 Status ValidateVectorSigma(const std::string &op_name, const std::vector<float> &sigma);
97 
98 // Helper function to validate size of size
99 Status ValidateVectorSize(const std::string &op_name, const std::vector<int32_t> &size);
100 
101 // Helper function to validate scale
102 Status ValidateVectorScale(const std::string &op_name, const std::vector<float> &scale);
103 
104 // Helper function to validate ratio
105 Status ValidateVectorRatio(const std::string &op_name, const std::vector<float> &ratio);
106 
107 // Helper function to validate transforms
108 Status ValidateVectorTransforms(const std::string &op_name,
109                                 const std::vector<std::shared_ptr<TensorOperation>> &transforms);
110 
111 // Helper function to compare float value
112 bool CmpFloat(const float a, const float b, float epsilon = 0.0000000001f);
113 }  // namespace dataset
114 }  // namespace mindspore
115 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IR_VALIDATORS_H_
116