1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
19
20 #include "minddata/dataset/core/config_manager.h"
21
22 namespace mindspore {
23 namespace dataset {
24 // Constructor
WeightedRandomSamplerObj(std::vector<double> weights,int64_t num_samples,bool replacement)25 WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
26 : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
27
28 // Destructor
29 WeightedRandomSamplerObj::~WeightedRandomSamplerObj() = default;
30
ValidateParams()31 Status WeightedRandomSamplerObj::ValidateParams() {
32 if (weights_.empty()) {
33 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
34 }
35 int32_t zero_elem = 0;
36 for (int32_t i = 0; i < weights_.size(); ++i) {
37 if (weights_[i] < 0) {
38 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
39 std::to_string(weights_[i]));
40 }
41 if (weights_[i] == 0.0) {
42 zero_elem++;
43 }
44 }
45 if (zero_elem == weights_.size()) {
46 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
47 }
48 if (num_samples_ < 0) {
49 RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
50 std::to_string(num_samples_));
51 }
52 return Status::OK();
53 }
54
to_json(nlohmann::json * const out_json)55 Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
56 nlohmann::json args;
57 RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
58 args["sampler_name"] = "WeightedRandomSampler";
59 args["weights"] = weights_;
60 args["replacement"] = replacement_;
61 args["num_samples"] = num_samples_;
62 *out_json = args;
63 return Status::OK();
64 }
65
66 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)67 Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
68 std::shared_ptr<SamplerObj> *sampler) {
69 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("weights") != json_obj.end(), "Failed to find weights");
70 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");
71 std::vector<double> weights = json_obj["weights"];
72 bool replacement = json_obj["replacement"];
73 *sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
74 // Run common code in super class to add children samplers
75 RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
76 return Status::OK();
77 }
78 #endif
79
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)80 Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
81 *sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_);
82 Status s = BuildChildren(sampler);
83 sampler = s.IsOk() ? sampler : nullptr;
84 return s;
85 }
SamplerCopy()86 std::shared_ptr<SamplerObj> WeightedRandomSamplerObj::SamplerCopy() {
87 auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
88 for (const auto &child : children_) {
89 Status rc = sampler->AddChildSampler(child);
90 if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
91 }
92 return sampler;
93 }
94 } // namespace dataset
95 } // namespace mindspore
96