• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pybind_api/ir/map_tensor_py.h"
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include "pybind11/pytypes.h"
22 #include "pybind_api/ir/tensor_py.h"
23 #include "include/common/pybind_api/api_register.h"
24 #include "include/common/utils/python_adapter.h"
25 #include "mindspore/ccsrc/include/backend/distributed/embedding_cache/embedding_cache_utils.h"
26 #include "pipeline/jit/ps/parse/parse_base.h"
27 #include "utils/hash_set.h"
28 #include "utils/log_adapter.h"
29 
30 namespace mindspore {
31 using tensor::TensorPy;
32 
ConvertMapTensorDefaultValue(const py::object & default_value_obj,const TypePtr & value_dtype)33 static ValuePtr ConvertMapTensorDefaultValue(const py::object &default_value_obj, const TypePtr &value_dtype) {
34   static const mindspore::HashSet<std::string> support_init_names = {"zeros", "ones", "normal"};
35   if (py::isinstance<py::str>(default_value_obj)) {
36     std::string init_name = py::cast<std::string>(default_value_obj);
37     if (support_init_names.find(init_name) == support_init_names.end()) {
38       MS_EXCEPTION(ValueError) << "Unsupported init name for map parameter: " << init_name;
39     }
40     return std::make_shared<StringImm>(init_name);
41   }
42   ValuePtr default_value;
43   bool convert_ok = parse::ConvertData(default_value_obj, &default_value, false, value_dtype, false);
44   if (!convert_ok || default_value == nullptr) {
45     MS_EXCEPTION(ValueError) << "Incorrect default value for map parameter: " << py::str(default_value_obj);
46   }
47   return default_value;
48 }
49 
ConvertMapTensorFilterValue(const py::object & filter_value_obj)50 static ValuePtr ConvertMapTensorFilterValue(const py::object &filter_value_obj) {
51   ValuePtr filter_value;
52   bool convert_ok = parse::ConvertData(filter_value_obj, &filter_value);
53   if (!convert_ok || filter_value == nullptr) {
54     MS_EXCEPTION(ValueError) << "Incorrect filter value for map parameter: " << py::str(filter_value_obj);
55   }
56   return filter_value;
57 }
58 
UpdateFromNumpy(const MapTensorPtr & map_tensor,const std::tuple<py::array,py::array,py::array> & numpy_data)59 void MapTensorPy::UpdateFromNumpy(const MapTensorPtr &map_tensor,
60                                   const std::tuple<py::array, py::array, py::array> &numpy_data) {
61   MS_EXCEPTION_IF_NULL(map_tensor);
62   MapTensor::ExportData data;
63   constexpr size_t key_index = 0;
64   constexpr size_t value_index = 1;
65   constexpr size_t status_index = 2;
66   data.key_tensor = TensorPy::MakeTensorOfNumpy(std::get<key_index>(numpy_data));
67   data.value_tensor = TensorPy::MakeTensorOfNumpy(std::get<value_index>(numpy_data));
68   data.status_tensor = TensorPy::MakeTensorOfNumpy(std::get<status_index>(numpy_data));
69   map_tensor->Update(data);
70 }
71 
ExportAsNumpy(const MapTensorPtr & map_tensor,bool incremental)72 std::tuple<py::array, py::array, py::array> MapTensorPy::ExportAsNumpy(const MapTensorPtr &map_tensor,
73                                                                        bool incremental) {
74   MS_EXCEPTION_IF_NULL(map_tensor);
75   auto data = map_tensor->Export(incremental);
76   return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
77                          TensorPy::AsNumpy(*data.status_tensor));
78 }
79 
ExportBytes(const MapTensorPtr & map_tensor,bool incremental)80 std::tuple<py::bytes, py::bytes, py::bytes> MapTensorPy::ExportBytes(const MapTensorPtr &map_tensor, bool incremental) {
81   MS_EXCEPTION_IF_NULL(map_tensor);
82   auto data = map_tensor->Export(incremental);
83   return std::make_tuple(TensorPy::GetBytes(*data.key_tensor), TensorPy::GetBytes(*data.value_tensor),
84                          TensorPy::GetBytes(*data.status_tensor));
85 }
86 
ExportSliceAsNumpy(const MapTensorPtr & map_tensor,bool incremental)87 std::tuple<py::array, py::array, py::array, bool> MapTensorPy::ExportSliceAsNumpy(const MapTensorPtr &map_tensor,
88                                                                                   bool incremental) {
89   MS_EXCEPTION_IF_NULL(map_tensor);
90   bool last_slice = false;
91   auto data = map_tensor->ExportSlice(incremental, &last_slice);
92   return std::make_tuple(TensorPy::AsNumpy(*data.key_tensor), TensorPy::AsNumpy(*data.value_tensor),
93                          TensorPy::AsNumpy(*data.status_tensor), last_slice);
94 }
95 
ExportPersistentSliceAsNumpy(const MapTensorPtr & map_tensor,int32_t param_key,bool incremental)96 std::tuple<py::array, py::array, py::array, bool> MapTensorPy::ExportPersistentSliceAsNumpy(
97   const MapTensorPtr &map_tensor, int32_t param_key, bool incremental) {
98   MS_EXCEPTION_IF_NULL(map_tensor);
99   bool last_slice = false;
100   auto storage = embedding_storage_manager.Get(param_key);
101   MS_EXCEPTION_IF_NULL(storage);
102   auto slice_data = storage->ExportSlice(incremental, &last_slice);
103   map_tensor->TransExportDataToTensor(slice_data);
104 
105   return std::make_tuple(TensorPy::AsNumpy(*(map_tensor->key_tensor())),
106                          TensorPy::AsNumpy(*(map_tensor->value_tensor())),
107                          TensorPy::AsNumpy(*(map_tensor->status_tensor())), last_slice);
108 }
109 
PyMapTensorGetKeys(const MapTensorPtr & map_tensor)110 static tensor::TensorPtr PyMapTensorGetKeys(const MapTensorPtr &map_tensor) {
111   MS_EXCEPTION_IF_NULL(map_tensor);
112   return map_tensor->key_tensor();
113 }
114 
PyMapTensorGetValues(const MapTensorPtr & map_tensor)115 static tensor::TensorPtr PyMapTensorGetValues(const MapTensorPtr &map_tensor) {
116   MS_EXCEPTION_IF_NULL(map_tensor);
117   return map_tensor->value_tensor();
118 }
119 
PyMapTensorGetData(const MapTensorPtr & map_tensor)120 static std::pair<tensor::TensorPtr, tensor::TensorPtr> PyMapTensorGetData(const MapTensorPtr &map_tensor) {
121   MS_EXCEPTION_IF_NULL(map_tensor);
122   auto keys = map_tensor->key_tensor();
123   auto values = map_tensor->value_tensor();
124   return std::pair<tensor::TensorPtr, tensor::TensorPtr>(keys, values);
125 }
126 
127 namespace tensor {
RegMapTensor(const py::module * m)128 void RegMapTensor(const py::module *m) {
129   // Define python MapTensor class.
130   (void)py::class_<MapTensor, MapTensorPtr>(*m, "MapTensor_")
131     .def(py::init([](const TypePtr &key_dtype, const TypePtr &value_dtype, const ShapeVector &value_shape,
132                      const py::object &default_value_obj, const py::object &permit_filter_obj,
133                      const py::object &evict_filter_obj) {
134            TypeId key_dtype_id = ((key_dtype != nullptr) ? key_dtype->type_id() : TypeId::kNumberTypeInt32);
135            TypeId value_dtype_id = ((value_dtype != nullptr) ? value_dtype->type_id() : TypeId::kNumberTypeFloat32);
136            ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
137            ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
138            ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
139            return std::make_shared<MapTensor>(key_dtype_id, value_dtype_id, value_shape, default_value,
140                                               permit_filter_value, evict_filter_value);
141          }),
142          py::arg("key_dtype"), py::arg("value_dtype"), py::arg("value_shape"), py::arg("default_value"),
143          py::arg("permit_filter_value"), py::arg("evict_filter_value"))
144     .def(py::init([](const Tensor &key_tensor, const Tensor &value_tensor, const py::object &default_value_obj,
145                      const py::object &permit_filter_obj, const py::object &evict_filter_obj) {
146            auto key_tensor_ptr = std::make_shared<tensor::Tensor>(key_tensor);
147            auto value_tensor_ptr = std::make_shared<tensor::Tensor>(value_tensor);
148            auto status_tensor_ptr = std::make_shared<Tensor>(kNumberTypeInt, key_tensor.shape());
149            auto value_dtype = value_tensor_ptr->Dtype();
150            ValuePtr default_value = ConvertMapTensorDefaultValue(default_value_obj, value_dtype);
151            ValuePtr permit_filter_value = ConvertMapTensorFilterValue(permit_filter_obj);
152            ValuePtr evict_filter_value = ConvertMapTensorFilterValue(evict_filter_obj);
153            return std::make_shared<MapTensor>(key_tensor_ptr, value_tensor_ptr, status_tensor_ptr, default_value,
154                                               permit_filter_value, evict_filter_value);
155          }),
156          py::arg("key_tensor"), py::arg("value_tensor"), py::arg("default_value"), py::arg("permit_filter_value"),
157          py::arg("evict_filter_value"))
158     .def_property_readonly("key_dtype", &MapTensor::KeyDtype)
159     .def_property_readonly("value_dtype", &MapTensor::ValueDtype)
160     .def_property_readonly("value_shape", &MapTensor::value_shape)
161     .def_property_readonly("size", &MapTensor::size)
162     .def("export_data", &MapTensorPy::ExportAsNumpy)
163     .def("export_bytes", &MapTensorPy::ExportBytes)
164     .def("import_data", &MapTensorPy::UpdateFromNumpy)
165     .def("export_slice_data", &MapTensorPy::ExportSliceAsNumpy)
166     .def("export_persistent_slice_data", &MapTensorPy::ExportPersistentSliceAsNumpy)
167     .def("__str__", &MapTensor::ToString)
168     .def("__repr__", &MapTensor::ToString)
169     .def("get_keys", &PyMapTensorGetKeys)
170     .def("get_values", &PyMapTensorGetValues)
171     .def("get_data", &PyMapTensorGetData);
172 }
173 }  // namespace tensor
174 }  // namespace mindspore
175