# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""embedding service"""
import json
import os
import math

from mindspore.common.initializer import Uniform, TruncatedNormal, Constant
from mindspore.nn.layer.embedding_service_layer import ESInitLayer, ESEmbeddingTableImport, \
    ESEmbeddingTableExport, ESEmbeddingCKPTImport, ESEmbeddingCKPTExport

_INT32_MAX_VALUE = 2147483647


class CounterFilter:
    """ Counter filter for embedding table. """
    def __init__(self, filter_freq, default_key_or_value, default_key=None, default_value=None):
        self.filter_freq = filter_freq
        self.default_key = default_key
        self.default_value = default_value
        self.default_key_or_value = default_key_or_value


class PaddingParamsOption:
    """ Padding key option for embedding service table. """
    def __init__(self, padding_key=None,
                 mask=True,
                 mask_zero=False):
        self.padding_key = padding_key
        self.mask = mask
        self.mask_zero = mask_zero


class CompletionKeyOption:
    """ Completion key option for embedding service table. """
    def __init__(self, completion_key=None, mask=1):
        self.completion_key = completion_key
        self.mask = mask


class EvictOption:
    """ Evict option for embedding table. """
    def __init__(self, steps_to_live):
        self.steps_to_live = steps_to_live


class EmbeddingVariableOption:
    """ Option for embedding service table. """
    def __init__(self, filter_option=None,
                 padding_option=None,
                 evict_option=None,
                 completion_option=None,
                 storage_option=None,
                 feature_freezing_option=None,
                 communication_option=None):
        self.filter_option = filter_option
        self.padding_option = padding_option
        self.evict_option = evict_option
        self.completion_option = completion_option
        self.storage_option = storage_option
        self.feature_freezing_option = feature_freezing_option
        self.communication_option = communication_option


class EsInitializer:
    """Initializer for embedding service table."""
    def __init__(self, initializer_mode, min_scale=-0.01, max_scale=0.01,
                 constant_value=1.0, mu=0.0, sigma=1.0, seed=0):
        self.initializer_mode = initializer_mode
        self.min = min_scale
        self.max = max_scale
        self.constant_value = constant_value
        self.mu = mu
        self.sigma = sigma
        self.seed = seed
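
# Hedged sketch: EsInitializer mirrors the supported MindSpore common
# initializers; per the mapping in _check_ps_opt_and_initializer below,
# Uniform(scale=0.01) corresponds to (assuming the default seed):
#
#   EsInitializer(initializer_mode="random_uniform", min_scale=-0.01,
#                 max_scale=0.01, seed=0)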


class EsOptimizer:
    """Optimizer for embedding service table."""
    def __init__(self, name, initial_accumulator_value=0., ms=0., mom=0.):
        self.name = name
        self.initial_accumulator_value = initial_accumulator_value
        self.ms = ms
        self.mom = mom


def check_common_init_params(name, init_vocabulary_size, embedding_dim):
    """
    Check common init params.
    """
    if (name is None) or (init_vocabulary_size is None) or (embedding_dim is None):
        raise ValueError("table name, init_vocabulary_size and embedding_dim can not be None.")
    if not isinstance(name, str):
        raise TypeError("embedding table name must be string.")
    if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)):
        raise TypeError("init_vocabulary_size and embedding_dim must be int.")
    if init_vocabulary_size < 0:
        raise ValueError("init_vocabulary_size can not be smaller than zero.")
    if embedding_dim <= 0:
        raise ValueError("embedding_dim must be greater than zero.")


class EmbeddingServiceOut:
    """
    Output of EmbeddingService init: table ids and per-table options.
    """
    def __init__(self, table_id_dict, es_initializer=None, es_counter_filter=None,
                 es_padding_keys=None, es_completion_keys=None):
        self.table_id_dict = table_id_dict
        self.es_initializer = es_initializer
        self.es_counter_filter = es_counter_filter
        self.es_padding_keys = es_padding_keys
        self.es_completion_keys = es_completion_keys


class EmbeddingService:
    """
    Client-side helper that initializes and manages embedding service (ES) tables on parameter servers.
    """
    def __init__(self):
        """
        Init EmbeddingService
        """
        env_dist = os.environ
        es_cluster_config = env_dist.get("ESCLUSTER_CONFIG_PATH")
        if es_cluster_config is None:
            raise ValueError("ESCLUSTER_CONFIG_PATH env is not set.")
        self._server_ip_to_ps_num = {}
        with open(es_cluster_config, encoding='utf-8') as config_file:
            es_cluster_config_json = json.load(config_file)
            self._es_cluster_conf = json.dumps(es_cluster_config_json)
            self._ps_num = int(es_cluster_config_json["psNum"])
            self._ps_ids = []
            self._ps_ids_list = es_cluster_config_json["psCluster"]
            if self._ps_num > 4:
                raise ValueError("Total PS num can not exceed 4, please check config params.")
            for each_ps in self._ps_ids_list:
                self._server_ip_to_ps_num[each_ps["ctrlPanel"]["ipaddr"]] = 0

            for each_ps in self._ps_ids_list:
                self._ps_ids.append(each_ps["id"])
                ctrl_panel = each_ps["ctrlPanel"]
                self._server_ip_to_ps_num[ctrl_panel["ipaddr"]] += 1

            for server_ps_num in self._server_ip_to_ps_num.values():
                if server_ps_num > 4:
                    raise ValueError("PS num of one server can not exceed 4, please check config params.")

        # stores each PS table's params
        self._table_to_embedding_dim = {}
        self._table_to_max_num = {}
        self._table_to_optimizer = {}
        self._table_to_slot_var_num = {}
        self._table_to_counter_filter = {}
        self._table_id_to_padding_key = {}
        self._table_id_to_completion_key = {}
        self._train_mode = True
        self._train_level = False
        self._optimizer = None
        self._init_table_flag = False

        self._small_table_name_list = []
        self._ps_table_count = 0
        self._table_name_to_id = {}
        self._table_id_to_name = {}
        self._table_id_to_initializer = {}
        self._table_id_to_steps_to_live = {}

        self._ps_table_id_list = []
        # storage for lookup: table_id list, lookup result list, lookup key list
        self._ps_lookup_index = 0
        # stores all initialized table names
        self._table_name_has_init = []
        # stores only the initialized PS table names
        self._ps_table_name_list = []
        # currently only used for the adagrad accumulator
        self._ps_table_id_to_optimizer_params = {}

        # used for counter filter
        self._table_use_counter_filter = {}
        self._use_counter_filter = False
        self._use_evict = False
        self._use_padding_key = False
        self._use_completion_key = False

    def embedding_init(self, name, init_vocabulary_size, embedding_dim, max_feature_count,
                       initializer=Uniform(scale=0.01), ev_option=None, optimizer=None, optimizer_param=None,
                       mode="train"):
        """
        Init embedding
        :param name: big table name
        :param init_vocabulary_size: vocab size
        :param embedding_dim: embedding dim
        :param max_feature_count: max feature count
        :param initializer: mindspore common initializer
        :param ev_option: output of embedding_variable_option
        :param optimizer: optimizer
        :param optimizer_param: optimizer param
        :param mode: mode, train or predict
        :return: EmbeddingServiceOut obj
        """
        check_common_init_params(name=name, init_vocabulary_size=init_vocabulary_size, embedding_dim=embedding_dim)
        table_id = self._check_and_update_ps_init_params(name=name, init_vocabulary_size=init_vocabulary_size,
                                                         max_feature_count=max_feature_count, ev_option=ev_option)
        self._ps_lookup_index = self._ps_table_count
        self._table_to_embedding_dim[table_id] = embedding_dim
        self._table_to_max_num[table_id] = max_feature_count
        # store the table id for this embedding PS table
        self._ps_table_id_list.append(table_id)
        self._ps_table_name_list.append(name)

        if len(self._ps_table_id_list) > 10:
            raise ValueError("At most 10 PS embedding tables can be initialized.")
        bucket_size = math.ceil(init_vocabulary_size / self._ps_num)
        if optimizer is None:
            self._train_mode = False
            self._table_to_slot_var_num[table_id] = 0
        else:
            self._check_ps_opt_and_initializer(optimizer=optimizer, initializer=initializer, table_id=table_id)
            self._optimizer = optimizer
            self._table_to_optimizer[table_id] = self._optimizer
            self._ps_table_id_to_optimizer_params[table_id] = []
            self._update_optimizer_slot_var_num(table_id=table_id)
            # new train or continue train from a checkpoint
            if initializer is not None:
                self._train_level = True
        filter_mode = self._init_counter_filter(table_id, ev_option)
        self._init_padding_key(table_id, ev_option)
        self._init_completion_key(table_id, ev_option)
        self._init_optimizer_mode_and_params(table_id, optimizer_param)
        es_init_layer = ESInitLayer(self._ps_num, self._ps_ids, self._train_mode, self._train_level, table_id,
                                    bucket_size, embedding_dim, self._table_to_slot_var_num.get(table_id),
                                    self._table_id_to_initializer.get(table_id), filter_mode, optimizer,
                                    self._ps_table_id_to_optimizer_params.get(table_id), max_feature_count, mode)
        es_init_layer()
        return EmbeddingServiceOut(self._table_name_to_id, self._table_id_to_initializer,
                                   self._table_to_counter_filter, self._table_id_to_padding_key,
                                   self._table_id_to_completion_key)
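
    # Hedged usage sketch (table name, sizes and optimizer below are
    # illustrative assumptions; a valid ESCLUSTER_CONFIG_PATH must be set
    # before constructing EmbeddingService):
    #
    #   es = EmbeddingService()
    #   es_out = es.embedding_init(name="user_table", init_vocabulary_size=1000000,
    #                              embedding_dim=16, max_feature_count=100000,
    #                              initializer=Uniform(scale=0.01),
    #                              optimizer="adagrad", optimizer_param=0.1,
    #                              mode="train")
    #   table_id = es_out.table_id_dict["user_table"]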

    def padding_param(self, padding_key, mask=True, mask_zero=False):
        """
        Init padding key param
        :param padding_key: padding key
        :param mask: padding key mask
        :param mask_zero: mask zero
        :return: PaddingParamsOption obj
        """
        if not isinstance(padding_key, int):
            raise TypeError("padding_key must be int, please check.")
        if not isinstance(mask, bool):
            raise TypeError("mask must be bool, please check.")
        self._use_padding_key = True
        return PaddingParamsOption(padding_key=padding_key, mask=mask, mask_zero=mask_zero)

    def completion_key(self, completion_key, mask=True):
        """
        Init completion key param
        :param completion_key: completion key
        :param mask: completion key mask
        :return: CompletionKeyOption obj
        """
        if not isinstance(completion_key, int):
            raise TypeError("completion_key must be int, please check.")
        if not isinstance(mask, bool):
            raise TypeError("mask must be bool, please check.")
        self._use_completion_key = True
        completion_key_mask = 1 if mask else 0
        return CompletionKeyOption(completion_key=completion_key, mask=completion_key_mask)
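
    # Hedged sketch of both key options (values are illustrative assumptions);
    # the returned objects are meant to be passed to embedding_variable_option:
    #
    #   padding_option = es.padding_param(padding_key=0, mask=True, mask_zero=False)
    #   completion_option = es.completion_key(completion_key=-1, mask=True)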

    def counter_filter(self, filter_freq, default_key=None, default_value=None):
        """
        Set filter_option
        :param filter_freq: filter freq
        :param default_key: default key
        :param default_value: default value
        :return: CounterFilter obj
        """
        if not isinstance(filter_freq, int):
            raise TypeError("filter_freq must be int, please check.")
        if filter_freq < 0:
            raise ValueError("filter_freq can not be smaller than 0.")
        if (default_key is None) and (default_value is None):
            raise ValueError("default_key and default_value can not be both None.")
        if (default_key is not None) and (default_value is not None):
            raise ValueError("default_key and default_value can not be both set.")
        if default_key is None and (not isinstance(default_value, (int, float))):
            raise TypeError("When default_value is not None, it must be float or int, please check.")
        if default_value is None and (not isinstance(default_key, int)):
            raise TypeError("When default_key is not None, it must be int, please check.")
        self._use_counter_filter = True
        if default_key is None:
            return CounterFilter(filter_freq=filter_freq, default_key_or_value=0,
                                 default_key=0, default_value=default_value)
        return CounterFilter(filter_freq=filter_freq, default_key_or_value=1,
                             default_key=default_key, default_value=1)
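
    # Hedged sketch: a counter filter takes either a default value or a
    # default key (mutually exclusive); keys seen fewer than filter_freq
    # times fall back to that default. Values are illustrative assumptions:
    #
    #   by_value = es.counter_filter(filter_freq=5, default_value=0.0)
    #   by_key = es.counter_filter(filter_freq=5, default_key=0)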

    def evict_option(self, steps_to_live):
        """
        Set evict_option
        :param steps_to_live: steps to live
        :return: EvictOption obj
        """
        if not isinstance(steps_to_live, int):
            raise TypeError("steps_to_live must be int, please check.")
        if steps_to_live <= 0:
            raise ValueError("steps_to_live must be greater than 0.")
        self._use_evict = True
        return EvictOption(steps_to_live=steps_to_live)

    def embedding_variable_option(self, filter_option=None, padding_option=None, evict_option=None,
                                  completion_option=None, storage_option=None, feature_freezing_option=None,
                                  communication_option=None):
        """
        Set embedding variable option
        :param filter_option: filter policy, is the output of counter_filter
        :param padding_option: padding policy, is the output of padding_param
        :param evict_option: evict policy, is the output of evict_option
        :param completion_option: not support
        :param storage_option: not support
        :param feature_freezing_option: not support
        :param communication_option: not support
        :return: EmbeddingVariableOption obj
        """
        if (filter_option is not None) and (not isinstance(filter_option, CounterFilter)):
            raise TypeError("If filter_option isn't None, it must be CounterFilter type.")
        if filter_option is not None:
            self._use_counter_filter = True
        if (padding_option is not None) and (not isinstance(padding_option, PaddingParamsOption)):
            raise TypeError("If padding_option isn't None, it must be PaddingParamsOption type.")
        if (completion_option is not None) and (not isinstance(completion_option, CompletionKeyOption)):
            raise TypeError("If completion_option isn't None, it must be CompletionKeyOption type.")
        if (evict_option is not None) and (not isinstance(evict_option, EvictOption)):
            raise TypeError("When evict_option is not None, it must be EvictOption type.")
        return EmbeddingVariableOption(filter_option=filter_option, padding_option=padding_option,
                                       evict_option=evict_option, completion_option=completion_option,
                                       storage_option=storage_option, feature_freezing_option=feature_freezing_option,
                                       communication_option=communication_option)
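
    # Hedged composite sketch (all values are illustrative assumptions):
    #
    #   ev_option = es.embedding_variable_option(
    #       filter_option=es.counter_filter(filter_freq=5, default_value=0.0),
    #       padding_option=es.padding_param(padding_key=0, mask=True),
    #       evict_option=es.evict_option(steps_to_live=1000))
    #
    # The result is passed as ev_option to embedding_init.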

    def embedding_ckpt_export(self, file_path):
        """
        Export big table ckpt
        :param file_path: the file path to store the exported ckpt
        :return:
        """
        embedding_dim_list = []
        value_total_len_list = []
        steps_to_live_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
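            # Per-key value length: one embedding_dim block for the embedding
            # itself, one per optimizer slot, plus 2 trailing floats (assumed
            # to be ES-internal per-key metadata).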
            value_total_len_list.append(self._table_to_embedding_dim.get(table_id) *
                                        (self._table_to_slot_var_num.get(table_id) + 1) + 2)
            steps_to_live_list.append(self._table_id_to_steps_to_live.get(table_id, 0))
        embedding_ckpt_export_layer = ESEmbeddingCKPTExport(embedding_dim_list, value_total_len_list,
                                                            self._ps_table_name_list, self._ps_table_id_list,
                                                            file_path, steps_to_live_list)
        embedding_ckpt_export_layer()
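
    # Hedged usage sketch: export a trained big-table checkpoint, and import
    # it again before resuming training (the path is an illustrative
    # assumption):
    #
    #   es.embedding_ckpt_export(file_path="/tmp/es_ckpt")
    #   es.embedding_ckpt_import(file_path="/tmp/es_ckpt")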

    def embedding_table_export(self, file_path):
        """
        Export big table embedding
        :param file_path: the file path to store the exported embedding
        :return:
        """
        embedding_dim_list = []
        steps_to_live_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
            steps_to_live_list.append(self._table_id_to_steps_to_live.get(table_id, 0))

        embedding_table_export_layer = ESEmbeddingTableExport(embedding_dim_list, embedding_dim_list,
                                                              self._ps_table_name_list, self._ps_table_id_list,
                                                              file_path, steps_to_live_list)
        embedding_table_export_layer()

    def embedding_ckpt_import(self, file_path):
        """
        Import big table ckpt
        :param file_path: the file path to import the ckpt from
        :return:
        """
        embedding_dim_list = []
        value_total_len_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
            value_total_len_list.append(self._table_to_embedding_dim.get(table_id) *
                                        (self._table_to_slot_var_num.get(table_id) + 1) + 2)

        embedding_ckpt_import_layer = ESEmbeddingCKPTImport(embedding_dim_list, value_total_len_list,
                                                            self._ps_table_name_list, self._ps_table_id_list,
                                                            file_path)
        embedding_ckpt_import_layer()

    def embedding_table_import(self, file_path):
        """
        Import big table embedding
        :param file_path: the file path to import the embedding from
        :return:
        """
        embedding_dim_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
        embedding_table_import_layer = ESEmbeddingTableImport(embedding_dim_list, embedding_dim_list,
                                                              self._ps_table_name_list, self._ps_table_id_list,
                                                              file_path)
        embedding_table_import_layer()

    def _check_and_update_ps_init_params(self, name, init_vocabulary_size, max_feature_count, ev_option):
        """
        Check parameter server params and init table id
        """
        steps_to_live = 0
        if max_feature_count is None:
            raise ValueError("For ps table, max_feature_count can not be None.")
        if (ev_option is not None) and (not isinstance(ev_option, EmbeddingVariableOption)):
            raise TypeError("For ps table, ev_option must be EmbeddingVariableOption type.")
        if (ev_option is not None) and (ev_option.evict_option is not None):
            steps_to_live = ev_option.evict_option.steps_to_live
        if not isinstance(max_feature_count, int):
            raise TypeError("For ps table, max_feature_count must be int.")
        if init_vocabulary_size >= _INT32_MAX_VALUE:
            raise ValueError("init_vocabulary_size exceeds int32 max value.")
        if max_feature_count <= 0:
            raise ValueError("For ps table, max_feature_count must be greater than zero.")
        if name not in self._table_name_has_init:
            table_id = self._ps_table_count
            self._table_name_to_id[name] = table_id
            self._table_id_to_name[table_id] = name
            self._table_id_to_steps_to_live[table_id] = steps_to_live
            self._ps_table_count += 1
            self._table_name_has_init.append(name)
        else:
            raise ValueError("This table has been initialized.")
        return table_id

    def _check_ps_opt_and_initializer(self, optimizer, initializer, table_id):
        """
        Check args of parameter server
        :param optimizer: the optimizer type, one of adam, adagrad, adamw and ftrl
        :param initializer: mindspore common initializer
        :param table_id: table id
        :return:
        """
        if optimizer not in ["adam", "adagrad", "adamw", "ftrl"]:
            raise ValueError("optimizer should be one of adam, adagrad, adamw, ftrl.")
        if initializer is not None:
            if isinstance(initializer, EsInitializer):
                self._table_id_to_initializer[table_id] = initializer
            elif isinstance(initializer, TruncatedNormal):
                self._table_id_to_initializer[table_id] = \
                    EsInitializer(initializer_mode="truncated_normal", mu=initializer.mean,
                                  sigma=initializer.sigma, seed=initializer.seed[0])
            elif isinstance(initializer, Uniform):
                self._table_id_to_initializer[table_id] = \
                    EsInitializer(initializer_mode="random_uniform",
                                  min_scale=-initializer.scale,
                                  max_scale=initializer.scale, seed=initializer.seed[0])
            elif isinstance(initializer, Constant):
                self._table_id_to_initializer[table_id] = \
                    EsInitializer(initializer_mode="constant", constant_value=initializer.value)
            else:
                raise TypeError("initializer must be EsInitializer or mindspore initializer, and only supports "
                                "Uniform, TruncatedNormal and Constant value.")

    def _update_optimizer_slot_var_num(self, table_id):
        """
        Update _table_to_slot_var_num by diff optimizer
        """
        # adam, adamw and rmsprop keep m and v, 2 slots; adagrad keeps an accumulator, 1 slot; sgd keeps 0 slots
        if self._optimizer == "adagrad":
            self._table_to_slot_var_num[table_id] = 1
        elif self._optimizer == "sgd":
            self._table_to_slot_var_num[table_id] = 0
        else:
            self._table_to_slot_var_num[table_id] = 2

    def _init_counter_filter(self, table_id, ev_option):
        """
        Init counter filter params
        """
        if (ev_option is not None) and (ev_option.filter_option is not None):
            filter_mode = "counter"
            self._table_to_counter_filter[table_id] = ev_option.filter_option
            self._table_use_counter_filter[table_id] = 1
        else:
            filter_mode = "no_filter"
            self._table_use_counter_filter[table_id] = 0
        return filter_mode

    def _init_padding_key(self, table_id, ev_option):
        """
        Init padding key params
        """
        if (ev_option is not None) and (ev_option.padding_option is not None):
            self._table_id_to_padding_key[table_id] = ev_option.padding_option

    def _init_completion_key(self, table_id, ev_option):
        """
        Init completion key params
        """
        if (ev_option is not None) and (ev_option.completion_option is not None):
            self._table_id_to_completion_key[table_id] = ev_option.completion_option

    def _init_optimizer_mode_and_params(self, table_id, optimizer_param):
        """
        Init _ps_table_id_to_optimizer_params by diff optimizer
        """
        optimizer = self._table_to_optimizer.get(table_id)
        if optimizer is None:
            return
        if optimizer in ["adagrad", "ftrl"]:
            if optimizer_param is not None:
                self._ps_table_id_to_optimizer_params[table_id].append(optimizer_param)
            else:
                raise ValueError("For adagrad and ftrl optimizers, optimizer_param should have 1 param, "
                                 "initial_accumulator_value.")

        if optimizer in ["adam", "adamw", "sgd", "ftrl"]:
            self._ps_table_id_to_optimizer_params[table_id].append(0.)
532