• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
18 
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/absl_casters.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"
35 
36 namespace xla {
37 
38 // Helper that converts a failing StatusOr to an exception.
39 // For use only inside pybind11 code.
40 template <typename T>
ValueOrThrow(StatusOr<T> v)41 T ValueOrThrow(StatusOr<T> v) {
42   if (!v.ok()) {
43     throw std::runtime_error(v.status().ToString());
44   }
45   return v.ConsumeValueOrDie();
46 }
47 
48 // Converts a NumPy dtype to a PrimitiveType.
49 StatusOr<PrimitiveType> DtypeToPrimitiveType(const pybind11::dtype& np_type);
50 
51 // Converts a PrimitiveType to a Numpy dtype.
52 StatusOr<pybind11::dtype> PrimitiveTypeToDtype(PrimitiveType type);
53 
54 // Returns a numpy-style format descriptor string for `type`.
55 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type);
56 
57 // Returns a numpy-style typestr for `type`, as returned by np.dtype(...).str
58 StatusOr<pybind11::str> TypeDescriptorForPrimitiveType(PrimitiveType type);
59 
60 // Returns the strides for `shape`.
61 std::vector<ssize_t> ByteStridesForShape(const Shape& shape);
62 
63 // Converts a literal to (possibly-nested tuples of) NumPy arrays.
64 // The literal's leaf arrays are not copied; instead the NumPy arrays share
65 // buffers with the literals. Takes ownership of `literal` and keeps the
66 // necessary pieces alive using Python reference counting.
67 // Requires the GIL.
68 StatusOr<pybind11::object> LiteralToPython(std::shared_ptr<Literal> literal);
69 
70 // Converts a Python object into an XLA shape and a vector of leaf buffers.
71 // The leaf buffers correspond to a depth-first, left-to-right traversal of
72 // the Python value.
73 // Requires the GIL.
74 struct PythonBufferTree {
75   // Holds a reference to the arrays pointed to by `leaves`, since we may
76   // need to make a copy if the array is not in a C-style layout.
77   absl::InlinedVector<pybind11::object, 1> arrays;
78   absl::InlinedVector<BorrowingLiteral, 1> leaves;
79   Shape shape;
80 };
81 StatusOr<PythonBufferTree> GetPythonBufferTree(
82     const pybind11::object& argument);
83 
84 // Converts a sequence of C++ ints to a Python tuple of ints.
85 // Pybind11 by default converts a std::vector<int64> to a Python list;
86 // we frequently want a tuple instead e.g. for shapes.
87 pybind11::tuple IntSpanToTuple(absl::Span<int64 const> xs);
88 pybind11::tuple IntSpanToTuple(absl::Span<int const> xs);
89 
90 // Converts a Python sequence of integers to a std::vector<int64>
91 std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
92 
93 // Private helper function used in the implementation of the type caster for
94 // xla::BorrowingLiteral. Converts a Python array-like object into a buffer
95 // pointer and shape.
96 struct CastToArrayResult {
97   pybind11::object array;  // Holds a reference to the array to keep it alive.
98   const char* buf_ptr;
99   xla::Shape shape;
100 };
101 absl::optional<CastToArrayResult> CastToArray(pybind11::handle h);
102 
103 }  // namespace xla
104 
105 // This namespace is a documented pybind11 extension point.
106 // Caution: Unusually for Google code, this code uses C++ exceptions because
107 // they are the only mechanism for reporting cast failures to pybind11. However,
108 // the exceptions are local to the binding code.
109 namespace pybind11 {
110 namespace detail {
111 
112 // Status, StatusOr. Failing statuses become Python exceptions; Status::OK()
113 // becomes None.
114 template <>
115 struct type_caster<xla::Status> {
116  public:
117   PYBIND11_TYPE_CASTER(xla::Status, _("Status"));
118 
119   static handle cast(xla::Status src, return_value_policy /* policy */,
120                      handle /* parent */) {
121     if (!src.ok()) {
122       throw std::runtime_error(src.ToString());
123     }
124     return none().inc_ref();
125   }
126 };
127 
128 template <typename T>
129 struct type_caster<xla::StatusOr<T>> {
130  public:
131   using value_conv = make_caster<T>;
132 
133   PYBIND11_TYPE_CASTER(xla::StatusOr<T>,
134                        _("StatusOr[") + value_conv::name + _("]"));
135 
136   static handle cast(xla::StatusOr<T> src, return_value_policy policy,
137                      handle parent) {
138     if (!src.ok()) {
139       throw std::runtime_error(src.status().ToString());
140     }
141     return value_conv::cast(std::forward<xla::StatusOr<T>>(src).ValueOrDie(),
142                             policy, parent);
143   }
144 };
145 
146 // Literals.
147 // Literal data can be passed to XLA as a NumPy array; its value can be
148 // cast to an xla::BorrowingLiteral or xla::LiteralSlice in a zero-copy way.
149 // We don't have any literal -> numpy conversions here, since all the methods
150 // that want to return arrays build Python objects directly.
151 
152 template <>
153 struct type_caster<xla::BorrowingLiteral> {
154  public:
155   PYBIND11_TYPE_CASTER(xla::BorrowingLiteral, _("xla::BorrowingLiteral"));
156 
157   // Pybind appears to keep type_casters alive until the callee has run.
158   absl::InlinedVector<pybind11::array, 1> arrays;
159 
160   bool load(handle input, bool) {
161     // TODO(b/79707221): support nested tuples if/when XLA adds support for
162     // nested BorrowingLiterals.
163     if (pybind11::isinstance<pybind11::tuple>(input)) {
164       pybind11::tuple tuple =
165           pybind11::reinterpret_borrow<pybind11::tuple>(input);
166       std::vector<xla::Shape> shapes;
167       std::vector<const char*> buffers;
168       arrays.reserve(tuple.size());
169       shapes.reserve(tuple.size());
170       buffers.reserve(tuple.size());
171       for (pybind11::handle entry : tuple) {
172         auto c = xla::CastToArray(entry);
173         if (!c) {
174           return false;
175         }
176         arrays.push_back(c->array);
177         buffers.push_back(c->buf_ptr);
178         shapes.push_back(c->shape);
179       }
180       value = xla::BorrowingLiteral(buffers,
181                                     xla::ShapeUtil::MakeTupleShape(shapes));
182     } else {
183       auto c = xla::CastToArray(input);
184       if (!c) {
185         return false;
186       }
187       arrays.push_back(c->array);
188       value = xla::BorrowingLiteral(c->buf_ptr, c->shape);
189     }
190     return true;
191   }
192 };
193 
194 template <>
195 struct type_caster<xla::LiteralSlice> {
196  public:
197   PYBIND11_TYPE_CASTER(xla::LiteralSlice, _("xla::LiteralSlice"));
198 
199   // Pybind appears to keep type_casters alive until the callee has run.
200   type_caster<xla::BorrowingLiteral> literal_caster;
201 
202   bool load(handle handle, bool convert) {
203     if (!literal_caster.load(handle, convert)) {
204       return false;
205     }
206     value = static_cast<const xla::BorrowingLiteral&>(literal_caster);
207     return true;
208   }
209 };
210 
211 // XLA protocol buffers
212 // We don't actually care that these are the protocol buffers, we merely want
213 // objects that duck type as protocol buffers. The client code currently avoids
214 // depending on Python protocol buffers to avoid conflicting definitions from
215 // different modules that both include XLA.
216 
217 template <>
218 struct type_caster<xla::ConvolutionDimensionNumbers> {
219  public:
220   PYBIND11_TYPE_CASTER(xla::ConvolutionDimensionNumbers,
221                        _("xla::ConvolutionDimensionNumbers"));
222 
223   // PyObject -> C++ conversion.
224   bool load(handle handle, bool) {
225     value.set_input_batch_dimension(
226         getattr(handle, "input_batch_dimension").cast<xla::int64>());
227     value.set_input_feature_dimension(
228         getattr(handle, "input_feature_dimension").cast<xla::int64>());
229     value.set_output_batch_dimension(
230         getattr(handle, "output_batch_dimension").cast<xla::int64>());
231     value.set_output_feature_dimension(
232         getattr(handle, "output_feature_dimension").cast<xla::int64>());
233     value.set_kernel_input_feature_dimension(
234         getattr(handle, "kernel_input_feature_dimension").cast<xla::int64>());
235     value.set_kernel_output_feature_dimension(
236         getattr(handle, "kernel_output_feature_dimension").cast<xla::int64>());
237     std::vector<xla::int64> dims;
238     dims = getattr(handle, "input_spatial_dimensions")
239                .cast<std::vector<xla::int64>>();
240     std::copy(dims.begin(), dims.end(),
241               tensorflow::protobuf::RepeatedFieldBackInserter(
242                   value.mutable_input_spatial_dimensions()));
243     dims = getattr(handle, "kernel_spatial_dimensions")
244                .cast<std::vector<xla::int64>>();
245     std::copy(dims.begin(), dims.end(),
246               tensorflow::protobuf::RepeatedFieldBackInserter(
247                   value.mutable_kernel_spatial_dimensions()));
248     dims = getattr(handle, "output_spatial_dimensions")
249                .cast<std::vector<xla::int64>>();
250     std::copy(dims.begin(), dims.end(),
251               tensorflow::protobuf::RepeatedFieldBackInserter(
252                   value.mutable_output_spatial_dimensions()));
253     return true;
254   }
255 };
256 
257 template <>
258 struct type_caster<xla::DotDimensionNumbers> {
259  public:
260   PYBIND11_TYPE_CASTER(xla::DotDimensionNumbers, _("xla::DotDimensionNumbers"));
261 
262   // PyObject -> C++ conversion.
263   bool load(handle handle, bool) {
264     std::vector<xla::int64> dims;
265     dims = getattr(handle, "lhs_contracting_dimensions")
266                .cast<std::vector<xla::int64>>();
267     std::copy(dims.begin(), dims.end(),
268               tensorflow::protobuf::RepeatedFieldBackInserter(
269                   value.mutable_lhs_contracting_dimensions()));
270     dims = getattr(handle, "rhs_contracting_dimensions")
271                .cast<std::vector<xla::int64>>();
272     std::copy(dims.begin(), dims.end(),
273               tensorflow::protobuf::RepeatedFieldBackInserter(
274                   value.mutable_rhs_contracting_dimensions()));
275     dims =
276         getattr(handle, "lhs_batch_dimensions").cast<std::vector<xla::int64>>();
277     std::copy(dims.begin(), dims.end(),
278               tensorflow::protobuf::RepeatedFieldBackInserter(
279                   value.mutable_lhs_batch_dimensions()));
280     dims =
281         getattr(handle, "rhs_batch_dimensions").cast<std::vector<xla::int64>>();
282     std::copy(dims.begin(), dims.end(),
283               tensorflow::protobuf::RepeatedFieldBackInserter(
284                   value.mutable_rhs_batch_dimensions()));
285     return true;
286   }
287 };
288 
289 template <>
290 struct type_caster<xla::GatherDimensionNumbers> {
291  public:
292   PYBIND11_TYPE_CASTER(xla::GatherDimensionNumbers,
293                        _("xla::GatherDimensionNumbers"));
294 
295   // PyObject -> C++ conversion.
296   bool load(handle handle, bool) {
297     std::vector<xla::int64> dims;
298     dims = getattr(handle, "offset_dims").cast<std::vector<xla::int64>>();
299     std::copy(dims.begin(), dims.end(),
300               tensorflow::protobuf::RepeatedFieldBackInserter(
301                   value.mutable_offset_dims()));
302     dims =
303         getattr(handle, "collapsed_slice_dims").cast<std::vector<xla::int64>>();
304     std::copy(dims.begin(), dims.end(),
305               tensorflow::protobuf::RepeatedFieldBackInserter(
306                   value.mutable_collapsed_slice_dims()));
307     dims = getattr(handle, "start_index_map").cast<std::vector<xla::int64>>();
308     std::copy(dims.begin(), dims.end(),
309               tensorflow::protobuf::RepeatedFieldBackInserter(
310                   value.mutable_start_index_map()));
311     value.set_index_vector_dim(
312         getattr(handle, "index_vector_dim").cast<xla::int64>());
313     return true;
314   }
315 };
316 
317 template <>
318 struct type_caster<xla::ScatterDimensionNumbers> {
319  public:
320   PYBIND11_TYPE_CASTER(xla::ScatterDimensionNumbers,
321                        _("xla::ScatterDimensionNumbers"));
322 
323   // PyObject -> C++ conversion.
324   bool load(handle handle, bool) {
325     std::vector<xla::int64> dims;
326     dims =
327         getattr(handle, "update_window_dims").cast<std::vector<xla::int64>>();
328     std::copy(dims.begin(), dims.end(),
329               tensorflow::protobuf::RepeatedFieldBackInserter(
330                   value.mutable_update_window_dims()));
331     dims =
332         getattr(handle, "inserted_window_dims").cast<std::vector<xla::int64>>();
333     std::copy(dims.begin(), dims.end(),
334               tensorflow::protobuf::RepeatedFieldBackInserter(
335                   value.mutable_inserted_window_dims()));
336     dims = getattr(handle, "scatter_dims_to_operand_dims")
337                .cast<std::vector<xla::int64>>();
338     std::copy(dims.begin(), dims.end(),
339               tensorflow::protobuf::RepeatedFieldBackInserter(
340                   value.mutable_scatter_dims_to_operand_dims()));
341     value.set_index_vector_dim(
342         getattr(handle, "index_vector_dim").cast<xla::int64>());
343     return true;
344   }
345 };
346 
347 template <>
348 struct type_caster<xla::ReplicaGroup> {
349  public:
350   PYBIND11_TYPE_CASTER(xla::ReplicaGroup, _("xla::ReplicaGroup"));
351 
352   // PyObject -> C++ conversion.
353   bool load(handle handle, bool) {
354     std::vector<xla::int64> dims;
355     dims = getattr(handle, "replica_ids").cast<std::vector<xla::int64>>();
356     std::copy(dims.begin(), dims.end(),
357               tensorflow::protobuf::RepeatedFieldBackInserter(
358                   value.mutable_replica_ids()));
359     return true;
360   }
361 };
362 
363 template <>
364 struct type_caster<xla::PaddingConfig> {
365  public:
366   PYBIND11_TYPE_CASTER(xla::PaddingConfig, _("xla::PaddingConfig"));
367 
368   // PyObject -> C++ conversion.
369   bool load(handle handle, bool) {
370     sequence dimensions =
371         reinterpret_borrow<sequence>(getattr(handle, "dimensions"));
372 
373     for (const auto& dimension : dimensions) {
374       xla::PaddingConfig::PaddingConfigDimension* config_dim =
375           value.add_dimensions();
376       config_dim->set_edge_padding_low(
377           getattr(dimension, "edge_padding_low").cast<xla::int64>());
378       config_dim->set_edge_padding_high(
379           getattr(dimension, "edge_padding_high").cast<xla::int64>());
380       config_dim->set_interior_padding(
381           getattr(dimension, "interior_padding").cast<xla::int64>());
382     }
383     return true;
384   }
385 };
386 
387 template <>
388 struct type_caster<xla::OpMetadata> {
389  public:
390   PYBIND11_TYPE_CASTER(xla::OpMetadata, _("xla::OpMetadata"));
391 
392   // PyObject -> C++ conversion.
393   bool load(handle handle, bool) {
394     pybind11::handle op_type = getattr(handle, "op_type");
395     if (!op_type.is_none()) {
396       value.set_op_type(op_type.cast<std::string>());
397     }
398     pybind11::handle op_name = getattr(handle, "op_name");
399     if (!op_name.is_none()) {
400       value.set_op_name(op_name.cast<std::string>());
401     }
402     pybind11::handle source_file = getattr(handle, "source_file");
403     if (!source_file.is_none()) {
404       value.set_source_file(source_file.cast<std::string>());
405     }
406     pybind11::handle source_line = getattr(handle, "source_line");
407     if (!source_line.is_none()) {
408       value.set_source_line(source_line.cast<xla::int32>());
409     }
410     return true;
411   }
412 };
413 
414 template <>
415 struct type_caster<xla::PrecisionConfig> {
416  public:
417   PYBIND11_TYPE_CASTER(xla::PrecisionConfig, _("xla::PrecisionConfig"));
418 
419   // PyObject -> C++ conversion.
420   bool load(handle handle, bool) {
421     if (handle.is_none()) {
422       return true;
423     }
424 
425     sequence operand_precisions =
426         reinterpret_borrow<sequence>(getattr(handle, "operand_precision"));
427 
428     for (const auto& operand_precision : operand_precisions) {
429       value.add_operand_precision(
430           operand_precision.cast<xla::PrecisionConfig::Precision>());
431     }
432     return true;
433   }
434 };
435 
436 template <>
437 struct type_caster<xla::OpSharding> {
438  public:
439   PYBIND11_TYPE_CASTER(xla::OpSharding, _("xla::OpSharding"));
440 
441   // PyObject -> C++ conversion.
442   bool load(handle handle_obj, bool) {
443     if (handle_obj.is_none()) {
444       return true;
445     }
446 
447     // Sets `type` field.
448     handle sharding_type = getattr(handle_obj, "type");
449     if (!sharding_type.is_none()) {
450       value.set_type(sharding_type.cast<xla::OpSharding_Type>());
451     }
452 
453     // Sets `tile_assignment_dimensions` field.
454     std::vector<xla::int64> dims;
455     dims = getattr(handle_obj, "tile_assignment_dimensions")
456                .cast<std::vector<xla::int64>>();
457     std::copy(dims.begin(), dims.end(),
458               tensorflow::protobuf::RepeatedFieldBackInserter(
459                   value.mutable_tile_assignment_dimensions()));
460 
461     // Sets `tile_assignment_devices` field.
462     std::vector<xla::int64> devices;
463     devices = getattr(handle_obj, "tile_assignment_devices")
464                   .cast<std::vector<xla::int64>>();
465     std::copy(devices.begin(), devices.end(),
466               tensorflow::protobuf::RepeatedFieldBackInserter(
467                   value.mutable_tile_assignment_devices()));
468 
469     // Sets `tuple_shardings` field.
470     sequence tuple_shardings =
471         reinterpret_borrow<sequence>(getattr(handle_obj, "tuple_shardings"));
472 
473     for (const auto& tuple_sharding : tuple_shardings) {
474       xla::OpSharding* sharding = value.add_tuple_shardings();
475 
476       handle sharding_type = getattr(tuple_sharding, "type");
477       if (!sharding_type.is_none()) {
478         sharding->set_type(sharding_type.cast<xla::OpSharding_Type>());
479       }
480       std::vector<xla::int64> dims;
481       dims = getattr(tuple_sharding, "tile_assignment_dimensions")
482                  .cast<std::vector<xla::int64>>();
483       std::copy(dims.begin(), dims.end(),
484                 tensorflow::protobuf::RepeatedFieldBackInserter(
485                     sharding->mutable_tile_assignment_dimensions()));
486 
487       std::vector<xla::int64> devices;
488       devices = getattr(tuple_sharding, "tile_assignment_devices")
489                     .cast<std::vector<xla::int64>>();
490       std::copy(devices.begin(), devices.end(),
491                 tensorflow::protobuf::RepeatedFieldBackInserter(
492                     sharding->mutable_tile_assignment_devices()));
493 
494       sharding->set_replicate_on_last_tile_dim(
495           getattr(tuple_sharding, "replicate_on_last_tile_dim").cast<bool>());
496     }
497 
498     // Sets `replicate_on_last_tile_dim` field.
499     value.set_replicate_on_last_tile_dim(
500         getattr(handle_obj, "replicate_on_last_tile_dim").cast<bool>());
501 
502     return true;
503   }
504 };
505 
506 }  // namespace detail
507 }  // namespace pybind11
508 
509 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TYPES_H_
510