Home
last modified time | relevance | path

Searched refs:xla (Results 1 – 25 of 2370) sorted by relevance

12345678910>>...95

/external/tensorflow/tensorflow/core/tpu/kernels/
Dtopk_ops.cc43 xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder, in CreateKthOrderStatisticComputation()
45 const xla::XlaOp input, in CreateKthOrderStatisticComputation()
46 const xla::XlaOp k) { in CreateKthOrderStatisticComputation()
50 xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32); in CreateKthOrderStatisticComputation()
51 xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0); in CreateKthOrderStatisticComputation()
52 xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height}); in CreateKthOrderStatisticComputation()
53 xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width}); in CreateKthOrderStatisticComputation()
55 xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF); in CreateKthOrderStatisticComputation()
56 xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height}); in CreateKthOrderStatisticComputation()
59 xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000); in CreateKthOrderStatisticComputation()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Dutil.cc29 xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { in Zeros()
30 return xla::Broadcast( in Zeros()
31 xla::ConstantLiteral(builder, in Zeros()
32 xla::LiteralUtil::Zero(shape.element_type())), in Zeros()
36 xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, in FloatLiteral()
39 case xla::F16: in FloatLiteral()
40 return xla::ConstantR0<xla::half>(builder, static_cast<xla::half>(value)); in FloatLiteral()
42 case xla::BF16: in FloatLiteral()
43 return xla::ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value)); in FloatLiteral()
45 case xla::F32: in FloatLiteral()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dimage_ops.cc45 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, in RGBToHSV()
46 const std::array<xla::XlaOp, 3>& rgb, in RGBToHSV()
54 auto value = xla::Max(xla::Max(red, green), blue); in RGBToHSV()
55 auto minimum = xla::Min(xla::Min(red, green), blue); in RGBToHSV()
56 auto range = xla::Sub(value, minimum); in RGBToHSV()
58 auto zeros = xla::Broadcast(zero, shape.dim_sizes()); in RGBToHSV()
60 xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros); in RGBToHSV()
62 auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); in RGBToHSV()
65 xla::Select(xla::Eq(green, value), in RGBToHSV()
66 xla::Add(xla::Mul(norm, xla::Sub(blue, red)), in RGBToHSV()
[all …]
Dunique_op.cc57 xla::XlaOp MoveAxis(xla::XlaOp a, int64_t from, int64_t to, in MoveAxis()
58 const xla::Shape& input_shape) { in MoveAxis()
65 return xla::Transpose(a, permutation); in MoveAxis()
68 xla::XlaOp CumSumR1(XlaOpKernelContext* ctx, xla::XlaOp input, int64_t size) { in CumSumR1()
69 auto init = xla::Zero(ctx->builder(), xla::S32); in CumSumR1()
70 auto reducer = xla::CreateScalarAddComputation(xla::S32, ctx->builder()); in CumSumR1()
72 return xla::ReduceWindowWithGeneralPadding( in CumSumR1()
86 xla::XlaOp RollingSelectR1(XlaOpKernelContext* ctx, xla::XlaOp data, in RollingSelectR1()
87 xla::XlaOp mask, int64_t size) { in RollingSelectR1()
88 xla::XlaComputation cond, body; in RollingSelectR1()
[all …]
Dtensor_list_utils.cc115 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { in IsTensorListInitialized()
116 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); in IsTensorListInitialized()
121 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { in IsNestedTensorList()
127 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); in IsNestedTensorList()
128 *is_nested_list = (xla::ShapeUtil::TupleElementCount(list_shape) > 2); in IsNestedTensorList()
132 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, in BuildNonNestedTensorList()
133 xla::XlaOp* output_list) { in BuildNonNestedTensorList()
135 *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); in BuildNonNestedTensorList()
139 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { in GetTensorListBufferShape()
145 TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); in GetTensorListBufferShape()
[all …]
Dbinary_ops.cc43 xla::XlaOp Computation( \
44 XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \
45 const absl::Span<const int64_t>& lhs_shape, const xla::XlaOp& rhs, \
49 xla::XlaBuilder* b = ctx->builder(); \
59 XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
60 XLA_MAKE_BINARY(AddV2, xla::Add(lhs, rhs, extend_dimensions));
61 XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
62 XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
63 XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
65 XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
[all …]
Dfake_quantize_ops.cc49 void XlaNudge(xla::XlaBuilder* b, const DataType data_type, in XlaNudge()
50 const xla::XlaOp& min, const xla::XlaOp& max, in XlaNudge()
52 xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, in XlaNudge()
53 xla::XlaOp* scale) { in XlaNudge()
54 *scale = xla::Div(xla::Sub(max, min), in XlaNudge()
57 xla::XlaOp quant_min = in XlaNudge()
59 xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); in XlaNudge()
60 xla::XlaOp quant_max = in XlaNudge()
62 xla::XlaOp nudged_zero_point = in XlaNudge()
63 xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min, in XlaNudge()
[all …]
Dreduction_ops.cc35 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
36 return xla::Zero(builder, xla_reduction_type_); in InitialValue()
38 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
39 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
40 xla::Add(scalar_lhs, scalar_rhs); in BuildReducer()
53 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
54 return xla::One(builder, xla_reduction_type_); in InitialValue()
57 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
58 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
59 xla::Mul(scalar_lhs, scalar_rhs); in BuildReducer()
[all …]
Dstateful_random_ops.cc42 xla::BitGeneratorTy BitGen(Algorithm alg) { in BitGen()
44 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in BitGen()
46 xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0); in BitGen()
47 xla::XlaOp result = in BitGen()
48 xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape); in BitGen()
49 xla::XlaOp data = xla::GetTupleElement(result, 1); in BitGen()
50 xla::XlaOp new_state = in BitGen()
51 xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1}); in BitGen()
52 return xla::RngOutput{data, new_state}; in BitGen()
55 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in BitGen()
[all …]
Dstateless_random_ops.cc40 xla::BitGeneratorTy GetBitGeneratorForDevice( in GetBitGeneratorForDevice()
46 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in GetBitGeneratorForDevice()
47 std::tie(state, key) = xla::ScramblePhiloxKey(key); in GetBitGeneratorForDevice()
48 xla::XlaOp philox_state = in GetBitGeneratorForDevice()
49 xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0); in GetBitGeneratorForDevice()
50 xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, in GetBitGeneratorForDevice()
52 return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1), in GetBitGeneratorForDevice()
53 /*state=*/xla::GetTupleElement(result, 0)}; in GetBitGeneratorForDevice()
56 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in GetBitGeneratorForDevice()
57 state = xla::ConcatScalars(key.builder(), {key, state}); in GetBitGeneratorForDevice()
[all …]
Dtensor_list_utils.h29 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized);
34 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list);
37 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
38 xla::XlaOp* output_list);
43 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape);
48 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer);
53 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index);
58 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
59 xla::XlaOp* result);
62 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
[all …]
Ddynamic_partition_op.cc50 xla::XlaOp CountS32(XlaOpKernelContext* ctx, xla::XlaOp input, in CountS32()
52 xla::XlaOp equal_dim = in CountS32()
53 xla::Compare(input, xla::ConstantR0<int32>(ctx->builder(), target), {}, in CountS32()
54 xla::ComparisonDirection::kEq); in CountS32()
55 xla::XlaOp casted = xla::ConvertElementType(equal_dim, xla::S32); in CountS32()
56 return xla::ReduceAll( in CountS32()
57 casted, xla::Zero(ctx->builder(), xla::S32), in CountS32()
58 xla::CreateScalarAddComputation(xla::S32, ctx->builder())); in CountS32()
61 std::pair<std::vector<xla::XlaOp>, std::vector<xla::XlaOp>>
62 DynamicPartition1D(XlaOpKernelContext* ctx, xla::XlaOp data_1d, in DynamicPartition1D()
[all …]
Din_topk_op.cc67 xla::XlaOp predictions_r2 = context->Input(0); in Compile()
68 xla::XlaOp targets_r1 = context->Input(1); in Compile()
70 xla::XlaBuilder* xla_builder = context->builder(); in Compile()
71 xla::XlaOp iota_r1 = in Compile()
72 xla::Iota(xla_builder, targets_type_, predictions_shape.dim_size(1)); in Compile()
73 xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {batch_size}); in Compile()
75 xla::XlaOp eq_r2 = xla::Eq(targets_r1, iota_r2, {0}); in Compile()
76 xla::XlaOp zero_r0_f32 = xla::Zero(xla_builder, xla::F32); in Compile()
77 xla::XlaOp zero_r2_f32 = xla::ZerosLike(predictions_r2); in Compile()
78 xla::XlaOp select_r2 = xla::Select(eq_r2, predictions_r2, zero_r2_f32); in Compile()
[all …]
Dimage_resize_ops.cc149 xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder, in MakeBilinear1DKernel()
150 xla::PrimitiveType type, int64_t n) { in MakeBilinear1DKernel()
157 return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type); in MakeBilinear1DKernel()
169 xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder, in MakeNearestNeighbor1DKernel()
170 xla::PrimitiveType type, int64_t n) { in MakeNearestNeighbor1DKernel()
174 return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type); in MakeNearestNeighbor1DKernel()
181 xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder, in MakeGeneralResizeKernel()
182 xla::PrimitiveType type, in MakeGeneralResizeKernel()
191 xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]), in MakeGeneralResizeKernel()
194 return xla::Mul(depthwise_kernel, in MakeGeneralResizeKernel()
[all …]
/external/tensorflow/tensorflow/compiler/xla/client/lib/
DBUILD5 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test")
8 default_visibility = ["//tensorflow/compiler/xla/client:friends"],
30 "//tensorflow/compiler/xla:shape_util",
31 "//tensorflow/compiler/xla:status_macros",
32 "//tensorflow/compiler/xla:types",
33 "//tensorflow/compiler/xla:xla_data_proto_cc",
34 "//tensorflow/compiler/xla/client:xla_builder",
35 "//tensorflow/compiler/xla/client:xla_computation",
45 "//tensorflow/compiler/xla:literal_util",
46 "//tensorflow/compiler/xla:shape_util",
[all …]
Dlogdet_test.cc32 using LogDetTest = xla::ClientLibraryTestBase;
35 xla::XlaBuilder builder(TestName()); in XLA_TEST_F()
37 xla::Array2D<float> a_vals({ in XLA_TEST_F()
44 xla::XlaOp a; in XLA_TEST_F()
46 xla::SignAndLogDet slogdet = xla::SLogDet(a); in XLA_TEST_F()
47 xla::XlaOp logdet = xla::LogDet(a); in XLA_TEST_F()
48 xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet}); in XLA_TEST_F()
49 xla::Literal expected = xla::LiteralUtil::MakeTupleOwned( in XLA_TEST_F()
50 xla::LiteralUtil::CreateR0<float>(1.f), in XLA_TEST_F()
51 xla::LiteralUtil::CreateR0<float>(14.1601f), in XLA_TEST_F()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dmhlo_to_lhlo_with_xla.h39 class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
45 LhloDialectEmitter(const xla::BufferAssignment& assignment, in LhloDialectEmitter()
46 const xla::HloComputation& computation, ModuleOp module) in LhloDialectEmitter()
53 xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);
55 static xla::StatusOr<mhlo::ScatterDimensionNumbersAttr>
56 GetScatterDimensionNumbers(const xla::HloInstruction* instr,
60 xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
61 xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
62 xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
63 const xla::HloInstruction* instr);
[all …]
/external/tensorflow/tensorflow/compiler/xla/tests/
DBUILD4 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend…
26 "//tensorflow/compiler/xla:friends",
50 "//tensorflow/compiler/xla:debug_options_flags",
87 "//tensorflow/compiler/xla:literal",
88 "//tensorflow/compiler/xla:literal_util",
89 "//tensorflow/compiler/xla:shape_util",
90 "//tensorflow/compiler/xla:xla_data_proto_cc",
91 "//tensorflow/compiler/xla/service:hlo",
92 "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
93 "//tensorflow/compiler/xla/service:hlo_verifier",
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
DBUILD51 "//tensorflow/compiler/xla:friends",
96 "//tensorflow/compiler/xla:xla_data_proto",
97 "//tensorflow/compiler/xla/stream_executor:dnn_proto",
108 "//tensorflow/compiler/xla:status_macros",
109 "//tensorflow/compiler/xla:statusor",
110 "//tensorflow/compiler/xla:types",
111 "//tensorflow/compiler/xla/service:executable",
112 "//tensorflow/compiler/xla/service:global_device_id",
123 "//tensorflow/compiler/xla:types",
132 "//tensorflow/compiler/xla:types",
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
DBUILD5 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
6 load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library", "xla_py_test_deps")
43 "//tensorflow/compiler/xla:friends",
51 protodeps = ["//tensorflow/compiler/xla:xla_data_proto"],
106 "//tensorflow/compiler/xla:util",
107 "//tensorflow/compiler/xla/tests:hlo_test_base",
108 "//tensorflow/compiler/xla/tests:xla_internal_test_main",
129 "//tensorflow/compiler/xla/tests:hlo_test_base",
130 "//tensorflow/compiler/xla/tests:xla_internal_test_main",
141 "//tensorflow/compiler/xla:statusor",
[all …]
/external/tensorflow/tensorflow/compiler/xla/client/
DBUILD16 "//tensorflow/compiler/xla:friends",
34 "//tensorflow/compiler/xla:service_interface",
35 "//tensorflow/compiler/xla:types",
36 "//tensorflow/compiler/xla:xla_data_proto_cc",
37 "//tensorflow/compiler/xla:xla_proto_cc",
49 "//tensorflow/compiler/xla:statusor",
50 "//tensorflow/compiler/xla:types",
51 "//tensorflow/compiler/xla:util",
74 "//tensorflow/compiler/xla:debug_options_flags",
75 "//tensorflow/compiler/xla:execution_options_util",
[all …]
/external/tensorflow/tensorflow/compiler/xla/tools/
DBUILD15 load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library")
18 default_visibility = ["//tensorflow/compiler/xla:internal"],
29 visibility = ["//tensorflow/compiler/xla:internal"],
36 "//tensorflow/compiler/xla:types",
49 "//tensorflow/compiler/xla:shape_util",
50 "//tensorflow/compiler/xla:statusor",
51 "//tensorflow/compiler/xla:types",
52 "//tensorflow/compiler/xla:xla_data_proto_cc",
53 "//tensorflow/compiler/xla/client",
54 "//tensorflow/compiler/xla/client:client_library",
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/gpu/tests/
DBUILD8 load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
32 "//tensorflow/compiler/xla:friends",
52 "//tensorflow/compiler/xla:debug_options_flags",
53 "//tensorflow/compiler/xla:shape_util",
54 "//tensorflow/compiler/xla:types",
55 "//tensorflow/compiler/xla/service:gpu_plugin",
56 "//tensorflow/compiler/xla/service/gpu:gpu_executable",
57 "//tensorflow/compiler/xla/tests:filecheck",
58 "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
59 "//tensorflow/compiler/xla/tests:verified_hlo_module",
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
DBUILD14 "//tensorflow/compiler/xla:friends",
36 "//tensorflow/compiler/xla:comparison_util",
37 "//tensorflow/compiler/xla:literal_util",
38 "//tensorflow/compiler/xla:protobuf_util",
39 "//tensorflow/compiler/xla:shape_util",
40 "//tensorflow/compiler/xla:status",
41 "//tensorflow/compiler/xla:util",
42 "//tensorflow/compiler/xla:window_util",
43 "//tensorflow/compiler/xla:xla_data_proto_cc",
44 "//tensorflow/compiler/xla/client:xla_builder",
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
DBUILD5 load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
31 "//tensorflow/compiler/xla:friends",
116 "//tensorflow/compiler/xla:literal",
117 "//tensorflow/compiler/xla:literal_util",
118 "//tensorflow/compiler/xla:shape_util",
119 "//tensorflow/compiler/xla:status",
120 "//tensorflow/compiler/xla:status_macros",
121 "//tensorflow/compiler/xla:statusor",
122 "//tensorflow/compiler/xla:types",
123 "//tensorflow/compiler/xla:util",
[all …]

12345678910>>...95