1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
17
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
RandomSelectSubpolicyOp(const std::vector<Subpolicy> & policy)24 RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy)
25 : policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) {}
26
Compute(const TensorRow & input,TensorRow * output)27 Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) {
28 IO_CHECK_VECTOR(input, output);
29 TensorRow in_row = input;
30 size_t rand_num = rand_int_(random_generator_);
31 CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(),
32 "RandomSelectSubpolicy: "
33 "get rand number failed:" +
34 std::to_string(rand_num));
35 for (auto &sub : policy_[rand_num]) {
36 if (rand_double_(random_generator_) <= sub.second) {
37 RETURN_IF_NOT_OK(sub.first->Compute(in_row, output));
38 in_row = std::move(*output);
39 }
40 }
41 *output = std::move(in_row);
42 return Status::OK();
43 }
44
NumInput()45 uint32_t RandomSelectSubpolicyOp::NumInput() {
46 uint32_t num_in = policy_.front().front().first->NumInput();
47 for (auto &sub : policy_) {
48 for (auto &p : sub) {
49 if (num_in != p.first->NumInput()) {
50 MS_LOG(WARNING) << "Unable to determine numInput.";
51 return 0;
52 }
53 }
54 }
55 return num_in;
56 }
57
NumOutput()58 uint32_t RandomSelectSubpolicyOp::NumOutput() {
59 uint32_t num_out = policy_.front().front().first->NumOutput();
60 for (auto &sub : policy_) {
61 for (auto &p : sub) {
62 if (num_out != p.first->NumOutput()) {
63 MS_LOG(WARNING) << "Unable to determine numInput.";
64 return 0;
65 }
66 }
67 }
68 return num_out;
69 }
70
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)71 Status RandomSelectSubpolicyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
72 outputs.clear();
73 outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
74 return Status::OK();
75 }
76
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)77 Status RandomSelectSubpolicyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
78 RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs));
79 for (auto &sub : policy_) {
80 for (auto &p : sub) {
81 std::vector<DataType> tmp_types;
82 RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types));
83 if (outputs != tmp_types) {
84 outputs.clear();
85 outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
86 return Status::OK();
87 }
88 }
89 }
90 return Status::OK();
91 }
92 } // namespace dataset
93 } // namespace mindspore
94