Searched defs:shifted_logits (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/dtensor/mlir/expansions/
softmax_spmd_expander.cc
  103  mlir::Value& shifted_logits,  in ComputeExpAndSum()
  154  const mlir::Value& shifted_logits,  in ComputeLogSoftmax()
  170  mlir::Value shifted_logits;  in ComputeShardedSoftmax() local
  568  mlir::Value shifted_logits;  in ExpandOp() local
/external/tensorflow/tensorflow/core/kernels/
softmax_op_functor.h
  61  auto shifted_logits = (logits - logits.maximum(along_class)  in Compute() local
/external/tensorflow/tensorflow/compiler/mlir/tfrt/benchmarks/
softmax_op_benchmark.cc
  62  auto shifted_logits = (logits - logits.maximum(along_class)  in ComputeSoftmax() local
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
softmax_op.cc
  60  auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});  in CrossEntropyWithLogits() local
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
lower_tf.cc
  1670  auto shifted_logits = rewriter.create<TF::SubOp>(loc, logits, max_logits);  in matchAndRewrite() local
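Several of the hits above (the locals in `Compute()`, `ComputeSoftmax()`, `CrossEntropyWithLogits()`, and `matchAndRewrite()`) compute `shifted_logits` as the logits minus their maximum along the class dimension. That is the usual trick for keeping softmax numerically stable. Below is a minimal standalone sketch of the idea. It assumes plain `std::vector<double>` rows in place of TensorFlow tensors, and `ShiftLogits`, `StableSoftmax`, and `StableLogSoftmax` are hypothetical names, not TensorFlow APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Row = std::vector<double>;
using Batch = std::vector<Row>;

// Hypothetical helper: subtract the row maximum from every logit, giving
// "shifted logits" whose largest entry is exactly 0.
Row ShiftLogits(const Row& logits) {
  assert(!logits.empty() && "softmax over an empty class dimension");
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  Row shifted(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    shifted[i] = logits[i] - max_logit;
  }
  return shifted;
}

// Row-wise softmax: exp(shifted) / sum(exp(shifted)). Every shifted logit is
// <= 0, so std::exp cannot overflow and the sum is at least 1.
Batch StableSoftmax(const Batch& logits) {
  Batch out;
  out.reserve(logits.size());
  for (const Row& row : logits) {
    Row probs = ShiftLogits(row);
    double sum = 0.0;
    for (double& v : probs) {
      v = std::exp(v);
      sum += v;
    }
    for (double& v : probs) v /= sum;
    out.push_back(probs);
  }
  return out;
}

// Row-wise log-softmax: shifted - log(sum(exp(shifted))).
Batch StableLogSoftmax(const Batch& logits) {
  Batch out;
  out.reserve(logits.size());
  for (const Row& row : logits) {
    Row shifted = ShiftLogits(row);
    double sum = 0.0;
    for (double v : shifted) sum += std::exp(v);
    const double log_sum = std::log(sum);
    for (double& v : shifted) v -= log_sum;
    out.push_back(shifted);
  }
  return out;
}
```

For a row such as `{1000.0, 1000.0}`, a naive `std::exp(1000.0)` overflows to infinity. The shifted version instead evaluates `std::exp(0.0)` for each entry and returns `0.5` for both classes.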