Home
last modified time | relevance | path

Searched refs:weights_ (Results 1 – 8 of 8) sorted by relevance

/third_party/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/
Dweighted_random_sampler.cc32 …: SamplerRT(num_samples, samples_per_tensor), weights_(weights), replacement_(replacement), sample… in WeightedRandomSamplerRT()
52 if (weights_.size() > static_cast<size_t>(num_rows_)) { in InitSampler()
56 … std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_)); in InitSampler()
58 if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { in InitSampler()
62 std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); in InitSampler()
74 …te_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end()); in InitSampler()
86 for (size_t i = 0; i < weights_.size(); i++) { in InitOnePassSampling()
87 val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i)); in InitOnePassSampling()
116 if (weights_.size() > static_cast<size_t>(num_rows_)) { in GetNextSample()
120 … std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_)); in GetNextSample()
[all …]
Dweighted_random_sampler.h68 std::vector<double> weights_;
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/
Dweighted_random_sampler_ir.cc26 : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} in WeightedRandomSamplerObj()
32 if (weights_.empty()) { in ValidateParams()
36 for (int32_t i = 0; i < weights_.size(); ++i) { in ValidateParams()
37 if (weights_[i] < 0) { in ValidateParams()
39 std::to_string(weights_[i])); in ValidateParams()
41 if (weights_[i] == 0.0) { in ValidateParams()
45 if (zero_elem == weights_.size()) { in ValidateParams()
59 args["weights"] = weights_; in to_json()
81 …*sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_… in SamplerBuild()
87 auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); in SamplerCopy()
Dweighted_random_sampler_ir.h66 const std::vector<double> weights_;
/third_party/mindspore/mindspore/ccsrc/ps/
Dparameter_server.cc150 if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) { in InitWeight()
152 weights_[key] = weight; in InitWeight()
220 if (weights_.count(key) == 0) { in InitEmbeddingTable()
256 weights_[key] = embedding; in InitEmbeddingTable()
265 bool ParameterServer::HasWeight(const Key &key) { return (weights_.count(key) > 0 && !is_embedding_… in HasWeight()
281 for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { in UpdateWeights()
337 OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths, in AccumGrad()
358 if (weights_.count(key) == 0) { in weight()
361 WeightPtr weight_ptr = weights_[key]; in weight()
373 if (weights_.count(key) == 0) { in DoEmbeddingLookup()
[all …]
Dparameter_server.h166 std::unordered_map<Key, WeightPtr> weights_; variable
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/api/
Dsamplers.cc89 : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} in WeightedRandomSampler()
92 return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); in Parse()
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/include/dataset/
Dsamplers.h249 std::vector<double> weights_;