• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 
23 namespace mindspore {
24 namespace dataset {
SequentialSamplerRT(int64_t start_index,int64_t num_samples,int64_t samples_per_tensor)25 SequentialSamplerRT::SequentialSamplerRT(int64_t start_index, int64_t num_samples, int64_t samples_per_tensor)
26     : SamplerRT(num_samples, samples_per_tensor),
27       current_index_(start_index),
28       start_index_(start_index),
29       index_produced_(0) {}
30 
GetNextSample(TensorRow * out)31 Status SequentialSamplerRT::GetNextSample(TensorRow *out) {
32   RETURN_UNEXPECTED_IF_NULL(out);
33   if (index_produced_ > num_samples_) {
34     RETURN_STATUS_UNEXPECTED(
35       "[Internal ERROR] Sampler index must be less than or equal to num_samples(total rows in dataset), but got:" +
36       std::to_string(index_produced_) + ", num_samples_: " + std::to_string(num_samples_));
37   } else if (index_produced_ == num_samples_) {
38     (*out) = TensorRow(TensorRow::kFlagEOE);
39   } else {
40     if (HasChildSampler()) {
41       RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
42     }
43 
44     std::shared_ptr<Tensor> sampleIds;
45 
46     // Compute how many ids are left to pack, and pack this amount into a new Tensor.  Respect the setting for
47     // samples per Tensor though.
48     int64_t remaining_ids = num_samples_ - index_produced_;
49     int64_t num_elements = std::min(remaining_ids, samples_per_tensor_);
50 
51     RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements));
52     auto idPtr = sampleIds->begin<int64_t>();
53     for (int64_t i = 0; i < num_elements; i++) {
54       int64_t sampled_id = current_index_;
55       if (HasChildSampler()) {
56         RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
57       }
58 
59       *idPtr = sampled_id;
60       current_index_++;  // Move the current id to the next one in the sequence
61       ++idPtr;
62     }
63 
64     index_produced_ += num_elements;  // Count the packed ids towards our overall sample count
65 
66     (*out) = {sampleIds};
67   }
68   return Status::OK();
69 }
70 
InitSampler()71 Status SequentialSamplerRT::InitSampler() {
72   if (is_initialized) {
73     return Status::OK();
74   }
75   CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0,
76                                "Invalid parameter, start_index must be greater than or equal to 0, but got " +
77                                  std::to_string(start_index_) + ".\n");
78   CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_ || (num_rows_ == 0 && start_index_ == 0),
79                                "Invalid parameter, start_index must be less than num_rows, but got start_index: " +
80                                  std::to_string(start_index_) + ", num_rows: " + std::to_string(num_rows_) + ".\n");
81   CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0,
82                                "Invalid parameter, num_samples must be greater than or equal to 0, but got " +
83                                  std::to_string(num_samples_) + ".\n");
84   // Adjust the num_samples count based on the range of ids we are sequencing.  If num_samples is 0, we sample
85   // the entire set.  If it's non-zero, we will implicitly cap the amount sampled based on available data.
86   int64_t available_row_count = num_rows_ - start_index_;
87   if (num_samples_ == 0 || num_samples_ > available_row_count) {
88     num_samples_ = available_row_count;
89   }
90   CHECK_FAIL_RETURN_UNEXPECTED((num_samples_ > 0 && samples_per_tensor_ > 0) || num_samples_ == 0,
91                                "Invalid parameter, samples_per_tensor(num_samplers) must be greater than 0, but got " +
92                                  std::to_string(samples_per_tensor_));
93   samples_per_tensor_ = samples_per_tensor_ > num_samples_ ? num_samples_ : samples_per_tensor_;
94 
95   is_initialized = true;
96   return Status::OK();
97 }
98 
ResetSampler(const bool failover_reset)99 Status SequentialSamplerRT::ResetSampler(const bool failover_reset) {
100   CHECK_FAIL_RETURN_UNEXPECTED(failover_reset || index_produced_ == num_samples_,
101                                "[Internal ERROR] ResetSampler() called early or late.");
102   current_index_ = start_index_;
103   index_produced_ = 0;
104 
105   if (HasChildSampler()) {
106     RETURN_IF_NOT_OK(child_[0]->ResetSampler(failover_reset));
107   }
108 
109   return Status::OK();
110 }
111 
CalculateNumSamples(int64_t num_rows)112 int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) {
113   // Holds the number of rows available for Sequential sampler. It can be the rows passed from its child sampler or the
114   // num_rows from the dataset
115   int64_t child_num_rows = num_rows;
116   if (!child_.empty()) {
117     child_num_rows = child_[0]->CalculateNumSamples(num_rows);
118     // return -1 if child_num_rows is undetermined
119     if (child_num_rows == -1) {
120       return child_num_rows;
121     }
122   }
123   int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows;
124   // For this sampler we need to take start_index into account. Because for example in the case we are given n rows
125   // and start_index != 0 and num_samples >= n then we can't return all the n rows.
126   if (child_num_rows - start_index_ <= 0) {
127     return 0;
128   }
129   if (child_num_rows - start_index_ < num_samples) {
130     num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_;
131   }
132   return num_samples;
133 }
134 
SamplerPrint(std::ostream & out,bool show_all) const135 void SequentialSamplerRT::SamplerPrint(std::ostream &out, bool show_all) const {
136   out << "\nSampler: SequentialSampler";
137   if (show_all) {
138     // Call the super class for displaying any common detailed info
139     SamplerRT::SamplerPrint(out, show_all);
140     // Then add our own info
141     out << "\nStart index: " << start_index_;
142   }
143 }
144 
to_json(nlohmann::json * out_json)145 Status SequentialSamplerRT::to_json(nlohmann::json *out_json) {
146   RETURN_UNEXPECTED_IF_NULL(out_json);
147   nlohmann::json args;
148   RETURN_IF_NOT_OK(SamplerRT::to_json(&args));
149   args["sampler_name"] = "SequentialSampler";
150   args["start_index"] = start_index_;
151   *out_json = args;
152   return Status::OK();
153 }
154 }  // namespace dataset
155 }  // namespace mindspore
156