• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDDATA_PYBINDSUPPORT_H
18 #define MINDDATA_PYBINDSUPPORT_H
19 
20 #include <string>
21 
22 #include "pybind11/numpy.h"
23 #include "pybind11/pybind11.h"
24 #include "base/float16.h"
25 
26 namespace py = pybind11;
27 
28 namespace pybind11 {
29 namespace detail {
// NumPy type number for float16, in the same spirit as the type-number enums
// declared in `pybind11/numpy.h`. The value can be reproduced with:
//   python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int kNpyFloat16{23};
33 
34 template <typename T>
35 struct npy_scalar_caster {
36   PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
37   using Array = array_t<T>;
38 
loadnpy_scalar_caster39   bool load(handle src, bool convert) {
40     // Taken from Eigen casters. Permits either scalar dtype or scalar array.
41     handle type = dtype::of<T>().attr("type");  // Could make more efficient.
42     if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) {
43       return false;
44     }
45 
46     Array tmp = Array::ensure(src);
47     if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
48       this->value = *tmp.data();
49       return true;
50     }
51 
52     return false;
53   }
54 
castnpy_scalar_caster55   static handle cast(T src, return_value_policy, handle) {
56     Array tmp({1});
57     tmp.mutable_at(0) = src;
58     tmp.resize({});
59 
60     // You could also just return the array if you want a scalar array.
61     object scalar = tmp[tuple()];
62     return scalar.release();
63   }
64 };
65 
66 template <>
67 struct npy_format_descriptor<float16> {
68   static constexpr auto name = "float16";
69   static pybind11::dtype dtype() {
70     handle ptr = npy_api::get().PyArray_DescrFromType_(kNpyFloat16);
71     return reinterpret_borrow<pybind11::dtype>(ptr);
72   }
73   virtual ~npy_format_descriptor<float16>() {}
74 
75   static std::string format() {
76     // following: https://docs.python.org/3/library/struct.html#format-characters
77     return "e";
78   }
79 };
80 
81 template <>
82 struct type_caster<float16> : public npy_scalar_caster<float16> {
83   static constexpr auto name = "float16";
84 };
85 }  // namespace detail
86 }  // namespace pybind11
87 
88 #endif  // MINDDATA_PYBINDSUPPORT_H
89