Home
last modified time | relevance | path

Searched refs:split_shape (Results 1 – 4 of 4) sorted by relevance

/external/ComputeLibrary/tests/validation/fixtures/
DSplitFixture.h174 for(const auto &split_shape : split_shapes) in compute_target() local
176 TensorType dst = create_tensor<TensorType>(split_shape, data_type); in compute_target()
229 for(const auto &split_shape : split_shapes) in compute_reference() local
232 const size_t axis_split_step = split_shape[axis]; in compute_reference()
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
Dxla_sharding_test.py106 split_shape = xla_sharding.get_sharding_tile_shape(split_sharding)
108 self.assertEqual(expected_shape, split_shape)
/external/tensorflow/tensorflow/python/ops/
Darray_grad.py846 split_shape = array_ops.reshape(
848 axes = math_ops.range(0, array_ops.size(split_shape), 2)
854 split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
855 input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
/external/tensorflow/tensorflow/cc/gradients/
Darray_grad.cc614 auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1}); in TileGrad() local
615 auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2); in TileGrad()
617 scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output); in TileGrad()