1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ 18 19 #include <limits> 20 #include <memory> 21 #include <vector> 22 23 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 24 #include "minddata/dataset/engine/datasetops/source/sampler/subset_sampler.h" 25 26 namespace mindspore { 27 namespace dataset { 28 /// Randomly samples elements from a given list of indices, without replacement. 29 class SubsetRandomSamplerRT : public SubsetSamplerRT { 30 public: 31 /// Constructor. 32 /// \param indices List of indices from where we will randomly draw samples. 33 /// \param num_samples The number of samples to draw. 0 for the full amount. 34 /// \param samples_per_tensor The number of ids we draw on each call to GetNextSample(). 35 /// When samples_per_tensor=0, GetNextSample() will draw all the sample ids and return them at once. 36 SubsetRandomSamplerRT(const std::vector<int64_t> &indices, int64_t num_samples, 37 std::int64_t samples_per_tensor = std::numeric_limits<int64_t>::max()); 38 39 /// Destructor. 40 ~SubsetRandomSamplerRT() = default; 41 42 /// Initialize the sampler. 43 /// \return Status 44 Status InitSampler() override; 45 46 /// Reset the internal variable to the initial state and reshuffle the indices. 47 /// \return Status 48 Status ResetSampler() override; 49 50 /// Printer for debugging purposes. 51 /// \param out - output stream to write to 52 /// \param show_all - bool to show detailed vs summary 53 void SamplerPrint(std::ostream &out, bool show_all) const override; 54 55 /// \brief Get the arguments of node 56 /// \param[out] out_json JSON string of all attributes 57 /// \return Status of the function 58 Status to_json(nlohmann::json *out_json) override; 59 60 private: 61 // A random number generator. 62 std::mt19937 rand_gen_; 63 }; 64 } // namespace dataset 65 } // namespace mindspore 66 67 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ 68