Searched refs:kBatchDim (Results 1 – 5 of 5) sorted by relevance
96 const int kBatchDim = 0; in CrossEntropyWithLogits() local106 auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); in CrossEntropyWithLogits()127 auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim}); in CrossEntropyWithLogits()138 xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); in CrossEntropyWithLogits()
44 const int kBatchDim = 0; in Compute() local46 const int rows = logits.dimension(kBatchDim); in Compute()
46 const int kBatchDim = 0; in Compute() local49 const int batch_size = logits.dimension(kBatchDim); in Compute()
66 const int kBatchDim = 0; in Compute() local69 const int batch_size = shape[kBatchDim]; in Compute()
183 const int kBatchDim = 0; in Compute() local186 const int batch_size = logits.dimension(kBatchDim); in Compute()