• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "minddata/dataset/engine/ir/datasetops/source/samplers/weighted_random_sampler_ir.h"

#include <cmath>

#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"

#include "minddata/dataset/core/config_manager.h"
21 
22 namespace mindspore {
23 namespace dataset {
24 // Constructor
WeightedRandomSamplerObj(std::vector<double> weights,int64_t num_samples,bool replacement)25 WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
26     : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
27 
28 // Destructor
29 WeightedRandomSamplerObj::~WeightedRandomSamplerObj() = default;
30 
ValidateParams()31 Status WeightedRandomSamplerObj::ValidateParams() {
32   if (weights_.empty()) {
33     RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not be empty");
34   }
35   int32_t zero_elem = 0;
36   for (int32_t i = 0; i < weights_.size(); ++i) {
37     if (weights_[i] < 0) {
38       RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
39                                std::to_string(weights_[i]));
40     }
41     if (weights_[i] == 0.0) {
42       zero_elem++;
43     }
44   }
45   if (zero_elem == weights_.size()) {
46     RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: elements of weights vector must not be all zero");
47   }
48   if (num_samples_ < 0) {
49     RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: num_samples must be greater than or equal to 0, but got: " +
50                              std::to_string(num_samples_));
51   }
52   return Status::OK();
53 }
54 
to_json(nlohmann::json * const out_json)55 Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
56   nlohmann::json args;
57   RETURN_IF_NOT_OK(SamplerObj::to_json(&args));
58   args["sampler_name"] = "WeightedRandomSampler";
59   args["weights"] = weights_;
60   args["replacement"] = replacement_;
61   args["num_samples"] = num_samples_;
62   *out_json = args;
63   return Status::OK();
64 }
65 
66 #ifndef ENABLE_ANDROID
from_json(nlohmann::json json_obj,int64_t num_samples,std::shared_ptr<SamplerObj> * sampler)67 Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
68                                            std::shared_ptr<SamplerObj> *sampler) {
69   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("weights") != json_obj.end(), "Failed to find weights");
70   CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");
71   std::vector<double> weights = json_obj["weights"];
72   bool replacement = json_obj["replacement"];
73   *sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
74   // Run common code in super class to add children samplers
75   RETURN_IF_NOT_OK(SamplerObj::from_json(json_obj, sampler));
76   return Status::OK();
77 }
78 #endif
79 
SamplerBuild(std::shared_ptr<SamplerRT> * sampler)80 Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
81   *sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_);
82   Status s = BuildChildren(sampler);
83   sampler = s.IsOk() ? sampler : nullptr;
84   return s;
85 }
SamplerCopy()86 std::shared_ptr<SamplerObj> WeightedRandomSamplerObj::SamplerCopy() {
87   auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
88   for (const auto &child : children_) {
89     Status rc = sampler->AddChildSampler(child);
90     if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
91   }
92   return sampler;
93 }
94 }  // namespace dataset
95 }  // namespace mindspore
96