• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
18 
19 #include <limits>
20 #include <memory>
21 #include <random>
22 #include <vector>
23 
24 #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
25 
26 namespace mindspore {
27 namespace dataset {
28 class DistributedSamplerRT : public SamplerRT {
29  public:
30   /// \brief Constructor
31   /// \param[in] num_shards Total number of shards for the distributed sampler
32   /// \param[in] shard_id Device id of the shard
33   /// \param[in] shuffle Option to shuffle
34   /// \param[in] num_samples The total number of rows in the dataset
35   /// \param seed Seed parameter to shuffle, default to max unsigned int (different seed in sampler will
36   ///     result in different samples being picked
37   /// \param[in] offset The starting device id where the elements in the dataset are send to, which should be no more
38   ///     than num_dev. The application scenario of this parameter is when the concatdataset is set distributedSampler
39   /// \param even_dist The option to indicate whether or not each shard returns the same number of rows.
40   ///     This option is not exposed in the python API. Current behavior is that the remainder will always
41   ///     be handled by the first n shards, n being the corresponding device id. Please notice that when offset is set,
42   ///     even_dist will be forcibly converted to false for sending rest datasets in concatdataset scenario.
43   DistributedSamplerRT(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
44                        uint32_t seed = std::numeric_limits<uint32_t>::max(), int64_t offset = -1,
45                        bool even_dist = true);
46 
47   /// \brief default destructor
48   ~DistributedSamplerRT() = default;
49 
50   /// \param TensorRow out
51   /// \param int32_t workerId
52   /// \return Status code
53   Status GetNextSample(TensorRow *out) override;
54 
55   /// Init sampler, called by base class or python
56   Status InitSampler() override;
57 
58   /// \brief Reset for next epoch.
59   /// \param[in] failover_reset A boolean to show whether we are resetting the pipeline
60   /// \return Status The status code returned
61   Status ResetSampler(const bool failover_reset) override;
62 
GetDeviceID()63   int64_t GetDeviceID() { return device_id_; }
64 
GetDeviceNum()65   int64_t GetDeviceNum() { return num_devices_; }
66 
67   /// \brief Recursively calls this function on its children to get the actual number of samples on a tree of samplers
68   /// \note This is not a getter for num_samples_. For example, if num_samples_ is 0 or if it's smaller than num_rows,
69   ///     then num_samples_ is not returned at all.
70   /// \param[in] num_rows The total number of rows in the dataset
71   /// \return int64_t Calculated number of samples
72   int64_t CalculateNumSamples(int64_t num_rows) override;
73 
74   void SamplerPrint(std::ostream &out, bool show_all) const override;
75 
76   /// \brief Get the arguments of node
77   /// \param[out] out_json JSON string of all attributes
78   /// \return Status of the function
79   Status to_json(nlohmann::json *out_json) override;
80 
81  private:
82   int64_t cnt_;  // number of samples that have already been filled in to Tensor
83   uint32_t seed_;
84   int64_t device_id_;
85   int64_t num_devices_;
86   bool shuffle_;
87   std::mt19937 rnd_;
88   std::vector<int64_t> shuffle_vec_;
89   bool even_dist_;
90   int64_t offset_;
91   bool non_empty_;
92 };
93 }  // namespace dataset
94 }  // namespace mindspore
95 
96 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
97