Home
last modified time | relevance | path

Searched refs:side_input (Results 1 – 25 of 26) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dgpu_fusion.cc72 Value side_input; in matchAndRewrite() local
82 side_input = add_op.y(); in matchAndRewrite()
88 side_input = add_op.x(); in matchAndRewrite()
99 if (side_input) state.operands.push_back(side_input); in matchAndRewrite()
/external/tensorflow/tensorflow/core/kernels/
Dfused_batch_norm_op.cu.cc138 const T* __restrict__ side_input, float epsilon, in run()
159 shifted_v += U(side_input[index]); in run()
205 const IT* side_input = reinterpret_cast<const IT*>(_side_input); in run() local
242 reinterpret_cast<const half2*>(side_input)[index]); in run()
274 shifted_v = __hadd(shifted_v, side_input[index]); in run()
299 const T* side_input, float epsilon, T* out) { in FusedBatchNormInferenceMetaKernel() argument
311 scale, offset, mean, var, side_input, in FusedBatchNormInferenceMetaKernel()
323 typename TTypes<T, 4>::ConstTensor side_input, U epsilon, in operator ()()
353 estimated_mean.data(), estimated_variance.data(), side_input.data(), \ in operator ()()
356 const bool no_side_input = side_input.dimensions().TotalSize() == 0; in operator ()()
[all …]
Dfused_batch_norm_ex_op_test.cc159 Output side_input = ops::Const(root.WithOpName("side_input"), in RunFusedBatchNorm() local
171 ops::Add(root.WithOpName("with_side_input"), fwd.y, side_input); in RunFusedBatchNorm()
250 Output side_input = ops::Const(root.WithOpName("side_input"), in RunFusedBatchNormEx() local
261 side_inputs.push_back({side_input.name(), 0, t_dtype}); in RunFusedBatchNormEx()
385 Tensor side_input(t_dtype, input_shape); in VerifyTensorsNear() local
386 side_input.flat<T>().setRandom(); in VerifyTensorsNear()
387 side_input.flat<T>() += side_input.flat<T>().constant(static_cast<T>(5.0)); in VerifyTensorsNear()
402 is_training ? empty : var, side_input, &fbn_forward, in VerifyTensorsNear()
407 run_default(y_backprop, side_input, scale, offset, in VerifyTensorsNear()
412 is_training ? empty : var, side_input, &fbn_ex_forward, in VerifyTensorsNear()
Dfused_batch_norm_op.cc97 const Tensor* side_input, U epsilon, U exponential_avg_factor, in operator ()()
103 OP_REQUIRES(context, side_input == nullptr, in operator ()()
238 const Tensor* side_input, U epsilon, U exponential_avg_factor, in operator ()()
244 OP_REQUIRES(context, side_input == nullptr, in operator ()()
772 const Tensor& estimated_variance, const Tensor* side_input, in operator ()()
839 const bool has_side_input = side_input != nullptr; in operator ()()
851 side_input->tensor<T, 4>(), epsilon, activation_mode, in operator ()()
922 side_input != nullptr in operator ()()
923 ? StreamExecutorUtil::AsDeviceMemory<T>(*side_input) in operator ()()
1192 typename TTypes<T, 4>::ConstTensor side_input, U epsilon, \
[all …]
Dfused_batch_norm_op.h50 typename TTypes<T, 4>::ConstTensor side_input, U epsilon,
/external/tensorflow/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/
Dconv_to_jitrt.mlir271 %side_input: memref<1x3x3x64xf64, #map0>,
275 // CHECK: call @xla.gpu.conv.forward.fused.side_input(
282 %input, %filter, %bias, %side_input, %output, %scratch)
316 // CHECK: func private @xla.gpu.conv.forward.fused.side_input(
321 // CHECK-SAME: "xla.gpu.conv.forward.fused.side_input"}
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dgpu_conv_runner.cc124 se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf); in RunGpuConvForwardActivation() local
126 if (side_input.is_null()) { in RunGpuConvForwardActivation()
138 side_input = output_buf; in RunGpuConvForwardActivation()
174 filter_buf, side_input, params.fusion->bias_buf, output_buf); in RunGpuConvForwardActivation()
Dcudnn_fused_conv_rewriter.cc347 HloInstruction* side_input; in FuseSideInputAlpha() local
350 .WithOperand(3, m::Op(&side_input)); in FuseSideInputAlpha()
369 HloInstruction* before_reshape = side_input; in FuseSideInputAlpha()
432 new_operands[3] = clone(side_input); in FuseSideInputAlpha()
Djitrt_custom_calls.cc721 Optional<SideInputAttrs> side_input = llvm::None) { in GetConvDescriptor() argument
794 if (side_input.has_value()) in GetConvDescriptor()
796 side_input->side_input_scale); in GetConvDescriptor()
809 Optional<runtime::StridedMemrefView> side_input, in operator ()()
845 if (side_input.has_value()) in operator ()()
846 buffers.push_back(GetDeviceAddress(*side_input)); in operator ()()
/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/
Dlhlo_gpu_ops.td107 // side_input * side_input_scale +
116 Arg<LHLO_Buffer, "", [MemRead]>:$side_input,
/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo_gpu/
Dlhlo_gpu_ops.mlir163 …17x9x9xf16>, %filter : memref<3x3x17x32xf16>, %bias : memref<32xf16>, %side_input: memref<32xf16>…
165 …lmhlo_gpu.conv_forward_fused_with_side_input(%input, %filter, %bias, %side_input, %output, %scratc…
/external/tensorflow/tensorflow/core/grappler/optimizers/
Dremapper.cc119 int side_input = kMissingIndex; member
1764 matched->side_input = add_regular_fanin_1.node_index(); in FindFusedBatchNormEx()
1772 matched->side_input = add_regular_fanin_0.node_index(); in FindFusedBatchNormEx()
1858 fwd_matched.side_input == kMissingIndex; in FindFusedBatchNormGradEx()
1860 fwd_matched.side_input != kMissingIndex; in FindFusedBatchNormGradEx()
2774 << (matched.side_input != kMissingIndex in AddFusedBatchNormExNode()
2775 ? graph->node(matched.side_input).name() in AddFusedBatchNormExNode()
2800 if (matched.side_input != kMissingIndex) { in AddFusedBatchNormExNode()
2802 const NodeDef& side_input = graph->node(matched.side_input); in AddFusedBatchNormExNode() local
2803 fused_op.add_input(side_input.name()); // 5: side_input in AddFusedBatchNormExNode()
[all …]
Dmkl_remapper_test.cc342 auto side_input = in TEST_F() local
346 ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_FLOAT); in TEST_F()
Dremapper_test.cc383 auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT, in TEST_F() local
386 ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF); in TEST_F()
480 auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT, in TEST_F() local
483 ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF); in TEST_F()
/external/tensorflow/tensorflow/core/util/autotune_maps/
Dconv_parameters.proto42 // activation function (including no activation) as well as the side_input.
/external/tensorflow/tensorflow/compiler/xla/stream_executor/cuda/
Dcuda_dnn.h284 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
298 const DeviceMemory<Eigen::half>& side_input,
543 const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
Dcuda_dnn.cc5413 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc, in DoBatchNormalizationForward() argument
5424 offset, estimated_mean, estimated_variance, side_input, x_desc, in DoBatchNormalizationForward()
5436 const DeviceMemory<Eigen::half>& side_input, in DoBatchNormalizationForward() argument
5448 estimated_mean, estimated_variance, side_input, x_desc, in DoBatchNormalizationForward()
5462 const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc, in DoBatchNormalizationForwardImpl() argument
5486 if (side_input.is_null()) { in DoBatchNormalizationForwardImpl()
5522 !side_input.is_null()) { in DoBatchNormalizationForwardImpl()
5564 /*zData=*/side_input.opaque(), in DoBatchNormalizationForwardImpl()
/external/tensorflow/tensorflow/stream_executor/rocm/
Drocm_dnn.h288 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
302 const DeviceMemory<Eigen::half>& side_input,
647 const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
Drocm_dnn.cc3641 const DeviceMemory<Eigen::half>& side_input, in DoBatchNormalizationForward() argument
3652 estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc, in DoBatchNormalizationForward()
3662 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc, in DoBatchNormalizationForward() argument
3672 estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc, in DoBatchNormalizationForward()
3684 const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc, in DoBatchNormalizationForwardImpl() argument
/external/tensorflow/tensorflow/core/grappler/costs/
Dop_level_cost_estimator.cc1778 auto& side_input = op_context.op_info.inputs(3); in PredictFusedConv2DBiasActivation() local
1804 if (side_input.shape().dim_size() > 0) { in PredictFusedConv2DBiasActivation()
1805 component_ops.push_back(FusedChildContext(op_context, "Mul", side_input, in PredictFusedConv2DBiasActivation()
1806 {side_input, side_input_scale})); in PredictFusedConv2DBiasActivation()
Dop_level_cost_estimator_test.cc281 auto side_input = op_context.op_info.add_inputs(); in DescribeFusedConv2DBiasActivation() local
284 DescribeTensor4D(batch, ox, oy, oz, side_input); in DescribeFusedConv2DBiasActivation()
286 DescribeTensor4D(batch, oz, ox, oy, side_input); in DescribeFusedConv2DBiasActivation()
291 DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input); in DescribeFusedConv2DBiasActivation()
/external/tensorflow/tensorflow/compiler/xla/stream_executor/
Dstream.cc333 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc, in ThenBatchNormalizationForward() argument
345 this, x, scale, offset, estimated_mean, estimated_variance, side_input, in ThenBatchNormalizationForward()
386 const DeviceMemory<Eigen::half> &side_input, in ThenBatchNormalizationForward() argument
399 this, x, scale, offset, estimated_mean, estimated_variance, side_input, in ThenBatchNormalizationForward()
Ddnn.h1167 const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc, in DoBatchNormalizationForward() argument
1186 const DeviceMemory<Eigen::half>& side_input, in DoBatchNormalizationForward() argument
Dstream.h257 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
283 const DeviceMemory<Eigen::half> &side_input,
/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/
Dconvert_nodes.cc3139 TRT_ShapedWeights side_input = inputs.at(3).weights(); in ConvertFusedConv2DBiasActivation() local
3140 if (side_input.count() != 0) { in ConvertFusedConv2DBiasActivation()

12