# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""embedding service"""
import json
import os
import math

from mindspore.nn.layer.embedding_service_layer import ESInitLayer
from mindspore.common.initializer import Uniform, TruncatedNormal, Constant
from mindspore.nn.layer.embedding_service_layer import ESEmbeddingTableImport, ESEmbeddingTableExport, \
    ESEmbeddingCKPTImport, ESEmbeddingCKPTExport

# Upper bound for init_vocabulary_size; table sizes are carried in int32 fields.
_INT32_MAX_VALUE = 2147483647


class CounterFilter:
    """ Counter filter for embedding table.

    Keys whose access frequency is below ``filter_freq`` are served with a
    default key or a default value instead of a trained embedding row.
    ``default_key_or_value`` selects which of the two defaults is active
    (1 -> use ``default_key``, 0 -> use ``default_value``).
    """
    def __init__(self, filter_freq, default_key_or_value, default_key=None, default_value=None):
        self.filter_freq = filter_freq
        self.default_key = default_key
        self.default_value = default_value
        self.default_key_or_value = default_key_or_value


class PaddingParamsOption:
    """ padding key option for embedding service table. """
    def __init__(self, padding_key=None,
                 mask=True,
                 mask_zero=False):
        self.padding_key = padding_key
        self.mask = mask
        self.mask_zero = mask_zero


class CompletionKeyOption:
    """ completion key option for embedding service table. """
    def __init__(self, completion_key=None, mask=1):
        self.completion_key = completion_key
        # NOTE: mask is an int flag here (1/0), unlike PaddingParamsOption.mask
        # which is a bool; EmbeddingService.completion_key() does the conversion.
        self.mask = mask


class EvictOption:
    """ Evict option for embedding table. """
    def __init__(self, steps_to_live):
        self.steps_to_live = steps_to_live


class EmbeddingVariableOption:
    """ option for embedding service table.

    Bundles the per-table policies. Only filter/padding/evict/completion are
    consumed by this module; storage, feature-freezing and communication
    options are accepted but currently not supported downstream.
    """
    def __init__(self, filter_option=None,
                 padding_option=None,
                 evict_option=None,
                 completion_option=None,
                 storage_option=None,
                 feature_freezing_option=None,
                 communication_option=None):
        self.filter_option = filter_option
        self.padding_option = padding_option
        self.evict_option = evict_option
        self.completion_option = completion_option
        self.storage_option = storage_option
        self.feature_freezing_option = feature_freezing_option
        self.communication_option = communication_option


class EsInitializer:
    """Initializer for embedding service table.

    Plain parameter holder mirroring the supported MindSpore initializers:
    ``random_uniform`` (min/max), ``truncated_normal`` (mu/sigma) and
    ``constant`` (constant_value).
    """
    def __init__(self, initializer_mode, min_scale=-0.01, max_scale=0.01,
                 constant_value=1.0, mu=0.0, sigma=1.0, seed=0):
        self.initializer_mode = initializer_mode
        self.min = min_scale
        self.max = max_scale
        self.constant_value = constant_value
        self.mu = mu
        self.sigma = sigma
        self.seed = seed


class EsOptimizer:
    """Optimizer for embedding service table."""
    def __init__(self, name, initial_accumulator_value=0., ms=0., mom=0.):
        self.name = name
        self.initial_accumulator_value = initial_accumulator_value
        self.ms = ms
        self.mom = mom


def check_common_init_params(name, init_vocabulary_size, embedding_dim):
    """
    Check init params common to every embedding table.

    :param name: table name, must be a non-None string
    :param init_vocabulary_size: vocab size, non-negative int below int32 max
    :param embedding_dim: embedding dim, positive int
    :raises ValueError: on None/invalid values
    :raises TypeError: when name is not a string
    """
    if (name is None) or (init_vocabulary_size is None) or (embedding_dim is None):
        raise ValueError("table name, init_vocabulary_size and embedding_dim can not be None.")
    if not isinstance(name, str):
        raise TypeError("embedding table name must be string.")
    if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)):
        raise ValueError("init_vocabulary_size and embedding_dim must be int.")
    if init_vocabulary_size < 0:
        raise ValueError("init_vocabulary_size can not be smaller than zero.")
    if embedding_dim <= 0:
        raise ValueError("embedding_dim must be greater than zero.")


