Home
last modified time | relevance | path

Searched refs:batch_dims (Results 1 – 25 of 96) sorted by relevance

1234

/external/tensorflow/tensorflow/python/ops/ragged/
Dragged_gather_op_test.py103 batch_dims=1,
109 batch_dims=1,
114 batch_dims=1,
122 batch_dims=1,
132 batch_dims=2,
138 batch_dims=2,
145 batch_dims=3,
154 batch_dims=0, argument
161 params, indices, axis=axis, batch_dims=batch_dims)
332 dict(params_shape=[7, 3], indices_shape=[7], batch_dims=1),
[all …]
Dragged_gather_ops.py40 batch_dims=0, argument
91 if batch_dims != indices.shape.rank:
92 batch_dims = array_ops.get_positive_axis(
93 batch_dims,
97 if params.shape.rank is not None and batch_dims >= params.shape.rank:
100 axis = batch_dims
103 if axis < batch_dims:
106 if not 0 <= batch_dims <= indices.shape.rank:
109 (batch_dims, indices.shape.rank))
111 return _gather(params, indices, axis, batch_dims)
[all …]
/external/tensorflow/tensorflow/python/kernel_tests/
Dbanded_triangular_solve_op_test.py31 def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None): argument
40 batch_dims=batch_dims,
44 def _verifySolveAllWaysReal(self, x, y, batch_dims=None): argument
45 self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
47 def _verifySolveAllWaysComplex(self, x, y, batch_dims=None): argument
48 self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
55 batch_dims=None, argument
84 if batch_dims is not None:
85 a = np.tile(a, batch_dims + [1, 1])
86 a_np = np.tile(a_np, batch_dims + [1, 1])
[all …]
Dmatrix_triangular_solve_op_test.py31 def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None): argument
40 batch_dims=batch_dims,
44 def _verifySolveAllWaysReal(self, x, y, batch_dims=None): argument
45 self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
47 def _verifySolveAllWaysComplex(self, x, y, batch_dims=None): argument
48 self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
55 batch_dims=None, argument
75 if batch_dims is not None:
76 a = np.tile(a, batch_dims + [1, 1])
77 a_np = np.tile(a_np, batch_dims + [1, 1])
[all …]
Dmatrix_solve_op_test.py41 def _verifySolve(self, x, y, batch_dims=None): argument
56 if batch_dims is not None:
57 a = np.tile(a, batch_dims + [1, 1])
58 a_np = np.tile(a_np, batch_dims + [1, 1])
59 b = np.tile(b, batch_dims + [1, 1])
96 for batch_dims in [[2], [2, 2], [7, 4]]:
97 self._verifySolve(matrix, rhs, batch_dims=batch_dims)
Dmatrix_logarithm_op_test.py112 for batch_dims in [(), (1,), (3,), (2, 2)]:
114 shape = batch_dims + (size, size)
123 for batch_dims in [(), (1,), (3,), (2, 2)]:
125 shape = batch_dims + (size, size)
Darray_ops_test.py1930 for params_shape, indices_shape, batch_dims in shapes:
1934 batch_dims=batch_dims):
1939 params=params, indices=indices, batch_dims=batch_dims)
1940 ndims_params = len(params_shape) - batch_dims
1966 params=params, indices=indices, batch_dims=0)
1988 for params_shape, indices_shape, batch_dims in shapes:
1992 batch_dims=batch_dims):
1997 params=params, indices=indices, batch_dims=batch_dims)
1999 if batch_dims > 1:
2001 params, shape=[-1] + list(params_shape[batch_dims:]))
[all …]
/external/tensorflow/tensorflow/core/kernels/
Dgather_op.cc92 int32 batch_dims = batch_dims_; in Compute() local
93 if (batch_dims != 0) { in Compute()
95 batch_dims >= -indices.dims() && batch_dims <= indices.dims(), in Compute()
98 "], but got ", batch_dims)); in Compute()
100 if (batch_dims < 0) { in Compute()
101 batch_dims = indices.dims() + batch_dims; in Compute()
104 if (!axis_is_set) axis = batch_dims; in Compute()
106 OP_REQUIRES(c, batch_dims < params.dims(), in Compute()
107 errors::InvalidArgument("batch_dims (", batch_dims, in Compute()
111 OP_REQUIRES(c, axis >= batch_dims, in Compute()
[all …]
/external/tensorflow/tensorflow/python/kernel_tests/array_ops/
Dgather_op_test.py348 batch_dims=0,
353 batch_dims=0,
358 batch_dims=0,
368 batch_dims=1,
373 batch_dims=2,
378 batch_dims=-1,
383 batch_dims=-1,
390 batch_dims=1,
395 batch_dims=2,
403 batch_dims=1,
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
Dqr_expander.cc81 Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims, in House() argument
87 std::vector<int64> batch_dim_ids(batch_dims.size()); in House()
89 const int64 minor_dim = batch_dims.size(); in House()
94 XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); in House()
139 Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims), in House()
143 std::vector<int64>(batch_dims.size(), 1)); in House()
190 std::vector<int64> batch_dims(num_batch_dims); in QrBlock() local
192 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); in QrBlock()
207 batch_dims, m, &v, &tau, &beta)); in QrBlock()
209 const int64 minor_dim = batch_dims.size(); in QrBlock()
[all …]
Ddot_as_convolution_util.cc75 dims.batch_dims.push_back({lhs, rhs, output, i}); in ParseConvolutionDimsInfo()
115 for (const auto& dim : dot_dnums.batch_dims) { in CreateShardedConvForDotGeneralConvolution()
159 dnums.batch_dims.emplace_back(); in ParseDotGeneralFromDot()
160 dnums.batch_dims.back().lhs = dot_dim_numbs.lhs_batch_dimensions(i); in ParseDotGeneralFromDot()
161 dnums.batch_dims.back().rhs = dot_dim_numbs.rhs_batch_dimensions(i); in ParseDotGeneralFromDot()
162 dnums.batch_dims.back().output = i; in ParseDotGeneralFromDot()
163 dnums.batch_dims.back().spatial_dim = -1; in ParseDotGeneralFromDot()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dgather_op.cc158 int batch_dims, xla::XlaOp* gather_output) { in XlaGatherWithBatchDimsOpImpl() argument
188 if (batch_dims != 0) { in XlaGatherWithBatchDimsOpImpl()
189 if (batch_dims < 0) { in XlaGatherWithBatchDimsOpImpl()
190 batch_dims = indices_shape.dims() + batch_dims; in XlaGatherWithBatchDimsOpImpl()
193 axis = axis.value_or(batch_dims); in XlaGatherWithBatchDimsOpImpl()
195 if (batch_dims < -indices_shape.dims() || in XlaGatherWithBatchDimsOpImpl()
196 batch_dims > indices_shape.dims()) { in XlaGatherWithBatchDimsOpImpl()
199 indices_shape.dims(), "], but got ", batch_dims); in XlaGatherWithBatchDimsOpImpl()
202 if (batch_dims >= input_shape.dims()) { in XlaGatherWithBatchDimsOpImpl()
203 return errors::InvalidArgument("batch_dims (", batch_dims, in XlaGatherWithBatchDimsOpImpl()
[all …]
Dsoftmax_op.cc63 std::vector<int64> batch_dims(logits_shape.dims() - 1); in Compile() local
64 std::iota(batch_dims.begin(), batch_dims.end(), 0); in Compile()
88 auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); in Compile()
103 ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) in Compile()
105 : xla::Div(exp_shifted, sum, batch_dims); in Compile()
/external/tensorflow/tensorflow/python/ops/
Darray_ops.py4831 batch_dims=0): # pylint: disable=g-doc-args argument
5023 axis = batch_dims
5026 params, indices, axis, batch_dims=batch_dims, name=name)
5041 batch_dims=0, argument
5049 batch_dims=batch_dims)
5068 return _batch_gather(params, indices, batch_dims=indices.shape.ndims - 1)
5071 def _batch_gather(params, indices, batch_dims, axis=None): argument
5097 if batch_dims is not None and not isinstance(batch_dims, int):
5098 raise TypeError("batch_dims must be an int; got %r" % (batch_dims,))
5106 if batch_dims is None:
[all …]
Darray_grad.py585 def _GetBatchIndices(params_shape, indices, batch_dims): argument
592 for dim in range(batch_dims, 0, -1):
606 def _BatchGatherGrad(params_shape, values, indices, batch_dims, argument
612 if batch_dims:
615 outer_shape = values_shape[:batch_dims]
616 inner_shape = values_shape[batch_dims:][1:]
621 indices = _GetBatchIndices(params_shape, indices, batch_dims)
628 if batch_dims:
654 batch_dims = int(op.get_attr("batch_dims"))
656 if batch_dims < 0:
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dgather_scatter_handler.cc131 absl::Span<const int64> batch_dims,
137 const HloGatherInstruction* gather, absl::Span<const int64> batch_dims, in PartitionIndexOnlyPartition() argument
156 for (int64 i = 0; i < batch_dims.size(); ++i) { in PartitionIndexOnlyPartition()
158 output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim; in PartitionIndexOnlyPartition()
159 index_dim_to_output_dim[indices_batch_dim] = batch_dims[i]; in PartitionIndexOnlyPartition()
181 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonPassthroughOperand() argument
216 const HloSharding& output_sharding, absl::Span<const int64> batch_dims, in ParititonTrivialIndexedOperandDimension() argument
272 std::vector<int64> batch_dims; in ParititonTrivialIndexedOperandDimension() local
275 batch_dims.push_back(i); in ParititonTrivialIndexedOperandDimension()
280 batch_dims)); in ParititonTrivialIndexedOperandDimension()
[all …]
/external/tensorflow/tensorflow/core/ops/
Dresource_variable_ops.cc253 int32 batch_dims; in __anon50d007a70402() local
254 TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims)); in __anon50d007a70402()
255 if (batch_dims < 0) in __anon50d007a70402()
256 return errors::InvalidArgument("batch_dims is negative (", batch_dims, in __anon50d007a70402()
260 batch_dims + 1, &unused)); in __anon50d007a70402()
263 c->WithRankAtLeast(indices_shape, batch_dims, &unused)); in __anon50d007a70402()
267 batch_dims, &params_subshape1)); in __anon50d007a70402()
271 batch_dims + 1, &params_subshape2)); in __anon50d007a70402()
275 c->Subshape(indices_shape, batch_dims, &indices_subshape)); in __anon50d007a70402()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dlu_decomposition.cc35 const std::vector<int64> batch_dims( in LuDecomposition() local
39 std::vector<int64> pivot_dims = batch_dims; in LuDecomposition()
41 std::vector<int64> perm_dims = batch_dims; in LuDecomposition()
Dsvd.cc119 std::vector<int64> batch_dims(num_batch_dims); in HouseRow() local
121 batch_dims[k] = ShapeUtil::GetDimension(a_shape, k); in HouseRow()
184 std::vector<int64> batch_dims(num_batch_dims); in HouseCol() local
186 batch_dims[k] = ShapeUtil::GetDimension(a_shape, k); in HouseCol()
258 std::vector<int64> batch_dims(num_batch_dims); in HouseHolderBidiagonalization() local
260 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); in HouseHolderBidiagonalization()
265 IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims); in HouseHolderBidiagonalization()
267 IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims); in HouseHolderBidiagonalization()
460 std::vector<int64> batch_dims(num_batch_dims); in OneSidedJacobiUpdate() local
462 batch_dims[i] = ShapeUtil::GetDimension(d_shape, i); in OneSidedJacobiUpdate()
[all …]
Dslicing.cc263 XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) { in TorchIndexSelect() argument
268 if (dim < batch_dims) { in TorchIndexSelect()
281 if (batch_dims > 0) { in TorchIndexSelect()
284 to_concat.reserve(batch_dims + 1); in TorchIndexSelect()
285 for (int64 batch_dim = 0; batch_dim < batch_dims; ++batch_dim) { in TorchIndexSelect()
292 if (i < batch_dims || i == dim) { in TorchIndexSelect()
301 (1 + batch_dims)); in TorchIndexSelect()
Dself_adjoint_eig.cc120 const std::vector<int64> batch_dims(w_shape.dimensions().begin(), in Update() local
152 std::vector<int64> pq_dims(batch_dims.begin(), batch_dims.end()); in Update()
165 std::vector<int64> broadcast_dims(batch_dims.size()); in Update()
430 std::vector<int64> batch_dims(num_batch_dims); in SelfAdjointEig() local
432 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); in SelfAdjointEig()
437 auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); in SelfAdjointEig()
/external/tensorflow/tensorflow/compiler/tests/
Dsvd_op_test.py88 for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
89 self._testSvdCorrectness(dtype, batch_dims + (n, n))
90 self._testSvdCorrectness(dtype, batch_dims + (2 * n, n))
91 self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n))
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dcusolver_rewriter.cc60 std::vector<int64> batch_dims(a_shape.dimensions().begin(), in CreateCholesky() local
62 std::vector<int64> batch_dim_ids(batch_dims.size()); in CreateCholesky()
85 Shape info_shape = ShapeUtil::MakeShape(S32, batch_dims); in CreateCholesky()
109 HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, batch_dims), in CreateCholesky()
/external/tensorflow/tensorflow/python/data/experimental/ops/
Ddistribute.py506 batch_dims = [
511 if all(d is not None for d in batch_dims):
513 if all(d == batch_dims[0] for d in batch_dims):
515 batch_dim = batch_dims[0]
/external/tensorflow/tensorflow/python/ops/parallel_for/
Darray_test.py58 outputs.append(array_ops.gather(y, [0, 1, 2], axis=1, batch_dims=1))
59 outputs.append(array_ops.gather(y, [i, 1, 2], axis=2, batch_dims=1))
61 axis=-1, batch_dims=1))
63 array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=2))
64 outputs.append(array_ops.gather(y, [0, 1, 2], axis=1, batch_dims=-1))
66 array_ops.gather(y, [[0, 1, 2]] * 3, axis=2, batch_dims=-2))
78 outputs.append(array_ops.gather_nd(x_i, [0], batch_dims=0))
79 outputs.append(array_ops.gather_nd(x_i, [i], batch_dims=0))
80 outputs.append(array_ops.gather_nd(x_i, [[i], [i], [i]], batch_dims=1))

1234