/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/types.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/compiler/xla/status_macros.h"
20 #include "tensorflow/python/lib/core/bfloat16.h"
21
22 namespace xla {
23
24 namespace py = pybind11;
25
DtypeToPrimitiveType(const py::dtype & np_type)26 xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
27 static auto* types =
28 new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
29 {{'b', 1}, PRED},
30 {{'i', 1}, S8},
31 {{'i', 2}, S16},
32 {{'i', 4}, S32},
33 {{'i', 8}, S64},
34 {{'u', 1}, U8},
35 {{'u', 2}, U16},
36 {{'u', 4}, U32},
37 {{'u', 8}, U64},
38 {{'V', 2}, BF16}, // array protocol code for raw data (void*)
39 {{'f', 2}, F16},
40 {{'f', 4}, F32},
41 {{'f', 8}, F64},
42 {{'c', 8}, C64},
43 {{'c', 16}, C128},
44 });
45 auto it = types->find({np_type.kind(), np_type.itemsize()});
46 if (it == types->end()) {
47 return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(),
48 np_type.itemsize());
49 }
50 return it->second;
51 }
52
// Returns the NumPy dtype corresponding to an XLA PrimitiveType, or
// Unimplemented for types with no NumPy equivalent.
xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
  switch (type) {
    case PRED:
      return py::dtype::of<bool>();
    case S8:
      return py::dtype::of<int8>();
    case S16:
      return py::dtype::of<int16>();
    case S32:
      return py::dtype::of<int32>();
    case S64:
      return py::dtype::of<int64>();
    case U8:
      return py::dtype::of<uint8>();
    case U16:
      return py::dtype::of<uint16>();
    case U32:
      return py::dtype::of<uint32>();
    case U64:
      return py::dtype::of<uint64>();
    case BF16: {
      // bfloat16 is not a builtin NumPy type; borrow the custom dtype that
      // TensorFlow registers with NumPy.
      py::handle bfloat16(tensorflow::Bfloat16Dtype());
      return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
    }
    case F16:
      return py::dtype("e");  // PEP 3118 code for "float16".
    case F32:
      return py::dtype::of<float>();
    case F64:
      return py::dtype::of<double>();
    case C64:
      return py::dtype::of<std::complex<float>>();
    case C128:
      return py::dtype::of<std::complex<double>>();
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}
92
GetNumpyScalarTypes()93 const NumpyScalarTypes& GetNumpyScalarTypes() {
94 static const NumpyScalarTypes* singleton = []() {
95 NumpyScalarTypes* dtypes = new NumpyScalarTypes();
96 const auto numpy = py::module::import("numpy");
97 dtypes->np_bool = py::object(numpy.attr("bool_"));
98 dtypes->np_int8 = py::object(numpy.attr("int8"));
99 dtypes->np_int16 = py::object(numpy.attr("int16"));
100 dtypes->np_int32 = py::object(numpy.attr("int32"));
101 dtypes->np_int64 = py::object(numpy.attr("int64"));
102 dtypes->np_uint8 = py::object(numpy.attr("uint8"));
103 dtypes->np_uint16 = py::object(numpy.attr("uint16"));
104 dtypes->np_uint32 = py::object(numpy.attr("uint32"));
105 dtypes->np_uint64 = py::object(numpy.attr("uint64"));
106 dtypes->np_bfloat16 =
107 py::reinterpret_borrow<py::object>(tensorflow::Bfloat16Dtype());
108 dtypes->np_float16 = py::object(numpy.attr("float16"));
109 dtypes->np_float32 = py::object(numpy.attr("float32"));
110 dtypes->np_float64 = py::object(numpy.attr("float64"));
111 dtypes->np_complex64 = py::object(numpy.attr("complex64"));
112 dtypes->np_complex128 = py::object(numpy.attr("complex128"));
113 dtypes->np_longlong = py::object(numpy.attr("longlong"));
114 dtypes->np_intc = py::object(numpy.attr("intc"));
115 return dtypes;
116 }();
117 return *singleton;
118 }
119
120 // Returns a numpy-style format descriptor string for `type`.
FormatDescriptorForPrimitiveType(PrimitiveType type)121 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
122 // We use an "=" prefix to indicate that we prefer "standard" types like
123 // np.int32 rather than "native" types like np.cint. pybind11 does not qualify
124 // its format descriptors.
125 switch (type) {
126 case PRED:
127 return std::string("?");
128 case S8:
129 return std::string("=b");
130 case S16:
131 return std::string("=h");
132 case S32:
133 return std::string("=i");
134 case S64:
135 return std::string("=q");
136 case U8:
137 return std::string("=B");
138 case U16:
139 return std::string("=H");
140 case U32:
141 return std::string("=I");
142 case U64:
143 return std::string("=Q");
144 case F16:
145 return std::string("=e");
146 case F32:
147 return std::string("=f");
148 case F64:
149 return std::string("=d");
150 case C64:
151 return std::string("=Zf");
152 case C128:
153 return std::string("=Zd");
154 default:
155 return Unimplemented("Unimplemented primitive type %s",
156 PrimitiveType_Name(type));
157 }
158 }
159
TypeDescriptorForPrimitiveType(PrimitiveType type)160 StatusOr<py::str> TypeDescriptorForPrimitiveType(PrimitiveType type) {
161 static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__,
162 "Big endian support not implemented");
163 switch (type) {
164 case PRED:
165 return py::str("|b1");
166 case S8:
167 return py::str("|i1");
168 case S16:
169 return py::str("<i2");
170 case S32:
171 return py::str("<i4");
172 case S64:
173 return py::str("<i8");
174 case U8:
175 return py::str("|u1");
176 case U16:
177 return py::str("<u2");
178 case U32:
179 return py::str("<u4");
180 case U64:
181 return py::str("<u8");
182 case BF16:
183 return py::str("<V2");
184 case F16:
185 return py::str("<f2");
186 case F32:
187 return py::str("<f4");
188 case F64:
189 return py::str("<f8");
190 case C64:
191 return py::str("<c8");
192 case C128:
193 return py::str("<c16");
194 default:
195 return Unimplemented("Unimplemented primitive type %s",
196 PrimitiveType_Name(type));
197 }
198 }
199
Squash64BitTypes(PrimitiveType type)200 PrimitiveType Squash64BitTypes(PrimitiveType type) {
201 switch (type) {
202 case S64:
203 return S32;
204 case U64:
205 return U32;
206 case F64:
207 return F32;
208 case C128:
209 return C64;
210 default:
211 return type;
212 }
213 }
214
215 // Returns the strides for `shape`.
ByteStridesForShape(const Shape & shape)216 std::vector<ssize_t> ByteStridesForShape(const Shape& shape) {
217 std::vector<ssize_t> strides;
218 CHECK(shape.IsArray());
219 CHECK(shape.has_layout());
220
221 strides.resize(shape.dimensions_size());
222 ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
223 for (int i : shape.layout().minor_to_major()) {
224 strides.at(i) = stride;
225 stride *= shape.dimensions(i);
226 }
227 return strides;
228 }
229
ByteStridesForShapeInt64(const Shape & shape)230 std::vector<int64_t> ByteStridesForShapeInt64(const Shape& shape) {
231 std::vector<int64_t> strides;
232 CHECK(shape.IsArray());
233 CHECK(shape.has_layout());
234
235 strides.resize(shape.dimensions_size());
236 int64_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
237 for (int i : shape.layout().minor_to_major()) {
238 strides.at(i) = stride;
239 stride *= shape.dimensions(i);
240 }
241 return strides;
242 }
243
// Converts `literal` to a Python object: a tuple literal becomes a Python
// tuple (recursively), and an array literal becomes a NumPy array that
// aliases the literal's buffer, with the literal kept alive as the array's
// base object.
StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
  xla::Literal& m = *literal;
  if (m.shape().IsTuple()) {
    // Take ownership of each tuple element and convert it independently.
    std::vector<Literal> elems = m.DecomposeTuple();
    std::vector<py::object> arrays(elems.size());
    for (int i = 0; i < elems.size(); ++i) {
      TF_ASSIGN_OR_RETURN(
          arrays[i],
          LiteralToPython(absl::make_unique<Literal>(std::move(elems[i]))));
    }
    py::tuple result(elems.size());
    for (int i = 0; i < elems.size(); ++i) {
      // PyTuple_SET_ITEM steals a reference, which `release()` hands over.
      PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr());
    }
    return result;
  }
  TF_RET_CHECK(m.shape().IsArray());

  // `literal_object` is passed as the array's base object below so the
  // literal's storage outlives the NumPy view into it.
  py::object literal_object = py::cast(literal);
  TF_ASSIGN_OR_RETURN(py::dtype dtype,
                      PrimitiveTypeToDtype(m.shape().element_type()));
  return py::array(dtype, m.shape().dimensions(),
                   ByteStridesForShape(m.shape()), m.untyped_data(),
                   literal_object);
}
269
// Flattens a Python argument (a possibly nested tuple of array-likes) into
// a PythonBufferTree: the leaf literals, the NumPy arrays backing them, and
// the overall (tuple) shape.
StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
  PythonBufferTree tree;
  if (py::isinstance<py::tuple>(argument)) {
    py::tuple tuple = py::reinterpret_borrow<py::tuple>(argument);
    std::vector<Shape> host_shapes(tuple.size());
    for (int i = 0; i < host_shapes.size(); ++i) {
      // Recursively flatten each element, then splice its leaves/arrays
      // onto the end of ours in order.
      TF_ASSIGN_OR_RETURN(PythonBufferTree subtree,
                          GetPythonBufferTree(tuple[i]));
      tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size());
      std::move(subtree.leaves.begin(), subtree.leaves.end(),
                std::back_inserter(tree.leaves));
      tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size());
      std::move(subtree.arrays.begin(), subtree.arrays.end(),
                std::back_inserter(tree.arrays));
      host_shapes[i] = std::move(subtree.shape);
    }
    tree.shape = ShapeUtil::MakeTupleShape(host_shapes);
  } else {
    // Leaf: use the BorrowingLiteral caster to view the argument as an
    // XLA literal without copying.
    pybind11::detail::type_caster<BorrowingLiteral> caster;
    if (!caster.load(argument, /*convert=*/true)) {
      return InvalidArgument("Invalid array value.");
    }
    // A non-tuple argument yields exactly one backing array.
    DCHECK_EQ(caster.arrays.size(), 1);
    tree.arrays.push_back(std::move(caster.arrays.front()));
    tree.leaves.push_back(std::move(*caster));
    tree.shape = tree.leaves.front().shape();
  }
  return tree;
}
299
300 template <typename IntType>
IntSpanToTupleHelper(absl::Span<IntType const> xs)301 static py::tuple IntSpanToTupleHelper(absl::Span<IntType const> xs) {
302 py::tuple out(xs.size());
303 for (int i = 0; i < xs.size(); ++i) {
304 out[i] = py::int_(xs[i]);
305 }
306 return out;
307 }
308
// Specialization of SpanToTuple for spans of `int`.
template <>
pybind11::tuple SpanToTuple(absl::Span<int const> xs) {
  return IntSpanToTupleHelper(xs);
}
// Specialization of SpanToTuple for spans of `int64`.
template <>
pybind11::tuple SpanToTuple(absl::Span<int64 const> xs) {
  return IntSpanToTupleHelper(xs);
}
317
CastToArray(py::handle h)318 absl::optional<CastToArrayResult> CastToArray(py::handle h) {
319 py::array array = py::array::ensure(
320 h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_);
321 if (!array) {
322 return absl::nullopt;
323 }
324 auto type_or_status = DtypeToPrimitiveType(array.dtype());
325 if (!type_or_status.ok()) {
326 throw std::runtime_error(type_or_status.status().ToString());
327 }
328 PrimitiveType type = type_or_status.ValueOrDie();
329
330 absl::InlinedVector<int64, 4> dims(array.ndim());
331 for (int i = 0; i < array.ndim(); ++i) {
332 dims[i] = array.shape(i);
333 }
334 Shape shape = ShapeUtil::MakeShape(type, dims);
335 if (array.size() * array.itemsize() != ShapeUtil::ByteSizeOf(shape)) {
336 throw std::runtime_error(absl::StrCat(
337 "Size mismatch for buffer: ", array.size() * array.itemsize(), " vs. ",
338 ShapeUtil::ByteSizeOf(shape)));
339 }
340 return CastToArrayResult{array, static_cast<const char*>(array.data()),
341 shape};
342 }
343
344 } // namespace xla
345