• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/data/random_choice_op.h"
17 
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> & ops)24 RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops)
25     : ops_(ops), rand_int_(0, ops.size() - 1) {}
26 
NumInput()27 uint32_t RandomChoiceOp::NumInput() {
28   uint32_t num_input = ops_.front()->NumInput();
29   for (auto &op : ops_) {
30     uint32_t cur_num = op->NumInput();
31     if (num_input != cur_num && cur_num > 0) {
32       MS_LOG(WARNING) << "Unable to determine Num of Input, ops in RandomChoice don't take the same number of input.";
33       return 0;
34     }
35   }
36   return num_input;
37 }
38 
NumOutput()39 uint32_t RandomChoiceOp::NumOutput() {
40   uint32_t num_output = ops_.front()->NumOutput();
41   for (auto &op : ops_) {
42     uint32_t cur_num = op->NumOutput();
43     if (num_output != cur_num) {
44       MS_LOG(WARNING) << "Unable to determine NumOutput, ops in RandomChoice don't have the same number of output.";
45       return 0;
46     }
47   }
48   return num_output;
49 }
50 
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)51 Status RandomChoiceOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
52   RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs));
53   for (auto &op : ops_) {
54     std::vector<TensorShape> out_shapes;
55     RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes));
56     if (outputs != out_shapes) {
57       MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape.";
58       outputs.clear();
59       outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
60       return Status::OK();
61     }
62   }
63   return Status::OK();
64 }
65 
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)66 Status RandomChoiceOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
67   RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs));
68   for (auto &op : ops_) {
69     std::vector<DataType> out_types;
70     RETURN_IF_NOT_OK(op->OutputType(inputs, out_types));
71     if (outputs != out_types) {
72       MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType.";
73       outputs.clear();
74       outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
75       return Status::OK();
76     }
77   }
78   return Status::OK();
79 }
80 
Compute(const TensorRow & input,TensorRow * output)81 Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) {
82   IO_CHECK_VECTOR(input, output);
83   size_t rand_num = rand_int_(random_generator_);
84   CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num));
85   RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output));
86   return Status::OK();
87 }
88 }  // namespace dataset
89 }  // namespace mindspore
90