1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22
23 namespace mindspore {
24 namespace dataset {
SequentialSamplerRT(int64_t start_index,int64_t num_samples,int64_t samples_per_tensor)25 SequentialSamplerRT::SequentialSamplerRT(int64_t start_index, int64_t num_samples, int64_t samples_per_tensor)
26 : SamplerRT(num_samples, samples_per_tensor),
27 current_index_(start_index),
28 start_index_(start_index),
29 index_produced_(0) {}
30
GetNextSample(TensorRow * out)31 Status SequentialSamplerRT::GetNextSample(TensorRow *out) {
32 RETURN_UNEXPECTED_IF_NULL(out);
33 if (index_produced_ > num_samples_) {
34 RETURN_STATUS_UNEXPECTED(
35 "[Internal ERROR] Sampler index must be less than or equal to num_samples(total rows in dataset), but got:" +
36 std::to_string(index_produced_) + ", num_samples_: " + std::to_string(num_samples_));
37 } else if (index_produced_ == num_samples_) {
38 (*out) = TensorRow(TensorRow::kFlagEOE);
39 } else {
40 if (HasChildSampler()) {
41 RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
42 }
43
44 std::shared_ptr<Tensor> sampleIds;
45
46 // Compute how many ids are left to pack, and pack this amount into a new Tensor. Respect the setting for
47 // samples per Tensor though.
48 int64_t remaining_ids = num_samples_ - index_produced_;
49 int64_t num_elements = std::min(remaining_ids, samples_per_tensor_);
50
51 RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
52 auto idPtr = sampleIds->begin<int64_t>();
53 for (int64_t i = 0; i < num_elements; i++) {
54 int64_t sampled_id = current_index_;
55 if (HasChildSampler()) {
56 RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
57 }
58
59 *idPtr = sampled_id;
60 current_index_++; // Move the current id to the next one in the sequence
61 ++idPtr;
62 }
63
64 index_produced_ += num_elements; // Count the packed ids towards our overall sample count
65
66 (*out) = {sampleIds};
67 }
68 return Status::OK();
69 }
70
InitSampler()71 Status SequentialSamplerRT::InitSampler() {
72 if (is_initialized) {
73 return Status::OK();
74 }
75 CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
76 "Invalid parameter, start_index must be greater than or equal to 0, but got " +
77 std::to_string(start_index_) + ".\n");
78 CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_ || (num_rows_ == 0 && start_index_ == 0),
79 "Invalid parameter, start_index must be less than num_rows, but got start_index: " +
80 std::to_string(start_index_) + ", num_rows: " + std::to_string(num_rows_) + ".\n");
81 CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0,
82 "Invalid parameter, num_samples must be greater than or equal to 0, but got " +
83 std::to_string(num_samples_) + ".\n");
84 // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample
85 // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data.
86 int64_t available_row_count = num_rows_ - start_index_;
87 if (num_samples_ == 0 || num_samples_ > available_row_count) {
88 num_samples_ = available_row_count;
89 }
90 CHECK_FAIL_RETURN_UNEXPECTED((num_samples_ > 0 && samples_per_tensor_ > 0) || num_samples_ == 0,
91 "Invalid parameter, samples_per_tensor(num_samplers) must be greater than 0, but got " +
92 std::to_string(samples_per_tensor_));
93 samples_per_tensor_ = samples_per_tensor_ > num_samples_ ? num_samples_ : samples_per_tensor_;
94
95 is_initialized = true;
96 return Status::OK();
97 }
98
ResetSampler(const bool failover_reset)99 Status SequentialSamplerRT::ResetSampler(const bool failover_reset) {
100 CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || index_produced_ == num_samples_,
101 "[Internal ERROR] ResetSampler() called early or late.");
102 current_index_ = start_index_;
103 index_produced_ = 0;
104
105 if (HasChildSampler()) {
106 RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
107 }
108
109 return Status::OK();
110 }
111
CalculateNumSamples(int64_t num_rows)112 int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
113 // Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the
114 // num_rows from the dataset
115 int64_t child_num_rows = num_rows;
116 if (!child_.empty()) {
117 child_num_rows = child_[0]->CalculateNumSamples(num_rows);
118 // return -1 if child_num_rows is undetermined
119 if (child_num_rows == -1) {
120 return child_num_rows;
121 }
122 }
123 int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
124 // For this sampler we need to take start_index into account. Because for example in the case we are given n rows
125 // and start_index != 0 and num_samples >= n then we can't return all the n rows.
126 if (child_num_rows - start_index_ <= 0) {
127 return 0;
128 }
129 if (child_num_rows - start_index_ < num_samples) {
130 num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
131 }
132 return num_samples;
133 }
134
SamplerPrint(std::ostream & out,bool show_all) const135 void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
136 out << "\nSampler: SequentialSampler";
137 if (show_all) {
138 // Call the super class for displaying any common detailed info
139 SamplerRT::SamplerPrint(out, show_all);
140 // Then add our own info
141 out << "\nStart index: " << start_index_;
142 }
143 }
144
to_json(nlohmann::json * out_json)145 Status SequentialSamplerRT::to_json(nlohmann::json *out_json) {
146 RETURN_UNEXPECTED_IF_NULL(out_json);
147 nlohmann::json args;
148 RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
149 args["sampler_name"] = "SequentialSampler";
150 args["start_index"] = start_index_;
151 *out_json = args;
152 return Status::OK();
153 }
154 } // namespace dataset
155 } // namespace mindspore
156