1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/subset_random_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
19 #include "minddata/dataset/core/config_manager.h"
20
21 #ifndef ENABLE_ANDROID
22 #include "minddata/dataset/util/random.h"
23 #include "minddata/mindrecord/include/shard_distributed_sample.h"
24 #include "minddata/mindrecord/include/shard_operator.h"
25 #include "minddata/mindrecord/include/shard_pk_sample.h"
26 #include "minddata/mindrecord/include/shard_sample.h"
27 #include "minddata/mindrecord/include/shard_sequential_sample.h"
28 #include "minddata/mindrecord/include/shard_shuffle.h"
29 #endif
30
31 namespace mindspore {
32 namespace dataset {
33 // Constructor
SubsetRandomSamplerObj(std::vector<int64_t> indices,int64_t num_samples)34 SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
35 : SubsetSamplerObj(std::move(indices), num_samples) {}
36
37 // Destructor
38 SubsetRandomSamplerObj::~SubsetRandomSamplerObj() = default;
39
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)40 Status SubsetRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
41 // runtime sampler object
42 *sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(indices_, num_samples_);
43 Status s = BuildChildren(sampler);
44 sampler = s.IsOk() ? sampler : nullptr;
45 return s;
46 }
47
48 #ifndef ENABLE_ANDROID
BuildForMindDataset()49 std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
50 // runtime mindrecord sampler object
51 auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
52
53 return mind_sampler;
54 }
55 #endif
56
to_json(nlohmann::json * const out_json)57 Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) {
58 nlohmann::json args;
59 RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
60 args["sampler_name"] = "SubsetRandomSampler";
61 args["indices"] = indices_;
62 args["num_samples"] = num_samples_;
63 *out_json = args;
64 return Status::OK();
65 }
66
67 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)68 Status SubsetRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
69 std::shared_ptr<SamplerObj> *sampler) {
70 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "indices", "SubsetRandomSampler"));
71 std::vector<int64_t> indices = json_obj["indices"];
72 *sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
73 // Run common code in super class to add children samplers
74 RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
75 return Status::OK();
76 }
77 #endif
78
SamplerCopy()79 std::shared_ptr<SamplerObj> SubsetRandomSamplerObj::SamplerCopy() {
80 auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
81 for (const auto &child : children_) {
82 Status rc = sampler->AddChildSampler(child);
83 if (rc.IsError()) {
84 MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
85 }
86 }
87 return sampler;
88 }
89 } // namespace dataset
90 } // namespace mindspore
91