1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "minddata/dataset/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.h"

#include <cstddef>
#include <utility>

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
21
22 namespace mindspore {
23 namespace dataset {
24 // Constructor
WeightedRandomSamplerObj(std::vector<double> weights,int64_t num_samples,bool replacement)25 WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
26 : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
27
28 // Destructor
29 WeightedRandomSamplerObj::~WeightedRandomSamplerObj() = default;
30
ValidateParams()31 Status WeightedRandomSamplerObj::ValidateParams() {
32 if (weights_.empty()) {
33 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
34 }
35 int32_t zero_elem = 0;
36 for (int32_t i = 0; i < weights_.size(); ++i) {
37 if (weights_[i] < 0) {
38 RETURN_STATUS_UNEXPECTED(
39 "WeightedRandomSampler: weights vector must not contain negative numbers, got: "
40 "weights[" +
41 std::to_string(i) + "] = " + std::to_string(weights_[i]));
42 }
43 if (weights_[i] == 0) {
44 zero_elem++;
45 }
46 }
47 if (zero_elem == weights_.size()) {
48 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
49 }
50 if (num_samples_ < 0) {
51 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
52 std::to_string(num_samples_));
53 }
54 return Status::OK();
55 }
56
to_json(nlohmann::json * const out_json)57 Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
58 nlohmann::json args;
59 RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
60 args["sampler_name"] = "WeightedRandomSampler";
61 args["weights"] = weights_;
62 args["replacement"] = replacement_;
63 args["num_samples"] = num_samples_;
64 *out_json = args;
65 return Status::OK();
66 }
67
68 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)69 Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
70 std::shared_ptr<SamplerObj> *sampler) {
71 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "weights", "WeightedRandomSampler"));
72 RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "replacement", "WeightedRandomSampler"));
73 std::vector<double> weights = json_obj["weights"];
74 bool replacement = json_obj["replacement"];
75 *sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
76 // Run common code in super class to add children samplers
77 RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
78 return Status::OK();
79 }
80 #endif
81
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)82 Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
83 *sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_);
84 Status s = BuildChildren(sampler);
85 sampler = s.IsOk() ? sampler : nullptr;
86 return s;
87 }
88
SamplerCopy()89 std::shared_ptr<SamplerObj> WeightedRandomSamplerObj::SamplerCopy() {
90 auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
91 for (const auto &child : children_) {
92 Status rc = sampler->AddChildSampler(child);
93 if (rc.IsError()) {
94 MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
95 }
96 }
97 return sampler;
98 }
99 } // namespace dataset
100 } // namespace mindspore
101