1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/ops.h"
17
18 #include <string>
19 #include <vector>
20
21 #include "absl/types/optional.h"
22 #include "absl/types/span.h"
23 #include "pybind11/attr.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/client/lib/comparators.h"
26 #include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
27 #include "tensorflow/compiler/xla/client/lib/math.h"
28 #include "tensorflow/compiler/xla/client/lib/qr.h"
29 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
30 #include "tensorflow/compiler/xla/client/lib/sorting.h"
31 #include "tensorflow/compiler/xla/client/lib/svd.h"
32 #include "tensorflow/compiler/xla/client/xla_builder.h"
33 #include "tensorflow/compiler/xla/client/xla_computation.h"
34 #include "tensorflow/compiler/xla/python/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36
37 namespace xla {
38
39 namespace py = pybind11;
40
BuildOpsSubmodule(py::module * m)41 void BuildOpsSubmodule(py::module* m) {
42 // ops submodule, containing free functions that add operators to an
43 // XlaBuilder.
44 py::module ops = m->def_submodule("ops", "XLA operations");
45
46 py::enum_<TriangularSolveOptions::Transpose>(
47 ops, "TriangularSolveOptions_Transpose")
48 .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
49 .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
50 .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
51 .value("ADJOINT", TriangularSolveOptions::ADJOINT);
52
53 py::enum_<RandomAlgorithm>(ops, "RandomAlgorithm")
54 .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT)
55 .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY)
56 .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX);
57
58 ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
59 ops.def("AllGather", &AllGather, py::arg("operand"),
60 py::arg("all_gather_dimension"), py::arg("shard_count"),
61 py::arg("replica_groups") = py::list(),
62 py::arg("channel_id") = absl::nullopt,
63 py::arg("shape_with_layout") = absl::nullopt,
64 py::arg("use_global_device_ids") = absl::nullopt);
65 ops.def(
66 "AllReduce",
67 static_cast<XlaOp (*)(
68 XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
69 const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
70 &AllReduce),
71 py::arg("operand"), py::arg("computation"),
72 py::arg("replica_groups") = py::list(),
73 py::arg("channel_id") = absl::nullopt,
74 py::arg("shape_with_layout") = absl::nullopt);
75 ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
76 py::arg("concat_dimension"), py::arg("split_count"),
77 py::arg("replica_groups") = py::list(),
78 py::arg("layout") = absl::nullopt);
79 ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
80 py::arg("source_target_pairs"));
81 ops.def("CreateToken", &CreateToken, py::arg("builder"));
82 ops.def("CrossReplicaSum",
83 static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
84 &CrossReplicaSum),
85 py::arg("operand"), py::arg("replica_groups") = py::list());
86 ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
87 py::arg("new_element_type"));
88 ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
89 ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
90 py::arg("shape"), py::arg("broadcast_dimensions"));
91 ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
92 py::arg("operands"));
93 ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
94 ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
95 ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
96 ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
97 py::arg("dimension"));
98 ops.def("Conditional",
99 static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
100 absl::Span<const XlaOp>)>(&Conditional),
101 py::arg("branch_index"), py::arg("branch_computations"),
102 py::arg("branch_operands"));
103 ops.def("Conditional",
104 static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
105 const XlaComputation&)>(&Conditional),
106 py::arg("predicate"), py::arg("true_operand"),
107 py::arg("true_computation"), py::arg("false_operand"),
108 py::arg("false_computation"));
109 ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
110 ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
111 py::arg("literal"));
112 ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
113 py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
114 py::arg("lhs_dilation"), py::arg("rhs_dilation"),
115 py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
116 py::arg("batch_group_count") = 1,
117 py::arg("precision_config") = nullptr,
118 py::arg("preferred_element_type") = absl::nullopt);
119 ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
120 py::arg("new_element_type"));
121 ops.def(
122 "CustomCall",
123 [](XlaBuilder* builder, const py::bytes& call_target_name,
124 absl::Span<const XlaOp> operands, const Shape& shape,
125 const py::bytes& opaque, bool has_side_effect) -> XlaOp {
126 return CustomCall(builder, call_target_name, operands, shape, opaque,
127 has_side_effect);
128 },
129 py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
130 py::arg("shape"), py::arg("opaque") = py::bytes(""),
131 py::arg("has_side_effect") = false);
132 ops.def(
133 "CustomCallWithLayout",
134 [](XlaBuilder* builder, const py::bytes& call_target_name,
135 absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
136 absl::Span<const Shape> operand_shapes_with_layout,
137 const py::bytes& opaque, bool has_side_effect) -> XlaOp {
138 return CustomCallWithLayout(
139 builder, call_target_name, operands, shape_with_layout,
140 operand_shapes_with_layout, opaque, has_side_effect);
141 },
142 py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
143 py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
144 py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
145 ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
146 py::arg("precision_config") = nullptr,
147 py::arg("preferred_element_type") = absl::nullopt);
148 ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
149 py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
150 py::arg("preferred_element_type") = absl::nullopt);
151 ops.def("DynamicSlice",
152 static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
153 absl::Span<const int64>)>(&DynamicSlice),
154 py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
155 ops.def("DynamicUpdateSlice",
156 static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
157 &DynamicUpdateSlice),
158 py::arg("operand"), py::arg("update"), py::arg("start_indices"));
159
160 ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
161 py::arg("fft_length"));
162
163 ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
164 py::arg("dimension_numbers"), py::arg("slice_sizes"),
165 py::arg("indices_are_sorted") = false);
166 ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
167 py::arg("index"));
168 ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
169 py::arg("shape"), py::arg("config") = "");
170 ops.def("Iota",
171 static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
172 py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
173 ops.def("Iota",
174 static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
175 py::arg("builder"), py::arg("type"), py::arg("size"));
176 ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
177 py::arg("computation"), py::arg("dimensions"),
178 py::arg("static_operands") = py::list());
179 ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
180 ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
181 py::arg("token"), py::arg("shape_with_layout"),
182 py::arg("outfeed_config") = "");
183 ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
184 py::arg("padding_config"));
185 ops.def("Parameter",
186 static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
187 const std::string&, const std::vector<bool>&)>(
188 &Parameter),
189 py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
190 py::arg("name") = "",
191 py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
192 ops.def(
193 "QR",
194 [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
195 TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
196 return std::make_pair(qr.q, qr.r);
197 },
198 py::arg("operand"), py::arg("full_matrices"));
199 ops.def(
200 "LU",
201 [](XlaOp a) -> StatusOr<std::tuple<XlaOp, XlaOp, XlaOp>> {
202 LuDecompositionResult lu = LuDecomposition(a);
203 return std::make_tuple(lu.lu, lu.pivots, lu.permutation);
204 },
205 py::arg("operand"));
206 ops.def(
207 "Eigh",
208 [](XlaOp a, bool lower, int64 max_iter,
209 float epsilon) -> std::pair<XlaOp, XlaOp> {
210 auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
211 return std::make_pair(eigh.v, eigh.w);
212 },
213 py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
214 py::arg("epsilon") = 1e-6);
215 ops.def(
216 "SVD",
217 [](XlaOp a, int64 max_iter,
218 float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
219 auto svd = SVD(a, max_iter, epsilon);
220 return std::make_tuple(svd.u, svd.d, svd.v);
221 },
222 py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
223 ops.def("Reduce",
224 static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
225 absl::Span<const XlaOp>, const XlaComputation&,
226 absl::Span<const int64>)>(&Reduce),
227 py::arg("builder"), py::arg("operands"), py::arg("init_values"),
228 py::arg("computation"), py::arg("dimensions_to_reduce"));
229 ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
230 py::arg("exponent_bits"), py::arg("mantissa_bits"));
231 ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
232 py::arg("operand"), py::arg("init_value"), py::arg("computation"),
233 py::arg("window_dimensions"), py::arg("window_strides"),
234 py::arg("base_dilations"), py::arg("window_dilations"),
235 py::arg("padding"));
236 ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
237 ops.def("Reshape",
238 static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
239 absl::Span<const int64>)>(&Reshape),
240 py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
241 ops.def("Reshape",
242 static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
243 py::arg("operand"), py::arg("new_sizes"));
244 ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
245 ops.def("RngBitGenerator", &RngBitGenerator, py::arg("algorithm"),
246 py::arg("initial_state"), py::arg("shape"));
247 ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
248 py::arg("shape"));
249 ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
250 py::arg("shape"));
251 ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
252 py::arg("updates"), py::arg("update_computation"),
253 py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
254 py::arg("unique_indices") = false);
255 ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
256 py::arg("on_false"));
257 ops.def("SelectAndScatterWithGeneralPadding",
258 &SelectAndScatterWithGeneralPadding, py::arg("operand"),
259 py::arg("select"), py::arg("window_dimensions"),
260 py::arg("window_strides"), py::arg("padding"), py::arg("source"),
261 py::arg("init_value"), py::arg("scatter"));
262 ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
263 py::arg("limit_indices"), py::arg("strides"));
264 ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
265 py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
266 ops.def(
267 "Sort",
268 [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
269 absl::optional<const XlaComputation*> comparator, int64 dimension,
270 bool is_stable) -> XlaOp {
271 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
272 std::vector<PrimitiveType> operand_types;
273 for (const auto& operand : operands) {
274 TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
275 operand_types.push_back(operand_shape.element_type());
276 }
277
278 if (comparator) {
279 return Sort(operands, **comparator, dimension, is_stable);
280 } else {
281 return Sort(operands,
282 CreateScalarLtComputation(operand_types, builder),
283 dimension, is_stable);
284 }
285 });
286 },
287 py::arg("builder"), py::arg("operands"),
288 py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
289 py::arg("is_stable") = false);
290 ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
291 ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
292 ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
293 py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
294 py::arg("transpose_a"));
295 ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
296 ops.def("While", &While, py::arg("condition"), py::arg("body"),
297 py::arg("init"));
298
299 ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
300 ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
301 ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
302 ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
303 ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
304 py::arg("b"), py::arg("x"));
305 ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
306
307 #define BINARY_OP(op) \
308 ops.def( \
309 #op, \
310 [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
311 return dims ? op(a, b, *dims) : op(a, b); \
312 }, \
313 py::arg("lhs"), py::arg("rhs"), \
314 py::arg("broadcast_dimensions") = absl::nullopt)
315 BINARY_OP(Eq);
316 BINARY_OP(Ne);
317 BINARY_OP(Ge);
318 BINARY_OP(Gt);
319 BINARY_OP(Lt);
320 BINARY_OP(Le);
321 BINARY_OP(Add);
322 BINARY_OP(Sub);
323 BINARY_OP(Mul);
324 BINARY_OP(Div);
325 BINARY_OP(Rem);
326 BINARY_OP(Max);
327 BINARY_OP(Min);
328 BINARY_OP(And);
329 BINARY_OP(Or);
330 BINARY_OP(Xor);
331 BINARY_OP(ShiftLeft);
332 BINARY_OP(ShiftRightArithmetic);
333 BINARY_OP(ShiftRightLogical);
334 BINARY_OP(Atan2);
335 BINARY_OP(Pow);
336 BINARY_OP(Complex);
337 #undef BINARY_OP
338
339 #define UNARY_OP(op) ops.def(#op, &op)
340 UNARY_OP(Not);
341 UNARY_OP(PopulationCount);
342 UNARY_OP(Clz);
343 UNARY_OP(Abs);
344 UNARY_OP(Exp);
345 UNARY_OP(Expm1);
346 UNARY_OP(Floor);
347 UNARY_OP(Ceil);
348 UNARY_OP(Round);
349 UNARY_OP(Log);
350 UNARY_OP(Log1p);
351 UNARY_OP(Sign);
352 UNARY_OP(Cos);
353 UNARY_OP(Sin);
354 UNARY_OP(Tanh);
355 UNARY_OP(IsFinite);
356 UNARY_OP(Neg);
357 UNARY_OP(Sqrt);
358 UNARY_OP(Rsqrt);
359 UNARY_OP(Square);
360 UNARY_OP(Reciprocal);
361 UNARY_OP(Erfc);
362 UNARY_OP(Erf);
363 UNARY_OP(ErfInv);
364 UNARY_OP(Lgamma);
365 UNARY_OP(Digamma);
366 UNARY_OP(BesselI0e);
367 UNARY_OP(BesselI1e);
368 UNARY_OP(Acos);
369 UNARY_OP(Asin);
370 UNARY_OP(Atan);
371 UNARY_OP(Tan);
372 UNARY_OP(Acosh);
373 UNARY_OP(Asinh);
374 UNARY_OP(Atanh);
375 UNARY_OP(Cosh);
376 UNARY_OP(Sinh);
377 UNARY_OP(Real);
378 UNARY_OP(Imag);
379 UNARY_OP(Conj);
380 #undef UNARY_OP
381 }
382
383 } // namespace xla
384