• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "minddata/dataset/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.h"
18 #include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
19 
20 #include "minddata/dataset/core/config_manager.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 // Constructor
WeightedRandomSamplerObj(std::vector<double> weights,int64_t num_samples,bool replacement)25 WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
26     : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
27 
28 // Destructor
29 WeightedRandomSamplerObj::~WeightedRandomSamplerObj() = default;
30 
ValidateParams()31 Status WeightedRandomSamplerObj::ValidateParams() {
32   if (weights_.empty()) {
33     RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
34   }
35   int32_t zero_elem = 0;
36   for (int32_t i = 0; i < weights_.size(); ++i) {
37     if (weights_[i] < 0) {
38       RETURN_STATUS_UNEXPECTED(
39         "WeightedRandomSampler: weights vector must not contain negative numbers, got: "
40         "weights[" +
41         std::to_string(i) + "] = " + std::to_string(weights_[i]));
42     }
43     if (weights_[i] == 0) {
44       zero_elem++;
45     }
46   }
47   if (zero_elem == weights_.size()) {
48     RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
49   }
50   if (num_samples_ < 0) {
51     RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
52                              std::to_string(num_samples_));
53   }
54   return Status::OK();
55 }
56 
to_json(nlohmann::json * const out_json)57 Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
58   nlohmann::json args;
59   RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
60   args["sampler_name"] = "WeightedRandomSampler";
61   args["weights"] = weights_;
62   args["replacement"] = replacement_;
63   args["num_samples"] = num_samples_;
64   *out_json = args;
65   return Status::OK();
66 }
67 
68 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)69 Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
70                                            std::shared_ptr<SamplerObj> *sampler) {
71   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "weights", "WeightedRandomSampler"));
72   RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "replacement", "WeightedRandomSampler"));
73   std::vector<double> weights = json_obj["weights"];
74   bool replacement = json_obj["replacement"];
75   *sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
76   // Run common code in super class to add children samplers
77   RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
78   return Status::OK();
79 }
80 #endif
81 
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)82 Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
83   *sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_);
84   Status s = BuildChildren(sampler);
85   sampler = s.IsOk() ? sampler : nullptr;
86   return s;
87 }
88 
SamplerCopy()89 std::shared_ptr<SamplerObj> WeightedRandomSamplerObj::SamplerCopy() {
90   auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
91   for (const auto &child : children_) {
92     Status rc = sampler->AddChildSampler(child);
93     if (rc.IsError()) {
94       MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
95     }
96   }
97   return sampler;
98 }
99 }  // namespace dataset
100 }  // namespace mindspore
101