1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
17
18 #include <memory>
19 #include <vector>
20
21 namespace mindspore {
22 namespace dataset {
SequentialSamplerRT(int64_t start_index,int64_t num_samples,int64_t samples_per_tensor)23 SequentialSamplerRT::SequentialSamplerRT(int64_t start_index, int64_t num_samples, int64_t samples_per_tensor)
24 : SamplerRT(num_samples, samples_per_tensor), current_id_(start_index), start_index_(start_index), id_count_(0) {}
25
GetNextSample(TensorRow * out)26 Status SequentialSamplerRT::GetNextSample(TensorRow *out) {
27 if (id_count_ > num_samples_) {
28 RETURN_STATUS_UNEXPECTED(
29 "Sampler index must be less than or equal to num_samples(total rows in dataset), but got:" +
30 std::to_string(id_count_) + ", num_samples_: " + std::to_string(num_samples_));
31 } else if (id_count_ == num_samples_) {
32 (*out) = TensorRow(TensorRow::kFlagEOE);
33 } else {
34 if (HasChildSampler()) {
35 RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
36 }
37
38 std::shared_ptr<Tensor> sampleIds;
39
40 // Compute how many ids are left to pack, and pack this amount into a new Tensor. Respect the setting for
41 // samples per Tensor though.
42 int64_t remaining_ids = num_samples_ - id_count_;
43 int64_t num_elements = std::min(remaining_ids, samples_per_tensor_);
44
45 RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
46 auto idPtr = sampleIds->begin<int64_t>();
47 for (int64_t i = 0; i < num_elements; i++) {
48 int64_t sampled_id = current_id_;
49 if (HasChildSampler()) {
50 RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
51 }
52
53 *idPtr = sampled_id;
54 current_id_++; // Move the current id to the next one in the sequence
55 ++idPtr;
56 }
57
58 id_count_ += num_elements; // Count the packed ids towards our overall sample count
59
60 (*out) = {sampleIds};
61 }
62 return Status::OK();
63 }
64
InitSampler()65 Status SequentialSamplerRT::InitSampler() {
66 if (is_initialized) {
67 return Status::OK();
68 }
69 CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
70 "Invalid parameter, start_index must be greater than or equal to 0, but got " +
71 std::to_string(start_index_) + ".\n");
72 CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_ || (num_rows_ == 0 && start_index_ == 0),
73 "Invalid parameter, start_index must be less than num_rows, but got start_index: " +
74 std::to_string(start_index_) + ", num_rows: " + std::to_string(num_rows_) + ".\n");
75 CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0,
76 "Invalid parameter, num_samples must be greater than or equal to 0, but got " +
77 std::to_string(num_samples_) + ".\n");
78 // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample
79 // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data.
80 int64_t available_row_count = num_rows_ - start_index_;
81 if (num_samples_ == 0 || num_samples_ > available_row_count) {
82 num_samples_ = available_row_count;
83 }
84 CHECK_FAIL_RETURN_UNEXPECTED((num_samples_ > 0 && samples_per_tensor_ > 0) || num_samples_ == 0,
85 "Invalid parameter, samples_per_tensor(num_samplers) must be greater than 0, but got " +
86 std::to_string(samples_per_tensor_));
87 samples_per_tensor_ = samples_per_tensor_ > num_samples_ ? num_samples_ : samples_per_tensor_;
88
89 is_initialized = true;
90 return Status::OK();
91 }
92
ResetSampler()93 Status SequentialSamplerRT::ResetSampler() {
94 CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "[Internal ERROR] Reset() Sampler called early or late.");
95 current_id_ = start_index_;
96 id_count_ = 0;
97
98 if (HasChildSampler()) {
99 RETURN_IF_NOT_OK(child_[0]->ResetSampler());
100 }
101
102 return Status::OK();
103 }
104
CalculateNumSamples(int64_t num_rows)105 int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
106 // Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the
107 // num_rows from the dataset
108 int64_t child_num_rows = num_rows;
109 if (!child_.empty()) {
110 child_num_rows = child_[0]->CalculateNumSamples(num_rows);
111 // return -1 if child_num_rows is undetermined
112 if (child_num_rows == -1) return child_num_rows;
113 }
114 int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
115 // For this sampler we need to take start_index into account. Because for example in the case we are given n rows
116 // and start_index != 0 and num_samples >= n then we can't return all the n rows.
117 if (child_num_rows - start_index_ <= 0) {
118 return 0;
119 }
120 if (child_num_rows - start_index_ < num_samples)
121 num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
122 return num_samples;
123 }
124
SamplerPrint(std::ostream & out,bool show_all) const125 void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
126 out << "\nSampler: SequentialSampler";
127 if (show_all) {
128 // Call the super class for displaying any common detailed info
129 SamplerRT::SamplerPrint(out, show_all);
130 // Then add our own info
131 out << "\nStart index: " << start_index_;
132 }
133 }
134
to_json(nlohmann::json * out_json)135 Status SequentialSamplerRT::to_json(nlohmann::json *out_json) {
136 nlohmann::json args;
137 RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
138 args["sampler_name"] = "SequentialSampler";
139 args["start_index"] = start_index_;
140 *out_json = args;
141 return Status::OK();
142 }
143 } // namespace dataset
144 } // namespace mindspore
145