/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/xla_compiler.h"

#include <cstdint>
#include <string>
#include <vector>

#include "absl/hash/hash.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

namespace py = pybind11;

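// Computation names passed to XlaBuilder are made unique across the process
// by a mutex-guarded NameUniquer; see UniquifyName() below.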
struct Uniquer {
  absl::Mutex mu;
  NameUniquer name_uniquer TF_GUARDED_BY(mu);
};

Uniquer* GetUniquer() {
  static Uniquer* uniquer = new Uniquer;
  return uniquer;
}

static std::string UniquifyName(const std::string& name) {
  Uniquer* uniquer = GetUniquer();
  absl::MutexLock lock(&uniquer->mu);
  return uniquer->name_uniquer.GetUniqueName(name);
}

// Converts a computation to a serialized HloModuleProto.
StatusOr<py::bytes> GetComputationSerializedProto(
    const XlaComputation& computation) {
  std::string result;
  if (!computation.proto().SerializeToString(&result)) {
    return Unknown("Failed to serialize the HloModuleProto.");
  }
  return py::bytes(result);
}

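// Deserializes a computation's HloModuleProto into an HloModule, using a
// module config derived from the current debug-options flags.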
StatusOr<std::shared_ptr<HloModule>> GetHloModule(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
                      HloModule::CreateModuleConfigFromProto(
                          computation.proto(), GetDebugOptionsFromFlags()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      HloModule::CreateFromProto(computation.proto(), module_config));
  return std::shared_ptr<HloModule>(std::move(module));
}

// Converts a computation to textual HLO form.
StatusOr<std::string> GetComputationHloText(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  HloPrintOptions options = HloPrintOptions::ShortParsable();
  options.set_print_large_constants(false);
  return hlo_module->ToString(options);
}

// Converts a computation to HLO dot graph form.
StatusOr<std::string> GetComputationHloDotGraph(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return RenderGraph(*hlo_module->entry_computation(), /*label=*/"",
                     hlo_module->config().debug_options(),
                     RenderedGraphFormat::kDot);
}

// Hashes the HLO module.
StatusOr<uint64> HashComputation(const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(std::shared_ptr<HloModule> hlo_module,
                      GetHloModule(computation));
  return hlo_module->Hash();
}

// Safe version of ShapeUtil::MakeShapeWithLayout that fails gracefully on
// invalid input.
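// For example, MakeShapeWithLayout(F32, {2, 3}, {{1, 0}}) yields an f32[2,3]
// shape with minor-to-major layout {1, 0}, while passing absl::nullopt for
// minor_to_major yields the same shape with no layout set.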
StatusOr<Shape> MakeShapeWithLayout(
    PrimitiveType element_type, absl::Span<const int64> dims,
    absl::optional<absl::Span<const int64>> minor_to_major) {
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeUtil::MakeValidatedShape(element_type, dims));
  if (minor_to_major) {
    *shape.mutable_layout() = LayoutUtil::MakeLayout(*minor_to_major);
    TF_RETURN_IF_ERROR(
        LayoutUtil::ValidateLayoutForShape(shape.layout(), shape));
  } else {
    shape.clear_layout();
  }
  return shape;
}

// Registers a 'fn_capsule' as a CPU custom call target.
// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object,
// with name "xla._CUSTOM_CALL_TARGET".
// 'platform' is an XLA platform name, e.g., "Host" or "CUDA".
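//
// Illustrative Python-side usage (a sketch; assumes the submodule is
// imported as xla_extension, and `capsule` is a PyCapsule named
// "xla._CUSTOM_CALL_TARGET" wrapping the target function pointer):
//   xla_extension.register_custom_call_target("my_kernel", capsule, "Host")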
Status PyRegisterCustomCallTarget(const std::string& fn_name,
                                  py::capsule capsule,
                                  const std::string& platform) {
  static const char* const kName = "xla._CUSTOM_CALL_TARGET";
  // TODO(phawkins): remove old name after fixing users.
  static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET";
  if (absl::string_view(capsule.name()) != kName &&
      absl::string_view(capsule.name()) != kOldCpuName) {
    return InvalidArgument(
        "Argument to RegisterCustomCallTargetRegistry was not a "
        "xla._CUSTOM_CALL_TARGET capsule.");
  }
  CustomCallTargetRegistry::Global()->Register(
      fn_name, static_cast<void*>(capsule), platform);
  return Status::OK();
}

}  // namespace

void BuildXlaCompilerSubmodule(py::module& m) {
  // Shapes
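  //
  // Illustrative Python usage (a sketch; assumes the submodule is imported
  // as xla_extension):
  //   s = xla_extension.Shape("f32[2,3]")  # parsed by ParseShape
  //   s.dimensions()  # -> (2, 3)
  //   s.rank()        # -> 2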
  py::class_<Shape> shape_class(m, "Shape");
  shape_class
      .def(py::init([](const string& s) {
        return absl::make_unique<Shape>(ValueOrThrow(ParseShape(s)));
      }))
      .def_static(
          "tuple_shape",
          [](std::vector<Shape> shapes) -> Shape {
            return ShapeUtil::MakeTupleShape(shapes);
          },
          "Constructs a tuple shape.")
      .def_static(
          "array_shape",
          [](PrimitiveType type, py::object dims_seq,
             absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
            std::vector<int64> dims = IntSequenceToVector(dims_seq);
            if (layout_seq) {
              std::vector<int64> layout = IntSequenceToVector(*layout_seq);
              return MakeShapeWithLayout(type, dims, layout);
            } else {
              return MakeShapeWithLayout(type, dims, absl::nullopt);
            }
          },
          "Constructs an array shape.", py::arg("type"), py::arg("dims"),
          py::arg("layout") = absl::nullopt)
      .def_static(
          "array_shape",
          [](py::dtype dtype, py::object dims_seq,
             absl::optional<py::object> layout_seq) -> StatusOr<Shape> {
            PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype));
            std::vector<int64> dims = IntSequenceToVector(dims_seq);
            if (layout_seq) {
              std::vector<int64> layout = IntSequenceToVector(*layout_seq);
              return MakeShapeWithLayout(type, dims, layout);
            } else {
              return MakeShapeWithLayout(type, dims, absl::nullopt);
            }
          },
          "Constructs an array shape.", py::arg("type"), py::arg("dims"),
          py::arg("layout") = absl::nullopt)
      .def_static("token_shape", []() { return ShapeUtil::MakeTokenShape(); })
      .def("dimensions",
           [](const Shape& shape) -> py::tuple {
             return IntSpanToTuple(shape.dimensions());
           })
      .def("xla_element_type", &Shape::element_type)
      .def("element_type",
           [](const Shape& shape) {
             return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
           })
      .def("numpy_dtype",
           [](const Shape& shape) {
             if (shape.IsTuple()) {
               return py::dtype("O");
             }
             return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type()));
           })
      .def("is_tuple", &Shape::IsTuple)
      .def("is_array", &Shape::IsArray)
      .def("rank", &Shape::rank)
      .def("to_serialized_proto",
           [](const Shape& shape) {
             ShapeProto proto = shape.ToProto();
             return py::bytes(proto.SerializeAsString());
           })
      .def("tuple_shapes",
           [](const Shape& shape) {
             return std::vector<Shape>(shape.tuple_shapes());
           })
      .def("leaf_count",
           [](const Shape& shape) { return ShapeUtil::GetLeafCount(shape); })
      .def(
          "with_major_to_minor_layout_if_absent",
          [](const Shape& shape) {
            Shape out = shape;
            ShapeUtil::ForEachMutableSubshape(
                &out, [](Shape* subshape, const ShapeIndex&) {
                  if (!subshape->has_layout()) {
                    LayoutUtil::SetToDefaultLayout(subshape);
                  }
                });
            return out;
          },
          "Returns a copy of a shape with missing layouts set to "
          "major-to-minor.")
      .def("__eq__", [](const Shape& shape,
                        const Shape& other) { return shape == other; })
      .def("__ne__", [](const Shape& shape,
                        const Shape& other) { return shape != other; })
      .def("__hash__",
           [](const Shape& shape) { return absl::Hash<Shape>()(shape); })
      .def("__repr__", [](const Shape& shape) {
        return shape.ToString(/*print_layout=*/true);
      });

  py::class_<ProgramShape>(m, "ProgramShape")
      .def(py::init(
          [](absl::Span<const Shape> params, Shape result) -> ProgramShape {
            ProgramShape program_shape;
            for (const Shape& param : params) {
              *program_shape.add_parameters() = param;
            }
            *program_shape.mutable_result() = result;
            return program_shape;
          }))
      .def("parameter_shapes",
           static_cast<const std::vector<Shape>& (ProgramShape::*)() const>(
               &ProgramShape::parameters))
      .def("result_shape", &ProgramShape::result)
      .def("__repr__", &ProgramShape::ToString);

  // Literals
  py::class_<Literal, std::shared_ptr<Literal>>(m, "Literal")
      .def("__repr__", &Literal::ToString);

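  // An XlaComputation round-trips through its serialized HloModuleProto
  // (a sketch; the xla_extension module alias is assumed):
  //   c2 = xla_extension.XlaComputation(c.as_serialized_hlo_module_proto())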
  py::class_<XlaComputation>(m, "XlaComputation")
      .def(py::init([](const py::bytes& serialized_hlo_module_proto)
                        -> std::unique_ptr<XlaComputation> {
        HloModuleProto proto;
        proto.ParseFromString(std::string(serialized_hlo_module_proto));
        return absl::make_unique<XlaComputation>(proto);
      }))
      .def("get_hlo_module", &GetHloModule)
      .def("program_shape", &XlaComputation::GetProgramShape)
      .def("as_serialized_hlo_module_proto", &GetComputationSerializedProto)
      .def("as_hlo_text", &GetComputationHloText)
      .def("as_hlo_dot_graph", &GetComputationHloDotGraph)
      .def("hash", &HashComputation)
      .def("as_hlo_module", &GetHloModule);

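  // Each property below pairs an HloPrintOptions getter with its setter, so
  // the options read as plain attributes from Python, e.g. (a sketch):
  //   opts = xla_extension.HloPrintOptions.short_parsable()
  //   opts.print_large_constants = False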
  py::class_<HloPrintOptions> hlo_print_options_class(m, "HloPrintOptions");
  hlo_print_options_class.def(py::init<>())
      .def_static("short_parsable", &HloPrintOptions::ShortParsable)
      .def_static("canonical", &HloPrintOptions::Canonical)
      .def_static("fingerprint", &HloPrintOptions::Fingerprint)
      .def_property("print_large_constants",
                    &HloPrintOptions::print_large_constants,
                    &HloPrintOptions::set_print_large_constants)
      .def_property("print_metadata", &HloPrintOptions::print_metadata,
                    &HloPrintOptions::set_print_metadata)
      .def_property("print_backend_config",
                    &HloPrintOptions::print_backend_config,
                    &HloPrintOptions::set_print_backend_config)
      .def_property("print_result_shape", &HloPrintOptions::print_result_shape,
                    &HloPrintOptions::set_print_result_shape)
      .def_property("print_operand_shape",
                    &HloPrintOptions::print_operand_shape,
                    &HloPrintOptions::set_print_operand_shape)
      .def_property("print_operand_names",
                    &HloPrintOptions::print_operand_names,
                    &HloPrintOptions::set_print_operand_names)
      .def_property("print_ids", &HloPrintOptions::print_ids,
                    &HloPrintOptions::set_print_ids)
      .def_property("print_extra_attributes",
                    &HloPrintOptions::print_extra_attributes,
                    &HloPrintOptions::set_print_extra_attributes)
      .def_property("print_program_shape",
                    &HloPrintOptions::print_program_shape,
                    &HloPrintOptions::set_print_program_shape)
      .def_property("print_percent", &HloPrintOptions::print_percent,
                    &HloPrintOptions::set_print_percent)
      .def_property("print_control_dependencies",
                    &HloPrintOptions::print_control_dependencies,
                    &HloPrintOptions::set_print_control_dependencies)
      .def_property("compact_operands", &HloPrintOptions::compact_operands,
                    &HloPrintOptions::set_compact_operands)
      .def_property("include_layout_in_shapes",
                    &HloPrintOptions::include_layout_in_shapes,
                    &HloPrintOptions::set_include_layout_in_shapes)
      .def_property("canonicalize_instruction_names",
                    &HloPrintOptions::canonicalize_instruction_names,
                    &HloPrintOptions::set_canonicalize_instruction_names)
      .def_property("canonicalize_computations",
                    &HloPrintOptions::canonicalize_computations,
                    &HloPrintOptions::set_canonicalize_computations)
      .def_property("indent_amount", &HloPrintOptions::indent_amount,
                    &HloPrintOptions::set_indent_amount)
      .def_property("is_in_nested_computation",
                    &HloPrintOptions::is_in_nested_computation,
                    &HloPrintOptions::set_is_in_nested_computation)
      .def_property(
          "leading_and_trailing_instructions_number",
          &HloPrintOptions::leading_and_trailing_instructions_number,
          &HloPrintOptions::set_leading_and_trailing_instructions_number);

  py::class_<HloModule, std::shared_ptr<HloModule>> hlo_module_class(
      m, "HloModule");
  hlo_module_class.def(
      "to_string",
      static_cast<std::string (HloModule::*)(const HloPrintOptions&) const>(
          &HloModule::ToString),
      py::arg("options") = HloPrintOptions());

  m.def("hlo_module_to_dot_graph",
        [](const HloModule& hlo_module) -> StatusOr<std::string> {
          return RenderGraph(*hlo_module.entry_computation(), /*label=*/"",
                             hlo_module.config().debug_options(),
                             RenderedGraphFormat::kDot);
        });

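  // Runs the client's HloCostAnalysis over the module's entry computation
  // and returns the accumulated property map (module-wide totals such as
  // flop and byte counts; the exact keys are defined by HloCostAnalysis).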
  m.def(
      "hlo_module_cost_analysis",
      [](PyClient* client,
         const HloModule& module) -> StatusOr<std::map<string, float>> {
        TF_ASSIGN_OR_RETURN(auto analysis,
                            client->pjrt_client()->GetHloCostAnalysis());
        TF_RETURN_IF_ERROR(module.entry_computation()->Accept(analysis.get()));
        return analysis->properties();
      });

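  // Illustrative Python usage of the builder (a sketch; the op-construction
  // helpers live in a separate ops submodule not defined in this file):
  //   b = xla_extension.XlaBuilder("computation_name")
  //   root = ...  # add ops to `b` via the ops helpers
  //   computation = b.build(root)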
  py::class_<XlaOp> xla_op_class(m, "XlaOp");

  py::class_<XlaBuilder>(m, "XlaBuilder")
      .def(py::init(
          [](const std::string& name) -> std::unique_ptr<XlaBuilder> {
            return absl::make_unique<XlaBuilder>(UniquifyName(name));
          }))
      // TODO(phawkins): delete capitalized names after updating callers.
      .def(
          "Build",
          [](XlaBuilder& builder, absl::optional<XlaOp> root) {
            return root ? builder.Build(*root) : builder.Build();
          },
          "Builds a computation from the contents of the builder.",
          py::arg("root") = absl::nullopt)
      .def("GetShape", &XlaBuilder::GetShape)
      .def(
          "build",
          [](XlaBuilder& builder, absl::optional<XlaOp> root) {
            return root ? builder.Build(*root) : builder.Build();
          },
          "Builds a computation from the contents of the builder.",
          py::arg("root") = absl::nullopt)
      .def("clear_op_metadata", &XlaBuilder::ClearOpMetadata)
      .def("get_shape", &XlaBuilder::GetShape)
      .def(
          "get_program_shape",
          [](const XlaBuilder& builder,
             absl::optional<XlaOp> root) -> StatusOr<ProgramShape> {
            return root ? builder.GetProgramShape(*root)
                        : builder.GetProgramShape();
          },
          py::arg("root") = absl::nullopt)
      .def("is_constant", &XlaBuilder::IsConstant)
      .def("set_op_metadata", &XlaBuilder::SetOpMetadata)
      .def("set_sharding", &XlaBuilder::SetSharding)
      .def("clear_sharding", &XlaBuilder::ClearSharding)
      .def("setup_alias",
           [](XlaBuilder& builder, const std::vector<int64>& output_index,
              int64 param_number, const std::vector<int64>& param_index) {
             builder.SetUpAlias(
                 ShapeIndex(output_index.begin(), output_index.end()),
                 param_number,
                 ShapeIndex(param_index.begin(), param_index.end()));
           });

  // Device assignments
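  // A DeviceAssignment is a (replica, computation) -> device-id matrix.
  // Illustrative Python usage (a sketch; np is numpy):
  //   da = xla_extension.DeviceAssignment.create(np.array([[0], [1]]))
  //   da.replica_count()  # -> 2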
  py::class_<DeviceAssignment>(m, "DeviceAssignment")
      .def_static("create",
                  [](py::array_t<int> array) -> StatusOr<DeviceAssignment> {
                    if (array.ndim() != 2) {
                      return InvalidArgument(
                          "Argument to DeviceAssignment constructor must be a "
                          "2D array, received an %dD array.",
                          array.ndim());
                    }
                    DeviceAssignment result(array.shape(0), array.shape(1));
                    for (int i = 0; i < array.shape(0); ++i) {
                      for (int j = 0; j < array.shape(1); ++j) {
                        result(i, j) = array.at(i, j);
                      }
                    }
                    return result;
                  })
      .def("replica_count", &DeviceAssignment::replica_count)
      .def("computation_count", &DeviceAssignment::computation_count)
      .def("__repr__", &DeviceAssignment::ToString);

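  // The CompileOptions constructor below deliberately overrides two
  // DebugOptions defaults: JAX expects fast min/max to be disabled so that
  // min/max propagate NaNs rather than taking the fast-math shortcut.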
  py::class_<CompileOptions> compile_options(m, "CompileOptions");
  compile_options
      .def(py::init([]() -> CompileOptions {
        CompileOptions options;
        DebugOptions* debug_options =
            options.executable_build_options.mutable_debug_options();
        // Sets fast-math-disabling default options expected by JAX.
        debug_options->set_xla_cpu_enable_fast_min_max(false);
        debug_options->set_xla_gpu_enable_fast_min_max(false);
        return options;
      }))
      .def_readwrite("argument_layouts", &CompileOptions::argument_layouts)
      .def_readwrite("parameter_is_tupled_arguments",
                     &CompileOptions::parameter_is_tupled_arguments)
      .def_readonly("executable_build_options",
                    &CompileOptions::executable_build_options)
      // TODO(phawkins): the following fields exist for backward compatibility.
      // Remove them after JAX has been updated not to use them.
      .def_readwrite("tuple_arguments",
                     &CompileOptions::parameter_is_tupled_arguments)
      .def_property(
          "num_replicas",
          [](const CompileOptions& options) {
            return options.executable_build_options.num_replicas();
          },
          [](CompileOptions& options, int num_replicas) {
            options.executable_build_options.set_num_replicas(num_replicas);
          })
      .def_property(
          "num_partitions",
          [](const CompileOptions& options) {
            return options.executable_build_options.num_partitions();
          },
          [](CompileOptions& options, int num_partitions) {
            options.executable_build_options.set_num_partitions(
                num_partitions);
          })
      .def_property(
          "device_assignment",
          [](const CompileOptions& options)
              -> absl::optional<DeviceAssignment> {
            return options.executable_build_options.has_device_assignment()
                       ? absl::optional<DeviceAssignment>(
                             options.executable_build_options
                                 .device_assignment())
                       : absl::nullopt;
          },
          [](CompileOptions& options,
             const DeviceAssignment& device_assignment) {
            options.executable_build_options.set_device_assignment(
                device_assignment);
          });

  // Custom-call targets.
  m.def("register_custom_call_target", &PyRegisterCustomCallTarget);

  py::class_<DebugOptions>(m, "DebugOptions")
      .def("__repr__", &DebugOptions::DebugString)
      .def_property("xla_cpu_enable_fast_math",
                    &DebugOptions::xla_cpu_enable_fast_math,
                    &DebugOptions::set_xla_cpu_enable_fast_math)
      .def_property("xla_cpu_fast_math_honor_infs",
                    &DebugOptions::xla_cpu_fast_math_honor_infs,
                    &DebugOptions::set_xla_cpu_fast_math_honor_infs)
      .def_property("xla_cpu_fast_math_honor_nans",
                    &DebugOptions::xla_cpu_fast_math_honor_nans,
                    &DebugOptions::set_xla_cpu_fast_math_honor_nans)
      .def_property("xla_cpu_fast_math_honor_division",
                    &DebugOptions::xla_cpu_fast_math_honor_division,
                    &DebugOptions::set_xla_cpu_fast_math_honor_division)
      .def_property("xla_cpu_fast_math_honor_functions",
                    &DebugOptions::xla_cpu_fast_math_honor_functions,
                    &DebugOptions::set_xla_cpu_fast_math_honor_functions)
      .def_property("xla_gpu_enable_fast_min_max",
                    &DebugOptions::xla_gpu_enable_fast_min_max,
                    &DebugOptions::set_xla_gpu_enable_fast_min_max)
      .def_property("xla_backend_optimization_level",
                    &DebugOptions::xla_backend_optimization_level,
                    &DebugOptions::set_xla_backend_optimization_level)
      .def_property("xla_cpu_enable_xprof_traceme",
                    &DebugOptions::xla_cpu_enable_xprof_traceme,
                    &DebugOptions::set_xla_cpu_enable_xprof_traceme)
      .def_property("xla_llvm_disable_expensive_passes",
                    &DebugOptions::xla_llvm_disable_expensive_passes,
                    &DebugOptions::set_xla_llvm_disable_expensive_passes)
      .def_property("xla_test_all_input_layouts",
                    &DebugOptions::xla_test_all_input_layouts,
                    &DebugOptions::set_xla_test_all_input_layouts);

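  // `debug_options` below is returned by reference, with a keep_alive policy
  // tying the lifetime of the returned DebugOptions wrapper to its owning
  // ExecutableBuildOptions.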
  py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions")
      .def(py::init<>())
      .def("__repr__", &ExecutableBuildOptions::ToString)
      .def_property(
          "result_layout",
          [](const ExecutableBuildOptions& options) -> absl::optional<Shape> {
            return options.result_layout()
                       ? absl::optional<Shape>(*options.result_layout())
                       : absl::nullopt;
          },
          &ExecutableBuildOptions::set_result_layout)
      .def_property("num_replicas", &ExecutableBuildOptions::num_replicas,
                    &ExecutableBuildOptions::set_num_replicas)
      .def_property("num_partitions", &ExecutableBuildOptions::num_partitions,
                    &ExecutableBuildOptions::set_num_partitions)
      .def_property_readonly(
          "debug_options", &ExecutableBuildOptions::mutable_debug_options,
          py::return_value_policy::reference, py::keep_alive<1, 0>())
      .def_property(
          "device_assignment",
          [](const ExecutableBuildOptions& options)
              -> absl::optional<DeviceAssignment> {
            return options.has_device_assignment()
                       ? absl::optional<DeviceAssignment>(
                             options.device_assignment())
                       : absl::nullopt;
          },
          &ExecutableBuildOptions::set_device_assignment)
      .def_property("use_spmd_partitioning",
                    &ExecutableBuildOptions::use_spmd_partitioning,
                    &ExecutableBuildOptions::set_use_spmd_partitioning);

  py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision")
      .value("DEFAULT", PrecisionConfig::DEFAULT)
      .value("HIGH", PrecisionConfig::HIGH)
      .value("HIGHEST", PrecisionConfig::HIGHEST);

  py::enum_<OpSharding::Type>(m, "OpSharding_Type")
      .value("REPLICATED", OpSharding::REPLICATED)
      .value("MAXIMAL", OpSharding::MAXIMAL)
      .value("TUPLE", OpSharding::TUPLE)
      .value("OTHER", OpSharding::OTHER);

  py::enum_<ChannelHandle::ChannelType>(m, "ChannelHandle_ChannelType")
      .value("CHANNEL_TYPE_INVALID", ChannelHandle::CHANNEL_TYPE_INVALID)
      .value("DEVICE_TO_DEVICE", ChannelHandle::DEVICE_TO_DEVICE)
      .value("DEVICE_TO_HOST", ChannelHandle::DEVICE_TO_HOST)
      .value("HOST_TO_DEVICE", ChannelHandle::HOST_TO_DEVICE);

  py::class_<ChannelHandle>(m, "ChannelHandle")
      .def_property_readonly("type", &ChannelHandle::type)
      .def_property_readonly("handle", &ChannelHandle::handle)
      .def("__repr__", [](ChannelHandle* h) { return h->DebugString(); });

  py::enum_<FftType>(m, "FftType")
      .value("FFT", FftType::FFT)
      .value("IFFT", FftType::IFFT)
      .value("RFFT", FftType::RFFT)
      .value("IRFFT", FftType::IRFFT);
}
}  // namespace xla