• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_UTILS_TENSOR_PY_H_
18 #define MINDSPORE_CCSRC_UTILS_TENSOR_PY_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "pybind11/pybind11.h"
25 #include "pybind11/numpy.h"
26 
27 #include "ir/tensor.h"
28 
29 namespace py = pybind11;
30 
31 namespace pybind11 {
32 namespace detail {
// NumPy type number for float16, named similarly to the enums in `pybind11/numpy.h`.
// The value was determined by running:
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
constexpr int NPY_FLOAT16 = 23;
36 
37 template <typename T>
38 struct npy_scalar_caster {
39   PYBIND11_TYPE_CASTER(T, _("PleaseOverride"));
40   using Array = array_t<T>;
41 
loadnpy_scalar_caster42   bool load(handle src, bool convert) {
43     // Taken from Eigen casters. Permits either scalar dtype or scalar array.
44     handle type = dtype::of<T>().attr("type");
45     if (!convert && !isinstance<Array>(src) && !isinstance(src, type)) return false;
46 
47     Array tmp = Array::ensure(src);
48     if (tmp && tmp.size() == 1 && tmp.ndim() == 0) {
49       this->value = *tmp.data();
50       return true;
51     }
52 
53     return false;
54   }
55 
castnpy_scalar_caster56   static handle cast(T src, return_value_policy, handle) {
57     Array tmp({1});
58     tmp.mutable_at(0) = src;
59     tmp.resize({});
60 
61     // You could also just return the array if you want a scalar array.
62     object scalar = tmp[tuple()];
63     return scalar.release();
64   }
65 };
66 
67 template <>
68 struct npy_format_descriptor<float16> {
69   static constexpr auto name = "float16";
70   static pybind11::dtype dtype() {
71     handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
72     return reinterpret_borrow<pybind11::dtype>(ptr);
73   }
74   virtual ~npy_format_descriptor<float16>() {}
75 };
76 
77 template <>
78 struct type_caster<float16> : public npy_scalar_caster<float16> {
79   static constexpr auto name = "float16";
80 };
81 }  // namespace detail
82 }  // namespace pybind11
83 
// brief mindspore namespace.
//
// The mindspore namespace is the top-level namespace of the MindSpore project.
// Other namespaces should be sub-namespaces of the mindspore namespace in the ME project.
88 namespace mindspore {
// brief mindspore::tensor namespace.
//
// A sub-namespace in ME that holds tensor-related definitions.
92 namespace tensor {
93 // Tensor python wrapper and adapter class.
94 class TensorPy {
95  public:
96   // brief Create Tensor from a numpy array object.
97   //
98   // param input [py::array] Data value of the tensor.
99   // param data_type [TypeId] Data type of the tensor.
100   static TensorPtr MakeTensor(const py::array &input, const TypePtr &data_type = nullptr);
101 
102   // brief Create Tensor from a numpy array without copy.
103   //
104   // param input [py::array] Data value of the tensor.
105   static TensorPtr MakeTensorOfNumpy(const py::array &input);
106 
107   static py::array SyncAsNumpy(const Tensor &tensor);
108 
109   static py::array AsNumpy(const Tensor &tensor);
110 
111   static py::tuple GetPyTupleShape(const Tensor &tensor);
112 
113   static py::tuple GetPyTupleStrides(const Tensor &tensor);
114 
115   static py::int_ GetPyItemSize(const Tensor &tensor);
116 
117   static py::int_ GetPyNBytes(const Tensor &tensor);
118 
119   static void FlushFromCache(const Tensor &tensor);
120 };
121 }  // namespace tensor
122 }  // namespace mindspore
123 
124 #endif  // MINDSPORE_CCSRC_UTILS_TENSOR_PY_H_
125