
Searched refs:bcast (Results 1 – 25 of 59) sorted by relevance


/external/tensorflow/tensorflow/core/util/
matmul_bcast_test.cc
38 MatMulBCast bcast({1, 5, 3}, {4, 3, 7}); in TEST() local
40 EXPECT_TRUE(bcast.IsValid()); in TEST()
41 EXPECT_TRUE(bcast.IsBroadcastingRequired()); in TEST()
43 EXPECT_EQ(1, bcast.x_batch_size()); in TEST()
44 EXPECT_EQ(4, bcast.y_batch_size()); in TEST()
45 EXPECT_EQ(4, bcast.output_batch_size()); in TEST()
47 EXPECT_EQ("[4][0,0,0,0][0,1,2,3]", MatMulBCastToStr(bcast)); in TEST()
51 MatMulBCast bcast({5, 3}, {3, 7}); in TEST() local
53 EXPECT_TRUE(bcast.IsValid()); in TEST()
54 EXPECT_FALSE(bcast.IsBroadcastingRequired()); in TEST()
[all …]
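
A minimal sketch of the MatMulBCast usage these matches exercise; the include path and the tensorflow namespace are assumed from the tree above, and the example itself is not taken from the indexed files.

#include <cstdint>

#include "tensorflow/core/util/matmul_bcast.h"

void MatMulBCastExample() {
  // x holds one 5x3 matrix ({1, 5, 3}), y holds four 3x7 matrices ({4, 3, 7}).
  // Only the leading (batch) dimensions take part in broadcasting; the trailing
  // two dimensions are the matrix dimensions.
  tensorflow::MatMulBCast bcast({1, 5, 3}, {4, 3, 7});
  if (!bcast.IsValid()) return;  // incompatible batch dimensions

  if (bcast.IsBroadcastingRequired()) {
    // x_batch_size() == 1, y_batch_size() == 4, output_batch_size() == 4 here.
    for (int64_t i = 0; i < bcast.output_batch_size(); ++i) {
      // Map each output batch back to the batch it reads from in x and in y,
      // which is what the batch_matmul kernels below do per worker range.
      const int64_t x_batch = bcast.x_batch_indices()[i];
      const int64_t y_batch = bcast.y_batch_indices()[i];
      (void)x_batch;
      (void)y_batch;
    }
  }
}
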
/external/tensorflow/tensorflow/core/kernels/
ops_util_test.cc
124 static void VerifyBoundaries(bcast_struct bcast, error::Code code) { in VerifyBoundaries() argument
127 bcast.input.index, bcast.input.in_size, bcast.input.ksize, in VerifyBoundaries()
128 bcast.input.stride, bcast.input.pad_size, &new_index, &new_size); in VerifyBoundaries()
132 static void VerifyBcastValues(bcast_struct bcast) { in VerifyBcastValues() argument
135 GetBroadcastSize(bcast.input.index, bcast.input.in_size, in VerifyBcastValues()
136 bcast.input.ksize, bcast.input.stride, in VerifyBcastValues()
137 bcast.input.pad_size, &new_index, &new_size)); in VerifyBcastValues()
138 EXPECT_EQ(bcast.output.new_index, new_index); in VerifyBcastValues()
139 EXPECT_EQ(bcast.output.new_size, new_size); in VerifyBcastValues()
178 bcast_struct bcast = {{2, 3, 1, 2, 0}, {0, 3}}; in TEST_F() local
[all …]
broadcast_to_op.h
37 const typename Eigen::array<int, NDIMS> &bcast) const { in DoBCast32Bit()
38 To32Bit(out).device(device) = To32Bit(in).broadcast(bcast); in DoBCast32Bit()
45 const typename Eigen::array<Eigen::DenseIndex, NDIMS> &bcast) const { in DoBCast()
46 out.device(device) = in.broadcast(bcast); in DoBCast()
51 const Tensor &input_tensor, const BCast &bcast) const { in ReshapeAndBCast()
57 device, output_tensor.template shaped<T, NDIMS>(bcast.result_shape()), in ReshapeAndBCast()
58 input_tensor.template shaped<T, NDIMS>(bcast.x_reshape()), in ReshapeAndBCast()
59 BCast::ToIndexArrayType<int, NDIMS>(bcast.x_bcast())); in ReshapeAndBCast()
62 device, output_tensor.template shaped<T, NDIMS>(bcast.result_shape()), in ReshapeAndBCast()
63 input_tensor.template shaped<T, NDIMS>(bcast.x_reshape()), in ReshapeAndBCast()
[all …]
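
The ReshapeAndBCast functor above boils down to Eigen's Tensor::broadcast: the input is reshaped so every broadcast dimension has extent 1 (bcast.x_reshape()), and each dimension is then replicated by the factors in bcast.x_bcast(). A standalone Eigen sketch of that primitive, assuming only the usual Eigen include path:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A [1, 3] input, i.e. a [3] vector already reshaped with a leading 1,
  // the form ReshapeAndBCast produces before calling broadcast().
  Eigen::Tensor<float, 2> in(1, 3);
  in.setValues({{1.f, 2.f, 3.f}});

  // Replicate dimension 0 four times and dimension 1 once: output shape [4, 3].
  Eigen::array<Eigen::DenseIndex, 2> bcast{4, 1};
  Eigen::Tensor<float, 2> out = in.broadcast(bcast);

  return (out.dimension(0) == 4 && out.dimension(1) == 3) ? 0 : 1;
}
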
training_ops_gpu.cu.cc
114 Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; in operator ()() local
115 bcast[0] = grad.dimension(0); in operator ()()
117 var.device(d) -= lr.reshape(single).broadcast(bcast) * grad; in operator ()()
130 Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; in operator ()() local
131 bcast[0] = grad.dimension(0); in operator ()()
133 var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt(); in operator ()()
144 Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; in operator ()() local
145 bcast[0] = grad.dimension(0); in operator ()()
151 grad / (accum.sqrt() + epsilon.reshape(single).broadcast(bcast)); in operator ()()
152 var.device(d) -= lr.reshape(single).broadcast(bcast) * update; in operator ()()
[all …]
batch_matmul_op_impl.h
82 const MatMulBCast& bcast, Tensor* out, int start, int limit) { in Run()
96 const bool should_bcast = bcast.IsBroadcastingRequired(); in Run()
97 const auto& x_batch_indices = bcast.x_batch_indices(); in Run()
98 const auto& y_batch_indices = bcast.y_batch_indices(); in Run()
125 const MatMulBCast& bcast, Tensor* out, int start, int limit) {
133 const bool should_bcast = bcast.IsBroadcastingRequired();
134 const auto& x_batch_indices = bcast.x_batch_indices();
135 const auto& y_batch_indices = bcast.y_batch_indices();
172 bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
174 const bool should_bcast = bcast.IsBroadcastingRequired();
[all …]
substr_op.cc
131 BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape)); in Compute() local
132 OP_REQUIRES(context, bcast.IsValid(), in Compute()
136 TensorShape output_shape = BCast::ToShape(bcast.result_shape()); in Compute()
144 auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape()); in Compute()
145 auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape()); in Compute()
146 auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape()); in Compute()
147 auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape()); in Compute()
154 input_buffer.shaped<tstring, 1>(bcast.result_shape()); in Compute()
156 input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast())); in Compute()
164 pos_buffer.shaped<T, 1>(bcast.result_shape())); in Compute()
[all …]
matrix_triangular_solve_op_impl.h
85 bool adjoint, const MatMulBCast& bcast, Tensor* out, in Run()
87 const bool should_bcast = bcast.IsBroadcastingRequired(); in Run()
88 const auto& x_batch_indices = bcast.x_batch_indices(); in Run()
89 const auto& y_batch_indices = bcast.y_batch_indices(); in Run()
122 const MatMulBCast& bcast, Tensor* out) {
124 const int64 batch_size = bcast.output_batch_size();
142 [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) {
144 in_x, in_y, lower, adjoint, bcast, out, start, limit);
166 MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
168 ctx, bcast.IsValid(),
[all …]
xent_op.cc
49 BCast bcast(BCast::FromShape(logits_in.shape()), in Compute() local
52 OP_REQUIRES(context, bcast.IsValid(), in Compute()
57 shape_in = BCast::ToShape(bcast.output_shape()); in Compute()
90 BCast::ToIndexArray<2>(bcast.x_bcast()), in Compute()
91 BCast::ToIndexArray<2>(bcast.y_bcast()), in Compute()
92 logits_in.template shaped<T, 2>(bcast.x_reshape()), in Compute()
93 labels_in.template shaped<T, 2>(bcast.y_reshape()), in Compute()
bias_op.h
37 Eigen::DSizes<Eigen::Index, 1> bcast(rest_size); in operator()
39 input.reshape(one_d) + bias.broadcast(bcast); in operator()
44 Eigen::DSizes<int, 1> bcast(rest_size); in operator()
46 To32Bit(input).reshape(one_d) + To32Bit(bias).broadcast(bcast); in operator()
broadcast_to_op.cc
79 BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(output_shape), in Compute() local
81 OP_REQUIRES(ctx, bcast.IsValid(), in Compute()
85 OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape, in Compute()
91 input_tensor, input_shape, bcast); in Compute()
114 const TensorShape& input_shape, const BCast& bcast) const; \
bcast_ops.cc
46 BCast bcast(shapes[0], shapes[1]); in Compute() local
47 OP_REQUIRES(ctx, bcast.IsValid(), in Compute()
51 Output(ctx, 0, bcast.output_shape()); in Compute()
95 BCast bcast(shapes[0], shapes[1]); in Compute() local
96 OP_REQUIRES(ctx, bcast.IsValid(), in Compute()
100 Output(ctx, 0, bcast.grad_x_reduce_idx()); in Compute()
101 Output(ctx, 1, bcast.grad_y_reduce_idx()); in Compute()
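
The two Compute() bodies above output the broadcast shape and the gradient reduce indices respectively; the part of BCast the gradient kernel reads is the pair of reduce-index vectors. A minimal sketch, with the include path assumed and the example not taken from the indexed files:

#include "tensorflow/core/util/bcast.h"

void BCastGradArgsExample() {
  // x: [2, 3], y: [3]; y is implicitly broadcast along axis 0 of the output.
  tensorflow::BCast bcast({2, 3}, {3});
  if (!bcast.IsValid()) return;

  // output_shape() is the broadcast result shape, [2, 3] here.
  const auto& out_shape = bcast.output_shape();

  // grad_x_reduce_idx() / grad_y_reduce_idx() list the output axes each
  // operand's gradient must be summed over: x needs no reduction, while y's
  // gradient is reduced over axis 0, the axis it was broadcast along.
  const auto& rx = bcast.grad_x_reduce_idx();
  const auto& ry = bcast.grad_y_reduce_idx();
  (void)out_shape; (void)rx; (void)ry;
}
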
cwise_ops_common.h
66 BCast bcast; member
151 auto& bcast = state.bcast; in Compute() local
153 if (!bcast.IsValid()) { in Compute()
192 eigen_device, out->shaped<Tout, 2>(bcast.result_shape()), in Compute()
193 in0.template shaped<Tin, 2>(bcast.x_reshape()), in Compute()
194 BCast::ToIndexArray<2>(bcast.x_bcast()), in Compute()
195 in1.template shaped<Tin, 2>(bcast.y_reshape()), in Compute()
196 BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr); in Compute()
199 eigen_device, out->shaped<Tout, 3>(bcast.result_shape()), in Compute()
200 in0.template shaped<Tin, 3>(bcast.x_reshape()), in Compute()
[all …]
cwise_ops_common.cc
58 bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) { in BinaryOpState()
59 if (!bcast.IsValid()) { in BinaryOpState()
77 const TensorShape output_shape = BCast::ToShape(bcast.output_shape()); in BinaryOpState()
84 ndims = static_cast<int>(bcast.x_reshape().size()); in BinaryOpState()
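
BinaryOpState above is the shared setup for the broadcasting cwise kernels (substr_op, xent_op and cwise_op_select in this list follow the same pattern): build a BCast from the two input shapes, reject invalid combinations, then evaluate both inputs reshaped to x_reshape()/y_reshape() and broadcast by x_bcast()/y_bcast() into result_shape(). A hedged sketch of just the shape bookkeeping, with include paths assumed:

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/bcast.h"

// Derives the broadcast output shape for two operand shapes the way
// BinaryOpState does; returns false when the shapes are not broadcastable.
bool BroadcastShape(const tensorflow::TensorShape& a,
                    const tensorflow::TensorShape& b,
                    tensorflow::TensorShape* shape) {
  using tensorflow::BCast;
  BCast bcast(BCast::FromShape(a), BCast::FromShape(b));
  if (!bcast.IsValid()) return false;  // e.g. [2, 3] vs. [4] is rejected

  // The kernels then compute, per rank NDIMS:
  //   out.shaped(bcast.result_shape()) =
  //       in0.shaped(bcast.x_reshape()).broadcast(bcast.x_bcast())  OP
  //       in1.shaped(bcast.y_reshape()).broadcast(bcast.y_bcast());
  *shape = BCast::ToShape(bcast.output_shape());
  return true;
}
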
random_binomial_op.cc
174 const BCast& bcast, typename TTypes<T>::ConstFlat counts, in operator ()()
183 auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs, in operator ()()
189 const bool should_bcast = bcast.IsBroadcastingRequired(); in operator ()()
190 const auto& counts_batch_indices = bcast.x_batch_indices(); in operator ()()
191 const auto& probs_batch_indices = bcast.y_batch_indices(); in operator ()()
340 tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(), in Compute() local
344 OP_REQUIRES(ctx, bcast.IsValid(), in Compute()
360 TensorShape bcast_shape = BCast::ToShape(bcast.output_shape()); in Compute()
381 (shape_tensor.dim_size(0) - bcast.output_shape().size()); in Compute()
428 samples_per_batch, num_elements, bcast, in Compute()
cwise_op_select.cc
172 BCast bcast( in Compute() local
176 OP_REQUIRES(ctx, bcast.IsValid(), in Compute()
184 BCast cond_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), in Compute()
186 BCast then_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), in Compute()
188 BCast else_bcast(BCast::FromShape(BCast::ToShape(bcast.output_shape())), in Compute()
201 cond_bcast.output_shape() == bcast.output_shape() && in Compute()
202 then_bcast.output_shape() == bcast.output_shape() && in Compute()
203 else_bcast.output_shape() == bcast.output_shape(), in Compute()
210 const TensorShape output_shape = BCast::ToShape(bcast.output_shape()); in Compute()
222 output->shaped<T, NDIMS>(bcast.result_shape()), \ in Compute()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
bcast_ops.cc
52 BCast bcast(shapes[0], shapes[1]); in Compile() local
53 OP_REQUIRES(ctx, bcast.IsValid(), in Compile()
58 const int64 len = bcast.output_shape().size(); in Compile()
61 output.flat<int32>()(i) = static_cast<int32>(bcast.output_shape()[i]); in Compile()
102 BCast bcast(shapes[0], shapes[1]); in Compile() local
103 OP_REQUIRES(ctx, bcast.IsValid(), in Compile()
107 Output(ctx, 0, bcast.grad_x_reduce_idx()); in Compile()
108 Output(ctx, 1, bcast.grad_y_reduce_idx()); in Compile()
select_op.cc
104 BCast bcast(bcast_then_else.output_shape(), BCast::FromShape(cond_shape), in Compile() local
106 if (!bcast.IsValid()) { in Compile()
114 auto bcasted_cond = BroadcastTo(ctx->Input(0), bcast.output_shape()); in Compile()
118 auto bcasted_then = BroadcastTo(ctx->Input(1), bcast.output_shape()); in Compile()
122 auto bcasted_else = BroadcastTo(ctx->Input(2), bcast.output_shape()); in Compile()
cwise_ops.cc
43 BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), in Compile() local
45 if (!bcast.IsValid()) { in Compile()
81 rhs_shape.dim_sizes(), bcast, extend_dimension); in Compile()
/external/tensorflow/tensorflow/compiler/xla/service/
hlo_rematerialization_test.cc
70 const HloInstruction* bcast = concat->operand(0); in TEST_F() local
84 EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast))); in TEST_F()
340 auto bcast = builder.AddInstruction( in TEST_F() local
343 vec1024_shape_, HloOpcode::kAdd, bcast, bcast)); in TEST_F()
347 vec1024_shape_, HloOpcode::kAdd, bcast, call_1)); in TEST_F()
351 vec1024_shape_, HloOpcode::kAdd, bcast, call_2)); in TEST_F()
355 vec1024_shape_, HloOpcode::kAdd, bcast, call_3)); in TEST_F()
374 EXPECT_EQ(add_2->operand(0), bcast); in TEST_F()
375 EXPECT_EQ(add_3->operand(0), bcast); in TEST_F()
376 EXPECT_EQ(add_4->operand(0), bcast); in TEST_F()
[all …]
hlo_rematerialization_test_utils.h
64 auto bcast = builder.AddInstruction( variable
67 HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
76 ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
109 auto bcast = builder.AddInstruction( variable
112 HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
118 ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst},
/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
unroll_batch_matmul.cc
162 const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type, in createMatMulOps() argument
167 for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) { in createMatMulOps()
169 if (bcast.IsBroadcastingRequired()) { in createMatMulOps()
170 lhs_batch_idx = bcast.x_batch_indices()[batch_idx]; in createMatMulOps()
171 rhs_batch_idx = bcast.y_batch_indices()[batch_idx]; in createMatMulOps()
187 {bcast.output_batch_size(), rows, cols}, element_type); in createMatMulOps()
279 tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>( in matchAndRewrite() local
284 if (!bcast.IsValid()) { in matchAndRewrite()
291 sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter); in matchAndRewrite()
293 sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter); in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/lite/toco/graph_transformations/
unroll_batch_matmul.cc
173 ::tensorflow::MatMulBCast bcast( in Run()
176 CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable"; in Run()
197 bcast.output_batch_size()); in Run()
202 SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a, in Run()
205 SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b, in Run()
211 for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) { in Run()
214 const int a_batch_idx = bcast.IsBroadcastingRequired() in Run()
215 ? bcast.x_batch_indices()[batch_idx] in Run()
217 const int b_batch_idx = bcast.IsBroadcastingRequired() in Run()
218 ? bcast.y_batch_indices()[batch_idx] in Run()
[all …]
/external/eigen/unsupported/test/
cxx11_tensor_forced_eval.cpp
61 Eigen::array<int, 2> bcast; in test_const() local
62 bcast[0] = 3; in test_const()
63 bcast[1] = 1; in test_const()
65 …t_tensor= (input_tensor - input_tensor.maximum(depth_dim).eval().reshape(dims2d).broadcast(bcast)); in test_const()
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
broadcast.cc
102 BCast bcast(BCast::FromShape(lhs_tf_shape), BCast::FromShape(rhs_tf_shape)); in BroadcastOpsToSame() local
103 if (!bcast.IsValid()) { in BroadcastOpsToSame()
107 TF_ASSIGN_OR_RETURN(*lhs, BroadcastTo(*lhs, bcast.output_shape())); in BroadcastOpsToSame()
108 TF_ASSIGN_OR_RETURN(*rhs, BroadcastTo(*rhs, bcast.output_shape())); in BroadcastOpsToSame()
/external/tensorflow/tensorflow/core/grappler/utils/
symbolic_shapes.cc
102 BCast bcast(ShapeDims(left), ShapeDims(right), in ShapesBroadcastable() local
104 return bcast.IsValid(); in ShapesBroadcastable()
118 BCast bcast(ShapeDims(left), ShapeDims(right), in ShapeAfterBroadcast() local
120 if (!bcast.IsValid()) { in ShapeAfterBroadcast()
125 for (const auto& dim : bcast.output_shape()) { in ShapeAfterBroadcast()
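
The grappler helpers above reduce symbolic-shape broadcastability to a BCast validity check; the third constructor argument they pass is fewer_dims_optimization, which a caller that only reads IsValid() and output_shape() can safely turn off. A minimal sketch, with the include path and the false value assumed:

#include "tensorflow/core/util/bcast.h"

bool Broadcastable(const tensorflow::BCast::Vec& left,
                   const tensorflow::BCast::Vec& right) {
  // Disabling dimension coalescing skips work that only matters when the
  // reshape/bcast vectors are actually used to evaluate a kernel.
  tensorflow::BCast bcast(left, right, /*fewer_dims_optimization=*/false);
  return bcast.IsValid();
}
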
