Home
last modified time | relevance | path

Searched refs:batch_idx (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/
Dgather_functor.h61 SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size); in HandleCopies() local
66 while ((batch_idx < batch_idx_end) || in HandleCopies()
67 (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) { in HandleCopies()
69 SliceIndex b_next = batch_idx + 1; in HandleCopies()
70 if ((batch_idx == batch_idx_end && i_next < indices_idx_end) || in HandleCopies()
73 &params(batch_idx, indices(i_next), 0)); in HandleCopies()
74 port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0)); in HandleCopies()
75 b_next = batch_idx; in HandleCopies()
93 out_base + (batch_idx * indices_size + indices_idx) * slice_elems, in HandleCopies()
94 params_base + (batch_idx * static_cast<SliceIndex>(limit) + in HandleCopies()
[all …]
Dresource_variable_ops.cc603 for (int64 batch_idx = 0, dest_idx = 0; batch_idx < batch_size; in AddBatchOffsets() local
604 ++batch_idx) { in AddBatchOffsets()
606 indices_flat(dest_idx++) += batch_offset * batch_idx; in AddBatchOffsets()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dreverse_sequence_op.cc76 xla::XlaOp batch_idx = xla::Iota( in Compile() local
92 batch_idx = xla::Transpose(batch_idx, {1, 0, 2}); in Compile()
97 xla::ConcatInDim(builder, {batch_idx, reverse_idx}, in Compile()
/external/tensorflow/tensorflow/stream_executor/
Ddnn.cc164 int depth_idx, batch_idx, spatial_idx; in GetDimIndices() local
168 batch_idx = data_dims - 2; in GetDimIndices()
174 batch_idx = data_dims - 1; in GetDimIndices()
180 batch_idx = 0; in GetDimIndices()
187 batch_idx = 0; in GetDimIndices()
195 return std::make_tuple(depth_idx, batch_idx, spatial_idx); in GetDimIndices()
/external/tensorflow/tensorflow/python/ops/
Dctc_ops.py400 batch_idx = array_ops.zeros_like(label_states[2:])
402 [batch_idx, label_states[2:], label_states[1:-1]], 1)
405 batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
406 indices += array_ops.expand_dims(batch_idx, 1)