1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
17
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
24
Compute(const TensorRow & input,TensorRow * output)25 Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) {
26 TensorRow in_row = input;
27 size_t rand_num = rand_int_(gen_);
28 CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(),
29 "RandomSelectSubpolicy: "
30 "get rand number failed:" +
31 std::to_string(rand_num));
32 for (auto &sub : policy_[rand_num]) {
33 if (rand_double_(gen_) <= sub.second) {
34 RETURN_IF_NOT_OK(sub.first->Compute(in_row, output));
35 in_row = std::move(*output);
36 }
37 }
38 *output = std::move(in_row);
39 return Status::OK();
40 }
41
NumInput()42 uint32_t RandomSelectSubpolicyOp::NumInput() {
43 uint32_t num_in = policy_.front().front().first->NumInput();
44 for (auto &sub : policy_) {
45 for (auto &p : sub) {
46 if (num_in != p.first->NumInput()) {
47 MS_LOG(WARNING) << "Unable to determine numInput.";
48 return 0;
49 }
50 }
51 }
52 return num_in;
53 }
54
NumOutput()55 uint32_t RandomSelectSubpolicyOp::NumOutput() {
56 uint32_t num_out = policy_.front().front().first->NumOutput();
57 for (auto &sub : policy_) {
58 for (auto &p : sub) {
59 if (num_out != p.first->NumOutput()) {
60 MS_LOG(WARNING) << "Unable to determine numInput.";
61 return 0;
62 }
63 }
64 }
65 return num_out;
66 }
67
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)68 Status RandomSelectSubpolicyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
69 outputs.clear();
70 outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
71 return Status::OK();
72 }
73
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)74 Status RandomSelectSubpolicyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
75 RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs));
76 for (auto &sub : policy_) {
77 for (auto &p : sub) {
78 std::vector<DataType> tmp_types;
79 RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types));
80 if (outputs != tmp_types) {
81 outputs.clear();
82 outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
83 return Status::OK();
84 }
85 }
86 }
87 return Status::OK();
88 }
RandomSelectSubpolicyOp(const std::vector<Subpolicy> & policy)89 RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy)
90 : gen_(GetSeed()), policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) {
91 if (policy_.empty()) {
92 MS_LOG(ERROR) << "RandomSelectSubpolicy: policy in RandomSelectSubpolicyOp is empty.";
93 }
94 is_deterministic_ = false;
95 }
96
97 } // namespace dataset
98 } // namespace mindspore
99