/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/absl_casters.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {

// Helper that converts a failing StatusOr to an exception.
// For use only inside pybind11 code.
template <typename T>
T ValueOrThrow(StatusOr<T> v) {
  if (!v.ok()) {
    throw std::runtime_error(v.status().ToString());
  }
  return v.ConsumeValueOrDie();
}
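
// Example (illustrative): unwrapping a StatusOr inside a binding lambda.
// `PlatformUtil::GetPlatform` is a stand-in for any function returning a
// StatusOr<T>; on failure, the error surfaces in Python as a RuntimeError.
//
//   m.def("get_platform", [](const std::string& name) {
//     return ValueOrThrow(PlatformUtil::GetPlatform(name));
//   });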

// Converts a NumPy dtype to a PrimitiveType.
StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);

// Converts a PrimitiveType to a NumPy dtype.
StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);

// Returns a NumPy-style format descriptor string for `type`.
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type);

// Returns a NumPy-style typestr for `type`, as returned by np.dtype(...).str.
StatusOr<pybind11::str> TypeDescriptorForPrimitiveType(PrimitiveType type);

// Returns the byte strides for `shape`, taking its layout into account.
std::vector<ssize_t> ByteStridesForShape(const Shape& shape);
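
// Example (illustrative): for a row-major (minor-to-major {1, 0}) F32 array
// of shape [3, 4], the element size is 4 bytes, so the byte strides are
// {16, 4}: advancing one step along dimension 0 skips a whole row of four
// floats.
//
//   Shape s = ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3, 4},
//                                            /*minor_to_major=*/{1, 0});
//   std::vector<ssize_t> strides = ByteStridesForShape(s);  // {16, 4}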

// Converts a literal to (possibly-nested tuples of) NumPy arrays.
// The literal's leaf arrays are not copied; instead the NumPy arrays share
// buffers with the literal. Takes ownership of `literal` and keeps the
// necessary pieces alive using Python reference counting.
// Requires the GIL.
StatusOr<pybind11::object> LiteralToPython(std::shared_ptr<Literal> literal);
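
// Example (illustrative): returning a Literal's contents to Python without a
// copy. The shared_ptr keeps the backing buffers alive for as long as the
// NumPy arrays reference them.
//
//   auto literal = std::make_shared<Literal>(
//       LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
//   pybind11::object obj = ValueOrThrow(LiteralToPython(std::move(literal)));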

// Converts a Python object into an XLA shape and a vector of leaf buffers.
// The leaf buffers correspond to a depth-first, left-to-right traversal of
// the Python value.
// Requires the GIL.
struct PythonBufferTree {
  // Holds references to the arrays pointed to by `leaves`, since we may
  // need to make a copy if an array is not in a C-style layout.
  absl::InlinedVector<pybind11::object, 1> arrays;
  absl::InlinedVector<BorrowingLiteral, 1> leaves;
  Shape shape;
};
StatusOr<PythonBufferTree> GetPythonBufferTree(
    const pybind11::object& argument);
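
// Example (illustrative): flattening a nested Python value. For an argument
// like `(np.float32(1.0), (np.arange(3, dtype=np.int32),))`, `shape` is a
// nested tuple shape and `leaves` holds two BorrowingLiterals in depth-first,
// left-to-right order.
//
//   PythonBufferTree tree = ValueOrThrow(GetPythonBufferTree(argument));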

// Converts a sequence of C++ ints to a Python tuple of ints.
// Pybind11 by default converts a std::vector<int64> to a Python list; we
// frequently want a tuple instead, e.g. for shapes.
pybind11::tuple IntSpanToTuple(absl::Span<int64 const> xs);
pybind11::tuple IntSpanToTuple(absl::Span<int const> xs);

// Converts a Python sequence of integers to a std::vector<int64>.
std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
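
// Example (illustrative): round-tripping dimensions through Python.
//
//   std::vector<int64> dims = {2, 3, 5};
//   pybind11::tuple t = IntSpanToTuple(dims);          // (2, 3, 5) in Python
//   std::vector<int64> back = IntSequenceToVector(t);  // {2, 3, 5}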

// Private helper function used in the implementation of the type caster for
// xla::BorrowingLiteral. Converts a Python array-like object into a buffer
// pointer and shape.
struct CastToArrayResult {
  pybind11::object array;  // Holds a reference to the array to keep it alive.
  const char* buf_ptr;
  xla::Shape shape;
};
absl::optional<CastToArrayResult> CastToArray(pybind11::handle h);

}  // namespace xla

// This namespace is a documented pybind11 extension point.
// Caution: unusually for Google code, this code uses C++ exceptions because
// they are the only mechanism for reporting cast failures to pybind11.
// However, the exceptions are local to the binding code.
namespace pybind11 {
namespace detail {

// Status, StatusOr. Failing statuses become Python exceptions; Status::OK()
// becomes None.
template <>
struct type_caster<xla::Status> {
 public:
  PYBIND11_TYPE_CASTER(xla::Status, _("Status"));

  static handle cast(xla::Status src, return_value_policy /* policy */,
                     handle /* parent */) {
    if (!src.ok()) {
      throw std::runtime_error(src.ToString());
    }
    return none().inc_ref();
  }
};

template <typename T>
struct type_caster<xla::StatusOr<T>> {
 public:
  using value_conv = make_caster<T>;

  PYBIND11_TYPE_CASTER(xla::StatusOr<T>,
                       _("StatusOr[") + value_conv::name + _("]"));

  static handle cast(xla::StatusOr<T> src, return_value_policy policy,
                     handle parent) {
    if (!src.ok()) {
      throw std::runtime_error(src.status().ToString());
    }
    return value_conv::cast(std::move(src).ValueOrDie(), policy, parent);
  }
};
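
// Example (illustrative): with these casters in place, a binding can return
// Status or StatusOr<T> directly; pybind11 converts failures into Python
// RuntimeErrors, and successes into None or the unwrapped value.
//
//   m.def("dtype_to_primitive_type", [](const pybind11::dtype& d) {
//     return xla::DtypeToPrimitiveType(d);  // StatusOr<PrimitiveType>
//   });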

// Literals.
// Literal data can be passed to XLA as a NumPy array; its value can be
// cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way.
// We don't have any literal -> numpy conversions here, since all the methods
// that want to return arrays build Python objects directly.

template <>
struct type_caster<xla::BorrowingLiteral> {
 public:
  PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral"));

  // Pybind appears to keep type_casters alive until the callee has run.
  absl::InlinedVector<pybind11::array, 1> arrays;

  bool load(handle input, bool) {
    // TODO(b/79707221): support nested tuples if/when XLA adds support for
    // nested BorrowingLiterals.
    if (pybind11::isinstance<pybind11::tuple>(input)) {
      pybind11::tuple tuple =
          pybind11::reinterpret_borrow<pybind11::tuple>(input);
      std::vector<xla::Shape> shapes;
      std::vector<const char*> buffers;
      arrays.reserve(tuple.size());
      shapes.reserve(tuple.size());
      buffers.reserve(tuple.size());
      for (pybind11::handle entry : tuple) {
        auto c = xla::CastToArray(entry);
        if (!c) {
          return false;
        }
        arrays.push_back(c->array);
        buffers.push_back(c->buf_ptr);
        shapes.push_back(c->shape);
      }
      value = xla::BorrowingLiteral(buffers,
                                    xla::ShapeUtil::MakeTupleShape(shapes));
    } else {
      auto c = xla::CastToArray(input);
      if (!c) {
        return false;
      }
      arrays.push_back(c->array);
      value = xla::BorrowingLiteral(c->buf_ptr, c->shape);
    }
    return true;
  }
};
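
// Example (illustrative, Python side): a parameter declared as
// xla::BorrowingLiteral accepts a NumPy ndarray, or a flat tuple of ndarrays,
// zero-copy when the arrays are C-contiguous. `transfer` below is a stand-in
// for any bound function taking a BorrowingLiteral:
//
//   transfer(np.ones((3, 4), dtype=np.float32))
//   transfer((np.int32(1), np.zeros(2, np.float64)))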

template <>
struct type_caster<xla::LiteralSlice> {
 public:
  PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice"));

  // Pybind appears to keep type_casters alive until the callee has run.
  type_caster<xla::BorrowingLiteral> literal_caster;

  bool load(handle input, bool convert) {
    if (!literal_caster.load(input, convert)) {
      return false;
    }
    value = static_cast<const xla::BorrowingLiteral&>(literal_caster);
    return true;
  }
};

// XLA protocol buffers.
// We don't actually care that these are the protocol buffers; we merely want
// objects that duck type as protocol buffers. The client code currently
// avoids depending on Python protocol buffers to avoid conflicting
// definitions from different modules that both include XLA.

template <>
struct type_caster<xla::ConvolutionDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers,
                       _("xla::ConvolutionDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle input, bool) {
    value.set_input_batch_dimension(
        getattr(input, "input_batch_dimension").cast<xla::int64>());
    value.set_input_feature_dimension(
        getattr(input, "input_feature_dimension").cast<xla::int64>());
    value.set_output_batch_dimension(
        getattr(input, "output_batch_dimension").cast<xla::int64>());
    value.set_output_feature_dimension(
        getattr(input, "output_feature_dimension").cast<xla::int64>());
    value.set_kernel_input_feature_dimension(
        getattr(input, "kernel_input_feature_dimension").cast<xla::int64>());
    value.set_kernel_output_feature_dimension(
        getattr(input, "kernel_output_feature_dimension").cast<xla::int64>());
    std::vector<xla::int64> dims;
    dims = getattr(input, "input_spatial_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_input_spatial_dimensions()));
    dims = getattr(input, "kernel_spatial_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_kernel_spatial_dimensions()));
    dims = getattr(input, "output_spatial_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_output_spatial_dimensions()));
    return true;
  }
};
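
// Example (illustrative, Python side): these casters accept any Python object
// whose attributes match the proto's field names; generated protobuf classes
// are not required. For instance, a plain namespace object works:
//
//   dnums = types.SimpleNamespace(
//       input_batch_dimension=0, input_feature_dimension=1,
//       output_batch_dimension=0, output_feature_dimension=1,
//       kernel_input_feature_dimension=1, kernel_output_feature_dimension=0,
//       input_spatial_dimensions=[2, 3], kernel_spatial_dimensions=[2, 3],
//       output_spatial_dimensions=[2, 3])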

template <>
struct type_caster<xla::DotDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle input, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(input, "lhs_contracting_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_lhs_contracting_dimensions()));
    dims = getattr(input, "rhs_contracting_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_rhs_contracting_dimensions()));
    dims =
        getattr(input, "lhs_batch_dimensions").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_lhs_batch_dimensions()));
    dims =
        getattr(input, "rhs_batch_dimensions").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_rhs_batch_dimensions()));
    return true;
  }
};

template <>
struct type_caster<xla::GatherDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers,
                       _("xla::GatherDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle input, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(input, "offset_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_offset_dims()));
    dims =
        getattr(input, "collapsed_slice_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_collapsed_slice_dims()));
    dims = getattr(input, "start_index_map").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_start_index_map()));
    value.set_index_vector_dim(
        getattr(input, "index_vector_dim").cast<xla::int64>());
    return true;
  }
};

template <>
struct type_caster<xla::ScatterDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers,
                       _("xla::ScatterDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle input, bool) {
    std::vector<xla::int64> dims;
    dims =
        getattr(input, "update_window_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_update_window_dims()));
    dims =
        getattr(input, "inserted_window_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_inserted_window_dims()));
    dims = getattr(input, "scatter_dims_to_operand_dims")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_scatter_dims_to_operand_dims()));
    value.set_index_vector_dim(
        getattr(input, "index_vector_dim").cast<xla::int64>());
    return true;
  }
};

template <>
struct type_caster<xla::ReplicaGroup> {
 public:
  PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup"));

  // PyObject -> C++ conversion.
  bool load(handle input, bool) {
    std::vector<xla::int64> ids =
        getattr(input, "replica_ids").cast<std::vector<xla::int64>>();
    std::copy(ids.begin(), ids.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_replica_ids()));
    return true;
  }
};

template <>
struct type_caster<xla::PaddingConfig> {
 public:
  PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig"));

  // PyObject -> C++ conversion.
  bool load(handle input, bool) {
    sequence dimensions =
        reinterpret_borrow<sequence>(getattr(input, "dimensions"));

    for (const auto& dimension : dimensions) {
      xla::PaddingConfig::PaddingConfigDimension* config_dim =
          value.add_dimensions();
      config_dim->set_edge_padding_low(
          getattr(dimension, "edge_padding_low").cast<xla::int64>());
      config_dim->set_edge_padding_high(
          getattr(dimension, "edge_padding_high").cast<xla::int64>());
      config_dim->set_interior_padding(
          getattr(dimension, "interior_padding").cast<xla::int64>());
    }
    return true;
  }
};
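
// Example (illustrative, Python side): a rank-2 padding config adding one row
// of low padding on dimension 0 and nothing on dimension 1, built from plain
// attribute-compatible objects:
//
//   dim0 = types.SimpleNamespace(edge_padding_low=1, edge_padding_high=0,
//                                interior_padding=0)
//   dim1 = types.SimpleNamespace(edge_padding_low=0, edge_padding_high=0,
//                                interior_padding=0)
//   padding_config = types.SimpleNamespace(dimensions=[dim0, dim1])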

template <>
struct type_caster<xla::OpMetadata> {
 public:
  PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata"));

  // PyObject -> C++ conversion. None-valued attributes are left unset.
  bool load(handle input, bool) {
    pybind11::handle op_type = getattr(input, "op_type");
    if (!op_type.is_none()) {
      value.set_op_type(op_type.cast<std::string>());
    }
    pybind11::handle op_name = getattr(input, "op_name");
    if (!op_name.is_none()) {
      value.set_op_name(op_name.cast<std::string>());
    }
    pybind11::handle source_file = getattr(input, "source_file");
    if (!source_file.is_none()) {
      value.set_source_file(source_file.cast<std::string>());
    }
    pybind11::handle source_line = getattr(input, "source_line");
    if (!source_line.is_none()) {
      value.set_source_line(source_line.cast<xla::int32>());
    }
    return true;
  }
};

template <>
struct type_caster<xla::PrecisionConfig> {
 public:
  PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));

  // PyObject -> C++ conversion. A None handle is accepted and yields a
  // default-constructed PrecisionConfig.
  bool load(handle input, bool) {
    if (input.is_none()) {
      return true;
    }

    sequence operand_precisions =
        reinterpret_borrow<sequence>(getattr(input, "operand_precision"));

    for (const auto& operand_precision : operand_precisions) {
      value.add_operand_precision(
          operand_precision.cast<xla::PrecisionConfig::Precision>());
    }
    return true;
  }
};

template <>
struct type_caster<xla::OpSharding> {
 public:
  PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));

  // PyObject -> C++ conversion. A None handle is accepted and yields a
  // default-constructed OpSharding.
  bool load(handle input, bool) {
    if (input.is_none()) {
      return true;
    }

    // Sets the `type` field.
    handle sharding_type = getattr(input, "type");
    if (!sharding_type.is_none()) {
      value.set_type(sharding_type.cast<xla::OpSharding_Type>());
    }

    // Sets the `tile_assignment_dimensions` field.
    std::vector<xla::int64> dims =
        getattr(input, "tile_assignment_dimensions")
            .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_tile_assignment_dimensions()));

    // Sets the `tile_assignment_devices` field.
    std::vector<xla::int64> devices =
        getattr(input, "tile_assignment_devices")
            .cast<std::vector<xla::int64>>();
    std::copy(devices.begin(), devices.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_tile_assignment_devices()));

    // Sets the `tuple_shardings` field.
    sequence tuple_shardings =
        reinterpret_borrow<sequence>(getattr(input, "tuple_shardings"));
    for (const auto& tuple_sharding : tuple_shardings) {
      xla::OpSharding* sharding = value.add_tuple_shardings();

      handle tuple_sharding_type = getattr(tuple_sharding, "type");
      if (!tuple_sharding_type.is_none()) {
        sharding->set_type(tuple_sharding_type.cast<xla::OpSharding_Type>());
      }
      std::vector<xla::int64> tuple_dims =
          getattr(tuple_sharding, "tile_assignment_dimensions")
              .cast<std::vector<xla::int64>>();
      std::copy(tuple_dims.begin(), tuple_dims.end(),
                tensorflow::protobuf::RepeatedFieldBackInserter(
                    sharding->mutable_tile_assignment_dimensions()));

      std::vector<xla::int64> tuple_devices =
          getattr(tuple_sharding, "tile_assignment_devices")
              .cast<std::vector<xla::int64>>();
      std::copy(tuple_devices.begin(), tuple_devices.end(),
                tensorflow::protobuf::RepeatedFieldBackInserter(
                    sharding->mutable_tile_assignment_devices()));

      sharding->set_replicate_on_last_tile_dim(
          getattr(tuple_sharding, "replicate_on_last_tile_dim").cast<bool>());
    }

    // Sets the `replicate_on_last_tile_dim` field.
    value.set_replicate_on_last_tile_dim(
        getattr(input, "replicate_on_last_tile_dim").cast<bool>());

    return true;
  }
};
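
// Example (illustrative, Python side): a 2x2 tiled sharding over four
// devices, expressed with attribute-compatible objects. `OTHER` stands in
// for however the client library exposes the OpSharding.Type enum value for
// tiled shardings:
//
//   sharding = types.SimpleNamespace(
//       type=OTHER,
//       tile_assignment_dimensions=[2, 2],
//       tile_assignment_devices=[0, 1, 2, 3],
//       tuple_shardings=[],
//       replicate_on_last_tile_dim=False)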

}  // namespace detail
}  // namespace pybind11

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_