1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ 18 19 #include <limits> 20 #include <memory> 21 #include <vector> 22 23 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 24 25 namespace mindspore { 26 namespace dataset { 27 /// Samples elements from a given list of indices. 28 class SubsetSamplerRT : public SamplerRT { 29 public: 30 /// Constructor. 31 /// \param indices List of indices. 32 /// \param num_samples The number of elements to sample. 0 for the full amount. 33 /// \param samples_per_tensor The number of ids we draw on each call to GetNextSample(). 34 /// When samples_per_tensor=0, GetNextSample() will draw all the sample ids and return them at once. 35 SubsetSamplerRT(const std::vector<int64_t> &indices, int64_t num_samples, 36 std::int64_t samples_per_tensor = std::numeric_limits<int64_t>::max()); 37 38 /// Destructor. 39 ~SubsetSamplerRT() = default; 40 41 /// Initialize the sampler. 42 /// \return Status 43 Status InitSampler() override; 44 45 /// Reset the internal variable to the initial state and reshuffle the indices. 46 /// \return Status 47 Status ResetSampler() override; 48 49 /// Get the sample ids. 50 /// \param[out] TensorRow where the sample ids will be placed. 51 /// @note the sample ids (int64_t) will be placed in one Tensor 52 Status GetNextSample(TensorRow *out) override; 53 54 /// Printer for debugging purposes. 55 /// \param out - output stream to write to 56 /// \param show_all - bool to show detailed vs summary 57 void SamplerPrint(std::ostream &out, bool show_all) const override; 58 59 /// \brief Get the arguments of node 60 /// \param[out] out_json JSON string of all attributes 61 /// \return Status of the function 62 Status to_json(nlohmann::json *out_json) override; 63 64 /// Calculate num samples. Unlike GetNumSamples, it is not a getter and doesn't necessarily return the value of 65 /// num_samples_ 66 /// \param num_rows the size of the dataset this sampler will be applied to. 67 /// \return number of samples 68 int64_t CalculateNumSamples(int64_t num_rows) override; 69 70 protected: 71 /// A list of indices (already randomized in constructor). 72 std::vector<int64_t> indices_; 73 74 private: 75 /// Current sample id. 76 int64_t sample_id_; 77 }; 78 } // namespace dataset 79 } // namespace mindspore 80 81 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ 82