• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""MapParameter implementation."""
16from __future__ import absolute_import
17
18__all__ = ['MapParameter']
19
20import os
21import sys
22from copy import copy
23import numbers
24import mindspore as ms
25from mindspore.common.parameter import Parameter, _get_unique_parameter_key
26from mindspore._c_expression import Tensor as Tensor_
27from mindspore._c_expression import MapTensor_
28from mindspore.ops.operations import _map_tensor_ops
29
30
class MapParameter(Parameter):
    """
    MapParameter is a parameter that stores a map like data structure.

    .. warning::
        This is an experimental API that is subject to change or deletion.

    Args:
        key_dtype (:class:`mindspore.dtype`): The data type of the key. The argument should be defined in
            `mindspore.dtype`, currently only integer types are supported. Default: int32.
        value_dtype (:class:`mindspore.dtype`): The data type of the value Tensor. The argument should
            be defined in `mindspore.dtype`. Default: float32.
        value_shape (Union[tuple, list, int]): Used to indicate the shape of the value Tensor. The argument should be
            a list of integers, a tuple of integers or an integer. Default: 1.
        key_tensor (:class:`mindspore.tensor`): The key Tensor. Default: ``None``.
        value_tensor (:class:`mindspore.tensor`): The value Tensor. Default: ``None``.
        default_value (Union[numbers.Number, str]): The default value number or initializer name. Default: 'normal'.
        permit_filter_value (numbers.Number): The permit filter value number. Default: 1.
        evict_filter_value (numbers.Number): The evict filter value number. Default: ``sys.maxsize``.
        name (str): Name of the map parameter. Default: ``None``.
        requires_grad (bool): True if the parameter requires gradient. Default: True.

    Raises:
        ValueError: If `key_dtype`, `value_dtype` or `value_shape` is given together with `key_tensor` /
            `value_tensor` and the two disagree.

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor
        >>> from mindspore.experimental import MapParameter
        >>>
        >>> m = MapParameter(key_dtype=ms.int32, value_dtype=ms.float32, value_shape=(3,), default_value='zeros')
        >>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
        >>> print(t)
        [[0. 0. 0.]
         [0. 0. 0.]
         [0. 0. 0.]]
        >>> m.put(Tensor([1, 2], dtype=ms.int32), Tensor([[1, 1, 1], [2, 2, 2]], dtype=ms.float32))
        >>> t = m.get(Tensor([1, 2, 3], dtype=ms.int32))
        >>> print(t)
        [[1. 1. 1.]
         [2. 2. 2.]
         [0. 0. 0.]]
        >>> m.erase(Tensor([2, 3], dtype=ms.int32))
        >>> t = m.get(Tensor([1], dtype=ms.int32))
        >>> print(t)
        [[1. 1. 1.]]

    """

    def __new__(cls, key_dtype=None, value_dtype=None, value_shape=None, key_tensor=None, value_tensor=None,
                default_value='normal', permit_filter_value=1, evict_filter_value=sys.maxsize, **kwargs):
        # Normalize `value_shape` once, up front, so that the consistency check, the placeholder data and the
        # stored attribute all see the same tuple form. An int becomes a 1-element tuple and a list becomes a
        # tuple (a list would otherwise never compare equal to the tuple `value_tensor.shape`).
        if isinstance(value_shape, numbers.Number):
            value_shape = (value_shape,)
        elif isinstance(value_shape, list):
            value_shape = tuple(value_shape)

        MapParameter._check_map_parameter_args(key_tensor, key_dtype, value_tensor, value_dtype, value_shape)

        # Resolve the effective key/value meta information. Explicit tensors take precedence; the check above
        # guarantees that any dtype/shape argument given alongside a tensor agrees with it.
        if key_tensor is not None:
            resolved_key_dtype = key_tensor.dtype
        else:
            resolved_key_dtype = key_dtype if key_dtype is not None else ms.int32

        if value_tensor is not None:
            resolved_value_dtype = value_tensor.dtype
            resolved_value_shape = value_tensor.shape
        else:
            resolved_value_dtype = value_dtype if value_dtype is not None else ms.float32
            resolved_value_shape = value_shape if value_shape is not None else (1,)

        # Placeholder tensor data backing the Parameter/Tensor base; it only carries the value dtype and shape.
        data = Tensor_(resolved_value_dtype, resolved_value_shape)
        obj = Tensor_.__new__(cls)
        Tensor_.__init__(obj, data)
        # Compatible attributes with Parameter.
        obj.has_init = False
        obj.init_mode = None
        obj.is_default_input_init = False
        # MapParameter added attributes.
        obj.key_dtype = resolved_key_dtype
        obj.value_dtype = resolved_value_dtype
        obj.value_shape = resolved_value_shape
        obj.default_value = default_value
        obj.permit_filter_value = permit_filter_value
        obj.evict_filter_value = evict_filter_value
        obj.key_tensor = key_tensor
        obj.value_tensor = value_tensor
        return obj

    def __init__(self, name=None, requires_grad=True, **kwargs):
        # `__new__` has already consumed the map-specific arguments; here the Parameter base is initialised on
        # the placeholder data and the backing MapTensor_ is created.
        Parameter.__init__(self, self, name=name, requires_grad=requires_grad)
        if self.key_tensor is not None and self.value_tensor is not None:
            # Build the map from the given key/value tensors.
            self._map_tensor = MapTensor_(self.key_tensor, self.value_tensor, self.default_value,
                                          self.permit_filter_value, self.evict_filter_value)
        else:
            # Build an empty map from dtype/shape meta information.
            self._map_tensor = MapTensor_(self.key_dtype, self.value_dtype, self.value_shape, self.default_value,
                                          self.permit_filter_value, self.evict_filter_value)
        self.map_put = _map_tensor_ops.put
        self.map_erase = _map_tensor_ops.erase

    def __getitem__(self, key_tensor):
        # `m[keys]` is `m.get(keys, True)`.
        return self.get(key_tensor, True)

    def __setitem__(self, key_tensor, value_tensor):
        # `m[keys] = values` is `m.put(keys, values)`.
        return self.put(key_tensor, value_tensor)

    def __str__(self):
        return 'MapParameter(' + str(self._map_tensor) + ')'

    def __copy__(self):
        # Shallow copy: construct a default instance, then share every instance attribute of `self`
        # (including `_map_tensor`).
        x = type(self)()
        x.__dict__.update(self.__dict__)
        return x

    @staticmethod
    def _check_map_parameter_args(key_tensor, key_dtype, value_tensor, value_dtype, value_shape):
        """
        Check that dtype/shape arguments agree with the given key/value tensors.

        Args:
            key_tensor: The key tensor, or None.
            key_dtype: The requested key dtype, or None.
            value_tensor: The value tensor, or None.
            value_dtype: The requested value dtype, or None.
            value_shape: The requested value shape (compared with `value_tensor.shape` as-is), or None.

        Raises:
            ValueError: If an argument and the corresponding tensor attribute are both given and differ.
        """
        if key_dtype is not None and key_tensor is not None and key_dtype != key_tensor.dtype:
            raise ValueError("When initializing a MapParameter, 'key_dtype' and 'key_tensor.dtype' should be set the"
                             " same.")
        if value_dtype is not None and value_tensor is not None and value_dtype != value_tensor.dtype:
            raise ValueError("When initializing a MapParameter, 'value_dtype' and 'value_tensor.dtype' should be set "
                             "the same.")
        if value_shape is not None and value_tensor is not None and value_shape != value_tensor.shape:
            raise ValueError("When initializing a map_parameter, 'value_shape' and 'value_tensor.shape' should be set "
                             "the same.")

    def clone(self, init='same'):
        """
        Clone the MapParameter.

        The clone gets a fresh, empty map tensor built from this parameter's key dtype, value dtype and value shape;
        existing records are not copied into it.

        Args:
            init (Union[str, numbers.Number]): Initialize the default value of the new map parameter.
                If `init` is a `numbers.Number`, clone a new map parameter with the same key value shape
                and dtype, and the default value of the new map parameter will be set according to `init`.
                If `init` is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
                If `init` is 'same', clone a new map parameter with the same default value. Default: 'same'.

        Returns:
            MapParameter, the new map parameter.
        """
        x = copy(self)
        x.param_info = self.param_info.clone()
        # Record the clone on the original parameter's info so the source keeps track of its clones.
        info = self.param_info
        if hasattr(info, "cloned_obj"):
            info.cloned_obj.append(x)
        else:
            info.cloned_obj = [x]
        self.param_info = info
        if init != 'same':
            x.default_value = init
        x._map_tensor = MapTensor_(x.key_dtype, x.value_dtype, x.value_shape, x.default_value, x.permit_filter_value,
                                   x.evict_filter_value)
        x.cache_enable = self.cache_enable
        if x.cache_enable:
            # A cache-enabled clone needs its own unique key.
            x.key = _get_unique_parameter_key()
        return x

    def get(self, key_tensor, insert_default_value=True):
        """
        Get the value tensor according to the key tensor. For keys that do not exist, the default value is
        filled in and returned.

        Args:
            key_tensor (Tensor): The key tensor.
            insert_default_value (bool): Whether to insert the default value for keys that do not exist.
                Default: True.

        Returns:
            Tensor, the value tensor for the key tensor.
        """
        map_get = _map_tensor_ops.MapTensorGet(insert_default_value)
        return map_get(self._map_tensor, key_tensor)

    def get_keys(self):
        """
        Get all keys as a tensor.

        Returns:
            Tensor, the tensor contains all keys.
        """
        return self._map_tensor.get_keys()

    def get_values(self):
        """
        Get all values as a tensor.

        Returns:
            Tensor, the tensor contains all values.
        """
        return self._map_tensor.get_values()

    def get_data(self):
        """
        Get all keys and values.

        Returns:
            The keys and values, as returned by the underlying map tensor's `get_data`.
        """
        return self._map_tensor.get_data()

    def put(self, key_tensor, value_tensor):
        """
        Insert or update records according to the given key tensor and value tensor.

        Args:
            key_tensor (Tensor): The key tensor.
            value_tensor (Tensor): The value tensor.

        Returns:
            MapTensor_, the underlying map tensor of this MapParameter (not the MapParameter object itself).
        """
        self.map_put(self._map_tensor, key_tensor, value_tensor)
        return self._map_tensor

    def erase(self, key_tensor):
        """
        Remove records according to the given key tensor.

        Args:
            key_tensor (Tensor): The key tensor.

        Returns:
            MapTensor_, the underlying map tensor of this MapParameter (not the MapParameter object itself).
        """
        self.map_erase(self._map_tensor, key_tensor)
        return self._map_tensor

    def export_data(self, incremental=False):
        """
        Export data from this map parameter.

        Args:
            incremental (bool): False for full export, otherwise for incremental export. Default: False.
                When exporting data incrementally, the value_array does not contain unchanged data. The length
                of the key_array and the length of the status_array are consistent.

        Returns:
            Tuple(key_array, value_array, status_array), The exported data as a tuple.
        """
        return self._map_tensor.export_data(incremental)

    def export_bytes(self, incremental=False):
        """
        Export bytes from this map parameter.

        Args:
            incremental (bool): False for full export, otherwise for incremental export. Default: False.
                When exporting data incrementally, the value_array does not contain unchanged data. The length
                of the key_array and the length of the status_array are consistent.

        Returns:
            Tuple(bytes, bytes, bytes), The exported bytes as a tuple.
        """
        return self._map_tensor.export_bytes(incremental)

    def import_data(self, data):
        """
        Import this map parameter from exported data.

        Args:
            data (Tuple): The data tuple with key_array, value_array and status_array.
        """
        self._map_tensor.import_data(data)

    def export_slice_data(self, incremental=False):
        """
        Export a slice of data from this map parameter.
        When MapParameter occupies a large memory, only one slice
        of MapParameter is exported at a time (the default slice size is 1GB).

        Args:
            incremental (bool): False for full export, otherwise for incremental export. Default: False.
                When exporting data incrementally, the value_array does not contain unchanged data. The length
                of the key_array and the length of the status_array are consistent.

        Returns:
            Tuple(key_array, value_array, status_array, last_slice), The exported data as a tuple, and
            the last_slice is a bool variable that indicates whether the export has finished.
        """
        # The persistent export path (keyed by this parameter's `key`) is used only when the remote cache
        # memory size environment variable is set.
        enable_persistent = "MS_EMBEDDING_REMOTE_CACHE_MEMORY_SIZE" in os.environ
        if not enable_persistent:
            return self._map_tensor.export_slice_data(incremental)
        return self._map_tensor.export_persistent_slice_data(self.key, incremental)
310