1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "minddata/dataset/kernels/image/random_rotation_op.h"

#include <random>
#include <utility>

#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
24
25 namespace mindspore {
26 namespace dataset {
27 const std::vector<float> RandomRotationOp::kDefCenter = {};
28 const InterpolationMode RandomRotationOp::kDefInterpolation = InterpolationMode::kNearestNeighbour;
29 const bool RandomRotationOp::kDefExpand = false;
30 const uint8_t RandomRotationOp::kDefFillR = 0;
31 const uint8_t RandomRotationOp::kDefFillG = 0;
32 const uint8_t RandomRotationOp::kDefFillB = 0;
33
34 // constructor
RandomRotationOp(float start_degree,float end_degree,InterpolationMode resample,bool expand,std::vector<float> center,uint8_t fill_r,uint8_t fill_g,uint8_t fill_b)35 RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, InterpolationMode resample, bool expand,
36 std::vector<float> center, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b)
37 : degree_start_(start_degree),
38 degree_end_(end_degree),
39 center_(center),
40 interpolation_(resample),
41 expand_(expand),
42 fill_r_(fill_r),
43 fill_g_(fill_g),
44 fill_b_(fill_b) {
45 rnd_.seed(GetSeed());
46 is_deterministic_ = false;
47 }
48
49 // main function call for random rotation : Generate the random degrees
Compute(const std::shared_ptr<Tensor> & input,std::shared_ptr<Tensor> * output)50 Status RandomRotationOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
51 IO_CHECK(input, output);
52 float random_double = distribution_(rnd_);
53 // get the degree rotation range, mod by 360 because full rotation doesn't affect
54 // the way this op works (uniform distribution)
55 // assumption here is that mDegreesEnd > mDegreeStart so we always get positive number
56 // Note: the range technically is greater than 360 degrees, but will be halved
57 float degree_range = (degree_end_ - degree_start_) / 2;
58 float mid = (degree_end_ + degree_start_) / 2;
59 float degree = mid + random_double * degree_range;
60
61 return Rotate(input, output, center_, degree, interpolation_, expand_, fill_r_, fill_g_, fill_b_);
62 }
63
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)64 Status RandomRotationOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
65 RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
66 outputs.clear();
67 int32_t outputH = -1, outputW = -1;
68 // if expand_, then we cannot know the shape. We need the input image to find the output shape --> set it to
69 // <-1,-1[,3]>
70 CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() > 0 && inputs[0].Size() >= 2,
71 "RandomRotationOp: invalid input shape, expected 2D or 3D input, but got input"
72 " dimension is: " +
73 std::to_string(inputs[0].Rank()));
74 if (!expand_) {
75 outputH = inputs[0][0];
76 outputW = inputs[0][1];
77 }
78 TensorShape out = TensorShape{outputH, outputW};
79 if (inputs[0].Rank() == 2) outputs.emplace_back(out);
80 if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2]));
81 if (!outputs.empty()) return Status::OK();
82 return Status(StatusCode::kMDUnexpectedError,
83 "RandomRotation: invalid input shape, expected 2D or 3D input, but got input dimension is:" +
84 std::to_string(inputs[0].Rank()));
85 }
86 } // namespace dataset
87 } // namespace mindspore
88