Home
last modified time | relevance | path

Searched refs:accumulation_type (Results 1 – 14 of 14) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dlrn_ops.cc51 auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); in Compile() local
52 auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); in Compile()
55 squared, XlaHelpers::Zero(builder, accumulation_type), in Compile()
56 *ctx->GetOrCreateAdd(accumulation_type), in Compile()
137 auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); in Compile() local
139 XlaHelpers::ConvertElementType(in_image, accumulation_type); in Compile()
142 squared, XlaHelpers::Zero(builder, accumulation_type), in Compile()
143 *ctx->GetOrCreateAdd(accumulation_type), in Compile()
157 auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type); in Compile()
159 converted_dy, XlaHelpers::Zero(builder, accumulation_type), in Compile()
[all …]
Dl2loss_op.cc38 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); in Compile() local
39 auto t = XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type); in Compile()
41 auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type), in Compile()
42 *ctx->GetOrCreateAdd(accumulation_type), dims); in Compile()
Dfake_quantize_ops.cc198 const DataType accumulation_type = in Compile() local
218 XlaHelpers::ConvertElementType(select1, accumulation_type), in Compile()
219 XlaHelpers::Zero(b, accumulation_type), in Compile()
220 *ctx->GetOrCreateAdd(accumulation_type)); in Compile()
227 XlaHelpers::ConvertElementType(select2, accumulation_type), in Compile()
228 XlaHelpers::Zero(b, accumulation_type), in Compile()
229 *ctx->GetOrCreateAdd(accumulation_type)); in Compile()
313 const DataType accumulation_type = in Compile() local
348 xla::Reduce(XlaHelpers::ConvertElementType(select1, accumulation_type), in Compile()
349 XlaHelpers::Zero(b, accumulation_type), in Compile()
[all …]
Dsoftmax_op.cc66 const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); in CrossEntropyWithLogits() local
68 XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type); in CrossEntropyWithLogits()
70 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in CrossEntropyWithLogits()
71 *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); in CrossEntropyWithLogits()
87 auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type), in CrossEntropyWithLogits()
88 XlaHelpers::Zero(b, accumulation_type), in CrossEntropyWithLogits()
89 *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); in CrossEntropyWithLogits()
Dlower_upper_bound_ops.cc67 const DataType accumulation_type = XlaHelpers::SumAccumulationType(out_dtype); in BuildLowerUpperBoundOp() local
71 XlaHelpers::ConvertElementType(comparison, accumulation_type); in BuildLowerUpperBoundOp()
77 xla::Reduce(comparison_int, XlaHelpers::Zero(builder, accumulation_type), in BuildLowerUpperBoundOp()
78 *ctx->GetOrCreateAdd(accumulation_type), {2}); in BuildLowerUpperBoundOp()
Dbatch_norm_op.cc313 const DataType accumulation_type = in Compile() local
316 XlaHelpers::ConvertElementType(grad_backprop, accumulation_type); in Compile()
318 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
319 *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); in Compile()
329 converted = XlaHelpers::ConvertElementType(mul, accumulation_type); in Compile()
331 xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
332 *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); in Compile()
Dimage_ops.cc204 const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); in Compile() local
205 auto converted = XlaHelpers::ConvertElementType(input, accumulation_type); in Compile()
206 auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), in Compile()
207 *context->GetOrCreateAdd(accumulation_type), in Compile()
211 reduce, XlaHelpers::FloatLiteral(b, accumulation_type, height * width)); in Compile()
/external/tensorflow/tensorflow/compiler/xla/service/
Dresult_caster_test.cc52 const PrimitiveType accumulation_type = in TEST_P() local
54 const bool should_cast = result_type != accumulation_type; in TEST_P()
62 primitive_util::LowercasePrimitiveTypeName(accumulation_type)); in TEST_P()
/external/tensorflow/tensorflow/lite/toco/
Dtoco_cmdline_flags.cc198 Flag("accumulation_type", parsed_flags.accumulation_type.bind(), in ParseTocoFlagsFromCommandLineFlags()
199 parsed_flags.accumulation_type.default_value(), in ParseTocoFlagsFromCommandLineFlags()
302 PARSE_TOCO_FLAG(IODataType, accumulation_type, FlagRequirement::kNone); in ReadTocoFlagsFromCommandLineFlags()
Dargs.h198 Arg<std::string> accumulation_type; member
Dtoco_flags.proto251 optional IODataType accumulation_type = 37; field
/external/tensorflow/tensorflow/lite/python/
Dconvert.py491 accumulation_type=None, argument
631 if accumulation_type:
632 conversion_flags.accumulation_type = convert_tensor_tf_type_to_tflite_type(
633 accumulation_type, usage="accumulation_type flag")
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dpooling.cc81 PrimitiveType accumulation_type = init_shape.element_type(); in ComputeSums() local
82 auto add_computation = CreateScalarAddComputation(accumulation_type, b); in ComputeSums()
/external/tensorflow/tensorflow/compiler/mlir/lite/python/
Dtf_tfl_flatbuffer_helpers.cc294 if (toco_flags.accumulation_type() == toco::IODataType::FLOAT16) { in PopulateQuantizationSpecs()