1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/data/random_choice_op.h"
17
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21
22 namespace mindspore {
23 namespace dataset {
RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> & ops)24 RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops)
25 : ops_(ops), rand_int_(0, ops.size() - 1) {}
26
NumInput()27 uint32_t RandomChoiceOp::NumInput() {
28 uint32_t num_input = ops_.front()->NumInput();
29 for (auto &op : ops_) {
30 uint32_t cur_num = op->NumInput();
31 if (num_input != cur_num && cur_num > 0) {
32 MS_LOG(WARNING) << "Unable to determine Num of Input, ops in RandomChoice don't take the same number of input.";
33 return 0;
34 }
35 }
36 return num_input;
37 }
38
NumOutput()39 uint32_t RandomChoiceOp::NumOutput() {
40 uint32_t num_output = ops_.front()->NumOutput();
41 for (auto &op : ops_) {
42 uint32_t cur_num = op->NumOutput();
43 if (num_output != cur_num) {
44 MS_LOG(WARNING) << "Unable to determine NumOutput, ops in RandomChoice don't have the same number of output.";
45 return 0;
46 }
47 }
48 return num_output;
49 }
50
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)51 Status RandomChoiceOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
52 RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs));
53 for (auto &op : ops_) {
54 std::vector<TensorShape> out_shapes;
55 RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes));
56 if (outputs != out_shapes) {
57 MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape.";
58 outputs.clear();
59 outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
60 return Status::OK();
61 }
62 }
63 return Status::OK();
64 }
65
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)66 Status RandomChoiceOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
67 RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs));
68 for (auto &op : ops_) {
69 std::vector<DataType> out_types;
70 RETURN_IF_NOT_OK(op->OutputType(inputs, out_types));
71 if (outputs != out_types) {
72 MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType.";
73 outputs.clear();
74 outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
75 return Status::OK();
76 }
77 }
78 return Status::OK();
79 }
80
Compute(const TensorRow & input,TensorRow * output)81 Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) {
82 IO_CHECK_VECTOR(input, output);
83 size_t rand_num = rand_int_(random_generator_);
84 CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num));
85 RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output));
86 return Status::OK();
87 }
88 } // namespace dataset
89 } // namespace mindspore
90