/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/types.h"

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/python/lib/core/bfloat16.h"

namespace xla {

namespace py = pybind11;

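// Maps a NumPy dtype (identified by its kind character and item size) to the
// corresponding XLA PrimitiveType. Returns an InvalidArgument error for
// unsupported dtypes.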
xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
  static auto* types =
      new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
          {{'b', 1}, PRED},
          {{'i', 1}, S8},
          {{'i', 2}, S16},
          {{'i', 4}, S32},
          {{'i', 8}, S64},
          {{'u', 1}, U8},
          {{'u', 2}, U16},
          {{'u', 4}, U32},
          {{'u', 8}, U64},
          {{'V', 2}, BF16},  // array protocol code for raw data (void*)
          {{'f', 2}, F16},
          {{'f', 4}, F32},
          {{'f', 8}, F64},
          {{'c', 8}, C64},
          {{'c', 16}, C128},
      });
  auto it = types->find({np_type.kind(), np_type.itemsize()});
  if (it == types->end()) {
    return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(),
                           np_type.itemsize());
  }
  return it->second;
}

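// Returns the NumPy dtype corresponding to `type`. BF16 maps to the custom
// bfloat16 dtype registered by TensorFlow; types without a NumPy equivalent
// yield an Unimplemented error.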
xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
  switch (type) {
    case PRED:
      return py::dtype::of<bool>();
    case S8:
      return py::dtype::of<int8>();
    case S16:
      return py::dtype::of<int16>();
    case S32:
      return py::dtype::of<int32>();
    case S64:
      return py::dtype::of<int64>();
    case U8:
      return py::dtype::of<uint8>();
    case U16:
      return py::dtype::of<uint16>();
    case U32:
      return py::dtype::of<uint32>();
    case U64:
      return py::dtype::of<uint64>();
    case BF16: {
      py::handle bfloat16(tensorflow::Bfloat16Dtype());
      return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
    }
    case F16:
      return py::dtype("e");  // PEP 3118 code for "float16"
    case F32:
      return py::dtype::of<float>();
    case F64:
      return py::dtype::of<double>();
    case C64:
      return py::dtype::of<std::complex<float>>();
    case C128:
      return py::dtype::of<std::complex<double>>();
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}

// Returns a numpy-style format descriptor string for `type`.
StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
  // We use an "=" prefix to indicate that we prefer "standard" types like
  // np.int32 rather than "native" types like np.intc. pybind11 does not
  // qualify its format descriptors.
  switch (type) {
    case PRED:
      return std::string("?");
    case S8:
      return std::string("=b");
    case S16:
      return std::string("=h");
    case S32:
      return std::string("=i");
    case S64:
      return std::string("=q");
    case U8:
      return std::string("=B");
    case U16:
      return std::string("=H");
    case U32:
      return std::string("=I");
    case U64:
      return std::string("=Q");
    case F16:
      return std::string("=e");
    case F32:
      return std::string("=f");
    case F64:
      return std::string("=d");
    case C64:
      return std::string("=Zf");
    case C128:
      return std::string("=Zd");
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}

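// Returns a NumPy type descriptor string ("typestr", e.g. "<f4") for `type`.
// Only little-endian hosts are supported, as enforced by the static_assert
// below.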
StatusOr<py::str> TypeDescriptorForPrimitiveType(PrimitiveType type) {
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__,
                "Big endian support not implemented");
  switch (type) {
    case PRED:
      return py::str("|b1");
    case S8:
      return py::str("|i1");
    case S16:
      return py::str("<i2");
    case S32:
      return py::str("<i4");
    case S64:
      return py::str("<i8");
    case U8:
      return py::str("|u1");
    case U16:
      return py::str("<u2");
    case U32:
      return py::str("<u4");
    case U64:
      return py::str("<u8");
    case BF16:
      return py::str("<V2");
    case F16:
      return py::str("<f2");
    case F32:
      return py::str("<f4");
    case F64:
      return py::str("<f8");
    case C64:
      return py::str("<c8");
    case C128:
      return py::str("<c16");
    default:
      return Unimplemented("Unimplemented primitive type %s",
                           PrimitiveType_Name(type));
  }
}

// Returns the byte strides for `shape`, computed from its layout
// (minor-to-major order).
std::vector<ssize_t> ByteStridesForShape(const Shape& shape) {
  std::vector<ssize_t> strides;
  CHECK(shape.IsArray());
  CHECK(shape.has_layout());

  strides.resize(shape.dimensions_size());
  ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
  for (int i : shape.layout().minor_to_major()) {
    strides.at(i) = stride;
    stride *= shape.dimensions(i);
  }
  return strides;
}

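// Converts `literal` to a Python object. Tuple literals are decomposed
// recursively and returned as Python tuples; array literals become NumPy
// arrays that alias the literal's buffer, with the wrapped Literal as the
// base object so it stays alive as long as the array does.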
StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
  xla::Literal& m = *literal;
  if (m.shape().IsTuple()) {
    std::vector<Literal> elems = m.DecomposeTuple();
    std::vector<py::object> arrays(elems.size());
    for (int i = 0; i < elems.size(); ++i) {
      TF_ASSIGN_OR_RETURN(
          arrays[i],
          LiteralToPython(absl::make_unique<Literal>(std::move(elems[i]))));
    }
    py::tuple result(elems.size());
    for (int i = 0; i < elems.size(); ++i) {
      PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr());
    }
    return result;
  }
  TF_RET_CHECK(m.shape().IsArray());

  py::object literal_object = py::cast(literal);
  TF_ASSIGN_OR_RETURN(py::dtype dtype,
                      PrimitiveTypeToDtype(m.shape().element_type()));
  return py::array(dtype, m.shape().dimensions(),
                   ByteStridesForShape(m.shape()), m.untyped_data(),
                   literal_object);
}

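// Flattens a Python argument (an array-like object or an arbitrarily nested
// tuple of array-like objects) into a PythonBufferTree: the leaf
// BorrowingLiterals, the NumPy arrays backing them, and the overall host
// Shape.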
StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
  PythonBufferTree tree;
  if (py::isinstance<py::tuple>(argument)) {
    py::tuple tuple = py::reinterpret_borrow<py::tuple>(argument);
    std::vector<Shape> host_shapes(tuple.size());
    for (int i = 0; i < host_shapes.size(); ++i) {
      TF_ASSIGN_OR_RETURN(PythonBufferTree subtree,
                          GetPythonBufferTree(tuple[i]));
      tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size());
      std::move(subtree.leaves.begin(), subtree.leaves.end(),
                std::back_inserter(tree.leaves));
      tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size());
      std::move(subtree.arrays.begin(), subtree.arrays.end(),
                std::back_inserter(tree.arrays));
      host_shapes[i] = std::move(subtree.shape);
    }
    tree.shape = ShapeUtil::MakeTupleShape(host_shapes);
  } else {
    pybind11::detail::type_caster<BorrowingLiteral> caster;
    if (!caster.load(argument, /*convert=*/true)) {
      return InvalidArgument("Invalid array value.");
    }
    DCHECK_EQ(caster.arrays.size(), 1);
    tree.arrays.push_back(std::move(caster.arrays.front()));
    tree.leaves.push_back(std::move(*caster));
    tree.shape = tree.leaves.front().shape();
  }
  return tree;
}

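// Converts a span of integers to a Python tuple of ints.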
template <typename IntType>
static py::tuple IntSpanToTupleHelper(absl::Span<IntType const> xs) {
  py::tuple out(xs.size());
  for (int i = 0; i < xs.size(); ++i) {
    out[i] = py::int_(xs[i]);
  }
  return out;
}

py::tuple IntSpanToTuple(absl::Span<int64 const> xs) {
  return IntSpanToTupleHelper(xs);
}
py::tuple IntSpanToTuple(absl::Span<int const> xs) {
  return IntSpanToTupleHelper(xs);
}

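// Converts a Python sequence of integers to a std::vector<int64>.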
std::vector<int64> IntSequenceToVector(const py::object& sequence) {
  std::vector<int64> output;
  for (auto item : sequence) {
    output.push_back(item.cast<int64>());
  }
  return output;
}

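// Converts `h` to a C-contiguous, aligned NumPy array and computes the
// corresponding XLA Shape. Returns nullopt if `h` is not convertible to an
// array; throws std::runtime_error for unsupported dtypes or if the buffer
// size does not match the shape.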
absl::optional<CastToArrayResult> CastToArray(py::handle h) {
  py::array array = py::array::ensure(
      h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_);
  if (!array) {
    return absl::nullopt;
  }
  auto type_or_status = DtypeToPrimitiveType(array.dtype());
  if (!type_or_status.ok()) {
    throw std::runtime_error(type_or_status.status().ToString());
  }
  PrimitiveType type = type_or_status.ValueOrDie();

  absl::InlinedVector<int64, 4> dims(array.ndim());
  for (int i = 0; i < array.ndim(); ++i) {
    dims[i] = array.shape(i);
  }
  Shape shape = ShapeUtil::MakeShape(type, dims);
  if (array.size() * array.itemsize() != ShapeUtil::ByteSizeOf(shape)) {
    throw std::runtime_error(absl::StrCat(
        "Size mismatch for buffer: ", array.size() * array.itemsize(), " vs. ",
        ShapeUtil::ByteSizeOf(shape)));
  }
  return CastToArrayResult{array, static_cast<const char*>(array.data()),
                           shape};
}

}  // namespace xla