# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Resnet50 utils"""

import time
import numpy as np

from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import LossBase
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy.ndarray): learning rate used for each training step.
        step_threshold (int): stop training once this global step is reached. Default: 10.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05] * 100).asnumpy(), step_threshold=10)
    """

    def __init__(self, lr_init=None, step_threshold=10):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        # Guard against the default of None so construction does not fail on len().
        self.lr_init_len = len(lr_init) if lr_init is not None else 0
        self.step_threshold = step_threshold

    def epoch_begin(self, run_context):
        # Reset the per-epoch loss history and start the epoch timer.
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()

        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(
            epoch_mseconds, per_step_mseconds, np.mean(self.losses)))
        self.epoch_mseconds = epoch_mseconds

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        # net_outputs may be a (loss, ...) tuple or a Tensor; reduce it to a scalar.
        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:8.6f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
            cb_params.cur_epoch_num, cb_params.epoch_num,
            cur_step_in_epoch + 1, cb_params.batch_num,
            step_loss, np.mean(self.losses),
            step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

        # Stop training early once the configured step threshold is reached.
        if cb_params.cur_step_num == self.step_threshold:
            run_context.request_stop()


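# A minimal usage sketch (assumption, not part of the original file): Monitor is a
# MindSpore Callback, so it is passed to Model.train together with a per-step
# learning-rate array. `net`, `opt`, and `dataset` below are placeholders.
#
#     from mindspore import Model
#     lr = Tensor([0.05] * 100).asnumpy()
#     monitor = Monitor(lr_init=lr, step_threshold=10)
#     model = Model(net, loss_fn=CrossEntropy(num_classes=1001), optimizer=opt)
#     model.train(1, dataset, callbacks=[monitor])

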
class CrossEntropy(LossBase):
    """Cross-entropy loss with label smoothing, built on SoftmaxCrossEntropyWithLogits."""

    def __init__(self, smooth_factor=0, num_classes=1001):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        # Label smoothing: the true class gets 1 - smooth_factor, the rest share smooth_factor.
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        # Convert sparse labels to smoothed one-hot vectors, then average the per-sample loss.
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
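
# An illustrative sketch (assumption, not from the original file) of calling the loss
# on dummy data; shapes and values below are made up for demonstration only.
#
#     logits = Tensor(np.random.randn(2, 10).astype(np.float32))
#     labels = Tensor(np.array([1, 5], dtype=np.int32))
#     loss_fn = CrossEntropy(smooth_factor=0.1, num_classes=10)
#     print(loss_fn(logits, labels))  # prints a scalar loss Tensor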