Home
last modified time | relevance | path

Searched refs:RandomnessType (Results 1 – 13 of 13) sorted by relevance

/external/pytorch/aten/src/ATen/functorch/
DBatchRulesRandomness.cpp28 RandomnessType randomness = maybe_layer->randomness(); in random_batching_rule()
30 if (randomness == RandomnessType::Different) { in random_batching_rule()
44 RandomnessType randomness = maybe_layer->randomness(); in random_inplace_batching_rule()
47 !(randomness == RandomnessType::Different && !self_bdim), in random_inplace_batching_rule()
50 if (randomness == RandomnessType::Same && self_bdim) { in random_inplace_batching_rule()
65 RandomnessType randomness = maybe_layer->randomness(); in bernoulli_inplace_Tensor_batching_rule()
92 !(randomness == RandomnessType::Different && !self_bdim), in bernoulli_inplace_Tensor_batching_rule()
95 if (randomness == RandomnessType::Same && self_bdim) { in bernoulli_inplace_Tensor_batching_rule()
111 RandomnessType randomness = maybe_layer->randomness(); in randperm_batching_rule()
113 if (randomness == RandomnessType::Different) { in randperm_batching_rule()
[all …]
DInterpreter.h44 enum class RandomnessType { enum
92 explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) : in VmapInterpreterMeta()
95 RandomnessType randomness_;
125 static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) { in Vmap()
DDynamicLayer.h47 std::optional<RandomnessType> randomness = std::nullopt,
60 RandomnessType randomness() const;
69 std::optional<RandomnessType> randomness = std::nullopt,
DBatchRulesHelper.cpp86 void check_randomness(RandomnessType randomness, bool any_tensor_batched) { in check_randomness()
88 randomness != RandomnessType::Error, in check_randomness()
94 !(randomness == RandomnessType::Same && any_tensor_batched), in check_randomness()
100 void check_randomness(RandomnessType randomness) { in check_randomness()
DVmapInterpreter.h18 RandomnessType randomness() const { in randomness()
DDynamicLayer.cpp31 std::optional<RandomnessType> randomness, in DynamicLayer()
72 RandomnessType DynamicLayer::randomness() const { in randomness()
254 std::optional<RandomnessType> randomness, in initAndPushDynamicLayer()
DBatchRulesBinaryOps.cpp61 RandomnessType randomness = maybe_layer->randomness(); in apply()
68 if (randomness == RandomnessType::Different && !tensor_bdim && !other_bdim) { in apply()
77 } else if (randomness == RandomnessType::Same && !tensor_bdim && !other_bdim) { in apply()
DBatchRulesHelper.h44 void check_randomness(RandomnessType randomness);
45 void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
DBatchRulesLinearAlgebra.cpp501 RandomnessType randomness = maybe_layer->randomness(); in _scaled_dot_product_flash_attention_batch_rule()
550 RandomnessType randomness = maybe_layer->randomness(); in _scaled_dot_product_efficient_attention_batch_rule()
592 RandomnessType randomness = maybe_layer->randomness(); in _scaled_dot_product_cudnn_attention_batch_rule()
/external/pytorch/torch/csrc/functorch/
Dinit.cpp228 RandomnessType get_randomness_enum(const std::string& randomness) { in get_randomness_enum()
230 return RandomnessType::Error; in get_randomness_enum()
232 return RandomnessType::Same; in get_randomness_enum()
234 return RandomnessType::Different; in get_randomness_enum()
567 py::enum_<RandomnessType>(m, "RandomnessType") in initFuncTorchBindings()
568 .value("Error", RandomnessType::Error) in initFuncTorchBindings()
569 .value("Same", RandomnessType::Same) in initFuncTorchBindings()
570 .value("Different", RandomnessType::Different); in initFuncTorchBindings()
/external/pytorch/torch/_functorch/
Dpyfunctorch.py16 RandomnessType,
137 if typ == RandomnessType.Error:
139 elif typ == RandomnessType.Same:
141 elif typ == RandomnessType.Different:
/external/pytorch/torch/_C/
D_functorch.pyi44 class RandomnessType(Enum):
74 def randomness(self) -> RandomnessType: ...
/external/pytorch/functorch/csrc/dim/
Ddim.cpp1137 …micLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different); in EnableAllLayers()