• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/data/random_choice_op.h"
17 
18 #include "minddata/dataset/core/tensor.h"
19 #include "minddata/dataset/kernels/tensor_op.h"
20 #include "minddata/dataset/util/status.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 
NumInput()25 uint32_t RandomChoiceOp::NumInput() {
26   uint32_t num_input = ops_.front()->NumInput();
27   for (auto &op : ops_) {
28     uint32_t cur_num = op->NumInput();
29     if (num_input != cur_num && cur_num > 0) {
30       MS_LOG(WARNING) << "Unable to determine Num of Input, ops in RandomChoice don't take the same number of input.";
31       return 0;
32     }
33   }
34   return num_input;
35 }
36 
NumOutput()37 uint32_t RandomChoiceOp::NumOutput() {
38   uint32_t num_output = ops_.front()->NumOutput();
39   for (auto &op : ops_) {
40     uint32_t cur_num = op->NumOutput();
41     if (num_output != cur_num) {
42       MS_LOG(WARNING) << "Unable to determine NumOutput, ops in RandomChoice don't have the same number of output.";
43       return 0;
44     }
45   }
46   return num_output;
47 }
48 
OutputShape(const std::vector<TensorShape> & inputs,std::vector<TensorShape> & outputs)49 Status RandomChoiceOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
50   RETURN_IF_NOT_OK(ops_.front()->OutputShape(inputs, outputs));
51   for (auto &op : ops_) {
52     std::vector<TensorShape> out_shapes;
53     RETURN_IF_NOT_OK(op->OutputShape(inputs, out_shapes));
54     if (outputs != out_shapes) {
55       MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorShape.";
56       outputs.clear();
57       outputs.resize(NumOutput(), TensorShape::CreateUnknownRankShape());
58       return Status::OK();
59     }
60   }
61   return Status::OK();
62 }
63 
OutputType(const std::vector<DataType> & inputs,std::vector<DataType> & outputs)64 Status RandomChoiceOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
65   RETURN_IF_NOT_OK(ops_.front()->OutputType(inputs, outputs));
66   for (auto &op : ops_) {
67     std::vector<DataType> out_types;
68     RETURN_IF_NOT_OK(op->OutputType(inputs, out_types));
69     if (outputs != out_types) {
70       MS_LOG(WARNING) << "TensorOp in RandomChoice don't return the same tensorType.";
71       outputs.clear();
72       outputs.resize(NumOutput(), DataType(DataType::DE_UNKNOWN));
73       return Status::OK();
74     }
75   }
76   return Status::OK();
77 }
78 
Compute(const TensorRow & input,TensorRow * output)79 Status RandomChoiceOp::Compute(const TensorRow &input, TensorRow *output) {
80   size_t rand_num = rand_int_(gen_);
81   CHECK_FAIL_RETURN_UNEXPECTED(rand_num < ops_.size(), "invalid rand_num:" + std::to_string(rand_num));
82   RETURN_IF_NOT_OK(ops_[rand_num]->Compute(input, output));
83   return Status::OK();
84 }
RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> & ops)85 RandomChoiceOp::RandomChoiceOp(const std::vector<std::shared_ptr<TensorOp>> &ops)
86     : ops_(ops), gen_(GetSeed()), rand_int_(0, ops.size() - 1) {
87   if (ops_.empty()) {
88     MS_LOG(ERROR) << "op_list in RandomChoiceOp is empty.";
89   }
90   if (ops_.size() == 1) {
91     MS_LOG(WARNING) << "op_list has only 1 op, this op would be picked every time.";
92   }
93   is_deterministic_ = false;
94 }
95 }  // namespace dataset
96 }  // namespace mindspore
97