• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
17 
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
RandomSelectSubpolicyOp(const std::vector<Subpolicy> & policy)24 RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy)
25     : policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) {}
26 
Compute(const TensorRow & input,TensorRow * output)27 Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) {
28   IO_CHECK_VECTOR(input, output);
29   TensorRow in_row = input;
30   size_t rand_num = rand_int_(random_generator_);
31   CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(),
32                                "RandomSelectSubpolicy: "
33                                "get rand number failed:" +
34                                  std::to_string(rand_num));
35   for (auto &sub : policy_[rand_num]) {
36     if (rand_double_(random_generator_) <= sub.second) {
37       RETURN_IF_NOT_OK(sub.first->Compute(in_row, output));
38       in_row = std::move(*output);
39     }
40   }
41   *output = std::move(in_row);
42   return Status::OK();
43 }
44 
NumInput()45 uint32_t RandomSelectSubpolicyOp::NumInput() {
46   uint32_t num_in = policy_.front().front().first->NumInput();
47   for (auto &sub : policy_) {
48     for (auto &p : sub) {
49       if (num_in != p.first->NumInput()) {
50         MS_LOG(WARNING) << "Unable to determine numInput.";
51         return 0;
52       }
53     }
54   }
55   return num_in;
56 }
57 
NumOutput()58 uint32_t RandomSelectSubpolicyOp::NumOutput() {
59   uint32_t num_out = policy_.front().front().first->NumOutput();
60   for (auto &sub : policy_) {
61     for (auto &p : sub) {
62       if (num_out != p.first->NumOutput()) {
63         MS_LOG(WARNING) << "Unable to determine numInput.";
64         return 0;
65       }
66     }
67   }
68   return num_out;
69 }
70 
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)71 Status RandomSelectSubpolicyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
72   outputs.clear();
73   outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
74   return Status::OK();
75 }
76 
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)77 Status RandomSelectSubpolicyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
78   RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs));
79   for (auto &sub : policy_) {
80     for (auto &p : sub) {
81       std::vector<DataType> tmp_types;
82       RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types));
83       if (outputs != tmp_types) {
84         outputs.clear();
85         outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
86         return Status::OK();
87       }
88     }
89   }
90   return Status::OK();
91 }
92 }  // namespace dataset
93 }  // namespace mindspore
94