1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/random_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
19 #include "minddata/dataset/core/config_manager.h"
20
21 #ifndef ENABLE_ANDROID
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 #endif
30
31 namespace mindspore {
32 namespace dataset {
33 // Constructor
RandomSamplerObj(bool replacement,int64_t num_samples,bool reshuffle_each_epoch)34 RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples, bool reshuffle_each_epoch)
35 : replacement_(replacement), num_samples_(num_samples), reshuffle_each_epoch_(reshuffle_each_epoch) {}
36
// Destructor. Defaulted: this class performs no custom cleanup of its own.
RandomSamplerObj::~RandomSamplerObj() = default;
39
ValidateParams()40 Status RandomSamplerObj::ValidateParams() {
41 if (num_samples_ < 0) {
42 RETURN_STATUS_UNEXPECTED("RandomSampler: num_samples must be greater than or equal to 0, but got: " +
43 std::to_string(num_samples_));
44 }
45 return Status::OK();
46 }
47
to_json(nlohmann::json * const out_json)48 Status RandomSamplerObj::to_json(nlohmann::json *const out_json) {
49 nlohmann::json args;
50 RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
51 args["sampler_name"] = "RandomSampler";
52 args["replacement"] = replacement_;
53 args["reshuffle_each_epoch"] = reshuffle_each_epoch_;
54 args["num_samples"] = num_samples_;
55 *out_json = args;
56 return Status::OK();
57 }
58
59 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)60 Status RandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
61 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");
62 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reshuffle_each_epoch") != json_obj.end(),
63 "Failed to find reshuffle_each_epoch");
64 bool replacement = json_obj["replacement"];
65 bool reshuffle_each_epoch = json_obj["reshuffle_each_epoch"];
66 *sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
67 // Run common code in super class to add children samplers
68 RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
69 return Status::OK();
70 }
71 #endif
72
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)73 Status RandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
74 // runtime sampler object
75 *sampler = std::make_shared<dataset::RandomSamplerRT>(replacement_, num_samples_, reshuffle_each_epoch_);
76 Status s = BuildChildren(sampler);
77 sampler = s.IsOk() ? sampler : nullptr;
78 return s;
79 }
80
81 #ifndef ENABLE_ANDROID
BuildForMindDataset()82 std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
83 // runtime mindrecord sampler object
84 auto mind_sampler =
85 std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
86
87 return mind_sampler;
88 }
89 #endif
90
SamplerCopy()91 std::shared_ptr<SamplerObj> RandomSamplerObj::SamplerCopy() {
92 auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
93 for (const auto &child : children_) {
94 Status rc = sampler->AddChildSampler(child);
95 if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
96 }
97 return sampler;
98 }
99 } // namespace dataset
100 } // namespace mindspore
101