Home
last modified time | relevance | path

Searched refs:batch_dims (Results 1 – 25 of 34) sorted by relevance

12

/external/tensorflow/tensorflow/python/kernel_tests/
Dgather_op_test.py256 batch_dims=0,
261 batch_dims=0,
266 batch_dims=0,
275 batch_dims=1,
280 batch_dims=2,
285 batch_dims=-1,
290 batch_dims=-1,
297 batch_dims=1,
302 batch_dims=1,
309 batch_dims=2,
[all …]
Dmatrix_triangular_solve_op_test.py32 def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None): argument
41 batch_dims=batch_dims,
45 def _verifySolveAllWaysReal(self, x, y, batch_dims=None): argument
46 self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
48 def _verifySolveAllWaysComplex(self, x, y, batch_dims=None): argument
49 self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
56 batch_dims=None, argument
73 if batch_dims is not None:
74 a = np.tile(a, batch_dims + [1, 1])
75 a_np = np.tile(a_np, batch_dims + [1, 1])
[all …]
Dmatrix_solve_op_test.py39 def _verifySolve(self, x, y, batch_dims=None): argument
53 if batch_dims is not None:
54 a = np.tile(a, batch_dims + [1, 1])
55 a_np = np.tile(a_np, batch_dims + [1, 1])
56 b = np.tile(b, batch_dims + [1, 1])
93 for batch_dims in [[2], [2, 2], [7, 4]]:
94 self._verifySolve(matrix, rhs, batch_dims=batch_dims)
Dresource_variable_ops_test.py1099 batch_dims=0,
1104 batch_dims=0,
1109 batch_dims=0,
1118 batch_dims=1,
1123 batch_dims=2,
1128 batch_dims=1,
1133 batch_dims=2,
1140 batch_dims=1,
1145 batch_dims=1,
1152 batch_dims=2,
[all …]
Dmatrix_logarithm_op_test.py112 for batch_dims in [(), (1,), (3,), (2, 2)]:
114 shape = batch_dims + (size, size)
123 for batch_dims in [(), (1,), (3,), (2, 2)]:
125 shape = batch_dims + (size, size)
Dqr_op_test.py208 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
211 shape = batch_dims + (rows, cols)
229 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
230 shape = batch_dims + (rows, cols)
Dsvd_op_test.py318 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
319 shape = batch_dims + (rows, cols)
337 for batch_dims in [(), (3,)]:
338 shape = batch_dims + mat_shape
Dbatch_matmul_op_test.py38 batch_dims = x.shape[:-2]
39 num = np.prod(batch_dims)
40 z = np.empty(list(batch_dims) + [d0, d2], dtype=x.dtype)
Dmatrix_inverse_op_test.py135 for batch_dims in [(), (1,), (3,), (2, 2)]:
137 shape = batch_dims + (size, size)
Dself_adjoint_eig_op_test.py247 for batch_dims in [(), (3,)] + [(3, 2)] * (max(size, size) < 10):
248 shape = batch_dims + (size, size)
Dmatrix_exponential_op_test.py218 def _TestRandomSmall(dtype, batch_dims, size): argument
222 shape = batch_dims + (size, size)
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dqr.cc75 Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims, in House() argument
81 std::vector<int64> batch_dim_ids(batch_dims.size()); in House()
83 const int64 minor_dim = batch_dims.size(); in House()
89 XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims); in House()
105 *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims), in House()
108 Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta); in House()
111 std::vector<int64>(batch_dims.size(), 1)); in House()
168 std::vector<int64> batch_dims(num_batch_dims); in QRBlock() local
170 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); in QRBlock()
186 batch_dims, m, &v, &tau, &beta)); in QRBlock()
[all …]
Dsvd.cc124 std::vector<int64> batch_dims(num_batch_dims); in HouseRow() local
126 batch_dims[k] = ShapeUtil::GetDimension(a_shape, k); in HouseRow()
190 std::vector<int64> batch_dims(num_batch_dims); in HouseCol() local
192 batch_dims[k] = ShapeUtil::GetDimension(a_shape, k); in HouseCol()
264 std::vector<int64> batch_dims(num_batch_dims); in HouseHolderBidiagonalization() local
266 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); in HouseHolderBidiagonalization()
271 IdentityMatrix(builder, a_shape.element_type(), m, m), batch_dims); in HouseHolderBidiagonalization()
273 IdentityMatrix(builder, a_shape.element_type(), n, n), batch_dims); in HouseHolderBidiagonalization()
469 std::vector<int64> batch_dims(num_batch_dims); in OneSidedJacobiUpdate() local
471 batch_dims[i] = ShapeUtil::GetDimension(d_shape, i); in OneSidedJacobiUpdate()
[all …]
Dself_adjoint_eig.cc120 const std::vector<int64> batch_dims(w_shape.dimensions().begin(), in Update() local
152 std::vector<int64> pq_dims(batch_dims.begin(), batch_dims.end()); in Update()
165 std::vector<int64> broadcast_dims(batch_dims.size()); in Update()
430 std::vector<int64> batch_dims(num_batch_dims); in SelfAdjointEig() local
432 batch_dims[i] = ShapeUtil::GetDimension(a_shape, i); in SelfAdjointEig()
437 auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims); in SelfAdjointEig()
/external/tensorflow/tensorflow/python/ops/
Darray_ops.py3293 batch_dims=0): argument
3360 if batch_dims != 0:
3362 return _batch_gather(params, indices, batch_dims, axis)
3364 axis = batch_dims
3381 batch_dims=0, name=None): argument
3383 axis=axis, batch_dims=batch_dims)
3402 return _batch_gather(params, indices, batch_dims=indices.shape.ndims - 1)
3405 def _batch_gather(params, indices, batch_dims, axis=None): argument
3431 if batch_dims is not None and not isinstance(batch_dims, int):
3432 raise TypeError("batch_dims must be an int; got %r" % batch_dims)
[all …]
/external/tensorflow/tensorflow/core/ops/
Dresource_variable_ops.cc270 int32 batch_dims; in __anon3f9cfa090402() local
271 TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims)); in __anon3f9cfa090402()
272 if (batch_dims < 0) in __anon3f9cfa090402()
273 return errors::InvalidArgument("batch_dims is negative (", batch_dims, in __anon3f9cfa090402()
277 batch_dims + 1, &unused)); in __anon3f9cfa090402()
280 c->WithRankAtLeast(indices_shape, batch_dims, &unused)); in __anon3f9cfa090402()
284 batch_dims, &params_subshape1)); in __anon3f9cfa090402()
288 batch_dims + 1, &params_subshape2)); in __anon3f9cfa090402()
292 c->Subshape(indices_shape, batch_dims, &indices_subshape)); in __anon3f9cfa090402()
Dmath_ops.cc144 ShapeHandle batch_dims; in __anonb22bfa860202() local
147 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims)); in __anonb22bfa860202()
156 batch_dims, c->Matrix(output_rows, output_cols), &out)); in __anonb22bfa860202()
/external/tensorflow/tensorflow/compiler/tests/
Dsvd_op_test.py74 for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
75 self._testSvdCorrectness(dtype, batch_dims + (n, n))
76 self._testSvdCorrectness(dtype, batch_dims + (2 * n, n))
77 self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n))
Dself_adjoint_eig_op_test.py57 for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10):
58 self._test(dtype, batch_dims + (n, n))
Dqr_op_test.py107 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
108 self._test(dtype, batch_dims + (rows, cols), full_matrices)
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dsoftmax_op.cc47 std::vector<int64> batch_dims(logits_shape.dims() - 1); in Compile() local
48 std::iota(batch_dims.begin(), batch_dims.end(), 0); in Compile()
63 auto shifted_logits = xla::Sub(logits, logits_max, batch_dims); in Compile()
78 ? xla::Sub(shifted_logits, xla::Log(sum), batch_dims) in Compile()
80 : xla::Div(exp_shifted, sum, batch_dims); in Compile()
/external/tensorflow/tensorflow/python/ops/ragged/
Dragged_gather_ops.py36 def gather(params, indices, validate_indices=None, axis=0, batch_dims=0, argument
92 if not isinstance(batch_dims, int) or batch_dims != 0:
Dragged_dispatch.py407 axis=0, batch_dims=0): argument
413 batch_dims=batch_dims,
/external/tensorflow/tensorflow/contrib/distributions/python/ops/
Dtest_util.py89 batch_dims = array_ops.shape(dist.batch_shape_tensor())[0]
90 edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]])
/external/tensorflow/tensorflow/compiler/xla/service/
Dindexed_array_analysis.cc977 absl::Span<const int64> batch_dims) { in GetOnlyNonContractingNonBatchDim() argument
981 !absl::c_linear_search(batch_dims, dim)) { in GetOnlyNonContractingNonBatchDim()
1002 absl::Span<const int64> batch_dims) { in CanFoldDotIntoIndexedArray() argument
1005 contracting_dims, batch_dims); in CanFoldDotIntoIndexedArray()

12