• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/types.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/compiler/xla/status_macros.h"
20 #include "tensorflow/python/lib/core/bfloat16.h"
21 
22 namespace xla {
23 
24 namespace py = pybind11;
25 
DtypeToPrimitiveType(const py::dtype & np_type)26 xla::StatusOr<PrimitiveType> DtypeToPrimitiveType(const py::dtype& np_type) {
27   static auto* types =
28       new absl::flat_hash_map<std::pair<char, int>, PrimitiveType>({
29           {{'b', 1}, PRED},
30           {{'i', 1}, S8},
31           {{'i', 2}, S16},
32           {{'i', 4}, S32},
33           {{'i', 8}, S64},
34           {{'u', 1}, U8},
35           {{'u', 2}, U16},
36           {{'u', 4}, U32},
37           {{'u', 8}, U64},
38           {{'V', 2}, BF16},  // array protocol code for raw data (void*)
39           {{'f', 2}, F16},
40           {{'f', 4}, F32},
41           {{'f', 8}, F64},
42           {{'c', 8}, C64},
43           {{'c', 16}, C128},
44       });
45   auto it = types->find({np_type.kind(), np_type.itemsize()});
46   if (it == types->end()) {
47     return InvalidArgument("Unknown NumPy type %c size %d", np_type.kind(),
48                            np_type.itemsize());
49   }
50   return it->second;
51 }
52 
PrimitiveTypeToDtype(PrimitiveType type)53 xla::StatusOr<py::dtype> PrimitiveTypeToDtype(PrimitiveType type) {
54   switch (type) {
55     case PRED:
56       return py::dtype::of<bool>();
57     case S8:
58       return py::dtype::of<int8>();
59     case S16:
60       return py::dtype::of<int16>();
61     case S32:
62       return py::dtype::of<int32>();
63     case S64:
64       return py::dtype::of<int64>();
65     case U8:
66       return py::dtype::of<uint8>();
67     case U16:
68       return py::dtype::of<uint16>();
69     case U32:
70       return py::dtype::of<uint32>();
71     case U64:
72       return py::dtype::of<uint64>();
73     case BF16: {
74       py::handle bfloat16(tensorflow::Bfloat16Dtype());
75       return py::dtype::from_args(py::reinterpret_borrow<py::object>(bfloat16));
76     }
77     case F16:
78       return py::dtype("e");  // PEP 3118 code for "float16
79     case F32:
80       return py::dtype::of<float>();
81     case F64:
82       return py::dtype::of<double>();
83     case C64:
84       return py::dtype::of<std::complex<float>>();
85     case C128:
86       return py::dtype::of<std::complex<double>>();
87     default:
88       return Unimplemented("Unimplemented primitive type %s",
89                            PrimitiveType_Name(type));
90   }
91 }
92 
93 // Returns a numpy-style format descriptor string for `type`.
FormatDescriptorForPrimitiveType(PrimitiveType type)94 StatusOr<std::string> FormatDescriptorForPrimitiveType(PrimitiveType type) {
95   // We use an "=" prefix to indicate that we prefer "standard" types like
96   // np.int32 rather than "native" types like np.cint. pybind11 does not qualify
97   // its format descriptors.
98   switch (type) {
99     case PRED:
100       return std::string("?");
101     case S8:
102       return std::string("=b");
103     case S16:
104       return std::string("=h");
105     case S32:
106       return std::string("=i");
107     case S64:
108       return std::string("=q");
109     case U8:
110       return std::string("=B");
111     case U16:
112       return std::string("=H");
113     case U32:
114       return std::string("=I");
115     case U64:
116       return std::string("=Q");
117     case F16:
118       return std::string("=e");
119     case F32:
120       return std::string("=f");
121     case F64:
122       return std::string("=d");
123     case C64:
124       return std::string("=Zf");
125     case C128:
126       return std::string("=Zd");
127     default:
128       return Unimplemented("Unimplemented primitive type %s",
129                            PrimitiveType_Name(type));
130   }
131 }
132 
TypeDescriptorForPrimitiveType(PrimitiveType type)133 StatusOr<py::str> TypeDescriptorForPrimitiveType(PrimitiveType type) {
134   static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__,
135                 "Big endian support not implemented");
136   switch (type) {
137     case PRED:
138       return py::str("|b1");
139     case S8:
140       return py::str("|i1");
141     case S16:
142       return py::str("<i2");
143     case S32:
144       return py::str("<i4");
145     case S64:
146       return py::str("<i8");
147     case U8:
148       return py::str("|u1");
149     case U16:
150       return py::str("<u2");
151     case U32:
152       return py::str("<u4");
153     case U64:
154       return py::str("<u8");
155     case BF16:
156       return py::str("<V2");
157     case F16:
158       return py::str("<f2");
159     case F32:
160       return py::str("<f4");
161     case F64:
162       return py::str("<f8");
163     case C64:
164       return py::str("<c8");
165     case C128:
166       return py::str("<c16");
167     default:
168       return Unimplemented("Unimplemented primitive type %s",
169                            PrimitiveType_Name(type));
170   }
171 }
172 
173 // Returns the strides for `shape`.
ByteStridesForShape(const Shape & shape)174 std::vector<ssize_t> ByteStridesForShape(const Shape& shape) {
175   std::vector<ssize_t> strides;
176   CHECK(shape.IsArray());
177   CHECK(shape.has_layout());
178 
179   strides.resize(shape.dimensions_size());
180   ssize_t stride = ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type());
181   for (int i : shape.layout().minor_to_major()) {
182     strides.at(i) = stride;
183     stride *= shape.dimensions(i);
184   }
185   return strides;
186 }
187 
LiteralToPython(std::shared_ptr<xla::Literal> literal)188 StatusOr<py::object> LiteralToPython(std::shared_ptr<xla::Literal> literal) {
189   xla::Literal& m = *literal;
190   if (m.shape().IsTuple()) {
191     std::vector<Literal> elems = m.DecomposeTuple();
192     std::vector<py::object> arrays(elems.size());
193     for (int i = 0; i < elems.size(); ++i) {
194       TF_ASSIGN_OR_RETURN(
195           arrays[i],
196           LiteralToPython(absl::make_unique<Literal>(std::move(elems[i]))));
197     }
198     py::tuple result(elems.size());
199     for (int i = 0; i < elems.size(); ++i) {
200       PyTuple_SET_ITEM(result.ptr(), i, arrays[i].release().ptr());
201     }
202     return result;
203   }
204   TF_RET_CHECK(m.shape().IsArray());
205 
206   py::object literal_object = py::cast(literal);
207   TF_ASSIGN_OR_RETURN(py::dtype dtype,
208                       PrimitiveTypeToDtype(m.shape().element_type()));
209   return py::array(dtype, m.shape().dimensions(),
210                    ByteStridesForShape(m.shape()), m.untyped_data(),
211                    literal_object);
212 }
213 
GetPythonBufferTree(const py::object & argument)214 StatusOr<PythonBufferTree> GetPythonBufferTree(const py::object& argument) {
215   PythonBufferTree tree;
216   if (py::isinstance<py::tuple>(argument)) {
217     py::tuple tuple = py::reinterpret_borrow<py::tuple>(argument);
218     std::vector<Shape> host_shapes(tuple.size());
219     for (int i = 0; i < host_shapes.size(); ++i) {
220       TF_ASSIGN_OR_RETURN(PythonBufferTree subtree,
221                           GetPythonBufferTree(tuple[i]));
222       tree.leaves.reserve(tree.leaves.size() + subtree.leaves.size());
223       std::move(subtree.leaves.begin(), subtree.leaves.end(),
224                 std::back_inserter(tree.leaves));
225       tree.arrays.reserve(tree.arrays.size() + subtree.arrays.size());
226       std::move(subtree.arrays.begin(), subtree.arrays.end(),
227                 std::back_inserter(tree.arrays));
228       host_shapes[i] = std::move(subtree.shape);
229     }
230     tree.shape = ShapeUtil::MakeTupleShape(host_shapes);
231   } else {
232     pybind11::detail::type_caster<BorrowingLiteral> caster;
233     if (!caster.load(argument, /*convert=*/true)) {
234       return InvalidArgument("Invalid array value.");
235     }
236     DCHECK_EQ(caster.arrays.size(), 1);
237     tree.arrays.push_back(std::move(caster.arrays.front()));
238     tree.leaves.push_back(std::move(*caster));
239     tree.shape = tree.leaves.front().shape();
240   }
241   return tree;
242 }
243 
244 template <typename IntType>
IntSpanToTupleHelper(absl::Span<IntType const> xs)245 static py::tuple IntSpanToTupleHelper(absl::Span<IntType const> xs) {
246   py::tuple out(xs.size());
247   for (int i = 0; i < xs.size(); ++i) {
248     out[i] = py::int_(xs[i]);
249   }
250   return out;
251 }
252 
IntSpanToTuple(absl::Span<int64 const> xs)253 py::tuple IntSpanToTuple(absl::Span<int64 const> xs) {
254   return IntSpanToTupleHelper(xs);
255 }
IntSpanToTuple(absl::Span<int const> xs)256 py::tuple IntSpanToTuple(absl::Span<int const> xs) {
257   return IntSpanToTupleHelper(xs);
258 }
259 
IntSequenceToVector(const py::object & sequence)260 std::vector<int64> IntSequenceToVector(const py::object& sequence) {
261   std::vector<int64> output;
262   for (auto item : sequence) {
263     output.push_back(item.cast<int64>());
264   }
265   return output;
266 }
267 
CastToArray(py::handle h)268 absl::optional<CastToArrayResult> CastToArray(py::handle h) {
269   py::array array = py::array::ensure(
270       h, py::array::c_style | py::detail::npy_api::NPY_ARRAY_ALIGNED_);
271   if (!array) {
272     return absl::nullopt;
273   }
274   auto type_or_status = DtypeToPrimitiveType(array.dtype());
275   if (!type_or_status.ok()) {
276     throw std::runtime_error(type_or_status.status().ToString());
277   }
278   PrimitiveType type = type_or_status.ValueOrDie();
279 
280   absl::InlinedVector<int64, 4> dims(array.ndim());
281   for (int i = 0; i < array.ndim(); ++i) {
282     dims[i] = array.shape(i);
283   }
284   Shape shape = ShapeUtil::MakeShape(type, dims);
285   if (array.size() * array.itemsize() != ShapeUtil::ByteSizeOf(shape)) {
286     throw std::runtime_error(absl::StrCat(
287         "Size mismatch for buffer: ", array.size() * array.itemsize(), " vs. ",
288         ShapeUtil::ByteSizeOf(shape)));
289   }
290   return CastToArrayResult{array, static_cast<const char*>(array.data()),
291                            shape};
292 }
293 
294 }  // namespace xla
295