Searched refs:new_batch_dim (Results 1 – 2 of 2) sorted by relevance
139 new_batch_dim = self._compute_static_batch_dim()143 lambda ts: ts._unbatch()._batch(new_batch_dim),175 new_batch_dim = tensor_util.constant_value(self._batch_sizes)176 if new_batch_dim is None:179 if isinstance(new_batch_dim, np.ndarray):180 if len(new_batch_dim.shape) == 1:181 if np.all(new_batch_dim == new_batch_dim[0]):182 new_batch_dim = new_batch_dim[0]185 elif len(new_batch_dim.shape) > 1:191 if self._may_form_partial_batches(new_batch_dim):[all …]
193 HloInstruction* select_val, int64_t new_batch_dim,658 int64_t new_batch_dim, new_spatial_dim; in BringSpaceNextToBatch() local668 new_batch_dim = pushed_counter; in BringSpaceNextToBatch()686 activations_batch_dim = new_batch_dim; in BringSpaceNextToBatch()700 new_batch_dim = pushed_counter; in BringSpaceNextToBatch()720 activations_batch_dim = new_batch_dim; in BringSpaceNextToBatch()1209 const int64_t new_batch_dim = in CanPropagate() local1214 first_operand->shape().dimensions(new_batch_dim); in CanPropagate()1250 const int64_t new_batch_dim = in CanPropagate() local1253 second_operand->shape().dimensions(new_batch_dim); in CanPropagate()[all …]