# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" test_training """
16import os
17
18import numpy as np
19from sklearn.metrics import roc_auc_score
20import mindspore.common.dtype as mstype
21from mindspore.ops import functional as F
22from mindspore.ops import composite as C
23from mindspore.ops import operations as P
24from mindspore.nn import Dropout
25from mindspore.nn.optim import Adam
26from mindspore.nn.metrics import Metric
27from mindspore import nn, Tensor, ParameterTuple, Parameter
28from mindspore.common.initializer import Uniform, initializer
29from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
30from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
31from mindspore.context import ParallelMode
32from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
33
34from .callback import EvalCallBack, LossCallBack
35
36
37np_type = np.float32
38ms_type = mstype.float32
39
40


class AUCMetric(Metric):
    """AUC metric for the DeepFM model."""
    def __init__(self):
        super(AUCMetric, self).__init__()
        self.pred_probs = []
        self.true_labels = []

    def clear(self):
        """Clear the internal evaluation result."""
        self.pred_probs = []
        self.true_labels = []

    def update(self, *inputs):
        """Accumulate one batch of (logits, pred_probs, labels) from the eval network."""
        batch_predict = inputs[1].asnumpy()
        batch_label = inputs[2].asnumpy()
        self.pred_probs.extend(batch_predict.flatten().tolist())
        self.true_labels.extend(batch_label.flatten().tolist())

    def eval(self):
        """Compute AUC over all accumulated batches."""
        if len(self.true_labels) != len(self.pred_probs):
            raise RuntimeError('The number of true labels does not match the number of predicted probabilities.')
        auc = roc_auc_score(self.true_labels, self.pred_probs)
        return auc
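

# Illustrative sketch (not part of the original file): how AUCMetric consumes the
# (logits, pred_probs, labels) triple emitted by PredictWithSigmoid below.
# The tensors here are made-up toy data.
def _auc_metric_example():
    metric = AUCMetric()
    metric.clear()
    logits = Tensor(np.array([[2.0], [-1.0], [0.5]], dtype=np_type))
    probs = Tensor(np.array([[0.88], [0.27], [0.62]], dtype=np_type))
    labels = Tensor(np.array([[1.0], [0.0], [1.0]], dtype=np_type))
    metric.update(logits, probs, labels)
    return metric.eval()  # scalar AUC over all accumulated batches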


def init_method(method, shape, name, max_val=1.0):
    """
    Initialize a parameter with the given method.

    Args:
        method (str): The method used to initialize the parameter: 'uniform', 'one', 'zero' or 'normal'.
        shape (list): The shape of the parameter.
        name (str): The name of the parameter.
        max_val (float): Bound on the parameter values when 'uniform' initialization is used.

    Returns:
        Parameter.
    """
    if method in ['uniform']:
        params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
    elif method == "one":
        params = Parameter(initializer("ones", shape, ms_type), name=name)
    elif method == 'zero':
        params = Parameter(initializer("zeros", shape, ms_type), name=name)
    elif method == "normal":
        params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
    else:
        raise ValueError("Unsupported init method: {}".format(method))
    return params


def init_var_dict(init_args, var_list):
    """
    Initialize parameters from a variable specification list.

    Args:
        init_args (list): Defines the min and max values for uniform initialization.
        var_list (list): Defines the name, shape and init method of each parameter.

    Returns:
        dict, a dict of Parameter.
    """
    var_map = {}
    _, max_val = init_args
    for key, shape, method in var_list:
        if key not in var_map.keys():
            if method in ['random', 'uniform']:
                var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key)
            elif method == "one":
                var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
            elif method == "zero":
                var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
            elif method == 'normal':
                var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape)
                                                .astype(dtype=np_type)), name=key)
    return var_map
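

# Illustrative sketch (not part of the original file): building the FM weight and
# embedding tables the same way DeepFMModel.__init__ does below; the shapes here
# are toy values, not the real config.
def _init_var_dict_example():
    init_args = [-0.01, 0.01]  # [min, max]; only max is used, for Uniform init
    var_list = [('W_l2', [1000, 1], 'normal'),
                ('V_l2', [1000, 8], 'normal')]
    var_map = init_var_dict(init_args, var_list)
    return var_map['W_l2'].shape, var_map['V_l2'].shape  # ((1000, 1), (1000, 8))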


class DenseLayer(nn.Cell):
    """
    Dense layer for the deep part of the DeepFM model.
    Contains: dropout, matmul, bias_add and an optional activation.

    Args:
        input_dim (int): the shape of the weight at axis 0.
        output_dim (int): the shape of the weight at axis 1, and the shape of the bias.
        weight_bias_init (list): weight and bias init methods, "random", "uniform", "one", "zero", "normal".
        act_str (str): activation function, "relu", "sigmoid" or "tanh".
        keep_prob (float): dropout keep rate.
        scale_coef (float): input scale coefficient.
    """

    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0,
                 convert_dtype=True, use_act=True):
        super(DenseLayer, self).__init__()
        weight_init, bias_init = weight_bias_init
        self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
        self.bias = init_method(bias_init, [output_dim], name="bias")
        self.act_func = self._init_activation(act_str)
        self.matmul = P.MatMul(transpose_b=False)
        self.bias_add = P.BiasAdd()
        self.cast = P.Cast()
        self.dropout = Dropout(keep_prob=keep_prob)
        self.mul = P.Mul()
        self.realDiv = P.RealDiv()
        self.scale_coef = scale_coef
        self.convert_dtype = convert_dtype
        self.use_act = use_act

    def _init_activation(self, act_str):
        """Init activation function"""
        act_str = act_str.lower()
        if act_str == "relu":
            act_func = P.ReLU()
        elif act_str == "sigmoid":
            act_func = P.Sigmoid()
        elif act_str == "tanh":
            act_func = P.Tanh()
        else:
            raise ValueError("Unsupported activation: {}".format(act_str))
        return act_func

    def construct(self, x):
        """Construct function"""
        x = self.dropout(x)
        if self.convert_dtype:
            x = self.cast(x, mstype.float16)
            weight = self.cast(self.weight, mstype.float16)
            bias = self.cast(self.bias, mstype.float16)
            wx = self.matmul(x, weight)
            wx = self.bias_add(wx, bias)
            if self.use_act:
                wx = self.act_func(wx)
            wx = self.cast(wx, mstype.float32)
        else:
            wx = self.matmul(x, self.weight)
            wx = self.bias_add(wx, self.bias)
            if self.use_act:
                wx = self.act_func(wx)
        return wx
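

# Illustrative sketch (not part of the original file): a single DenseLayer mapping a
# toy batch of 4 vectors from 8 to 4 dimensions; convert_dtype=False keeps everything
# in float32 so it also runs on CPU.
def _dense_layer_example():
    layer = DenseLayer(8, 4, ['uniform', 'zero'], "relu", keep_prob=1.0, convert_dtype=False)
    x = Tensor(np.ones((4, 8), dtype=np_type))
    return layer(x)  # shape (4, 4)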


class DeepFMModel(nn.Cell):
    """
    From the paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction".

    Args:
        batch_size (int): sample number per step in training; (int, batch_size=128)
        field_size (int): input field number, i.e. the number of id features; (int, field_size=39)
        vocab_size (int): id-feature vocab size, i.e. the id dict size; (int, vocab_size=200000)
        emb_dim (int): id embedding vector dim; each id is mapped to an embedding vector; (int, emb_dim=100)
        deep_layer_args (list): deep layer args: layer_dim_list, layer_activator;
                            (list, deep_layer_args=[[100, 100, 100], "relu"])
        init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds])
        weight_bias_init (list): weight and bias init methods for the deep layers;
                            (list[str], weight_bias_init=['random', 'zero'])
        keep_prob (float): dropout keep rate; (float, keep_prob=0.8)
    """

    def __init__(self, config):
        super(DeepFMModel, self).__init__()

        self.batch_size = config.batch_size
        self.field_size = config.data_field_size
        self.vocab_size = config.data_vocab_size
        self.emb_dim = config.data_emb_dim
        self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args
        self.init_args = config.init_args
        self.weight_bias_init = config.weight_bias_init
        self.keep_prob = config.keep_prob
        init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
        var_map = init_var_dict(self.init_args, init_acts)
        self.fm_w = var_map["W_l2"]
        self.embedding_table = var_map["V_l2"]
        # Deep layers
        self.deep_input_dims = self.field_size * self.emb_dim
        self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
        self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
                                        self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
        # FM and linear layers
        self.Gatherv2 = P.Gather()
        self.Mul = P.Mul()
        self.ReduceSum = P.ReduceSum(keep_dims=False)
        self.Reshape = P.Reshape()
        self.Square = P.Square()
        self.Shape = P.Shape()
        self.Tile = P.Tile()
        self.Concat = P.Concat(axis=1)
        self.Cast = P.Cast()

    def construct(self, id_hldr, wt_hldr):
        """
        Args:
            id_hldr: batch ids;   [bs, field_size]
            wt_hldr: batch weights;   [bs, field_size]
        """

        mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1))
        # Linear layer
        fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0)
        wx = self.Mul(fm_id_weight, mask)
        linear_out = self.ReduceSum(wx, 1)
        # FM layer: second-order term 0.5 * ((sum_i v_i)^2 - sum_i v_i^2)
        fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0)
        vx = self.Mul(fm_id_embs, mask)
        v1 = self.ReduceSum(vx, 1)
        v1 = self.Square(v1)
        v2 = self.Square(vx)
        v2 = self.ReduceSum(v2, 1)
        fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
        fm_out = self.Reshape(fm_out, (-1, 1))
        # Deep layers
        deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
        deep_in = self.dense_layer_1(deep_in)
        deep_in = self.dense_layer_2(deep_in)
        deep_in = self.dense_layer_3(deep_in)
        deep_in = self.dense_layer_4(deep_in)
        deep_out = self.dense_layer_5(deep_in)
        out = linear_out + fm_out + deep_out
        return out, self.fm_w, self.embedding_table
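

# Illustrative sketch (not part of the original file): the FM layer above relies on the
# identity sum_{i<j} <v_i, v_j> = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2), summed over the
# embedding dimension; this toy numpy check mirrors that computation for one sample.
def _fm_identity_example():
    v = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np_type)  # [field_size, emb_dim]
    pairwise = sum(np.dot(v[i], v[j]) for i in range(3) for j in range(i + 1, 3))
    trick = 0.5 * np.sum(v.sum(axis=0) ** 2 - (v ** 2).sum(axis=0))
    return pairwise, trick  # both equal 67.0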


class NetWithLossClass(nn.Cell):
    """
    NetWithLossClass definition: sigmoid cross-entropy plus L2 regularization
    of the FM weight and embedding tables.
    """
    def __init__(self, network, l2_coef=1e-6):
        super(NetWithLossClass, self).__init__(auto_prefix=False)
        self.loss = P.SigmoidCrossEntropyWithLogits()
        self.network = network
        self.l2_coef = l2_coef
        self.Square = P.Square()
        self.ReduceMean_false = P.ReduceMean(keep_dims=False)
        self.ReduceSum_false = P.ReduceSum(keep_dims=False)

    def construct(self, batch_ids, batch_wts, label):
        predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts)
        log_loss = self.loss(predict, label)
        mean_log_loss = self.ReduceMean_false(log_loss)
        l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight))
        l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs))
        l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5
        loss = mean_log_loss + l2_loss_all
        return loss
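

# Illustrative sketch (not part of the original file): the scalar loss above is
# mean(sigmoid_cross_entropy(logits, labels)) + 0.5 * l2_coef * (||W||^2 + ||V||^2);
# this toy numpy version spells out the same formula with made-up values.
def _loss_example():
    logits = np.array([0.3, -1.2], dtype=np_type)
    labels = np.array([1.0, 0.0], dtype=np_type)
    # numerically stable sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|))
    log_loss = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    w, v = np.ones(4, dtype=np_type), np.ones(8, dtype=np_type)
    return log_loss.mean() + 0.5 * 1e-6 * ((w ** 2).sum() + (v ** 2).sum())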


class TrainStepWrap(nn.Cell):
    """
    TrainStepWrap definition
    """
    def __init__(self, network, lr, eps, loss_scale=1000.0):
        super(TrainStepWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_train()
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale)
        self.hyper_map = C.HyperMap()
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = loss_scale

        self.reducer_flag = False
        self.grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)

    def construct(self, batch_ids, batch_wts, label):
        weights = self.weights
        loss = self.network(batch_ids, batch_wts, label)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
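

# Illustrative sketch (not part of the original file): one manual training step with the
# wrappers above; `config` here is a made-up object carrying the hyper-parameters that
# ModelBuilder normally reads from model_config and train_config.
def _train_step_example(deepfm_net, config):
    loss_net = NetWithLossClass(deepfm_net, l2_coef=config.l2_coef)
    train_net = TrainStepWrap(loss_net, lr=config.learning_rate, eps=config.epsilon,
                              loss_scale=config.loss_scale)
    batch_ids = Tensor(np.zeros((config.batch_size, config.data_field_size), dtype=np.int32))
    batch_wts = Tensor(np.ones((config.batch_size, config.data_field_size), dtype=np_type))
    label = Tensor(np.zeros((config.batch_size, 1), dtype=np_type))
    return train_net(batch_ids, batch_wts, label)  # scalar loss after one update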


class PredictWithSigmoid(nn.Cell):
    """
    Eval model with sigmoid.
    """
    def __init__(self, network):
        super(PredictWithSigmoid, self).__init__(auto_prefix=False)
        self.network = network
        self.sigmoid = P.Sigmoid()

    def construct(self, batch_ids, batch_wts, labels):
        logits, _, _ = self.network(batch_ids, batch_wts)
        pred_probs = self.sigmoid(logits)
        return logits, pred_probs, labels


class ModelBuilder:
    """
    Model builder for DeepFM.

    Args:
        model_config (ModelConfig): Model configuration.
        train_config (TrainConfig): Train configuration.
    """
    def __init__(self, model_config, train_config):
        self.model_config = model_config
        self.train_config = train_config

    def get_callback_list(self, model=None, eval_dataset=None):
        """
        Get the callback list, which may contain a checkpoint callback, an eval callback and a loss callback.

        Args:
            model (Cell): The network to attach the callbacks to (default=None).
            eval_dataset (Dataset): Dataset for evaluation (default=None).
        """
        callback_list = []
        if self.train_config.save_checkpoint:
            config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps,
                                         keep_checkpoint_max=self.train_config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix,
                                      directory=self.train_config.output_path,
                                      config=config_ck)
            callback_list.append(ckpt_cb)
        if self.train_config.eval_callback:
            if model is None:
                raise RuntimeError("train_config.eval_callback is {}; get_callback_list() arg model is {}".format(
                    self.train_config.eval_callback, model))
            if eval_dataset is None:
                raise RuntimeError("train_config.eval_callback is {}; get_callback_list() "
                                   "arg eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset))
            auc_metric = AUCMetric()
            eval_callback = EvalCallBack(model, eval_dataset, auc_metric,
                                         eval_file_path=os.path.join(self.train_config.output_path,
                                                                     self.train_config.eval_file_name))
            callback_list.append(eval_callback)
        if self.train_config.loss_callback:
            loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path,
                                                                     self.train_config.loss_file_name))
            callback_list.append(loss_callback)
        if callback_list:
            return callback_list
        return None

    def get_train_eval_net(self):
        """Build and return the train network and the eval network."""
        deepfm_net = DeepFMModel(self.model_config)
        loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef)
        train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate,
                                  eps=self.train_config.epsilon,
                                  loss_scale=self.train_config.loss_scale)
        eval_net = PredictWithSigmoid(deepfm_net)
        return train_net, eval_net
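

# Illustrative sketch (not part of the original file): tying the builder, metric and
# callbacks together; `model_config`, `train_config`, `train_ds` and `eval_ds` are
# assumed to come from the surrounding training script, and `train_epochs` is assumed
# to be a field of train_config.
def _model_builder_example(model_config, train_config, train_ds, eval_ds):
    from mindspore.train.model import Model
    builder = ModelBuilder(model_config, train_config)
    train_net, eval_net = builder.get_train_eval_net()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": AUCMetric()})
    callbacks = builder.get_callback_list(model=model, eval_dataset=eval_ds)
    model.train(train_config.train_epochs, train_ds, callbacks=callbacks)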