Searched refs:weights_ (Results 1 – 8 of 8) sorted by relevance
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/ |
D | weighted_random_sampler.cc | 32 …: SamplerRT(num_samples, samples_per_tensor), weights_(weights), replacement_(replacement), sample… in WeightedRandomSamplerRT() 52 if (weights_.size() > static_cast<size_t>(num_rows_)) { in InitSampler() 56 … std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_)); in InitSampler() 58 if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { in InitSampler() 62 std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_)); in InitSampler() 74 …te_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end()); in InitSampler() 86 for (size_t i = 0; i < weights_.size(); i++) { in InitOnePassSampling() 87 val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i)); in InitOnePassSampling() 116 if (weights_.size() > static_cast<size_t>(num_rows_)) { in GetNextSample() 120 … std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_)); in GetNextSample() [all …]
|
D | weighted_random_sampler.h | 68 std::vector<double> weights_;
|
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/samplers/ |
D | weighted_random_sampler_ir.cc | 26 : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} in WeightedRandomSamplerObj() 32 if (weights_.empty()) { in ValidateParams() 36 for (int32_t i = 0; i < weights_.size(); ++i) { in ValidateParams() 37 if (weights_[i] < 0) { in ValidateParams() 39 std::to_string(weights_[i])); in ValidateParams() 41 if (weights_[i] == 0.0) { in ValidateParams() 45 if (zero_elem == weights_.size()) { in ValidateParams() 59 args["weights"] = weights_; in to_json() 81 …*sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(weights_, num_samples_, replacement_… in SamplerBuild() 87 auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); in SamplerCopy()
|
D | weighted_random_sampler_ir.h | 66 const std::vector<double> weights_;
|
/third_party/mindspore/mindspore/ccsrc/ps/ |
D | parameter_server.cc | 150 if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) { in InitWeight() 152 weights_[key] = weight; in InitWeight() 220 if (weights_.count(key) == 0) { in InitEmbeddingTable() 256 weights_[key] = embedding; in InitEmbeddingTable() 265 bool ParameterServer::HasWeight(const Key &key) { return (weights_.count(key) > 0 && !is_embedding_… in HasWeight() 281 for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { in UpdateWeights() 337 OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths, in AccumGrad() 358 if (weights_.count(key) == 0) { in weight() 361 WeightPtr weight_ptr = weights_[key]; in weight() 373 if (weights_.count(key) == 0) { in DoEmbeddingLookup() [all …]
|
D | parameter_server.h | 166 std::unordered_map<Key, WeightPtr> weights_; variable
|
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/api/ |
D | samplers.cc | 89 : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} in WeightedRandomSampler() 92 return std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_); in Parse()
|
/third_party/mindspore/mindspore/ccsrc/minddata/dataset/include/dataset/ |
D | samplers.h | 249 std::vector<double> weights_;
|