Lines Matching refs:expected_parameters
89 std::vector<std::vector<torch::Tensor>> expected_parameters) { in check_exact_values() argument
138 expected_parameters.at(i / kSampleEvery).size() == parameters.size()); in check_exact_values()
145 expected_parameters.at(i / kSampleEvery).at(p).to(torch::kFloat64); in check_exact_values()
301 check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam()); in TEST()
307 expected_parameters::Adam_with_weight_decay()); in TEST()
313 expected_parameters::Adam_with_weight_decay_and_amsgrad()); in TEST()
325 check_exact_values<AdamW>(AdamWOptions(1.0), expected_parameters::AdamW()); in TEST()
331 expected_parameters::AdamW_without_weight_decay()); in TEST()
337 expected_parameters::AdamW_with_amsgrad()); in TEST()
342 AdagradOptions(1.0), expected_parameters::Adagrad()); in TEST()
348 expected_parameters::Adagrad_with_weight_decay()); in TEST()
354 expected_parameters::Adagrad_with_weight_decay_and_lr_decay()); in TEST()
359 RMSpropOptions(0.1), expected_parameters::RMSprop()); in TEST()
365 expected_parameters::RMSprop_with_weight_decay()); in TEST()
371 expected_parameters::RMSprop_with_weight_decay_and_centered()); in TEST()
379 expected_parameters:: in TEST()
384 check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD()); in TEST()
390 expected_parameters::SGD_with_weight_decay()); in TEST()
396 expected_parameters::SGD_with_weight_decay_and_momentum()); in TEST()
402 expected_parameters::SGD_with_weight_decay_and_nesterov_momentum()); in TEST()
406 check_exact_values<LBFGS>(LBFGSOptions(1.0), expected_parameters::LBFGS()); in TEST()
412 expected_parameters::LBFGS_with_line_search()); in TEST()