1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17
18 #include "minddata/dataset/kernels/ir/vision/random_rotation_ir.h"
19
20 #ifndef ENABLE_ANDROID
21 #include "minddata/dataset/kernels/image/random_rotation_op.h"
22 #endif
23
24 #include "minddata/dataset/kernels/ir/validators.h"
25
26 namespace mindspore {
27 namespace dataset {
28 namespace vision {
29 #ifndef ENABLE_ANDROID
// Element indices used to address entries of degrees_ and fill_value_.
constexpr size_t dimension_zero{0};
constexpr size_t dimension_one{1};
constexpr size_t dimension_two{2};
// Vector lengths that degrees_, center_ and fill_value_ sizes are compared against.
constexpr size_t size_one{1};
constexpr size_t size_two{2};
constexpr size_t size_three{3};
36
37 // Function to create RandomRotationOperation.
RandomRotationOperation(const std::vector<float> & degrees,InterpolationMode resample,bool expand,const std::vector<float> & center,const std::vector<uint8_t> & fill_value)38 RandomRotationOperation::RandomRotationOperation(const std::vector<float> °rees, InterpolationMode resample,
39 bool expand, const std::vector<float> ¢er,
40 const std::vector<uint8_t> &fill_value)
41 : TensorOperation(true),
42 degrees_(degrees),
43 interpolation_mode_(resample),
44 expand_(expand),
45 center_(center),
46 fill_value_(fill_value) {}
47
48 RandomRotationOperation::~RandomRotationOperation() = default;
49
Name() const50 std::string RandomRotationOperation::Name() const { return kRandomRotationOperation; }
51
ValidateParams()52 Status RandomRotationOperation::ValidateParams() {
53 // degrees
54 if (degrees_.size() != size_two && degrees_.size() != size_one) {
55 std::string err_msg =
56 "RandomRotation: degrees must be a vector of one or two values, got: " + std::to_string(degrees_.size());
57 MS_LOG(ERROR) << "RandomRotation: degrees must be a vector of one or two values, got: " << degrees_;
58 RETURN_STATUS_SYNTAX_ERROR(err_msg);
59 }
60 if ((degrees_.size() == size_two) && (degrees_[dimension_one] < degrees_[dimension_zero])) {
61 std::string err_msg = "RandomRotation: degrees must be in the format of (min, max), got: (" +
62 std::to_string(degrees_[dimension_zero]) + ", " + std::to_string(degrees_[dimension_one]) +
63 ")";
64 MS_LOG(ERROR) << err_msg;
65 RETURN_STATUS_SYNTAX_ERROR(err_msg);
66 } else if ((degrees_.size() == size_one) && (degrees_[dimension_zero] < 0)) {
67 std::string err_msg =
68 "RandomRotation: if degrees only has one value, it must be greater than or equal to 0, got: " +
69 std::to_string(degrees_[dimension_zero]);
70 MS_LOG(ERROR) << err_msg;
71 RETURN_STATUS_SYNTAX_ERROR(err_msg);
72 }
73 // center
74 if (center_.size() != 0 && center_.size() != size_two) {
75 std::string err_msg =
76 "RandomRotation: center must be a vector of two values or empty, got: " + std::to_string(center_.size());
77 MS_LOG(ERROR) << err_msg;
78 RETURN_STATUS_SYNTAX_ERROR(err_msg);
79 }
80 // fill_value
81 RETURN_IF_NOT_OK(ValidateVectorFillvalue("RandomRotation", fill_value_));
82 // interpolation
83 if (interpolation_mode_ != InterpolationMode::kLinear &&
84 interpolation_mode_ != InterpolationMode::kNearestNeighbour && interpolation_mode_ != InterpolationMode::kCubic &&
85 interpolation_mode_ != InterpolationMode::kArea) {
86 std::string err_msg = "RandomRotation: Invalid InterpolationMode, check input value of enum.";
87 MS_LOG(ERROR) << err_msg;
88 RETURN_STATUS_SYNTAX_ERROR(err_msg);
89 }
90 return Status::OK();
91 }
92
Build()93 std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
94 float start_degree, end_degree;
95 if (degrees_.size() == size_one) {
96 start_degree = -degrees_[dimension_zero];
97 end_degree = degrees_[dimension_zero];
98 } else if (degrees_.size() == size_two) {
99 start_degree = degrees_[dimension_zero];
100 end_degree = degrees_[dimension_one];
101 }
102
103 uint8_t fill_r, fill_g, fill_b;
104 fill_r = fill_value_[dimension_zero];
105 fill_g = fill_value_[dimension_zero];
106 fill_b = fill_value_[dimension_zero];
107
108 if (fill_value_.size() == size_three) {
109 fill_r = fill_value_[dimension_zero];
110 fill_g = fill_value_[dimension_one];
111 fill_b = fill_value_[dimension_two];
112 }
113
114 std::shared_ptr<RandomRotationOp> tensor_op = std::make_shared<RandomRotationOp>(
115 start_degree, end_degree, interpolation_mode_, expand_, center_, fill_r, fill_g, fill_b);
116 return tensor_op;
117 }
118
to_json(nlohmann::json * out_json)119 Status RandomRotationOperation::to_json(nlohmann::json *out_json) {
120 nlohmann::json args;
121 args["degrees"] = degrees_;
122 args["resample"] = interpolation_mode_;
123 args["expand"] = expand_;
124 args["center"] = center_;
125 args["fill_value"] = fill_value_;
126 *out_json = args;
127 return Status::OK();
128 }
129
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)130 Status RandomRotationOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
131 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("degrees") != op_params.end(), "Failed to find degrees");
132 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("resample") != op_params.end(), "Failed to find resample");
133 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("expand") != op_params.end(), "Failed to find expand");
134 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("center") != op_params.end(), "Failed to find center");
135 CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("fill_value") != op_params.end(), "Failed to find fill_value");
136 std::vector<float> degrees = op_params["degrees"];
137 InterpolationMode resample = static_cast<InterpolationMode>(op_params["resample"]);
138 bool expand = op_params["expand"];
139 std::vector<float> center = op_params["center"];
140 std::vector<uint8_t> fill_value = op_params["fill_value"];
141 *operation = std::make_shared<vision::RandomRotationOperation>(degrees, resample, expand, center, fill_value);
142 return Status::OK();
143 }
144
145 #endif
146 } // namespace vision
147 } // namespace dataset
148 } // namespace mindspore
149