1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/distributed_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
19 #include "minddata/dataset/core/config_manager.h"
20
21 #ifndef ENABLE_ANDROID
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 #endif
30
31 namespace mindspore {
32 namespace dataset {
33 // Constructor
DistributedSamplerObj(int64_t num_shards,int64_t shard_id,bool shuffle,int64_t num_samples,uint32_t seed,int64_t offset,bool even_dist)34 DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
35 uint32_t seed, int64_t offset, bool even_dist)
36 : num_shards_(num_shards),
37 shard_id_(shard_id),
38 shuffle_(shuffle),
39 num_samples_(num_samples),
40 seed_(seed),
41 offset_(offset),
42 even_dist_(even_dist) {
43 // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion
44 // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't
45 // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once
46 // PreBuildSampler is phased out, this can be cleaned up.
47 GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_);
48 }
49
50 // Destructor
51 DistributedSamplerObj::~DistributedSamplerObj() = default;
52
ValidateParams()53 Status DistributedSamplerObj::ValidateParams() {
54 if (num_shards_ <= 0) {
55 RETURN_STATUS_UNEXPECTED("DistributedSampler: num_shards must be greater than 0, but got: " +
56 std::to_string(num_shards_));
57 }
58
59 if (shard_id_ < 0 || shard_id_ >= num_shards_) {
60 RETURN_STATUS_UNEXPECTED("DistributedSampler: shard_id must be in range [0, " + std::to_string(num_shards_) +
61 "), but got: " + std::to_string(shard_id_));
62 }
63
64 if (num_samples_ < 0) {
65 RETURN_STATUS_UNEXPECTED("DistributedSampler: num_samples must be greater than or equal to 0, but got: " +
66 std::to_string(num_samples_));
67 }
68
69 if (offset_ > num_shards_) {
70 RETURN_STATUS_UNEXPECTED("DistributedSampler: offset must be no more than num_shards(" +
71 std::to_string(num_shards_) + "), but got: " + std::to_string(offset_));
72 }
73
74 return Status::OK();
75 }
76
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)77 Status DistributedSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
78 // runtime sampler object
79 *sampler = std::make_shared<dataset::DistributedSamplerRT>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
80 offset_, even_dist_);
81 Status s = BuildChildren(sampler);
82 sampler = s.IsOk() ? sampler : nullptr;
83 return s;
84 }
85
86 #ifndef ENABLE_ANDROID
BuildForMindDataset()87 std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
88 // runtime mindrecord sampler object
89 auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
90 num_samples_, offset_);
91 return mind_sampler;
92 }
93 #endif
94
to_json(nlohmann::json * const out_json)95 Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
96 nlohmann::json args;
97 RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
98 args["sampler_name"] = "DistributedSampler";
99 args["num_shards"] = num_shards_;
100 args["shard_id"] = shard_id_;
101 args["shuffle"] = shuffle_;
102 args["seed"] = seed_;
103 args["offset"] = offset_;
104 args["num_samples"] = num_samples_;
105 args["even_dist"] = even_dist_;
106 *out_json = args;
107 return Status::OK();
108 }
109
110 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)111 Status DistributedSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
112 std::shared_ptr<SamplerObj> *sampler) {
113 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
114 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
115 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
116 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("seed") != json_obj.end(), "Failed to find seed");
117 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("offset") != json_obj.end(), "Failed to find offset");
118 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("even_dist") != json_obj.end(), "Failed to find even_dist");
119 int64_t num_shards = json_obj["num_shards"];
120 int64_t shard_id = json_obj["shard_id"];
121 bool shuffle = json_obj["shuffle"];
122 uint32_t seed = json_obj["seed"];
123 int64_t offset = json_obj["offset"];
124 bool even_dist = json_obj["even_dist"];
125 *sampler =
126 std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed, offset, even_dist);
127 // Run common code in super class to add children samplers
128 RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
129 return Status::OK();
130 }
131 #endif
132
SamplerCopy()133 std::shared_ptr<SamplerObj> DistributedSamplerObj::SamplerCopy() {
134 auto sampler =
135 std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_);
136 for (const auto &child : children_) {
137 Status rc = sampler->AddChildSampler(child);
138 if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
139 }
140 return sampler;
141 }
142
ShardId()143 int64_t DistributedSamplerObj::ShardId() { return shard_id_; }
144 } // namespace dataset
145 } // namespace mindspore
146