1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ 17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ 18 19 #include <limits> 20 #include <memory> 21 #include <random> 22 #include <vector> 23 24 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" 25 26 namespace mindspore { 27 namespace dataset { 28 class DistributedSamplerRT : public SamplerRT { 29 public: 30 /// \brief Constructor 31 /// \param[in] num_shards Total number of shards for the distributed sampler 32 /// \param[in] shard_id Device id of the shard 33 /// \param[in] shuffle Option to shuffle 34 /// \param[in] num_samples The total number of rows in the dataset 35 /// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will 36 /// result in different samples being picked 37 /// \param[in] offset The starting device id where the elements in the dataset are send to, which should be no more 38 /// than num_dev. The application scenario of this parameter is when the concatdataset is set distributedSampler 39 /// \param even_dist The option to indicate whether or not each shard returns the same number of rows. 40 /// This option is not exposed in the python API. Current behavior is that the remainder will always 41 /// be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set, 42 /// even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario. 43 DistributedSamplerRT(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, 44 uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1, 45 bool even_dist = true); 46 47 /// \brief default destructor 48 ~DistributedSamplerRT() = default; 49 50 /// \param TensorRow out 51 /// \param int32_t workerId 52 /// \return Status code 53 Status GetNextSample(TensorRow *out) override; 54 55 /// Init sampler, called by base class or python 56 Status InitSampler() override; 57 58 /// \brief Reset for next epoch. 59 /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline 60 /// \return Status The status code returned 61 Status ResetSampler(const bool failover_reset) override; 62 GetDeviceID()63 int64_t GetDeviceID() { return device_id_; } 64 GetDeviceNum()65 int64_t GetDeviceNum() { return num_devices_; } 66 67 /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers 68 /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows, 69 /// then num_samples_ is not returned at all. 70 /// \param[in] num_rows The total number of rows in the dataset 71 /// \return int64_t Calculated number of samples 72 int64_t CalculateNumSamples(int64_t num_rows) override; 73 74 void SamplerPrint(std::ostream &out, bool show_all) const override; 75 76 /// \brief Get the arguments of node 77 /// \param[out] out_json JSON string of all attributes 78 /// \return Status of the function 79 Status to_json(nlohmann::json *out_json) override; 80 81 private: 82 int64_t cnt_; // number of samples that have already been filled in to Tensor 83 uint32_t seed_; 84 int64_t device_id_; 85 int64_t num_devices_; 86 bool shuffle_; 87 std::mt19937 rnd_; 88 std::vector<int64_t> shuffle_vec_; 89 bool even_dist_; 90 int64_t offset_; 91 bool non_empty_; 92 }; 93 } // namespace dataset 94 } // namespace mindspore 95 96 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_ 97