• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <algorithm>
17 
18 #include "minddata/dataset/kernels/ir/vision/random_rotation_ir.h"
19 
20 #ifndef ENABLE_ANDROID
21 #include "minddata/dataset/kernels/image/random_rotation_op.h"
22 #endif
23 
24 #include "minddata/dataset/kernels/ir/validators.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 namespace vision {
29 #ifndef ENABLE_ANDROID
// Index constants used to address elements of the degrees_/center_/fill_value_ vectors.
constexpr size_t dimension_zero{0};
constexpr size_t dimension_one{1};
constexpr size_t dimension_two{2};
// Expected element counts for those vectors.
constexpr size_t size_one{1};
constexpr size_t size_two{2};
constexpr size_t size_three{3};
36 
37 // Function to create RandomRotationOperation.
RandomRotationOperation(const std::vector<float> & degrees,InterpolationMode resample,bool expand,const std::vector<float> & center,const std::vector<uint8_t> & fill_value)38 RandomRotationOperation::RandomRotationOperation(const std::vector<float> &degrees, InterpolationMode resample,
39                                                  bool expand, const std::vector<float> &center,
40                                                  const std::vector<uint8_t> &fill_value)
41     : TensorOperation(true),
42       degrees_(degrees),
43       interpolation_mode_(resample),
44       expand_(expand),
45       center_(center),
46       fill_value_(fill_value) {}
47 
48 RandomRotationOperation::~RandomRotationOperation() = default;
49 
Name() const50 std::string RandomRotationOperation::Name() const { return kRandomRotationOperation; }
51 
ValidateParams()52 Status RandomRotationOperation::ValidateParams() {
53   // degrees
54   if (degrees_.size() != size_two && degrees_.size() != size_one) {
55     std::string err_msg =
56       "RandomRotation: degrees must be a vector of one or two values, got: " + std::to_string(degrees_.size());
57     MS_LOG(ERROR) << "RandomRotation: degrees must be a vector of one or two values, got: " << degrees_;
58     RETURN_STATUS_SYNTAX_ERROR(err_msg);
59   }
60   if ((degrees_.size() == size_two) && (degrees_[dimension_one] < degrees_[dimension_zero])) {
61     std::string err_msg = "RandomRotation: degrees must be in the format of (min, max), got: (" +
62                           std::to_string(degrees_[dimension_zero]) + ", " + std::to_string(degrees_[dimension_one]) +
63                           ")";
64     MS_LOG(ERROR) << err_msg;
65     RETURN_STATUS_SYNTAX_ERROR(err_msg);
66   } else if ((degrees_.size() == size_one) && (degrees_[dimension_zero] < 0)) {
67     std::string err_msg =
68       "RandomRotation: if degrees only has one value, it must be greater than or equal to 0, got: " +
69       std::to_string(degrees_[dimension_zero]);
70     MS_LOG(ERROR) << err_msg;
71     RETURN_STATUS_SYNTAX_ERROR(err_msg);
72   }
73   // center
74   if (center_.size() != 0 && center_.size() != size_two) {
75     std::string err_msg =
76       "RandomRotation: center must be a vector of two values or empty, got: " + std::to_string(center_.size());
77     MS_LOG(ERROR) << err_msg;
78     RETURN_STATUS_SYNTAX_ERROR(err_msg);
79   }
80   // fill_value
81   RETURN_IF_NOT_OK(ValidateVectorFillvalue("RandomRotation", fill_value_));
82   // interpolation
83   if (interpolation_mode_ != InterpolationMode::kLinear &&
84       interpolation_mode_ != InterpolationMode::kNearestNeighbour && interpolation_mode_ != InterpolationMode::kCubic &&
85       interpolation_mode_ != InterpolationMode::kArea) {
86     std::string err_msg = "RandomRotation: Invalid InterpolationMode, check input value of enum.";
87     MS_LOG(ERROR) << err_msg;
88     RETURN_STATUS_SYNTAX_ERROR(err_msg);
89   }
90   return Status::OK();
91 }
92 
Build()93 std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
94   float start_degree, end_degree;
95   if (degrees_.size() == size_one) {
96     start_degree = -degrees_[dimension_zero];
97     end_degree = degrees_[dimension_zero];
98   } else if (degrees_.size() == size_two) {
99     start_degree = degrees_[dimension_zero];
100     end_degree = degrees_[dimension_one];
101   }
102 
103   uint8_t fill_r, fill_g, fill_b;
104   fill_r = fill_value_[dimension_zero];
105   fill_g = fill_value_[dimension_zero];
106   fill_b = fill_value_[dimension_zero];
107 
108   if (fill_value_.size() == size_three) {
109     fill_r = fill_value_[dimension_zero];
110     fill_g = fill_value_[dimension_one];
111     fill_b = fill_value_[dimension_two];
112   }
113 
114   std::shared_ptr<RandomRotationOp> tensor_op = std::make_shared<RandomRotationOp>(
115     start_degree, end_degree, interpolation_mode_, expand_, center_, fill_r, fill_g, fill_b);
116   return tensor_op;
117 }
118 
to_json(nlohmann::json * out_json)119 Status RandomRotationOperation::to_json(nlohmann::json *out_json) {
120   nlohmann::json args;
121   args["degrees"] = degrees_;
122   args["resample"] = interpolation_mode_;
123   args["expand"] = expand_;
124   args["center"] = center_;
125   args["fill_value"] = fill_value_;
126   *out_json = args;
127   return Status::OK();
128 }
129 
from_json(nlohmann::json op_params,std::shared_ptr<TensorOperation> * operation)130 Status RandomRotationOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
131   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("degrees") != op_params.end(), "Failed to find degrees");
132   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("resample") != op_params.end(), "Failed to find resample");
133   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("expand") != op_params.end(), "Failed to find expand");
134   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("center") != op_params.end(), "Failed to find center");
135   CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("fill_value") != op_params.end(), "Failed to find fill_value");
136   std::vector<float> degrees = op_params["degrees"];
137   InterpolationMode resample = static_cast<InterpolationMode>(op_params["resample"]);
138   bool expand = op_params["expand"];
139   std::vector<float> center = op_params["center"];
140   std::vector<uint8_t> fill_value = op_params["fill_value"];
141   *operation = std::make_shared<vision::RandomRotationOperation>(degrees, resample, expand, center, fill_value);
142   return Status::OK();
143 }
144 
145 #endif
146 }  // namespace vision
147 }  // namespace dataset
148 }  // namespace mindspore
149