/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_

#include <memory>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/absl_casters.h"
#include "tensorflow/compiler/xla/python/status_casters.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

// Converts a NumPy dtype to a PrimitiveType.
StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);

// Converts a PrimitiveType to a NumPy dtype.
StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);
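
// Illustrative round trip (a sketch, not part of this header; requires a
// live Python interpreter and the GIL):
//   pybind11::dtype dt = pybind11::dtype::of<float>();
//   StatusOr<PrimitiveType> ty = DtypeToPrimitiveType(dt);      // F32
//   StatusOr<pybind11::dtype> back = PrimitiveTypeToDtype(F32);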

// Returns a numpy-style format descriptor string for `type`.
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type);

// Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str
StatusOr<pybind11::str> TypeDescriptorForPrimitiveType(PrimitiveType type);
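
// As an illustration of the convention: on a little-endian host, NumPy
// describes float32 with format descriptor "f" and typestr "<f4".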

struct NumpyScalarTypes {
  pybind11::object np_bool;
  pybind11::object np_int8;
  pybind11::object np_int16;
  pybind11::object np_int32;
  pybind11::object np_int64;
  pybind11::object np_uint8;
  pybind11::object np_uint16;
  pybind11::object np_uint32;
  pybind11::object np_uint64;
  pybind11::object np_bfloat16;
  pybind11::object np_float16;
  pybind11::object np_float32;
  pybind11::object np_float64;
  pybind11::object np_complex64;
  pybind11::object np_complex128;
  pybind11::object np_longlong;
  pybind11::object np_intc;
};
const NumpyScalarTypes& GetNumpyScalarTypes();
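
// The stored type objects are callable; a minimal sketch of constructing a
// NumPy scalar without repeated module lookups (GIL required):
//   const NumpyScalarTypes& st = GetNumpyScalarTypes();
//   pybind11::object x = st.np_float32(1.5f);  // numpy.float32(1.5)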

// For S64/U64/F64/C128 types, returns the largest 32-bit equivalent.
PrimitiveType Squash64BitTypes(PrimitiveType type);
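
// For example, Squash64BitTypes(S64) returns S32 and Squash64BitTypes(C128)
// returns C64; all other types are returned unchanged.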

// Returns the byte strides for `shape`.
std::vector<ssize_t> ByteStridesForShape(const Shape& shape);
std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape);
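
// Worked example: for a dense, row-major f32[2, 3] shape, the minor (last)
// dimension advances one f32 (4 bytes) per step and the major dimension
// advances a full row, so the byte strides are {12, 4}.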

// Converts a literal to (possibly-nested tuples of) NumPy arrays.
// The literal's leaf arrays are not copied; instead the NumPy arrays share
// buffers with the literals. Takes ownership of `literal` and keeps the
// necessary pieces alive using Python reference counting.
// Requires the GIL.
StatusOr<pybind11::object> LiteralToPython(std::shared_ptr<Literal> literal);
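
// Usage sketch (hypothetical caller; names are ours):
//   std::shared_ptr<Literal> lit = ...;  // some (possibly tuple-shaped) literal
//   StatusOr<pybind11::object> result = LiteralToPython(std::move(lit));
// On success, `result` holds (nested tuples of) NumPy arrays aliasing the
// literal's buffers; `lit` is moved from and must not be reused.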

// Converts a Python object into an XLA shape and a vector of leaf buffers.
// The leaf buffers correspond to a depth-first, left-to-right traversal of
// the Python value.
// Requires the GIL.
struct PythonBufferTree {
  // Holds a reference to the arrays pointed to by `leaves`, since we may
  // need to make a copy if the array is not in a C-style layout.
  absl::InlinedVector<pybind11::object, 1> arrays;
  absl::InlinedVector<BorrowingLiteral, 1> leaves;
  Shape shape;
};
StatusOr<PythonBufferTree> GetPythonBufferTree(
    const pybind11::object& argument);
// Converts a sequence of C++ ints to a Python tuple of ints.
// Pybind11 by default converts a std::vector<T> to a Python list; we
// frequently want a tuple instead, e.g. for shapes.
template <typename T>
pybind11::tuple SpanToTuple(absl::Span<T const> xs) {
  pybind11::tuple out(xs.size());
  for (size_t i = 0; i < xs.size(); ++i) {
    out[i] = pybind11::cast(xs[i]);
  }
  return out;
}
template <>
pybind11::tuple SpanToTuple(absl::Span<int const> xs);
template <>
pybind11::tuple SpanToTuple(absl::Span<int64 const> xs);
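
// For example (a sketch; the GIL must be held):
//   std::vector<int64> dims = {2, 3};
//   pybind11::tuple t = SpanToTuple<int64>(absl::MakeConstSpan(dims));
// yields the Python tuple (2, 3).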

// Converts a Python iterable/sequence of T to std::vector<T>
template <typename T>
std::vector<T> IterableToVector(const pybind11::iterable& iterable) {
  std::vector<T> output;
  for (auto item : iterable) {
    output.push_back(item.cast<T>());
  }
  return output;
}
template <typename T>
std::vector<T> SequenceToVector(const pybind11::sequence& sequence) {
  std::vector<T> output;
  output.reserve(sequence.size());
  for (auto item : sequence) {
    output.push_back(item.cast<T>());
  }
  return output;
}
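
// For example (a sketch; the GIL must be held):
//   pybind11::list xs = ...;  // a Python list of ints
//   std::vector<int> v = SequenceToVector<int>(xs);
// IterableToVector also handles generators and other one-shot iterables, at
// the cost of not knowing the element count up front to reserve().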

// Private helper function used in the implementation of the type caster for
// xla::BorrowingLiteral. Converts a Python array-like object into a buffer
// pointer and shape.
struct CastToArrayResult {
  pybind11::object array;  // Holds a reference to the array to keep it alive.
  const char* buf_ptr;
  xla::Shape shape;
};
absl::optional<CastToArrayResult> CastToArray(pybind11::handle h);

}  // namespace xla

// This namespace is a documented pybind11 extension point.
// Caution: Unusually for Google code, this code uses C++ exceptions because
// they are the only mechanism for reporting cast failures to pybind11. However,
// the exceptions are local to the binding code.
namespace pybind11 {
namespace detail {

// Literals.
// Literal data can be passed to XLA as a NumPy array; its value can be
// cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way.
// We don't have any literal -> numpy conversions here, since all the methods
// that want to return arrays build Python objects directly.

template <>
struct type_caster<xla::BorrowingLiteral> {
 public:
  PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral"));

  // Pybind appears to keep type_casters alive until the callee has run.
  absl::InlinedVector<pybind11::array, 1> arrays;

  bool load(handle input, bool) {
    // TODO(b/79707221): support nested tuples if/when XLA adds support for
    // nested BorrowingLiterals.
    if (pybind11::isinstance<pybind11::tuple>(input)) {
      pybind11::tuple tuple =
          pybind11::reinterpret_borrow<pybind11::tuple>(input);
      std::vector<xla::Shape> shapes;
      std::vector<const char*> buffers;
      arrays.reserve(tuple.size());
      shapes.reserve(tuple.size());
      buffers.reserve(tuple.size());
      for (pybind11::handle entry : tuple) {
        auto c = xla::CastToArray(entry);
        if (!c) {
          return false;
        }
        arrays.push_back(c->array);
        buffers.push_back(c->buf_ptr);
        shapes.push_back(c->shape);
      }
      value = xla::BorrowingLiteral(buffers,
                                    xla::ShapeUtil::MakeTupleShape(shapes));
    } else {
      auto c = xla::CastToArray(input);
      if (!c) {
        return false;
      }
      arrays.push_back(c->array);
      value = xla::BorrowingLiteral(c->buf_ptr, c->shape);
    }
    return true;
  }
};
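
// With this caster in scope, a bound function may declare a parameter of
// type xla::BorrowingLiteral and accept a NumPy array (or a flat tuple of
// arrays) directly. A hypothetical binding, as a sketch:
//   m.def("transfer", [](const xla::BorrowingLiteral& literal) { ... });
// callable from Python as transfer(np.zeros((2, 3), np.float32)).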

template <>
struct type_caster<xla::LiteralSlice> {
 public:
  PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice"));

  // Pybind appears to keep type_casters alive until the callee has run.
  type_caster<xla::BorrowingLiteral> literal_caster;

  bool load(handle handle, bool convert) {
    if (!literal_caster.load(handle, convert)) {
      return false;
    }
    value = static_cast<const xla::BorrowingLiteral&>(literal_caster);
    return true;
  }
};

// XLA protocol buffers
// We don't actually care that these are the protocol buffers, we merely want
// objects that duck type as protocol buffers. The client code currently avoids
// depending on Python protocol buffers to avoid conflicting definitions from
// different modules that both include XLA.

template <>
struct type_caster<xla::ConvolutionDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers,
                       _("xla::ConvolutionDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    value.set_input_batch_dimension(
        getattr(handle, "input_batch_dimension").cast<xla::int64>());
    value.set_input_feature_dimension(
        getattr(handle, "input_feature_dimension").cast<xla::int64>());
    value.set_output_batch_dimension(
        getattr(handle, "output_batch_dimension").cast<xla::int64>());
    value.set_output_feature_dimension(
        getattr(handle, "output_feature_dimension").cast<xla::int64>());
    value.set_kernel_input_feature_dimension(
        getattr(handle, "kernel_input_feature_dimension").cast<xla::int64>());
    value.set_kernel_output_feature_dimension(
        getattr(handle, "kernel_output_feature_dimension").cast<xla::int64>());
    std::vector<xla::int64> dims;
    dims = getattr(handle, "input_spatial_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_input_spatial_dimensions()));
    dims = getattr(handle, "kernel_spatial_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_kernel_spatial_dimensions()));
    dims = getattr(handle, "output_spatial_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_output_spatial_dimensions()));
    return true;
  }
};
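
// Note the duck typing above: the caster only requires the Python argument
// to expose attributes named after the proto fields (input_batch_dimension,
// input_spatial_dimensions, and so on), so e.g. a hypothetical plain Python
// class carrying those attributes converts successfully without importing
// any protobuf module. The same applies to the casters below.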

template <>
struct type_caster<xla::DotDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(handle, "lhs_contracting_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_lhs_contracting_dimensions()));
    dims = getattr(handle, "rhs_contracting_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_rhs_contracting_dimensions()));
    dims =
        getattr(handle, "lhs_batch_dimensions").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_lhs_batch_dimensions()));
    dims =
        getattr(handle, "rhs_batch_dimensions").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_rhs_batch_dimensions()));
    return true;
  }
};

template <>
struct type_caster<xla::GatherDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers,
                       _("xla::GatherDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(handle, "offset_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_offset_dims()));
    dims =
        getattr(handle, "collapsed_slice_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_collapsed_slice_dims()));
    dims = getattr(handle, "start_index_map").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_start_index_map()));
    value.set_index_vector_dim(
        getattr(handle, "index_vector_dim").cast<xla::int64>());
    return true;
  }
};

template <>
struct type_caster<xla::ScatterDimensionNumbers> {
 public:
  PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers,
                       _("xla::ScatterDimensionNumbers"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims =
        getattr(handle, "update_window_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_update_window_dims()));
    dims =
        getattr(handle, "inserted_window_dims").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_inserted_window_dims()));
    dims = getattr(handle, "scatter_dims_to_operand_dims")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_scatter_dims_to_operand_dims()));
    value.set_index_vector_dim(
        getattr(handle, "index_vector_dim").cast<xla::int64>());
    return true;
  }
};

template <>
struct type_caster<xla::ReplicaGroup> {
 public:
  PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    std::vector<xla::int64> dims;
    dims = getattr(handle, "replica_ids").cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_replica_ids()));
    return true;
  }
};

template <>
struct type_caster<xla::PaddingConfig> {
 public:
  PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    sequence dimensions =
        reinterpret_borrow<sequence>(getattr(handle, "dimensions"));

    for (const auto& dimension : dimensions) {
      xla::PaddingConfig::PaddingConfigDimension* config_dim =
          value.add_dimensions();
      config_dim->set_edge_padding_low(
          getattr(dimension, "edge_padding_low").cast<xla::int64>());
      config_dim->set_edge_padding_high(
          getattr(dimension, "edge_padding_high").cast<xla::int64>());
      config_dim->set_interior_padding(
          getattr(dimension, "interior_padding").cast<xla::int64>());
    }
    return true;
  }
};

template <>
struct type_caster<xla::OpMetadata> {
 public:
  PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    pybind11::handle op_type = getattr(handle, "op_type");
    if (!op_type.is_none()) {
      value.set_op_type(op_type.cast<std::string>());
    }
    pybind11::handle op_name = getattr(handle, "op_name");
    if (!op_name.is_none()) {
      value.set_op_name(op_name.cast<std::string>());
    }
    pybind11::handle source_file = getattr(handle, "source_file");
    if (!source_file.is_none()) {
      value.set_source_file(source_file.cast<std::string>());
    }
    pybind11::handle source_line = getattr(handle, "source_line");
    if (!source_line.is_none()) {
      value.set_source_line(source_line.cast<xla::int32>());
    }
    return true;
  }
};

template <>
struct type_caster<xla::PrecisionConfig> {
 public:
  PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));

  // PyObject -> C++ conversion.
  bool load(handle handle, bool) {
    if (handle.is_none()) {
      return true;
    }

    sequence operand_precisions =
        reinterpret_borrow<sequence>(getattr(handle, "operand_precision"));

    for (const auto& operand_precision : operand_precisions) {
      value.add_operand_precision(
          operand_precision.cast<xla::PrecisionConfig::Precision>());
    }
    return true;
  }
};

template <>
struct type_caster<xla::OpSharding> {
 public:
  PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));

  // PyObject -> C++ conversion.
  bool load(handle handle_obj, bool) {
    if (handle_obj.is_none()) {
      return true;
    }

    // Sets `type` field.
    handle sharding_type = getattr(handle_obj, "type");
    if (!sharding_type.is_none()) {
      value.set_type(sharding_type.cast<xla::OpSharding_Type>());
    }

    // Sets `tile_assignment_dimensions` field.
    std::vector<xla::int64> dims;
    dims = getattr(handle_obj, "tile_assignment_dimensions")
               .cast<std::vector<xla::int64>>();
    std::copy(dims.begin(), dims.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_tile_assignment_dimensions()));

    // Sets `tile_assignment_devices` field.
    std::vector<xla::int64> devices;
    devices = getattr(handle_obj, "tile_assignment_devices")
                  .cast<std::vector<xla::int64>>();
    std::copy(devices.begin(), devices.end(),
              tensorflow::protobuf::RepeatedFieldBackInserter(
                  value.mutable_tile_assignment_devices()));

    // Sets `tuple_shardings` field.
    sequence tuple_shardings =
        reinterpret_borrow<sequence>(getattr(handle_obj, "tuple_shardings"));

    for (const auto& tuple_sharding : tuple_shardings) {
      xla::OpSharding* sharding = value.add_tuple_shardings();

      handle sharding_type = getattr(tuple_sharding, "type");
      if (!sharding_type.is_none()) {
        sharding->set_type(sharding_type.cast<xla::OpSharding_Type>());
      }
      std::vector<xla::int64> dims;
      dims = getattr(tuple_sharding, "tile_assignment_dimensions")
                 .cast<std::vector<xla::int64>>();
      std::copy(dims.begin(), dims.end(),
                tensorflow::protobuf::RepeatedFieldBackInserter(
                    sharding->mutable_tile_assignment_dimensions()));

      std::vector<xla::int64> devices;
      devices = getattr(tuple_sharding, "tile_assignment_devices")
                    .cast<std::vector<xla::int64>>();
      std::copy(devices.begin(), devices.end(),
                tensorflow::protobuf::RepeatedFieldBackInserter(
                    sharding->mutable_tile_assignment_devices()));

      sharding->set_replicate_on_last_tile_dim(
          getattr(tuple_sharding, "replicate_on_last_tile_dim").cast<bool>());
    }

    // Sets `replicate_on_last_tile_dim` field.
    value.set_replicate_on_last_tile_dim(
        getattr(handle_obj, "replicate_on_last_tile_dim").cast<bool>());

    return true;
  }
};

}  // namespace detail
}  // namespace pybind11

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_