1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pybind_api/ir/map_tensor_py.h"
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include "pybind11/pytypes.h"
22 #include "pybind_api/ir/tensor_py.h"
23 #include "include/common/pybind_api/api_register.h"
24 #include "include/common/utils/python_adapter.h"
25 #include "mindspore/ccsrc/include/backend/distributed/embedding_cache/embedding_cache_utils.h"
26 #include "pipeline/jit/ps/parse/parse_base.h"
27 #include "utils/hash_set.h"
28 #include "utils/log_adapter.h"
29
30 namespace mindspore {
31 using tensor::TensorPy;
32
ConvertMapTensorDefaultValue(const py::object & default_value_obj,const TypePtr & value_dtype)33 static ValuePtr ConvertMapTensorDefaultValue(const py::object &default_value_obj, const TypePtr &value_dtype) {
34 static const mindspore::HashSet<std::string> support_init_names = {"zeros", "ones", "normal"};
35 if (py::isinstance<py::str>(default_value_obj)) {
36 std::string init_name = py::cast<std::string>(default_value_obj);
37 if (support_init_names.find(init_name) == support_init_names.end()) {
38 MS_EXCEPTION(ValueError) << "Unsupported init name for map parameter: " << init_name;
39 }
40 return std::make_shared<StringImm>(init_name);
41 }
42 ValuePtr default_value;
43 bool convert_ok = parse::ConvertData(default_value_obj, &default_value, false, value_dtype, false);
44 if (!convert_ok || default_value == nullptr) {
45 MS_EXCEPTION(ValueError) << "Incorrect default value for map parameter: " << py::str(default_value_obj);
46 }
47 return default_value;
48 }
49
ConvertMapTensorFilterValue(const py::object & filter_value_obj)50 static ValuePtr ConvertMapTensorFilterValue(const py::object &filter_value_obj) {
51 ValuePtr filter_value;
52 bool convert_ok = parse::ConvertData(filter_value_obj, &filter_value);
53 if (!convert_ok || filter_value == nullptr) {
54 MS_EXCEPTION(ValueError) << "Incorrect filter value for map parameter: " << py::str(filter_value_obj);
55 }
56 return filter_value;
57 }
58
UpdateFromNumpy(const MapTensorPtr & map_tensor,const std::tuple<py::array,py::array,py::array> & numpy_data)59 void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
60 const std::tuple<py::array, py::array, py::array> &numpy_data) {
61 MS_EXCEPTION_IF_NULL(map_tensor);
62 MapTensor::ExportData data;
63 constexpr size_t key_index = 0;
64 constexpr size_t value_index = 1;
65 constexpr size_t status_index = 2;
66 data.key_tensor = TensorPy::MakeTensorOfNumpy(std::get<key_index>(numpy_data));
67 data.value_tensor = TensorPy::MakeTensorOfNumpy(std::get<value_index>(numpy_data));
68 data.status_tensor = TensorPy::MakeTensorOfNumpy(std::get<status_index>(numpy_data));
69 map_tensor->Update(data);
70 }
71
ExportAsNumpy(const MapTensorPtr & map_tensor,bool incremental)72 std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor,
73 bool incremental) {
74 MS_EXCEPTION_IF_NULL(map_tensor);
75 auto data = map_tensor->Export(incremental);
76 return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
77 TensorPy::AsNumpy(*data.status_tensor));
78 }
79
ExportBytes(const MapTensorPtr & map_tensor,bool incremental)80 std::tuple<py::bytes, py::bytes, py::bytes> MapTensorPy::ExportBytes(const MapTensorPtr &map_tensor, bool incremental) {
81 MS_EXCEPTION_IF_NULL(map_tensor);
82 auto data = map_tensor->Export(incremental);
83 return std::make_tuple(TensorPy::GetBytes(*data.key_tensor), TensorPy::GetBytes(*data.value_tensor),
84 TensorPy::GetBytes(*data.status_tensor));
85 }
86
ExportSliceAsNumpy(const MapTensorPtr & map_tensor,bool incremental)87 std::tuple<py::array, py::array, py::array, bool> MapTensorPy::ExportSliceAsNumpy(const MapTensorPtr &map_tensor,
88 bool incremental) {
89 MS_EXCEPTION_IF_NULL(map_tensor);
90 bool last_slice = false;
91 auto data = map_tensor->ExportSlice(incremental, &last_slice);
92 return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
93 TensorPy::AsNumpy(*data.status_tensor), last_slice);
94 }
95
ExportPersistentSliceAsNumpy(const MapTensorPtr & map_tensor,int32_t param_key,bool incremental)96 std::tuple<py::array, py::array, py::array, bool> MapTensorPy::ExportPersistentSliceAsNumpy(
97 const MapTensorPtr &map_tensor, int32_t param_key, bool incremental) {
98 MS_EXCEPTION_IF_NULL(map_tensor);
99 bool last_slice = false;
100 auto storage = embedding_storage_manager.Get(param_key);
101 MS_EXCEPTION_IF_NULL(storage);
102 auto slice_data = storage->ExportSlice(incremental, &last_slice);
103 map_tensor->TransExportDataToTensor(slice_data);
104
105 return std::make_tuple(TensorPy::AsNumpy(*(map_tensor->key_tensor())),
106 TensorPy::AsNumpy(*(map_tensor->value_tensor())),
107 TensorPy::AsNumpy(*(map_tensor->status_tensor())), last_slice);
108 }
109
PyMapTensorGetKeys(const MapTensorPtr & map_tensor)110 static tensor::TensorPtr PyMapTensorGetKeys(const MapTensorPtr &map_tensor) {
111 MS_EXCEPTION_IF_NULL(map_tensor);
112 return map_tensor->key_tensor();
113 }
114
PyMapTensorGetValues(const MapTensorPtr & map_tensor)115 static tensor::TensorPtr PyMapTensorGetValues(const MapTensorPtr &map_tensor) {
116 MS_EXCEPTION_IF_NULL(map_tensor);
117 return map_tensor->value_tensor();
118 }
119
PyMapTensorGetData(const MapTensorPtr & map_tensor)120 static std::pair<tensor::TensorPtr, tensor::TensorPtr> PyMapTensorGetData(const MapTensorPtr &map_tensor) {
121 MS_EXCEPTION_IF_NULL(map_tensor);
122 auto keys = map_tensor->key_tensor();
123 auto values = map_tensor->value_tensor();
124 return std::pair<tensor::TensorPtr, tensor::TensorPtr>(keys, values);
125 }
126
127 namespace tensor {
RegMapTensor(const py::module * m)128 void RegMapTensor(const py::module *m) {
129 // Define python MapTensor class.
130 (void)py::class_<MapTensor, MapTensorPtr>(*m, "MapTensor_")
131 .def(py::init([](const TypePtr &key_dtype, const TypePtr &value_dtype, const ShapeVector &value_shape,
132 const py::object &default_value_obj, const py::object &permit_filter_obj,
133 const py::object &evict_filter_obj) {
134 TypeId key_dtype_id = ((key_dtype != nullptr) ? key_dtype->type_id() : TypeId::kNumberTypeInt32);
135 TypeId value_dtype_id = ((value_dtype != nullptr) ? value_dtype->type_id() : TypeId::kNumberTypeFloat32);
136 ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
137 ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
138 ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
139 return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value,
140 permit_filter_value, evict_filter_value);
141 }),
142 py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"), py::arg("default_value"),
143 py::arg("permit_filter_value"), py::arg("evict_filter_value"))
144 .def(py::init([](const Tensor &key_tensor, const Tensor &value_tensor, const py::object &default_value_obj,
145 const py::object &permit_filter_obj, const py::object &evict_filter_obj) {
146 auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor);
147 auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor);
148 auto status_tensor_ptr = std::make_shared<Tensor>(kNumberTypeInt, key_tensor.shape());
149 auto value_dtype = value_tensor_ptr->Dtype();
150 ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
151 ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
152 ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
153 return std::make_shared<MapTensor>(key_tensor_ptr, value_tensor_ptr, status_tensor_ptr, default_value,
154 permit_filter_value, evict_filter_value);
155 }),
156 py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"),
157 py::arg("evict_filter_value"))
158 .def_property_readonly("key_dtype", &MapTensor::KeyDtype)
159 .def_property_readonly("value_dtype", &MapTensor::ValueDtype)
160 .def_property_readonly("value_shape", &MapTensor::value_shape)
161 .def_property_readonly("size", &MapTensor::size)
162 .def("export_data", &MapTensorPy::ExportAsNumpy)
163 .def("export_bytes", &MapTensorPy::ExportBytes)
164 .def("import_data", &MapTensorPy::UpdateFromNumpy)
165 .def("export_slice_data", &MapTensorPy::ExportSliceAsNumpy)
166 .def("export_persistent_slice_data", &MapTensorPy::ExportPersistentSliceAsNumpy)
167 .def("__str__", &MapTensor::ToString)
168 .def("__repr__", &MapTensor::ToString)
169 .def("get_keys", &PyMapTensorGetKeys)
170 .def("get_values", &PyMapTensorGetValues)
171 .def("get_data", &PyMapTensorGetData);
172 }
173 } // namespace tensor
174 } // namespace mindspore
175