class EmbeddingServiceOut:
    """
    Result bundle returned by EmbeddingService.embedding_init.
    """
    def __init__(self, table_id_dict, es_initializer=None, es_counter_filter=None,
                 es_padding_keys=None, es_completion_keys=None):
        self.table_id_dict = table_id_dict
        self.es_initializer = es_initializer
        self.es_counter_filter = es_counter_filter
        self.es_padding_keys = es_padding_keys
        self.es_completion_keys = es_completion_keys


class EmbeddingService:
    """
    Front-end for parameter-server (PS) hosted embedding tables.

    Reads the PS cluster topology from the JSON file named by the
    ``ESCLUSTER_CONFIG_PATH`` environment variable and keeps per-table
    bookkeeping (dims, optimizer, filter/padding/completion policies) used
    when initializing, exporting and importing the tables.
    """
    def __init__(self):
        """
        Init EmbeddingService.

        :raises ValueError: when the env var is unset or the cluster exceeds
            the supported PS counts (max 4 per server, max 4 total).
        """
        env_dist = os.environ
        es_cluster_config = env_dist.get("ESCLUSTER_CONFIG_PATH")
        if es_cluster_config is None:
            raise ValueError("EsClusterConfig env is null.")
        self._server_ip_to_ps_num = {}
        with open(es_cluster_config, encoding='utf-8') as a:
            es_cluster_config_json = json.load(a)
            self._es_cluster_conf = json.dumps(es_cluster_config_json)
            self._ps_num = int(es_cluster_config_json["psNum"])
            self._ps_ids = []
            self._ps_ids_list = es_cluster_config_json["psCluster"]
            # Count how many PS processes land on each control-panel server IP.
            for each_ps in self._ps_ids_list:
                self._server_ip_to_ps_num[each_ps["ctrlPanel"]["ipaddr"]] = 0

            for each_ps in self._ps_ids_list:
                self._ps_ids.append(each_ps["id"])
                ctrl_panel = each_ps["ctrlPanel"]
                self._server_ip_to_ps_num[ctrl_panel["ipaddr"]] += 1

        for each_server_ps_num in self._server_ip_to_ps_num:
            if self._server_ip_to_ps_num[each_server_ps_num] > 4:
                raise ValueError("PS num of one server can not exceed 4, please check config params.")
        # Fix: this guards total psNum, not per-server count; the original
        # reused the per-server message verbatim, hiding which check failed.
        if self._ps_num > 4:
            raise ValueError("Total PS num can not exceed 4, please check config params.")

        # storage each ps table's params
        self._table_to_embedding_dim = {}
        self._table_to_max_num = {}
        self._table_to_optimizer = {}
        self._table_to_slot_var_num = {}
        self._table_to_counter_filter = {}
        self._table_id_to_padding_key = {}
        self._table_id_to_completion_key = {}
        self._train_mode = True
        self._train_level = False
        self._optimizer = None
        self._init_table_flag = False

        self._small_table_name_list = []
        self._ps_table_count = 0
        self._table_name_to_id = {}
        self._table_id_to_name = {}
        self._table_id_to_initializer = {}
        self._table_id_to_steps_to_live = {}

        self._ps_table_id_list = []
        # storage lookup: table_id list, lookup result list, lookup key list
        self._ps_lookup_index = 0
        # storage all inited table names
        self._table_name_has_init = []
        # only storage all inited PS table names
        self._ps_table_name_list = []
        # now only use for adagrad accum
        self._ps_table_id_to_optimizer_params = {}

        # use for counter filter
        self._table_use_counter_filter = {}
        self._use_counter_filter = False
        self._use_evict = False
        self._use_padding_key = False
        self._use_completion_key = False

    def embedding_init(self, name, init_vocabulary_size, embedding_dim, max_feature_count,
                       initializer=Uniform(scale=0.01), ev_option=None, optimizer=None, optimizer_param=None,
                       mode="train"):
        """
        Init embedding
        :param name: big table name
        :param init_vocabulary_size: vocab size
        :param embedding_dim: embedding dim
        :param max_feature_count: max feature count
        :param initializer: mindspore common initializer
        :param ev_option: output of embedding_variable_option
        :param optimizer: optimizer
        :param optimizer_param: optimizer param
        :param mode: mode, train or predict
        :return: table_id_dict, es_initializer_dict, es_filter_dict
        """
        check_common_init_params(name=name, init_vocabulary_size=init_vocabulary_size, embedding_dim=embedding_dim)
        table_id = self._check_and_update_ps_init_params(name=name, init_vocabulary_size=init_vocabulary_size,
                                                         max_feature_count=max_feature_count, ev_option=ev_option)
        self._ps_lookup_index = self._ps_table_count
        self._table_to_embedding_dim[table_id] = embedding_dim
        self._table_to_max_num[table_id] = max_feature_count
        # storage the table id for embedding PS table
        self._ps_table_id_list.append(table_id)
        self._ps_table_name_list.append(name)

        if len(self._ps_table_id_list) > 10:
            raise ValueError("Now only 10 PS embedding tables can be init.")
        # Rows are sharded evenly over the parameter servers.
        bucket_size = math.ceil(init_vocabulary_size / self._ps_num)
        if optimizer is None:
            # No optimizer means inference/predict: no slot variables needed.
            self._train_mode = False
            self._table_to_slot_var_num[table_id] = 0
        else:
            self._check_ps_opt_and_initializer(optimizer=optimizer, initializer=initializer, table_id=table_id)
            self._optimizer = optimizer
            self._table_to_optimizer[table_id] = self._optimizer
            self._ps_table_id_to_optimizer_params[table_id] = []
            self._update_optimizer_slot_var_num(table_id=table_id)
        # new train or continue train from a checkpoint
        if initializer is not None:
            self._train_level = True
        filter_mode = self._init_counter_filter(table_id, ev_option)
        self._init_padding_key(table_id, ev_option)
        self._init_completion_key(table_id, ev_option)
        self._init_optimizer_mode_and_params(table_id, optimizer_param)
        es_init_layer = ESInitLayer(self._ps_num, self._ps_ids, self._train_mode, self._train_level, table_id,
                                    bucket_size, embedding_dim,
                                    self._table_to_slot_var_num.get(table_id),
                                    self._table_id_to_initializer.get(table_id), filter_mode, optimizer,
                                    self._ps_table_id_to_optimizer_params.get(table_id), max_feature_count, mode)
        es_init_layer()
        return EmbeddingServiceOut(self._table_name_to_id, self._table_id_to_initializer,
                                   self._table_to_counter_filter, self._table_id_to_padding_key,
                                   self._table_id_to_completion_key)

    def padding_param(self, padding_key, mask=True, mask_zero=False):
        """
        Init padding key param
        :param padding_key: padding key
        :param mask: padding key mask
        :param mask_zero: mask zero
        :return: PaddingParamsOption obj
        """
        if not isinstance(padding_key, int):
            raise TypeError("padding_key must be int, please check.")
        if not isinstance(mask, bool):
            raise TypeError("mask must be bool, please check.")
        self._use_padding_key = True
        return PaddingParamsOption(padding_key=padding_key, mask=mask, mask_zero=mask_zero)

    def completion_key(self, completion_key, mask=True):
        """
        Init completion key param
        :param completion_key: completion key
        :param mask: completion key mask
        :return: CompletionKeyOption obj
        """
        if not isinstance(completion_key, int):
            raise TypeError("completion_key must be int, please check.")
        if not isinstance(mask, bool):
            raise TypeError("mask must be bool, please check.")
        self._use_completion_key = True
        # CompletionKeyOption carries the mask as an int flag.
        completion_key_mask = 1 if mask is True else 0
        return CompletionKeyOption(completion_key=completion_key, mask=completion_key_mask)

    def counter_filter(self, filter_freq, default_key=None, default_value=None):
        """
        Set filter_option
        :param filter_freq: filter freq
        :param default_key: default key (mutually exclusive with default_value)
        :param default_value: default value (mutually exclusive with default_key)
        :return: CounterFilter obj
        """
        if not isinstance(filter_freq, int):
            raise TypeError("filter_freq must be int, please check.")
        if filter_freq < 0:
            # Fix: original message was garbled ("must can not be smaller").
            raise ValueError("filter_freq can not be smaller than 0.")
        if (default_key is None) and (default_value is None):
            raise ValueError("default_key and default_value can not be both None.")
        if (default_key is not None) and (default_value is not None):
            raise ValueError("default_key and default_value can not be both set.")
        if default_key is None and (not isinstance(default_value, (int, float))):
            raise TypeError("When default_value is not None, it must be float or int, please check.")
        if default_value is None and (not isinstance(default_key, int)):
            raise TypeError("When default_key is not None, it must be int, please check.")
        self._use_counter_filter = True
        # default_key_or_value selects the active default; the inactive field
        # gets a placeholder (0 / 1) rather than None.
        if default_key is None:
            return CounterFilter(filter_freq=filter_freq, default_key_or_value=0,
                                 default_key=0, default_value=default_value)
        return CounterFilter(filter_freq=filter_freq, default_key_or_value=1,
                             default_key=default_key, default_value=1)

    def evict_option(self, steps_to_live):
        """
        Set evict_option
        :param steps_to_live: steps to live
        :return: EvictOption obj
        """
        if not isinstance(steps_to_live, int):
            raise TypeError("steps_to_live must be int, please check.")
        if steps_to_live <= 0:
            # Fix: original message doubled the word "must".
            raise ValueError("steps_to_live must be greater than 0.")
        self._use_evict = True
        return EvictOption(steps_to_live=steps_to_live)

    def embedding_variable_option(self, filter_option=None, padding_option=None, evict_option=None,
                                  completion_option=None, storage_option=None, feature_freezing_option=None,
                                  communication_option=None):
        """
        Set embedding variable option
        :param filter_option: filter policy, is the output of counter_filter
        :param padding_option: padding policy, is the output of padding_keys
        :param evict_option: evict policy
        :param completion_option: not support
        :param storage_option: not support
        :param feature_freezing_option: not support
        :param communication_option: not support
        :return: EmbeddingVariableOption obj
        """
        if (filter_option is not None) and (not isinstance(filter_option, CounterFilter)):
            # Fix: message previously named padding_option although the check
            # is on filter_option.
            raise ValueError("If filter_option isn't None, it must be CounterFilter type.")
        if filter_option is not None:
            self._use_counter_filter = True
        if (padding_option is not None) and (not isinstance(padding_option, PaddingParamsOption)):
            raise TypeError("If padding_option isn't None, it must be EmbeddingPaddingParamsOption type.")
        if (completion_option is not None) and (not isinstance(completion_option, CompletionKeyOption)):
            raise TypeError("If completion_option isn't None, it must be EmbeddingPaddingCompletionKeyOption type.")
        if (evict_option is not None) and (not isinstance(evict_option, EvictOption)):
            raise TypeError("When evict_option is not None, it must be EvictOption type.")
        return EmbeddingVariableOption(filter_option=filter_option, padding_option=padding_option,
                                       evict_option=evict_option, completion_option=completion_option,
                                       storage_option=storage_option, feature_freezing_option=feature_freezing_option,
                                       communication_option=communication_option)

    def embedding_ckpt_export(self, file_path):
        """
        Export big table ckpt
        :param file_path: the file path to storage ckpt ret
        :return:
        """
        embedding_dim_list = []
        value_total_len_list = []
        steps_to_live_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
            # Row layout: embedding + one copy per optimizer slot, plus 2
            # bookkeeping values.
            value_total_len_list.append(self._table_to_embedding_dim.get(table_id) *
                                        (self._table_to_slot_var_num.get(table_id) + 1) + 2)
            steps_to_live_list.append(self._table_id_to_steps_to_live.get(table_id, 0))
        embedding_ckpt_export_layer = ESEmbeddingCKPTExport(embedding_dim_list, value_total_len_list,
                                                            self._ps_table_name_list, self._ps_table_id_list,
                                                            file_path, steps_to_live_list)
        embedding_ckpt_export_layer()

    def embedding_table_export(self, file_path):
        """
        Export big table embedding
        :param file_path: the file path to storage embedding ret
        :return:
        """
        embedding_dim_list = []
        steps_to_live_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
            steps_to_live_list.append(self._table_id_to_steps_to_live.get(table_id, 0))

        # NOTE(review): embedding_dim_list is passed for both dim and total-len
        # arguments (no slot variables in a plain embedding export) — confirm
        # against ESEmbeddingTableExport's signature.
        embedding_table_export_layer = ESEmbeddingTableExport(embedding_dim_list, embedding_dim_list,
                                                              self._ps_table_name_list, self._ps_table_id_list,
                                                              file_path, steps_to_live_list)
        embedding_table_export_layer()

    def embedding_ckpt_import(self, file_path):
        """
        Import big table ckpt
        :param file_path: the file path to import ckpt ret
        :return:
        """
        embedding_dim_list = []
        value_total_len_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
            value_total_len_list.append(self._table_to_embedding_dim.get(table_id) *
                                        (self._table_to_slot_var_num.get(table_id) + 1) + 2)

        embedding_ckpt_import_layer = ESEmbeddingCKPTImport(embedding_dim_list, value_total_len_list,
                                                            self._ps_table_name_list, self._ps_table_id_list,
                                                            file_path)
        embedding_ckpt_import_layer()

    def embedding_table_import(self, file_path):
        """
        Import big table embedding
        :param file_path: the file path to import embedding ret
        :return:
        """
        embedding_dim_list = []
        for table_id in self._ps_table_id_list:
            embedding_dim_list.append(self._table_to_embedding_dim.get(table_id))
        embedding_table_import_layer = ESEmbeddingTableImport(embedding_dim_list, embedding_dim_list,
                                                              self._ps_table_name_list, self._ps_table_id_list,
                                                              file_path)
        embedding_table_import_layer()

    def _check_and_update_ps_init_params(self, name, init_vocabulary_size, max_feature_count, ev_option):
        """
        Check parameter server params and init table id.

        Registers the table name, assigns it the next table id and records its
        evict steps_to_live (0 when no evict policy is set).
        :raises ValueError: on invalid sizes or duplicate table name
        :raises TypeError: when ev_option has the wrong type
        """
        steps_to_live = 0
        if max_feature_count is None:
            raise ValueError("For ps table, max_feature_count can not be None.")
        if (ev_option is not None) and (not isinstance(ev_option, EmbeddingVariableOption)):
            raise TypeError("For ps table, ev_option must be EmbeddingVariableOption type.")
        if (ev_option is not None) and (ev_option.evict_option is not None):
            steps_to_live = ev_option.evict_option.steps_to_live
        if not isinstance(max_feature_count, int):
            raise ValueError("For ps table, max_feature_count must be int.")
        if init_vocabulary_size >= _INT32_MAX_VALUE:
            raise ValueError("init_vocabulary_size exceeds int32 max value.")
        if max_feature_count <= 0:
            raise ValueError("For ps table, max_feature_count must be greater than zero.")
        if name not in self._table_name_has_init:
            table_id = self._ps_table_count
            self._table_name_to_id[name] = table_id
            self._table_id_to_name[table_id] = name
            self._table_id_to_steps_to_live[table_id] = steps_to_live
            self._ps_table_count += 1
            self._table_name_has_init.append(name)
        else:
            raise ValueError("This table has been initialized.")
        return table_id

    def _check_ps_opt_and_initializer(self, optimizer, initializer, table_id):
        """
        Check args of parameter server
        :param optimizer: the optimizer type, one of adam/adagrad/adamw/ftrl
        :param initializer: mindspore common initializer or EsInitializer
        :param table_id: table id
        :return:
        """
        if optimizer not in ["adam", "adagrad", "adamw", "ftrl"]:
            raise ValueError("optimizer should be one of adam, adagrad, adamw, ftrl")
        if initializer is not None:
            # Translate MindSpore initializers into the EsInitializer params
            # the PS side understands.
            if isinstance(initializer, EsInitializer):
                self._table_id_to_initializer[table_id] = initializer
            elif isinstance(initializer, TruncatedNormal):
                self._table_id_to_initializer[table_id] = \
                    EsInitializer(initializer_mode="truncated_normal", mu=initializer.mean,
                                  sigma=initializer.sigma, seed=initializer.seed[0])
            elif isinstance(initializer, Uniform):
                self._table_id_to_initializer[table_id] = \
                    EsInitializer(initializer_mode="random_uniform",
                                  min_scale=-initializer.scale,
                                  max_scale=initializer.scale, seed=initializer.seed[0])
            elif isinstance(initializer, Constant):
                self._table_id_to_initializer[table_id] = \
                    EsInitializer(initializer_mode="constant", constant_value=initializer.value)
            else:
                raise TypeError("initializer must be EsInitializer or mindspore initializer, and only support"
                                "Uniform, TruncatedNormal and Constant value.")

    def _update_optimizer_slot_var_num(self, table_id):
        """
        Update _table_to_slot_var_num by diff optimizer
        """
        # adam, adamw, rmsprop include m and v, 2 slots; adagrad include accumulator, 1 slot; sgd include 0 slot
        if self._optimizer == "adagrad":
            self._table_to_slot_var_num[table_id] = 1
        elif self._optimizer == "sgd":
            self._table_to_slot_var_num[table_id] = 0
        else:
            self._table_to_slot_var_num[table_id] = 2

    def _init_counter_filter(self, table_id, ev_option):
        """
        Init counter filter params; returns the filter mode string
        ("counter" or "no_filter") consumed by ESInitLayer.
        """
        if (ev_option is not None) and (ev_option.filter_option is not None):
            filter_mode = "counter"
            self._table_to_counter_filter[table_id] = ev_option.filter_option
            self._table_use_counter_filter[table_id] = 1
        else:
            filter_mode = "no_filter"
            self._table_use_counter_filter[table_id] = 0
        return filter_mode

    def _init_padding_key(self, table_id, ev_option):
        """
        Init padding key params
        """
        if (ev_option is not None) and (ev_option.padding_option is not None):
            self._table_id_to_padding_key[table_id] = ev_option.padding_option

    def _init_completion_key(self, table_id, ev_option):
        """
        Init completion key params
        """
        if (ev_option is not None) and (ev_option.completion_option is not None):
            self._table_id_to_completion_key[table_id] = ev_option.completion_option

    def _init_optimizer_mode_and_params(self, table_id, optimizer_param):
        """
        Init _ps_table_id_to_optimizer_params by diff optimizer
        """
        optimizer = self._table_to_optimizer.get(table_id)
        if optimizer is None:
            return
        if optimizer in ["adagrad", "ftrl"]:
            if optimizer_param is not None:
                self._ps_table_id_to_optimizer_params[table_id].append(optimizer_param)
            else:
                # Fix: the requirement applies to ftrl too, not only adagrad.
                raise ValueError("For adagrad and ftrl optimizer, optimizer_param should have 1 param, "
                                 "initial_accumulator_value.")

        if optimizer in ["adam", "adamw", "sgd", "ftrl"]:
            self._ps_table_id_to_optimizer_params[table_id].append(0.)