Home
last modified time | relevance | path

Searched refs:XlaOp (Results 1 – 25 of 177) sorted by relevance

12345678

/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.h52 class XlaOp {
54 XlaOp() : handle_(-1), builder_(nullptr) { in XlaOp() function
55 static_assert(std::is_trivially_destructible<XlaOp>::value, in XlaOp()
58 ~XlaOp() = default;
60 XlaOp(const XlaOp& other) = default;
61 XlaOp& operator=(const XlaOp& other) = default;
81 bool IsIdenticalTo(const XlaOp& rhs) const { in IsIdenticalTo()
85 friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
91 explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} in XlaOp() function
92 XlaOp(int64 handle, XlaBuilder* builder) in XlaOp() function
[all …]
Dxla_builder.cc74 XlaOp operator-(const XlaOp& x) { return Neg(x); } in operator -()
75 XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); } in operator +()
76 XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); } in operator -()
77 XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); } in operator *()
78 XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); } in operator /()
79 XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); } in operator %()
81 XlaOp operator~(const XlaOp& x) { return Not(x); } in operator ~()
82 XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); } in operator &()
83 XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); } in operator |()
84 XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); } in operator ^()
[all …]
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dmath.h26 XlaOp IsPosInf(XlaOp operand);
27 XlaOp IsNegInf(XlaOp operand);
28 XlaOp IsInf(XlaOp operand);
29 XlaOp IsNan(XlaOp operand);
34 XlaOp IsNegZero(XlaOp operand);
38 XlaOp NextAfter(XlaOp from, XlaOp to);
41 XlaOp Square(XlaOp operand);
44 XlaOp Reciprocal(XlaOp operand);
48 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients);
51 XlaOp Erfc(XlaOp x);
[all …]
Dmath.cc33 static XlaOp DoWithUpcastToF32(XlaOp operand, in DoWithUpcastToF32()
35 const std::function<XlaOp(XlaOp)>& operation) { in DoWithUpcastToF32()
37 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in DoWithUpcastToF32()
45 XlaOp result = operation(operand); in DoWithUpcastToF32()
55 static Status EnsureOperandIsRealFp(absl::string_view op_name, XlaOp operand) { in EnsureOperandIsRealFp()
67 XlaOp IsPosInf(XlaOp operand) { in IsPosInf()
69 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in IsPosInf()
78 XlaOp IsNegInf(XlaOp operand) { in IsNegInf()
80 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in IsNegInf()
89 XlaOp IsInf(XlaOp operand) { in IsInf()
[all …]
Dslicing.h27 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start);
31 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
36 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
40 XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
43 XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
44 absl::Span<const XlaOp> starts);
58 XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim);
66 XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim);
Dmatrix.h31 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
41 XlaOp GetMatrixDiagonal(XlaOp x, int k = 0);
45 XlaOp TriangleMask(XlaOp x, int diagonal);
48 XlaOp Triangle(XlaOp x, bool lower);
51 XlaOp UpperTriangle(XlaOp x);
54 XlaOp LowerTriangle(XlaOp x);
73 xla::XlaOp BatchDot(
74 xla::XlaOp x, xla::XlaOp y,
97 xla::XlaOp Einsum(
98 xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
[all …]
Darithmetic.cc32 using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
58 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { in CreateScalarAddComputation()
67 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { in CreateScalarMultiplyComputation()
75 [](XlaBuilder* b, const XlaOp& lhs, in CreateScalarGeComputation()
76 const XlaOp& rhs) { return Ge(lhs, rhs); }); in CreateScalarGeComputation()
83 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { in CreateScalarMaxComputation()
92 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { in CreateScalarMinComputation()
101 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { in CreateScalarAndComputation()
109 [](XlaBuilder* b, const XlaOp& lhs, in CreateScalarOrComputation()
110 const XlaOp& rhs) { return Or(lhs, rhs); }); in CreateScalarOrComputation()
[all …]
Dprng.h29 using ThreeFry2x32State = std::array<XlaOp, 2>;
35 XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
36 XlaOp minval, XlaOp maxval);
40 XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
49 XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
55 ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
56 XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
Dslicing.cc21 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start, in SliceInMinorDims()
24 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in SliceInMinorDims()
52 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) { in UpdateSlice()
54 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in UpdateSlice()
61 std::vector<XlaOp> start_ops(start.size()); in UpdateSlice()
69 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, in UpdateSliceInMinorDims()
72 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in UpdateSliceInMinorDims()
94 StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims( in PrependZerosInMajorDims()
95 XlaOp x, absl::Span<const XlaOp> starts) { in PrependZerosInMajorDims()
100 std::vector<XlaOp> padded_starts(n_dims, zero); in PrependZerosInMajorDims()
[all …]
Dsvd.cc50 XlaOp v;
51 XlaOp beta;
52 XlaOp a;
60 XlaOp c; // cosine.
61 XlaOp s; // sine.
66 XlaOp v;
67 XlaOp w;
79 XlaOp off_diagonal_norm;
80 XlaOp total_norm;
114 StatusOr<HouseHolderResult> HouseRow(XlaOp a, XlaOp i, XlaOp j, XlaOp eps, in HouseRow()
[all …]
Dconstants.h34 XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { in ConstantR0WithType()
85 XlaOp ScalarLike(XlaOp prototype, T value) { in ScalarLike()
87 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in ScalarLike()
99 XlaOp FullLike(XlaOp prototype, T value) { in FullLike()
101 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in FullLike()
115 XlaOp Zero(XlaBuilder* builder, PrimitiveType type);
118 XlaOp Zeros(XlaBuilder* builder, const Shape& shape);
121 XlaOp ZerosLike(XlaOp prototype);
124 XlaOp One(XlaBuilder* builder, PrimitiveType type);
128 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);
[all …]
Dself_adjoint_eig.cc44 XlaOp c; // cosine.
45 XlaOp s; // sine.
50 XlaOp v;
51 XlaOp w;
55 XlaOp off_diagonal_norm;
56 XlaOp total_norm;
77 StatusOr<JacobiRotation> SymmetricShurDecomposition2x2(XlaOp a, XlaOp p, in SymmetricShurDecomposition2x2()
78 XlaOp q, XlaOp tol) { in SymmetricShurDecomposition2x2()
113 StatusOr<JacobiUpdate> Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q, in Update()
114 XlaOp tol, int64 n) { in Update()
[all …]
Dloops.h32 typedef std::function<StatusOr<XlaOp>(absl::Span<const XlaOp>, XlaBuilder*)>
37 typedef std::function<StatusOr<std::vector<XlaOp>>(absl::Span<const XlaOp>,
49 StatusOr<std::vector<XlaOp>> WhileLoopHelper(
52 absl::Span<const XlaOp> initial_values, absl::string_view name,
60 typedef std::function<StatusOr<std::vector<XlaOp>>(
61 XlaOp, absl::Span<const XlaOp>, XlaBuilder*)>
64 StatusOr<std::vector<XlaOp>> ForEachIndex(
67 absl::Span<const XlaOp> initial_values, absl::string_view name,
Dmatrix.cc41 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, in IdentityMatrix()
49 XlaOp GetMatrixDiagonal(XlaOp x, int k) { in GetMatrixDiagonal()
51 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in GetMatrixDiagonal()
91 XlaOp TriangleMask(XlaOp x, int diagonal) { in TriangleMask()
93 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in TriangleMask()
103 XlaOp indicator; in TriangleMask()
109 XlaOp Triangle(XlaOp x, bool lower) { in Triangle()
114 XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); } in UpperTriangle()
116 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); } in LowerTriangle()
159 xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y, in Einsum()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dfake_quantize_ops.cc49 const xla::XlaOp& min, const xla::XlaOp& max, in XlaNudge()
51 xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, in XlaNudge()
52 xla::XlaOp* scale) { in XlaNudge()
56 xla::XlaOp quant_min = in XlaNudge()
58 xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); in XlaNudge()
59 xla::XlaOp quant_max = in XlaNudge()
61 xla::XlaOp nudged_zero_point = in XlaNudge()
69 xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, in Quantize()
71 const xla::XlaOp& nudged_input_min, in Quantize()
72 const xla::XlaOp& nudged_input_max, in Quantize()
[all …]
Dvariable_ops.cc70 xla::XlaOp handle; in Compile()
96 xla::XlaOp handle; in Compile()
112 xla::XlaOp handle; in Compile()
132 xla::XlaOp resource_handle; in Compile()
139 xla::XlaOp gather; in Compile()
153 std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&, in ResourceScatterOp() argument
165 xla::XlaOp var_value; in Compile()
169 const xla::XlaOp indices = context->Input(1); in Compile()
170 const xla::XlaOp updates = context->Input(2); in Compile()
181 const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
[all …]
Dimage_ops.cc36 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, in RGBToHSV()
37 const std::array<xla::XlaOp, 3>& rgb, in RGBToHSV() argument
69 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b, in HSVToRGB()
70 const std::array<xla::XlaOp, 3>& hsv, in HSVToRGB() argument
72 xla::XlaOp hue = hsv[0]; in HSVToRGB()
73 xla::XlaOp saturation = hsv[1]; in HSVToRGB()
74 xla::XlaOp value = hsv[2]; in HSVToRGB()
111 xla::XlaOp input = context->Input(0); in Compile()
113 xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, in Compile()
116 xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, in Compile()
[all …]
Dreduction_ops.cc35 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
38 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
39 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
53 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
57 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
58 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
71 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
75 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
76 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
89 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
[all …]
Dstateful_random_ops.cc38 std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter( in GetInputsFromCounter()
39 xla::XlaOp counter, const int64 size) { in GetInputsFromCounter()
51 std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32( in StatefulRngUniformU32()
52 xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { in StatefulRngUniformU32()
69 std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64( in StatefulRngUniformU64()
70 xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { in StatefulRngUniformU64()
81 std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key, in StatefulRngUniform()
82 xla::XlaOp counter, in StatefulRngUniform()
84 xla::XlaOp minval, in StatefulRngUniform()
85 xla::XlaOp maxval) { in StatefulRngUniform()
[all …]
Dsegment_reduction_ops.cc36 virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
39 virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0;
83 auto combiner = [this](xla::XlaOp a, xla::XlaOp b, in Compile()
101 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
104 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; in Combine()
116 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
119 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; in Combine()
131 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
134 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { in Combine()
148 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
[all …]
Dtraining_ops.cc34 xla::XlaOp handle; in Compile()
59 xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, in ProximalGradientDescentUpdate()
60 xla::XlaOp l1, xla::XlaOp l2, in ProximalGradientDescentUpdate()
61 xla::XlaOp grad) { in ProximalGradientDescentUpdate()
62 xla::XlaOp one = xla::ScalarLike(lr, 1.0); in ProximalGradientDescentUpdate()
63 xla::XlaOp zero = xla::ScalarLike(lr, 0.0); in ProximalGradientDescentUpdate()
64 xla::XlaOp prox_var = var - grad * lr; in ProximalGradientDescentUpdate()
65 xla::XlaOp l1_gt_zero = xla::Sign(prox_var) * in ProximalGradientDescentUpdate()
68 xla::XlaOp l1_le_zero = prox_var / (one + lr * l2); in ProximalGradientDescentUpdate()
80 xla::XlaOp var; in Compile()
[all …]
Dtensor_list_ops.cc50 xla::XlaOp index; in Compile()
65 int64 leading_dim, DataType dtype, xla::XlaOp* list) { in CreateZerosList()
68 xla::XlaOp element_shape_handle = ctx->Input(element_shape_index); in CreateZerosList()
108 xla::XlaOp buffer; in Compile()
111 xla::XlaOp output_list; in Compile()
144 xla::XlaOp buffer; in Compile()
148 xla::XlaOp output_list; in Compile()
214 xla::XlaOp state = ctx->Input(0); in Compile()
219 xla::XlaOp buffer; in Compile()
221 xla::XlaOp index = ctx->Input(1); in Compile()
[all …]
Dtensor_list_utils.h35 Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index,
36 xla::XlaOp* output_list);
39 Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer);
42 Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index);
45 Status GetTensorListBufferShape(const xla::XlaOp& op,
56 Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized);
61 Status InitializeTensorList(const xla::XlaOp& uninitialized_list,
63 xla::XlaOp* output_list);
Dstateless_random_ops.cc36 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { in MaybeConvertF32ToBF16()
48 xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) { in Uniform2NormalUsingSqrtErfinv()
58 xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType dtype, in StatelessRandomUniformImpl()
59 xla::XlaOp seed, xla::XlaOp minval, in StatelessRandomUniformImpl()
60 xla::XlaOp maxval) { in StatelessRandomUniformImpl()
61 xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); in StatelessRandomUniformImpl()
62 xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); in StatelessRandomUniformImpl()
85 xla::XlaOp seed = ctx->Input(1); in Compile()
89 xla::XlaOp uniform = StatelessRandomUniformImpl( in Compile()
133 xla::XlaOp seed = ctx->Input(1); in Compile()
[all …]
Dtensor_list_utils.cc31 Status BuildTensorList(const xla::XlaOp& buffer, const xla::XlaOp& push_index, in BuildTensorList()
32 xla::XlaOp* output_list) { in BuildTensorList()
38 Status GetTensorListBuffer(const xla::XlaOp& op, xla::XlaOp* buffer) { in GetTensorListBuffer()
44 Status GetTensorListPushIndex(const xla::XlaOp& op, xla::XlaOp* push_index) { in GetTensorListPushIndex()
50 Status GetTensorListBufferShape(const xla::XlaOp& op, in GetTensorListBufferShape()
67 Status IsTensorListInitialized(const xla::XlaOp& op, bool* is_initialized) { in IsTensorListInitialized()
74 Status InitializeTensorList(const xla::XlaOp& uninitialized_list, in InitializeTensorList()
76 xla::XlaOp* output_list) { in InitializeTensorList()
87 xla::XlaOp input_buffer; in InitializeTensorList()
95 xla::XlaOp push_index; in InitializeTensorList()

12345678