• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/image/random_rotation_op.h"
17 
18 #include <random>
19 #include <utility>
20 
21 #include "minddata/dataset/core/cv_tensor.h"
22 #include "minddata/dataset/kernels/image/image_utils.h"
23 #include "minddata/dataset/util/random.h"
24 #include "minddata/dataset/util/status.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 // constructor
RandomRotationOp(float start_degree,float end_degree,InterpolationMode resample,bool expand,std::vector<float> center,uint8_t fill_r,uint8_t fill_g,uint8_t fill_b)29 RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, InterpolationMode resample, bool expand,
30                                    std::vector<float> center, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b)
31     : degree_start_(start_degree),
32       degree_end_(end_degree),
33       center_(std::move(center)),
34       interpolation_(resample),
35       expand_(expand),
36       fill_r_(fill_r),
37       fill_g_(fill_g),
38       fill_b_(fill_b) {}
39 
// Main entry point of the op: draws a random rotation angle and rotates the input image by it.
Compute(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)41 Status RandomRotationOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
42   IO_CHECK(input, output);
43   float random_double = distribution_(random_generator_);
44   // get the degree rotation range, mod by 360 because full rotation doesn't affect
45   // the way this op works (uniform distribution)
46   // assumption here is that mDegreesEnd > mDegreeStart so we always get positive number
47   // Note: the range technically is greater than 360 degrees, but will be halved
48   float degree_range = (degree_end_ - degree_start_) / 2;
49   float mid = (degree_end_ + degree_start_) / 2;
50   float degree = mid + random_double * degree_range;
51 
52   return Rotate(input, output, center_, degree, interpolation_, expand_, fill_r_, fill_g_, fill_b_);
53 }
54 
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)55 Status RandomRotationOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
56   RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
57   outputs.clear();
58   int32_t outputH = -1;
59   int32_t outputW = -1;
60   constexpr int32_t kDimensionTwo = 2;
61   constexpr int32_t kDimensionThree = 3;
62   // if expand_, then we cannot know the shape. We need the input image to find the output shape --> set it to
63   // <-1,-1[,3]>
64   CHECK_FAIL_RETURN_UNEXPECTED(!inputs.empty() && inputs[0].Size() >= kDimensionTwo,
65                                "RandomRotationOp: invalid input shape, expected 2D or 3D input, but got input"
66                                " dimension is: " +
67                                  std::to_string(inputs[0].Rank()));
68   if (!expand_) {
69     outputH = static_cast<int32_t>(inputs[0][0]);
70     outputW = static_cast<int32_t>(inputs[0][1]);
71   }
72   TensorShape out = TensorShape{outputH, outputW};
73   if (inputs[0].Rank() == kDimensionTwo) {
74     (void)outputs.emplace_back(out);
75   }
76   if (inputs[0].Rank() == kDimensionThree) {
77     (void)outputs.emplace_back(out.AppendDim(inputs[0][2]));
78   }
79   CHECK_FAIL_RETURN_UNEXPECTED(
80     !outputs.empty(), "RandomRotation: invalid input shape, expected 2D or 3D input, but got input dimension is:" +
81                         std::to_string(inputs[0].Rank()));
82   return Status::OK();
83 }
84 }  // namespace dataset
85 }  // namespace mindspore
86