Home
last modified time | relevance | path

Searched refs:probs_sort (Results 1 – 4 of 4) sorted by relevance

/external/executorch/examples/models/llama/runner/generation.py
30 probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
31 probs_sum = torch.cumsum(probs_sort, dim=-1)
32 mask = probs_sum - probs_sort > p
33 probs_sort[mask] = 0.0
34 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
35 next_token = torch.multinomial(probs_sort, num_samples=1)
/external/executorch/examples/qualcomm/oss_scripts/llama2/llama.py
196 probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
197 probs_sum = torch.cumsum(probs_sort, dim=-1)
198 mask = probs_sum - probs_sort > top_p
199 probs_sort[mask] = 0
200 probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
201 next_token = torch.multinomial(probs_sort, num_samples=1)
/external/executorch/examples/models/llama/experimental/generate.py
19 probs_sort, argument
21 q = torch.empty_like(probs_sort).exponential_(1)
22 return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
/external/pytorch/benchmarks/gpt_fast/generate.py
54 probs_sort, argument
56 q = torch.empty_like(probs_sort).exponential_(1)
57 return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)