1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
18
19 #include <memory>
20 #include <vector>
21
22 #include "absl/container/inlined_vector.h"
23 #include "absl/types/optional.h"
24 #include "pybind11/numpy.h"
25 #include "pybind11/pybind11.h"
26 #include "pybind11/pytypes.h"
27 #include "pybind11/stl.h"
28 #include "tensorflow/compiler/xla/python/absl_casters.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/python/status_casters.h"
31 #include "tensorflow/compiler/xla/shape.h"
32 #include "tensorflow/compiler/xla/status.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/core/platform/protobuf.h"
37
38 namespace xla {
39
40 // Converts a NumPy dtype to a PrimitiveType.
41 StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);
42
43 // Converts a PrimitiveType to a Numpy dtype.
44 StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);
45
46 // Returns a numpy-style format descriptor string for `type`.
47 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type);
48
49 // Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str
50 StatusOr<pybind11::str> TypeDescriptorForPrimitiveType(PrimitiveType type);
51
// Holds Python references to the NumPy scalar type objects (e.g. np.int32),
// one per XLA-relevant dtype. Returned by GetNumpyScalarTypes() below so
// callers can reuse the cached references instead of looking them up again.
struct NumpyScalarTypes {
  pybind11::object np_bool;
  pybind11::object np_int8;
  pybind11::object np_int16;
  pybind11::object np_int32;
  pybind11::object np_int64;
  pybind11::object np_uint8;
  pybind11::object np_uint16;
  pybind11::object np_uint32;
  pybind11::object np_uint64;
  // NOTE(review): presumably the custom bfloat16 NumPy extension type (no
  // native NumPy bfloat16) — confirm against the .cc that populates this.
  pybind11::object np_bfloat16;
  pybind11::object np_float16;
  pybind11::object np_float32;
  pybind11::object np_float64;
  pybind11::object np_complex64;
  pybind11::object np_complex128;
  pybind11::object np_longlong;
  pybind11::object np_intc;
};
71 const NumpyScalarTypes& GetNumpyScalarTypes();
72
73 // For S64/U64/F64/C128 types, returns the largest 32-bit equivalent.
74 PrimitiveType Squash64BitTypes(PrimitiveType type);
75
76 // Returns the strides for `shape`.
77 std::vector<ssize_t> ByteStridesForShape(const Shape& shape);
78 std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape);
79
80 // Converts a literal to (possibly-nested tuples of) NumPy arrays.
81 // The literal's leaf arrays are not copied; instead the NumPy arrays share
82 // buffers with the literals. Takes ownership of `literal` and keeps the
83 // necessary pieces alive using Python reference counting.
84 // Requires the GIL.
85 StatusOr<pybind11::object> LiteralToPython(std::shared_ptr<Literal> literal);
86
87 // Converts a Python object into an XLA shape and a vector of leaf buffers.
88 // The leaf buffers correspond to a depth-first, left-to-right traversal of
89 // the Python value.
90 // Requires the GIL.
struct PythonBufferTree {
  // Holds a reference to the arrays pointed to by `leaves`, since we may
  // need to make a copy if the array is not in a C-style layout.
  absl::InlinedVector<pybind11::object, 1> arrays;
  // Non-owning literals, one per leaf array; their storage is backed by
  // `arrays` above, so `arrays` must outlive `leaves`.
  absl::InlinedVector<BorrowingLiteral, 1> leaves;
  // XLA shape describing the converted Python value as a whole.
  Shape shape;
};
98 StatusOr<PythonBufferTree> GetPythonBufferTree(
99 const pybind11::object& argument);
100
101 // Converts a sequence of C++ ints to a Python tuple of ints.
102 // Pybind11 by default converts a std::vector<T> to a Python list;
103 // we frequently want a tuple instead e.g. for shapes.
104 template <typename T>
SpanToTuple(absl::Span<T const> xs)105 pybind11::tuple SpanToTuple(absl::Span<T const> xs) {
106 pybind11::tuple out(xs.size());
107 for (int i = 0; i < xs.size(); ++i) {
108 out[i] = pybind11::cast(xs[i]);
109 }
110 return out;
111 }
112 template <>
113 pybind11::tuple SpanToTuple(absl::Span<int const> xs);
114 template <>
115 pybind11::tuple SpanToTuple(absl::Span<int64 const> xs);
116
117 // Converts a Python iterable/sequence of T to std::vector<T>
118 template <typename T>
IterableToVector(const pybind11::iterable & iterable)119 std::vector<T> IterableToVector(const pybind11::iterable& iterable) {
120 std::vector<T> output;
121 for (auto item : iterable) {
122 output.push_back(item.cast<T>());
123 }
124 return output;
125 }
126 template <typename T>
SequenceToVector(const pybind11::sequence & sequence)127 std::vector<T> SequenceToVector(const pybind11::sequence& sequence) {
128 std::vector<T> output;
129 output.reserve(sequence.size());
130 for (auto item : sequence) {
131 output.push_back(item.cast<T>());
132 }
133 return output;
134 }
135
136 // Private helper function used in the implementation of the type caster for
137 // xla::BorrowingLiteral. Converts a Python array-like object into a buffer
138 // pointer and shape.
struct CastToArrayResult {
  pybind11::object array;  // Holds a reference to the array to keep it alive.
  const char* buf_ptr;     // Non-owning pointer to the array's data buffer;
                           // valid only while `array` is alive.
  xla::Shape shape;        // XLA shape (element type + dimensions) of `array`.
};
144 absl::optional<CastToArrayResult> CastToArray(pybind11::handle h);
145
146 } // namespace xla
147
148 // This namespace is a documented pybind11 extension point.
149 // Caution: Unusually for Google code, this code uses C++ exceptions because
150 // they are the only mechanism for reporting cast failures to pybind11. However,
151 // the exceptions are local to the binding code.
152 namespace pybind11 {
153 namespace detail {
154
155 // Literals.
156 // Literal data can be passed to XLA as a NumPy array; its value can be
157 // cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way.
158 // We don't have any literal -> numpy conversions here, since all the methods
159 // that want to return arrays build Python objects directly.
160
template <>
struct type_caster<xla::BorrowingLiteral> {
 public:
  PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral"));

  // Pybind appears to keep type_casters alive until the callee has run.
  // `arrays` therefore keeps the NumPy buffers that `value` borrows from
  // alive for the duration of the call.
  absl::InlinedVector<pybind11::array, 1> arrays;

  // PyObject -> C++ conversion. Accepts either a single array-like object,
  // or a flat tuple of array-like objects (converted to a tuple-shaped
  // literal). Returns false if xla::CastToArray fails for any element.
  bool load(handle input, bool) {
    // TODO(b/79707221): support nested tuples if/when XLA adds support for
    // nested BorrowingLiterals.
    if (pybind11::isinstance<pybind11::tuple>(input)) {
      pybind11::tuple tuple =
          pybind11::reinterpret_borrow<pybind11::tuple>(input);
      std::vector<xla::Shape> shapes;
      std::vector<const char*> buffers;
      arrays.reserve(tuple.size());
      shapes.reserve(tuple.size());
      buffers.reserve(tuple.size());
      for (pybind11::handle entry : tuple) {
        auto c = xla::CastToArray(entry);
        if (!c) {
          return false;
        }
        arrays.push_back(c->array);
        buffers.push_back(c->buf_ptr);
        shapes.push_back(c->shape);
      }
      // The literal borrows `buffers`; the underlying storage stays alive
      // via `arrays` (see above).
      value = xla::BorrowingLiteral(buffers,
                                    xla::ShapeUtil::MakeTupleShape(shapes));
    } else {
      auto c = xla::CastToArray(input);
      if (!c) {
        return false;
      }
      arrays.push_back(c->array);
      value = xla::BorrowingLiteral(c->buf_ptr, c->shape);
    }
    return true;
  }
};
202
203 template <>
204 struct type_caster<xla::LiteralSlice> {
205 public:
206 PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice"));
207
208 // Pybind appears to keep type_casters alive until the callee has run.
209 type_caster<xla::BorrowingLiteral> literal_caster;
210
211 bool load(handle handle, bool convert) {
212 if (!literal_caster.load(handle, convert)) {
213 return false;
214 }
215 value = static_cast<const xla::BorrowingLiteral&>(literal_caster);
216 return true;
217 }
218 };
219
220 // XLA protocol buffers
221 // We don't actually care that these are the protocol buffers, we merely want
222 // objects that duck type as protocol buffers. The client code currently avoids
223 // depending on Python protocol buffers to avoid conflicting definitions from
224 // different modules that both include XLA.
225
226 template <>
227 struct type_caster<xla::ConvolutionDimensionNumbers> {
228 public:
229 PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers,
230 _("xla::ConvolutionDimensionNumbers"));
231
232 // PyObject -> C++ conversion.
233 bool load(handle handle, bool) {
234 value.set_input_batch_dimension(
235 getattr(handle, "input_batch_dimension").cast<xla::int64>());
236 value.set_input_feature_dimension(
237 getattr(handle, "input_feature_dimension").cast<xla::int64>());
238 value.set_output_batch_dimension(
239 getattr(handle, "output_batch_dimension").cast<xla::int64>());
240 value.set_output_feature_dimension(
241 getattr(handle, "output_feature_dimension").cast<xla::int64>());
242 value.set_kernel_input_feature_dimension(
243 getattr(handle, "kernel_input_feature_dimension").cast<xla::int64>());
244 value.set_kernel_output_feature_dimension(
245 getattr(handle, "kernel_output_feature_dimension").cast<xla::int64>());
246 std::vector<xla::int64> dims;
247 dims = getattr(handle, "input_spatial_dimensions")
248 .cast<std::vector<xla::int64>>();
249 std::copy(dims.begin(), dims.end(),
250 tensorflow::protobuf::RepeatedFieldBackInserter(
251 value.mutable_input_spatial_dimensions()));
252 dims = getattr(handle, "kernel_spatial_dimensions")
253 .cast<std::vector<xla::int64>>();
254 std::copy(dims.begin(), dims.end(),
255 tensorflow::protobuf::RepeatedFieldBackInserter(
256 value.mutable_kernel_spatial_dimensions()));
257 dims = getattr(handle, "output_spatial_dimensions")
258 .cast<std::vector<xla::int64>>();
259 std::copy(dims.begin(), dims.end(),
260 tensorflow::protobuf::RepeatedFieldBackInserter(
261 value.mutable_output_spatial_dimensions()));
262 return true;
263 }
264 };
265
template <>
struct type_caster<xla::DotDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers"));

  // PyObject -> C++ conversion. Copies the four int64-sequence attributes
  // (lhs/rhs contracting and batch dimensions) of a duck-typed
  // DotDimensionNumbers Python object into the corresponding repeated proto
  // fields. Cast failures surface as C++ exceptions, which pybind reports
  // to Python.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(handle, "lhs_contracting_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_lhs_contracting_dimensions()));
    dims = getattr(handle, "rhs_contracting_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_rhs_contracting_dimensions()));
    dims =
        getattr(handle, "lhs_batch_dimensions").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_lhs_batch_dimensions()));
    dims =
        getattr(handle, "rhs_batch_dimensions").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_rhs_batch_dimensions()));
    return true;
  }
};
297
template <>
struct type_caster<xla::GatherDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers,
                       _("xla::GatherDimensionNumbers"));

  // PyObject -> C++ conversion. Reads `offset_dims`, `collapsed_slice_dims`
  // and `start_index_map` (int64 sequences) plus `index_vector_dim` (int64)
  // from a duck-typed GatherDimensionNumbers Python object into the proto.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(handle, "offset_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_offset_dims()));
    dims =
        getattr(handle, "collapsed_slice_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_collapsed_slice_dims()));
    dims = getattr(handle, "start_index_map").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_start_index_map()));
    value.set_index_vector_dim(
        getattr(handle, "index_vector_dim").cast<xla::int64>());
    return true;
  }
};
325
template <>
struct type_caster<xla::ScatterDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers,
                       _("xla::ScatterDimensionNumbers"));

  // PyObject -> C++ conversion. Reads `update_window_dims`,
  // `inserted_window_dims` and `scatter_dims_to_operand_dims` (int64
  // sequences) plus `index_vector_dim` (int64) from a duck-typed
  // ScatterDimensionNumbers Python object into the proto.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims =
        getattr(handle, "update_window_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_update_window_dims()));
    dims =
        getattr(handle, "inserted_window_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_inserted_window_dims()));
    dims = getattr(handle, "scatter_dims_to_operand_dims")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_scatter_dims_to_operand_dims()));
    value.set_index_vector_dim(
        getattr(handle, "index_vector_dim").cast<xla::int64>());
    return true;
  }
};
355
356 template <>
357 struct type_caster<xla::ReplicaGroup> {
358 public:
359 PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup"));
360
361 // PyObject -> C++ conversion.
362 bool load(handle handle, bool) {
363 std::vector<xla::int64> dims;
364 dims = getattr(handle, "replica_ids").cast<std::vector<xla::int64>>();
365 std::copy(dims.begin(), dims.end(),
366 tensorflow::protobuf::RepeatedFieldBackInserter(
367 value.mutable_replica_ids()));
368 return true;
369 }
370 };
371
372 template <>
373 struct type_caster<xla::PaddingConfig> {
374 public:
375 PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig"));
376
377 // PyObject -> C++ conversion.
378 bool load(handle handle, bool) {
379 sequence dimensions =
380 reinterpret_borrow<sequence>(getattr(handle, "dimensions"));
381
382 for (const auto& dimension : dimensions) {
383 xla::PaddingConfig::PaddingConfigDimension* config_dim =
384 value.add_dimensions();
385 config_dim->set_edge_padding_low(
386 getattr(dimension, "edge_padding_low").cast<xla::int64>());
387 config_dim->set_edge_padding_high(
388 getattr(dimension, "edge_padding_high").cast<xla::int64>());
389 config_dim->set_interior_padding(
390 getattr(dimension, "interior_padding").cast<xla::int64>());
391 }
392 return true;
393 }
394 };
395
396 template <>
397 struct type_caster<xla::OpMetadata> {
398 public:
399 PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata"));
400
401 // PyObject -> C++ conversion.
402 bool load(handle handle, bool) {
403 pybind11::handle op_type = getattr(handle, "op_type");
404 if (!op_type.is_none()) {
405 value.set_op_type(op_type.cast<std::string>());
406 }
407 pybind11::handle op_name = getattr(handle, "op_name");
408 if (!op_name.is_none()) {
409 value.set_op_name(op_name.cast<std::string>());
410 }
411 pybind11::handle source_file = getattr(handle, "source_file");
412 if (!source_file.is_none()) {
413 value.set_source_file(source_file.cast<std::string>());
414 }
415 pybind11::handle source_line = getattr(handle, "source_line");
416 if (!source_line.is_none()) {
417 value.set_source_line(source_line.cast<xla::int32>());
418 }
419 return true;
420 }
421 };
422
423 template <>
424 struct type_caster<xla::PrecisionConfig> {
425 public:
426 PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));
427
428 // PyObject -> C++ conversion.
429 bool load(handle handle, bool) {
430 if (handle.is_none()) {
431 return true;
432 }
433
434 sequence operand_precisions =
435 reinterpret_borrow<sequence>(getattr(handle, "operand_precision"));
436
437 for (const auto& operand_precision : operand_precisions) {
438 value.add_operand_precision(
439 operand_precision.cast<xla::PrecisionConfig::Precision>());
440 }
441 return true;
442 }
443 };
444
template <>
struct type_caster<xla::OpSharding> {
 public:
  PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));

  // PyObject -> C++ conversion. A None handle yields a default-constructed
  // proto. Otherwise reads `type`, `tile_assignment_dimensions`,
  // `tile_assignment_devices`, `tuple_shardings`, and
  // `replicate_on_last_tile_dim` from a duck-typed Python object.
  // Note: `tuple_shardings` is parsed one level deep only; nested
  // tuple_shardings of an element are not read here.
  bool load(handle handle_obj, bool) {
    if (handle_obj.is_none()) {
      return true;
    }

    // Sets `type` field.
    handle sharding_type = getattr(handle_obj, "type");
    if (!sharding_type.is_none()) {
      value.set_type(sharding_type.cast<xla::OpSharding_Type>());
    }

    // Sets `tile_assignment_dimensions` field.
    std::vector<xla::int64> dims;
    dims = getattr(handle_obj, "tile_assignment_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_tile_assignment_dimensions()));

    // Sets `tile_assignment_devices` field.
    std::vector<xla::int64> devices;
    devices = getattr(handle_obj, "tile_assignment_devices")
                  .cast<std::vector<xla::int64>>();
    std::copy(devices.begin(), devices.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_tile_assignment_devices()));

    // Sets `tuple_shardings` field: one sub-message per element, each
    // parsed with the same four fields as the top level.
    sequence tuple_shardings =
        reinterpret_borrow<sequence>(getattr(handle_obj, "tuple_shardings"));

    for (const auto& tuple_sharding : tuple_shardings) {
      xla::OpSharding* sharding = value.add_tuple_shardings();

      handle sharding_type = getattr(tuple_sharding, "type");
      if (!sharding_type.is_none()) {
        sharding->set_type(sharding_type.cast<xla::OpSharding_Type>());
      }
      std::vector<xla::int64> dims;
      dims = getattr(tuple_sharding, "tile_assignment_dimensions")
                 .cast<std::vector<xla::int64>>();
      std::copy(dims.begin(), dims.end(),
                tensorflow::protobuf::RepeatedFieldBackInserter(
                    sharding->mutable_tile_assignment_dimensions()));

      std::vector<xla::int64> devices;
      devices = getattr(tuple_sharding, "tile_assignment_devices")
                    .cast<std::vector<xla::int64>>();
      std::copy(devices.begin(), devices.end(),
                tensorflow::protobuf::RepeatedFieldBackInserter(
                    sharding->mutable_tile_assignment_devices()));

      sharding->set_replicate_on_last_tile_dim(
          getattr(tuple_sharding, "replicate_on_last_tile_dim").cast<bool>());
    }

    // Sets `replicate_on_last_tile_dim` field.
    value.set_replicate_on_last_tile_dim(
        getattr(handle_obj, "replicate_on_last_tile_dim").cast<bool>());

    return true;
  }
};
514
515 } // namespace detail
516 } // namespace pybind11
517
518 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
519