• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/image/random_select_subpolicy_op.h"
17 
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 
Compute(const TensorRow & input,TensorRow * output)25 Status RandomSelectSubpolicyOp::Compute(const TensorRow &input, TensorRow *output) {
26   TensorRow in_row = input;
27   size_t rand_num = rand_int_(gen_);
28   CHECK_FAIL_RETURN_UNEXPECTED(rand_num < policy_.size(),
29                                "RandomSelectSubpolicy: "
30                                "get rand number failed:" +
31                                  std::to_string(rand_num));
32   for (auto &sub : policy_[rand_num]) {
33     if (rand_double_(gen_) <= sub.second) {
34       RETURN_IF_NOT_OK(sub.first->Compute(in_row, output));
35       in_row = std::move(*output);
36     }
37   }
38   *output = std::move(in_row);
39   return Status::OK();
40 }
41 
NumInput()42 uint32_t RandomSelectSubpolicyOp::NumInput() {
43   uint32_t num_in = policy_.front().front().first->NumInput();
44   for (auto &sub : policy_) {
45     for (auto &p : sub) {
46       if (num_in != p.first->NumInput()) {
47         MS_LOG(WARNING) << "Unable to determine numInput.";
48         return 0;
49       }
50     }
51   }
52   return num_in;
53 }
54 
NumOutput()55 uint32_t RandomSelectSubpolicyOp::NumOutput() {
56   uint32_t num_out = policy_.front().front().first->NumOutput();
57   for (auto &sub : policy_) {
58     for (auto &p : sub) {
59       if (num_out != p.first->NumOutput()) {
60         MS_LOG(WARNING) << "Unable to determine numInput.";
61         return 0;
62       }
63     }
64   }
65   return num_out;
66 }
67 
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)68 Status RandomSelectSubpolicyOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
69   outputs.clear();
70   outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
71   return Status::OK();
72 }
73 
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)74 Status RandomSelectSubpolicyOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
75   RETURN_IF_NOT_OK(policy_.front().front().first->OutputType(inputs, outputs));
76   for (auto &sub : policy_) {
77     for (auto &p : sub) {
78       std::vector<DataType> tmp_types;
79       RETURN_IF_NOT_OK(p.first->OutputType(inputs, tmp_types));
80       if (outputs != tmp_types) {
81         outputs.clear();
82         outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
83         return Status::OK();
84       }
85     }
86   }
87   return Status::OK();
88 }
RandomSelectSubpolicyOp(const std::vector<Subpolicy> & policy)89 RandomSelectSubpolicyOp::RandomSelectSubpolicyOp(const std::vector<Subpolicy> &policy)
90     : gen_(GetSeed()), policy_(policy), rand_int_(0, policy.size() - 1), rand_double_(0, 1) {
91   if (policy_.empty()) {
92     MS_LOG(ERROR) << "RandomSelectSubpolicy: policy in RandomSelectSubpolicyOp is empty.";
93   }
94   is_deterministic_ = false;
95 }
96 
97 }  // namespace dataset
98 }  // namespace mindspore
99