• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 namespace mindspore {
22 namespace dataset {
SequentialSamplerRT(int64_t start_index,int64_t num_samples,int64_t samples_per_tensor)23 SequentialSamplerRT::SequentialSamplerRT(int64_t start_index, int64_t num_samples, int64_t samples_per_tensor)
24     : SamplerRT(num_samples, samples_per_tensor), current_id_(start_index), start_index_(start_index), id_count_(0) {}
25 
GetNextSample(TensorRow * out)26 Status SequentialSamplerRT::GetNextSample(TensorRow *out) {
27   if (id_count_ > num_samples_) {
28     RETURN_STATUS_UNEXPECTED(
29       "Sampler index must be less than or equal to num_samples(total rows in dataset), but got:" +
30       std::to_string(id_count_) + ", num_samples_: " + std::to_string(num_samples_));
31   } else if (id_count_ == num_samples_) {
32     (*out) = TensorRow(TensorRow::kFlagEOE);
33   } else {
34     if (HasChildSampler()) {
35       RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
36     }
37 
38     std::shared_ptr<Tensor> sampleIds;
39 
40     // Compute how many ids are left to pack, and pack this amount into a new Tensor.  Respect the setting for
41     // samples per Tensor though.
42     int64_t remaining_ids = num_samples_ - id_count_;
43     int64_t num_elements = std::min(remaining_ids, samples_per_tensor_);
44 
45     RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
46     auto idPtr = sampleIds->begin<int64_t>();
47     for (int64_t i = 0; i < num_elements; i++) {
48       int64_t sampled_id = current_id_;
49       if (HasChildSampler()) {
50         RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
51       }
52 
53       *idPtr = sampled_id;
54       current_id_++;  // Move the current id to the next one in the sequence
55       ++idPtr;
56     }
57 
58     id_count_ += num_elements;  // Count the packed ids towards our overall sample count
59 
60     (*out) = {sampleIds};
61   }
62   return Status::OK();
63 }
64 
InitSampler()65 Status SequentialSamplerRT::InitSampler() {
66   if (is_initialized) {
67     return Status::OK();
68   }
69   CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
70                                "Invalid parameter, start_index must be greater than or equal to 0, but got " +
71                                  std::to_string(start_index_) + ".\n");
72   CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_ || (num_rows_ == 0 && start_index_ == 0),
73                                "Invalid parameter, start_index must be less than num_rows, but got start_index: " +
74                                  std::to_string(start_index_) + ", num_rows: " + std::to_string(num_rows_) + ".\n");
75   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0,
76                                "Invalid parameter, num_samples must be greater than or equal to 0, but got " +
77                                  std::to_string(num_samples_) + ".\n");
78   // Adjust the num_samples count based on the range of ids we are sequencing.  If num_samples is 0, we sample
79   // the entire set.  If it's non-zero, we will implicitly cap the amount sampled based on available data.
80   int64_t available_row_count = num_rows_ - start_index_;
81   if (num_samples_ == 0 || num_samples_ > available_row_count) {
82     num_samples_ = available_row_count;
83   }
84   CHECK_FAIL_RETURN_UNEXPECTED((num_samples_ > 0 && samples_per_tensor_ > 0) || num_samples_ == 0,
85                                "Invalid parameter, samples_per_tensor(num_samplers) must be greater than 0, but got " +
86                                  std::to_string(samples_per_tensor_));
87   samples_per_tensor_ = samples_per_tensor_ > num_samples_ ? num_samples_ : samples_per_tensor_;
88 
89   is_initialized = true;
90   return Status::OK();
91 }
92 
ResetSampler()93 Status SequentialSamplerRT::ResetSampler() {
94   CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late.");
95   current_id_ = start_index_;
96   id_count_ = 0;
97 
98   if (HasChildSampler()) {
99     RETURN_IF_NOT_OK(child_[0]->ResetSampler());
100   }
101 
102   return Status::OK();
103 }
104 
CalculateNumSamples(int64_t num_rows)105 int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
106   // Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the
107   // num_rows from the dataset
108   int64_t child_num_rows = num_rows;
109   if (!child_.empty()) {
110     child_num_rows = child_[0]->CalculateNumSamples(num_rows);
111     // return -1 if child_num_rows is undetermined
112     if (child_num_rows == -1) return child_num_rows;
113   }
114   int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
115   // For this sampler we need to take start_index into account. Because for example in the case we are given n rows
116   // and start_index != 0 and num_samples >= n then we can't return all the n rows.
117   if (child_num_rows - start_index_ <= 0) {
118     return 0;
119   }
120   if (child_num_rows - start_index_ < num_samples)
121     num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
122   return num_samples;
123 }
124 
SamplerPrint(std::ostream & out,bool show_all) const125 void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
126   out << "\nSampler: SequentialSampler";
127   if (show_all) {
128     // Call the super class for displaying any common detailed info
129     SamplerRT::SamplerPrint(out, show_all);
130     // Then add our own info
131     out << "\nStart index: " << start_index_;
132   }
133 }
134 
to_json(nlohmann::json * out_json)135 Status SequentialSamplerRT::to_json(nlohmann::json *out_json) {
136   nlohmann::json args;
137   RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
138   args["sampler_name"] = "SequentialSampler";
139   args["start_index"] = start_index_;
140   *out_json = args;
141   return Status::OK();
142 }
143 }  // namespace dataset
144 }  // namespace mindspore
145