• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/python/types.h"

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"
21 
22 namespace xla {
23 
24 namespace py = pybind11;
25 
DtypeToPrimitiveType(const py::dtype & np_type)26 xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
27   static auto* types =
28       new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
29           {{'b', 1}, PRED},
30           {{'i', 1}, S8},
31           {{'i', 2}, S16},
32           {{'i', 4}, S32},
33           {{'i', 8}, S64},
34           {{'u', 1}, U8},
35           {{'u', 2}, U16},
36           {{'u', 4}, U32},
37           {{'u', 8}, U64},
38           {{'V', 2}, BF16},  // array protocol code for raw data (void*)
39           {{'f', 2}, F16},
40           {{'f', 4}, F32},
41           {{'f', 8}, F64},
42           {{'c', 8}, C64},
43           {{'c', 16}, C128},
44       });
45   auto it = types->find({np_type.kind(), np_type.itemsize()});
46   if (it == types->end()) {
47     return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(),
48                            np_type.itemsize());
49   }
50   return it->second;
51 }
52 
PrimitiveTypeToDtype(PrimitiveType type)53 xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
54   switch (type) {
55     case PRED:
56       return py::dtype::of<bool>();
57     case S8:
58       return py::dtype::of<int8>();
59     case S16:
60       return py::dtype::of<int16>();
61     case S32:
62       return py::dtype::of<int32>();
63     case S64:
64       return py::dtype::of<int64>();
65     case U8:
66       return py::dtype::of<uint8>();
67     case U16:
68       return py::dtype::of<uint16>();
69     case U32:
70       return py::dtype::of<uint32>();
71     case U64:
72       return py::dtype::of<uint64>();
73     case BF16: {
74       py::handle bfloat16(tensorflow::Bfloat16Dtype());
75       return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
76     }
77     case F16:
78       return py::dtype("e");  // PEP 3118 code for "float16
79     case F32:
80       return py::dtype::of<float>();
81     case F64:
82       return py::dtype::of<double>();
83     case C64:
84       return py::dtype::of<std::complex<float>>();
85     case C128:
86       return py::dtype::of<std::complex<double>>();
87     default:
88       return Unimplemented("Unimplemented primitive type %s",
89                            PrimitiveType_Name(type));
90   }
91 }
92 
GetNumpyScalarTypes()93 const NumpyScalarTypes& GetNumpyScalarTypes() {
94   static const NumpyScalarTypes* singleton = []() {
95     NumpyScalarTypes* dtypes = new NumpyScalarTypes();
96     const auto numpy = py::module::import("numpy");
97     dtypes->np_bool = py::object(numpy.attr("bool_"));
98     dtypes->np_int8 = py::object(numpy.attr("int8"));
99     dtypes->np_int16 = py::object(numpy.attr("int16"));
100     dtypes->np_int32 = py::object(numpy.attr("int32"));
101     dtypes->np_int64 = py::object(numpy.attr("int64"));
102     dtypes->np_uint8 = py::object(numpy.attr("uint8"));
103     dtypes->np_uint16 = py::object(numpy.attr("uint16"));
104     dtypes->np_uint32 = py::object(numpy.attr("uint32"));
105     dtypes->np_uint64 = py::object(numpy.attr("uint64"));
106     dtypes->np_bfloat16 =
107         py::reinterpret_borrow<py::object>(tensorflow::Bfloat16Dtype());
108     dtypes->np_float16 = py::object(numpy.attr("float16"));
109     dtypes->np_float32 = py::object(numpy.attr("float32"));
110     dtypes->np_float64 = py::object(numpy.attr("float64"));
111     dtypes->np_complex64 = py::object(numpy.attr("complex64"));
112     dtypes->np_complex128 = py::object(numpy.attr("complex128"));
113     dtypes->np_longlong = py::object(numpy.attr("longlong"));
114     dtypes->np_intc = py::object(numpy.attr("intc"));
115     return dtypes;
116   }();
117   return *singleton;
118 }
119 
120 // Returns a numpy-style format descriptor string for `type`.
FormatDescriptorForPrimitiveType(PrimitiveType type)121 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
122   // We use an "=" prefix to indicate that we prefer "standard" types like
123   // np.int32 rather than "native" types like np.cint. pybind11 does not qualify
124   // its format descriptors.
125   switch (type) {
126     case PRED:
127       return std::string("?");
128     case S8:
129       return std::string("=b");
130     case S16:
131       return std::string("=h");
132     case S32:
133       return std::string("=i");
134     case S64:
135       return std::string("=q");
136     case U8:
137       return std::string("=B");
138     case U16:
139       return std::string("=H");
140     case U32:
141       return std::string("=I");
142     case U64:
143       return std::string("=Q");
144     case F16:
145       return std::string("=e");
146     case F32:
147       return std::string("=f");
148     case F64:
149       return std::string("=d");
150     case C64:
151       return std::string("=Zf");
152     case C128:
153       return std::string("=Zd");
154     default:
155       return Unimplemented("Unimplemented primitive type %s",
156                            PrimitiveType_Name(type));
157   }
158 }
159 
TypeDescriptorForPrimitiveType(PrimitiveType type)160 StatusOr<py::str> TypeDescriptorForPrimitiveType(PrimitiveType type) {
161   static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__,
162                 "Big endian support not implemented");
163   switch (type) {
164     case PRED:
165       return py::str("|b1");
166     case S8:
167       return py::str("|i1");
168     case S16:
169       return py::str("<i2");
170     case S32:
171       return py::str("<i4");
172     case S64:
173       return py::str("<i8");
174     case U8:
175       return py::str("|u1");
176     case U16:
177       return py::str("<u2");
178     case U32:
179       return py::str("<u4");
180     case U64:
181       return py::str("<u8");
182     case BF16:
183       return py::str("<V2");
184     case F16:
185       return py::str("<f2");
186     case F32:
187       return py::str("<f4");
188     case F64:
189       return py::str("<f8");
190     case C64:
191       return py::str("<c8");
192     case C128:
193       return py::str("<c16");
194     default:
195       return Unimplemented("Unimplemented primitive type %s",
196                            PrimitiveType_Name(type));
197   }
198 }
199 
Squash64BitTypes(PrimitiveType type)200 PrimitiveType Squash64BitTypes(PrimitiveType type) {
201   switch (type) {
202     case S64:
203       return S32;
204     case U64:
205       return U32;
206     case F64:
207       return F32;
208     case C128:
209       return C64;
210     default:
211       return type;
212   }
213 }
214 
215 // Returns the strides for `shape`.
ByteStridesForShape(const Shape & shape)216 std::vector<ssize_t> ByteStridesForShape(const Shape& shape) {
217   std::vector<ssize_t> strides;
218   CHECK(shape.IsArray());
219   CHECK(shape.has_layout());
220 
221   strides.resize(shape.dimensions_size());
222   ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
223   for (int i : shape.layout().minor_to_major()) {
224     strides.at(i) = stride;
225     stride *= shape.dimensions(i);
226   }
227   return strides;
228 }
229 
ByteStridesForShapeInt64(const Shape & shape)230 std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape) {
231   std::vector<int64_t> strides;
232   CHECK(shape.IsArray());
233   CHECK(shape.has_layout());
234 
235   strides.resize(shape.dimensions_size());
236   int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
237   for (int i : shape.layout().minor_to_major()) {
238     strides.at(i) = stride;
239     stride *= shape.dimensions(i);
240   }
241   return strides;
242 }
243 
LiteralToPython(std::shared_ptr<xla::Literal> literal)244 StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
245   xla::Literal& m = *literal;
246   if (m.shape().IsTuple()) {
247     std::vector<Literal> elems = m.DecomposeTuple();
248     std::vector<py::object> arrays(elems.size());
249     for (int i = 0; i < elems.size(); ++i) {
250       TF_ASSIGN_OR_RETURN(
251           arrays[i],
252           LiteralToPython(absl::make_unique<Literal>(std::move(elems[i]))));
253     }
254     py::tuple result(elems.size());
255     for (int i = 0; i < elems.size(); ++i) {
256       PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr());
257     }
258     return result;
259   }
260   TF_RET_CHECK(m.shape().IsArray());
261 
262   py::object literal_object = py::cast(literal);
263   TF_ASSIGN_OR_RETURN(py::dtype dtype,
264                       PrimitiveTypeToDtype(m.shape().element_type()));
265   return py::array(dtype, m.shape().dimensions(),
266                    ByteStridesForShape(m.shape()), m.untyped_data(),
267                    literal_object);
268 }
269 
GetPythonBufferTree(const py::object & argument)270 StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
271   PythonBufferTree tree;
272   if (py::isinstance<py::tuple>(argument)) {
273     py::tuple tuple = py::reinterpret_borrow<py::tuple>(argument);
274     std::vector<Shape> host_shapes(tuple.size());
275     for (int i = 0; i < host_shapes.size(); ++i) {
276       TF_ASSIGN_OR_RETURN(PythonBufferTree subtree,
277                           GetPythonBufferTree(tuple[i]));
278       tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size());
279       std::move(subtree.leaves.begin(), subtree.leaves.end(),
280                 std::back_inserter(tree.leaves));
281       tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size());
282       std::move(subtree.arrays.begin(), subtree.arrays.end(),
283                 std::back_inserter(tree.arrays));
284       host_shapes[i] = std::move(subtree.shape);
285     }
286     tree.shape = ShapeUtil::MakeTupleShape(host_shapes);
287   } else {
288     pybind11::detail::type_caster<BorrowingLiteral> caster;
289     if (!caster.load(argument, /*convert=*/true)) {
290       return InvalidArgument("Invalid array value.");
291     }
292     DCHECK_EQ(caster.arrays.size(), 1);
293     tree.arrays.push_back(std::move(caster.arrays.front()));
294     tree.leaves.push_back(std::move(*caster));
295     tree.shape = tree.leaves.front().shape();
296   }
297   return tree;
298 }
299 
300 template <typename IntType>
IntSpanToTupleHelper(absl::Span<IntType const> xs)301 static py::tuple IntSpanToTupleHelper(absl::Span<IntType const> xs) {
302   py::tuple out(xs.size());
303   for (int i = 0; i < xs.size(); ++i) {
304     out[i] = py::int_(xs[i]);
305   }
306   return out;
307 }
308 
309 template <>
SpanToTuple(absl::Span<int const> xs)310 pybind11::tuple SpanToTuple(absl::Span<int const> xs) {
311   return IntSpanToTupleHelper(xs);
312 }
313 template <>
SpanToTuple(absl::Span<int64 const> xs)314 pybind11::tuple SpanToTuple(absl::Span<int64 const> xs) {
315   return IntSpanToTupleHelper(xs);
316 }
317 
CastToArray(py::handle h)318 absl::optional<CastToArrayResult> CastToArray(py::handle h) {
319   py::array array = py::array::ensure(
320       h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_);
321   if (!array) {
322     return absl::nullopt;
323   }
324   auto type_or_status = DtypeToPrimitiveType(array.dtype());
325   if (!type_or_status.ok()) {
326     throw std::runtime_error(type_or_status.status().ToString());
327   }
328   PrimitiveType type = type_or_status.ValueOrDie();
329 
330   absl::InlinedVector<int64, 4> dims(array.ndim());
331   for (int i = 0; i < array.ndim(); ++i) {
332     dims[i] = array.shape(i);
333   }
334   Shape shape = ShapeUtil::MakeShape(type, dims);
335   if (array.size() * array.itemsize() != ShapeUtil::ByteSizeOf(shape)) {
336     throw std::runtime_error(absl::StrCat(
337         "Size mismatch for buffer: ", array.size() * array.itemsize(), " vs. ",
338         ShapeUtil::ByteSizeOf(shape)));
339   }
340   return CastToArrayResult{array, static_cast<const char*>(array.data()),
341                            shape};
342 }
343 
344 }  // namespace xla
345