1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/data/random_choice_op.h"
17
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
24
NumInput()25 uint32_t RandomChoiceOp::NumInput() {
26 uint32_t num_input = ops_.front()->NumInput();
27 for (auto &op : ops_) {
28 uint32_t cur_num = op->NumInput();
29 if (num_input != cur_num && cur_num > 0) {
30 MS_LOG(WARNING) << "Unable to determine Num of Input, ops in RandomChoice don't take the same number of input.";
31 return 0;
32 }
33 }
34 return num_input;
35 }
36
NumOutput()37 uint32_t RandomChoiceOp::NumOutput() {
38 uint32_t num_output = ops_.front()->NumOutput();
39 for (auto &op : ops_) {
40 uint32_t cur_num = op->NumOutput();
41 if (num_output != cur_num) {
42 MS_LOG(WARNING) << "Unable to determine NumOutput, ops in RandomChoice don't have the same number of output.";
43 return 0;
44 }
45 }
46 return num_output;
47 }
48
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)49 Status RandomChoiceOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
50 RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs));
51 for (auto &op : ops_) {
52 std::vector<TensorShape> out_shapes;
53 RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes));
54 if (outputs != out_shapes) {
55 MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape.";
56 outputs.clear();
57 outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
58 return Status::OK();
59 }
60 }
61 return Status::OK();
62 }
63
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)64 Status RandomChoiceOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
65 RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs));
66 for (auto &op : ops_) {
67 std::vector<DataType> out_types;
68 RETURN_IF_NOT_OK(op->OutputType(inputs, out_types));
69 if (outputs != out_types) {
70 MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType.";
71 outputs.clear();
72 outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
73 return Status::OK();
74 }
75 }
76 return Status::OK();
77 }
78
Compute(const TensorRow & input,TensorRow * output)79 Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) {
80 size_t rand_num = rand_int_(gen_);
81 CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num));
82 RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output));
83 return Status::OK();
84 }
RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> & ops)85 RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops)
86 : ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) {
87 if (ops_.empty()) {
88 MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty.";
89 }
90 if (ops_.size() == 1) {
91 MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time.";
92 }
93 is_deterministic_ = false;
94 }
95 } // namespace dataset
96 } // namespace mindspore
97