# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train a tiny network backed by the MindSpore EmbeddingService on Ascend,
then export and re-import the embedding table / checkpoint on rank 0."""

import os

import numpy as np

import mindspore as ms
from mindspore import nn, ops, Tensor
from mindspore import context
from mindspore.communication import init, release, get_rank
from mindspore.nn.layer.embedding_service import EmbeddingService
from mindspore.nn.layer.embedding_service_layer import EsEmbeddingLookup


class Net(nn.Cell):
    """Lookup network: EmbeddingService-backed embedding scaled by a trainable weight.

    Args:
        embedding_dim (int): Width of each embedding vector.
        max_feature_count (int): Maximum number of keys per lookup
            (passed through as ``max_key_num``).
        table_id_dict (dict): Maps table name -> table id; must contain "test".
        es_initializer (dict): Per-table-id initializer, as returned by
            ``EmbeddingService.embedding_init``.
        es_counter_filter (dict): Per-table-id counter-filter option, as returned
            by ``EmbeddingService.embedding_init``.
    """

    def __init__(self, embedding_dim, max_feature_count, table_id_dict=None, es_initializer=None,
                 es_counter_filter=None):
        super(Net, self).__init__()
        self.table_id = table_id_dict["test"]
        # Adam with zero optimizer_params mirrors the "adam" optimizer chosen
        # in embedding_init; the filter controls counter-based feature admission.
        self.embedding = EsEmbeddingLookup(self.table_id, es_initializer[self.table_id],
                                           embedding_dim=[embedding_dim],
                                           max_key_num=max_feature_count, optimizer_mode="adam",
                                           optimizer_params=[0.0, 0.0],
                                           es_filter=es_counter_filter[self.table_id])
        # Trainable scalar so the graph has a gradient path outside the ES table.
        self.w = ms.Parameter(Tensor([1.5], ms.float32), name="w", requires_grad=True)

    def construct(self, keys, actual_keys_input=None, unique_indices=None, key_count=None):
        """Look up `keys` and scale the result by the trainable weight.

        The optional arguments enable the pre-uniquified lookup path of
        ``EsEmbeddingLookup``; when absent, the plain lookup is used.
        """
        if (actual_keys_input is not None) and (unique_indices is not None):
            es_out = self.embedding(keys, actual_keys_input, unique_indices, key_count)
        else:
            es_out = self.embedding(keys)
        output = es_out * self.w
        return output


class NetworkWithLoss(nn.Cell):
    """Wraps a backbone network and a loss function into a single Cell.

    Args:
        network (nn.Cell): Backbone producing logits from the input.
        loss (Cell/Primitive): Loss taking ``(logits, label)``.
    """

    def __init__(self, network, loss):
        super(NetworkWithLoss, self).__init__()
        self.network = network
        self.loss_fn = loss

    def construct(self, x, label):
        """Return the loss of ``network(x)`` against ``label``."""
        logits = self.network(x)
        loss = self.loss_fn(logits, label)
        return loss


def train():
    """Run one training step with EmbeddingService, then export/import tables.

    Steps:
      1. Initialize communication and the Ascend graph-mode context.
      2. Create an ES table "test" with a counter filter (freq >= 2,
         default value 10.0) in train mode.
      3. Build the network, run a single ``TrainOneStepCell`` step.
      4. On rank 0, export the embedding table and checkpoint, then
         re-import the checkpoint as a round-trip check.
    """
    init()
    vocab_size = 1000
    embedding_dim = 12
    feature_length = 16
    context.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

    es = EmbeddingService()
    filter_option = es.counter_filter(filter_freq=2, default_value=10.0)
    ev_option = es.embedding_variable_option(filter_option=filter_option)

    table_id_dict, es_initializer, es_counter_filter = es.embedding_init("test", init_vocabulary_size=vocab_size,
                                                                         embedding_dim=embedding_dim,
                                                                         max_feature_count=feature_length,
                                                                         optimizer="adam", ev_option=ev_option,
                                                                         mode="train")
    # Sanity-check the init result before wiring it into the network.
    if "test" not in table_id_dict:
        raise ValueError("embedding_init error, not contain test table!")
    if len(es_initializer) != 1 or len(es_counter_filter) != 1:
        raise ValueError("embedding_init error, table len should be 1!")

    print("Succ do embedding_init: ", table_id_dict, es_initializer, es_counter_filter, flush=True)

    net = Net(embedding_dim, feature_length, table_id_dict, es_initializer, es_counter_filter)
    loss_fn = ops.SigmoidCrossEntropyWithLogits()
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=1e-3)
    net_with_loss = NetworkWithLoss(net, loss_fn)
    train_network = nn.TrainOneStepCell(net_with_loss, optimizer=optimizer)
    train_network.set_train()

    # 2x8 keys; label matches the (2, 8, embedding_dim) lookup output shape.
    data = Tensor(np.array(np.ones((2, 8)), dtype=np.float32))
    label = Tensor(np.array(np.ones((2, 8, 12)), dtype=np.float32))

    loss = train_network(data, label)
    print("Succ do train, loss is: ", loss, flush=True)

    rank = get_rank()
    print("Succ get rank, rank is: ", rank, flush=True)
    if rank == 0:
        # Only rank 0 exports/imports to avoid concurrent writes to the same paths.
        save_embedding_path = os.path.join(os.getcwd(), "embedding")
        save_ckpt_path = os.path.join(os.getcwd(), "ckpt")
        print("After get path is: ", save_embedding_path, save_ckpt_path, flush=True)
        es.embedding_table_export(save_embedding_path)
        print("Succ do export embedding.", flush=True)
        es.embedding_ckpt_export(save_ckpt_path)
        print("Succ do export ckpt.", flush=True)

        # Round-trip: re-import the checkpoint that was just exported.
        es.embedding_ckpt_import(save_ckpt_path)
        print("Succ do import embedding.", flush=True)

    release()


if __name__ == "__main__":
    train()