1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" test_training """ 16import os 17 18import numpy as np 19from sklearn.metrics import roc_auc_score 20import mindspore.common.dtype as mstype 21from mindspore.ops import functional as F 22from mindspore.ops import composite as C 23from mindspore.ops import operations as P 24from mindspore.nn import Dropout 25from mindspore.nn.optim import Adam 26from mindspore.nn.metrics import Metric 27from mindspore import nn, Tensor, ParameterTuple, Parameter 28from mindspore.common.initializer import Uniform, initializer 29from mindspore.train.callback import ModelCheckpoint, CheckpointConfig 30from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean 31from mindspore.context import ParallelMode 32from mindspore.nn.wrap.grad_reducer import DistributedGradReducer 33 34from .callback import EvalCallBack, LossCallBack 35 36 37np_type = np.float32 38ms_type = mstype.float32 39 40 41class AUCMetric(Metric): 42 """AUC metric for DeepFM model.""" 43 def __init__(self): 44 super(AUCMetric, self).__init__() 45 self.pred_probs = [] 46 self.true_labels = [] 47 48 def clear(self): 49 """Clear the internal evaluation result.""" 50 self.pred_probs = [] 51 self.true_labels = [] 52 53 def update(self, *inputs): 54 batch_predict = inputs[1].asnumpy() 55 batch_label = inputs[2].asnumpy() 56 self.pred_probs.extend(batch_predict.flatten().tolist()) 57 self.true_labels.extend(batch_label.flatten().tolist()) 58 59 def eval(self): 60 if len(self.true_labels) != len(self.pred_probs): 61 raise RuntimeError('true_labels.size() is not equal to pred_probs.size()') 62 auc = roc_auc_score(self.true_labels, self.pred_probs) 63 return auc 64 65 66def init_method(method, shape, name, max_val=1.0): 67 """ 68 The method of init parameters. 69 70 Args: 71 method (str): The method uses to initialize parameter. 72 shape (list): The shape of parameter. 73 name (str): The name of parameter. 74 max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter. 75 76 Returns: 77 Parameter. 78 """ 79 if method in ['uniform']: 80 params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name) 81 elif method == "one": 82 params = Parameter(initializer("ones", shape, ms_type), name=name) 83 elif method == 'zero': 84 params = Parameter(initializer("zeros", shape, ms_type), name=name) 85 elif method == "normal": 86 params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name) 87 return params 88 89 90def init_var_dict(init_args, var_list): 91 """ 92 Init parameter. 93 94 Args: 95 init_args (list): Define max and min value of parameters. 96 values (list): Define name, shape and init method of parameters. 97 98 Returns: 99 dict, a dict ot Parameter. 100 """ 101 var_map = {} 102 _, max_val = init_args 103 for i, _ in enumerate(var_list): 104 key, shape, method = var_list[i] 105 if key not in var_map.keys(): 106 if method in ['random', 'uniform']: 107 var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key) 108 elif method == "one": 109 var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key) 110 elif method == "zero": 111 var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key) 112 elif method == 'normal': 113 var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape). 114 astype(dtype=np_type)), name=key) 115 return var_map 116 117 118class DenseLayer(nn.Cell): 119 """ 120 Dense Layer for Deep Layer of DeepFM Model; 121 Containing: activation, matmul, bias_add; 122 Args: 123 input_dim (int): the shape of weight at 0-aixs; 124 output_dim (int): the shape of weight at 1-aixs, and shape of bias 125 weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal"; 126 act_str (str): activation function method, "relu", "sigmoid", "tanh"; 127 keep_prob (float): Dropout Layer keep_prob_rate; 128 scale_coef (float): input scale coefficient; 129 """ 130 131 def __init__(self, input_dim, output_dim, weight_bias_init, act_str, scale_coef=1.0, convert_dtype=True, 132 use_act=True): 133 super(DenseLayer, self).__init__() 134 weight_init, bias_init = weight_bias_init 135 self.weight = init_method(weight_init, [input_dim, output_dim], name="weight") 136 self.bias = init_method(bias_init, [output_dim], name="bias") 137 self.act_func = self._init_activation(act_str) 138 self.matmul = P.MatMul(transpose_b=False) 139 self.bias_add = P.BiasAdd() 140 self.cast = P.Cast() 141 self.dropout = Dropout(keep_prob=1.0) 142 self.mul = P.Mul() 143 self.realDiv = P.RealDiv() 144 self.scale_coef = scale_coef 145 self.convert_dtype = convert_dtype 146 self.use_act = use_act 147 148 def _init_activation(self, act_str): 149 """Init activation function""" 150 act_str = act_str.lower() 151 if act_str == "relu": 152 act_func = P.ReLU() 153 elif act_str == "sigmoid": 154 act_func = P.Sigmoid() 155 elif act_str == "tanh": 156 act_func = P.Tanh() 157 return act_func 158 159 def construct(self, x): 160 """Construct function""" 161 x = self.dropout(x) 162 if self.convert_dtype: 163 x = self.cast(x, mstype.float16) 164 weight = self.cast(self.weight, mstype.float16) 165 bias = self.cast(self.bias, mstype.float16) 166 wx = self.matmul(x, weight) 167 wx = self.bias_add(wx, bias) 168 if self.use_act: 169 wx = self.act_func(wx) 170 wx = self.cast(wx, mstype.float32) 171 else: 172 wx = self.matmul(x, self.weight) 173 wx = self.bias_add(wx, self.bias) 174 if self.use_act: 175 wx = self.act_func(wx) 176 return wx 177 178 179class DeepFMModel(nn.Cell): 180 """ 181 From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction" 182 183 Args: 184 batch_size (int): smaple_number of per step in training; (int, batch_size=128) 185 filed_size (int): input filed number, or called id_feature number; (int, filed_size=39) 186 vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000) 187 emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100) 188 deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator; 189 (int, deep_layer_args=[[100, 100, 100], "relu"]) 190 init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds]) 191 weight_bias_init (list): weight, bias init method for deep layers; 192 (list[str], weight_bias_init=['random', 'zero']) 193 keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8) 194 """ 195 196 def __init__(self, config): 197 super(DeepFMModel, self).__init__() 198 199 self.batch_size = config.batch_size 200 self.field_size = config.data_field_size 201 self.vocab_size = config.data_vocab_size 202 self.emb_dim = config.data_emb_dim 203 self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args 204 self.init_args = config.init_args 205 self.weight_bias_init = config.weight_bias_init 206 self.keep_prob = config.keep_prob 207 init_acts = [('W_l2', [self.vocab_size, 1], 'normal'), 208 ('V_l2', [self.vocab_size, self.emb_dim], 'normal')] 209 var_map = init_var_dict(self.init_args, init_acts) 210 self.fm_w = var_map["W_l2"] 211 self.embedding_table = var_map["V_l2"] 212 " Deep Layers " 213 self.deep_input_dims = self.field_size * self.emb_dim 214 self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1] 215 self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init, 216 self.deep_layer_act, self.keep_prob, convert_dtype=True) 217 self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init, 218 self.deep_layer_act, self.keep_prob, convert_dtype=True) 219 self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init, 220 self.deep_layer_act, self.keep_prob, convert_dtype=True) 221 self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init, 222 self.deep_layer_act, self.keep_prob, convert_dtype=True) 223 self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init, 224 self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False) 225 " FM, linear Layers " 226 self.Gatherv2 = P.Gather() 227 self.Mul = P.Mul() 228 self.ReduceSum = P.ReduceSum(keep_dims=False) 229 self.Reshape = P.Reshape() 230 self.Square = P.Square() 231 self.Shape = P.Shape() 232 self.Tile = P.Tile() 233 self.Concat = P.Concat(axis=1) 234 self.Cast = P.Cast() 235 236 def construct(self, id_hldr, wt_hldr): 237 """ 238 Args: 239 id_hldr: batch ids; [bs, field_size] 240 wt_hldr: batch weights; [bs, field_size] 241 """ 242 243 mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1)) 244 # Linear layer 245 fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0) 246 wx = self.Mul(fm_id_weight, mask) 247 linear_out = self.ReduceSum(wx, 1) 248 # FM layer 249 fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0) 250 vx = self.Mul(fm_id_embs, mask) 251 v1 = self.ReduceSum(vx, 1) 252 v1 = self.Square(v1) 253 v2 = self.Square(vx) 254 v2 = self.ReduceSum(v2, 1) 255 fm_out = 0.5 * self.ReduceSum(v1 - v2, 1) 256 fm_out = self.Reshape(fm_out, (-1, 1)) 257 # Deep layer 258 deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim)) 259 deep_in = self.dense_layer_1(deep_in) 260 deep_in = self.dense_layer_2(deep_in) 261 deep_in = self.dense_layer_3(deep_in) 262 deep_in = self.dense_layer_4(deep_in) 263 deep_out = self.dense_layer_5(deep_in) 264 out = linear_out + fm_out + deep_out 265 return out, self.fm_w, self.embedding_table 266 267 268class NetWithLossClass(nn.Cell): 269 """ 270 NetWithLossClass definition. 271 """ 272 def __init__(self, network, l2_coef=1e-6): 273 super(NetWithLossClass, self).__init__(auto_prefix=False) 274 self.loss = P.SigmoidCrossEntropyWithLogits() 275 self.network = network 276 self.l2_coef = l2_coef 277 self.Square = P.Square() 278 self.ReduceMean_false = P.ReduceMean(keep_dims=False) 279 self.ReduceSum_false = P.ReduceSum(keep_dims=False) 280 281 def construct(self, batch_ids, batch_wts, label): 282 predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts) 283 log_loss = self.loss(predict, label) 284 mean_log_loss = self.ReduceMean_false(log_loss) 285 l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight)) 286 l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs)) 287 l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5 288 loss = mean_log_loss + l2_loss_all 289 return loss 290 291 292class TrainStepWrap(nn.Cell): 293 """ 294 TrainStepWrap definition 295 """ 296 def __init__(self, network, lr, eps, loss_scale=1000.0): 297 super(TrainStepWrap, self).__init__(auto_prefix=False) 298 self.network = network 299 self.network.set_train() 300 self.weights = ParameterTuple(network.trainable_params()) 301 self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale) 302 self.hyper_map = C.HyperMap() 303 self.grad = C.GradOperation(get_by_list=True, sens_param=True) 304 self.sens = loss_scale 305 306 self.reducer_flag = False 307 self.grad_reducer = None 308 parallel_mode = _get_parallel_mode() 309 if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): 310 self.reducer_flag = True 311 if self.reducer_flag: 312 mean = _get_gradients_mean() 313 degree = _get_device_num() 314 self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree) 315 316 def construct(self, batch_ids, batch_wts, label): 317 weights = self.weights 318 loss = self.network(batch_ids, batch_wts, label) 319 sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) # 320 grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens) 321 if self.reducer_flag: 322 # apply grad reducer on grads 323 grads = self.grad_reducer(grads) 324 return F.depend(loss, self.optimizer(grads)) 325 326 327class PredictWithSigmoid(nn.Cell): 328 """ 329 Eval model with sigmoid. 330 """ 331 def __init__(self, network): 332 super(PredictWithSigmoid, self).__init__(auto_prefix=False) 333 self.network = network 334 self.sigmoid = P.Sigmoid() 335 336 def construct(self, batch_ids, batch_wts, labels): 337 logits, _, _, = self.network(batch_ids, batch_wts) 338 pred_probs = self.sigmoid(logits) 339 340 return logits, pred_probs, labels 341 342 343class ModelBuilder: 344 """ 345 Model builder for DeepFM. 346 347 Args: 348 model_config (ModelConfig): Model configuration. 349 train_config (TrainConfig): Train configuration. 350 """ 351 def __init__(self, model_config, train_config): 352 self.model_config = model_config 353 self.train_config = train_config 354 355 def get_callback_list(self, model=None, eval_dataset=None): 356 """ 357 Get callbacks which contains checkpoint callback, eval callback and loss callback. 358 359 Args: 360 model (Cell): The network is added callback (default=None). 361 eval_dataset (Dataset): Dataset for eval (default=None). 362 """ 363 callback_list = [] 364 if self.train_config.save_checkpoint: 365 config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps, 366 keep_checkpoint_max=self.train_config.keep_checkpoint_max) 367 ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix, 368 directory=self.train_config.output_path, 369 config=config_ck) 370 callback_list.append(ckpt_cb) 371 if self.train_config.eval_callback: 372 if model is None: 373 raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format( 374 self.train_config.eval_callback, model)) 375 if eval_dataset is None: 376 raise RuntimeError("train_config.eval_callback is {}; get_callback_list() " 377 "args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset)) 378 auc_metric = AUCMetric() 379 eval_callback = EvalCallBack(model, eval_dataset, auc_metric, 380 eval_file_path=os.path.join(self.train_config.output_path, 381 self.train_config.eval_file_name)) 382 callback_list.append(eval_callback) 383 if self.train_config.loss_callback: 384 loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path, 385 self.train_config.loss_file_name)) 386 callback_list.append(loss_callback) 387 if callback_list: 388 return callback_list 389 return None 390 391 def get_train_eval_net(self): 392 deepfm_net = DeepFMModel(self.model_config) 393 loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef) 394 train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate, 395 eps=self.train_config.epsilon, 396 loss_scale=self.train_config.loss_scale) 397 eval_net = PredictWithSigmoid(deepfm_net) 398 return train_net, eval_net 399