• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
19 #include "minddata/dataset/core/config_manager.h"
20 
21 #ifndef ENABLE_ANDROID
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 #endif
30 
31 namespace mindspore {
32 namespace dataset {
33 // Constructor
DistributedSamplerObj(int64_t num_shards,int64_t shard_id,bool shuffle,int64_t num_samples,uint32_t seed,int64_t offset,bool even_dist)34 DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
35                                              uint32_t seed, int64_t offset, bool even_dist)
36     : num_shards_(num_shards),
37       shard_id_(shard_id),
38       shuffle_(shuffle),
39       num_samples_(num_samples),
40       seed_(seed),
41       offset_(offset),
42       even_dist_(even_dist) {
43   // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
44   // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
45   // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
46   // PreBuildSampler is phased out, this can be cleaned up.
47   GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
48 }
49 
50 // Destructor
51 DistributedSamplerObj::~DistributedSamplerObj() = default;
52 
ValidateParams()53 Status DistributedSamplerObj::ValidateParams() {
54   if (num_shards_ <= 0) {
55     RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
56                              std::to_string(num_shards_));
57   }
58 
59   if (shard_id_ < 0 || shard_id_ >= num_shards_) {
60     RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
61                              "), but got: " + std::to_string(shard_id_));
62   }
63 
64   if (num_samples_ < 0) {
65     RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
66                              std::to_string(num_samples_));
67   }
68 
69   if (offset_ > num_shards_) {
70     RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
71                              std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
72   }
73 
74   return Status::OK();
75 }
76 
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)77 Status DistributedSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
78   // runtime sampler object
79   *sampler = std::make_shared<dataset::DistributedSamplerRT>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
80                                                              offset_, even_dist_);
81   Status s = BuildChildren(sampler);
82   sampler = s.IsOk() ? sampler : nullptr;
83   return s;
84 }
85 
86 #ifndef ENABLE_ANDROID
BuildForMindDataset()87 std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
88   // runtime mindrecord sampler object
89   auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
90                                                                            num_samples_, offset_);
91   return mind_sampler;
92 }
93 #endif
94 
to_json(nlohmann::json * const out_json)95 Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
96   nlohmann::json args;
97   RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
98   args["sampler_name"] = "DistributedSampler";
99   args["num_shards"] = num_shards_;
100   args["shard_id"] = shard_id_;
101   args["shuffle"] = shuffle_;
102   args["seed"] = seed_;
103   args["offset"] = offset_;
104   args["num_samples"] = num_samples_;
105   args["even_dist"] = even_dist_;
106   *out_json = args;
107   return Status::OK();
108 }
109 
110 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)111 Status DistributedSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
112                                         std::shared_ptr<SamplerObj> *sampler) {
113   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
114   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
115   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
116   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("seed") != json_obj.end(), "Failed to find seed");
117   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("offset") != json_obj.end(), "Failed to find offset");
118   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("even_dist") != json_obj.end(), "Failed to find even_dist");
119   int64_t num_shards = json_obj["num_shards"];
120   int64_t shard_id = json_obj["shard_id"];
121   bool shuffle = json_obj["shuffle"];
122   uint32_t seed = json_obj["seed"];
123   int64_t offset = json_obj["offset"];
124   bool even_dist = json_obj["even_dist"];
125   *sampler =
126     std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
127   // Run common code in super class to add children samplers
128   RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
129   return Status::OK();
130 }
131 #endif
132 
SamplerCopy()133 std::shared_ptr<SamplerObj> DistributedSamplerObj::SamplerCopy() {
134   auto sampler =
135     std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_);
136   for (const auto &child : children_) {
137     Status rc = sampler->AddChildSampler(child);
138     if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
139   }
140   return sampler;
141 }
142 
ShardId()143 int64_t DistributedSamplerObj::ShardId() { return shard_id_; }
144 }  // namespace dataset
145 }  // namespace mindspore
146