• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""OhemLoss."""
16import mindspore.nn as nn
17import mindspore.common.dtype as mstype
18from mindspore.ops import operations as P
19from mindspore.ops import functional as F
20
21
class OhemLoss(nn.Cell):
    """Masked softmax cross-entropy loss for dense (per-pixel) classification.

    Despite the name, no online hard-example mining is performed: every pixel
    whose label differs from ``ignore_label`` contributes equally. The result
    is ``loss_weight`` times the mean cross-entropy over non-ignored pixels.

    Args:
        num (int): Number of classes. Used as the one-hot depth and as the
            size ``logits`` must have along axis 1.
        ignore_label (int): Label value excluded from both the loss sum and
            the pixel count.
        loss_weight (float): Scalar multiplier applied to every non-ignored
            pixel's loss. Defaults to 1.0 (the previously hard-coded value).

    Inputs:
        - **logits** (Tensor) - 4-D tensor whose axis 1 is the class axis;
          presumably (N, num, H, W) -- confirm against callers.
        - **labels** (Tensor) - class indices; cast to int32 and flattened,
          so it must hold as many elements as there are logit rows (N*H*W).

    Outputs:
        Tensor, scalar loss: sum of weighted per-pixel losses divided by the
        number of non-ignored pixels (divided by 1 instead when every pixel
        is ignored, so the result is 0 rather than NaN).

    NOTE(review): a label that is not ``ignore_label`` but lies outside
    [0, num) is still counted in the denominator; whatever ``nn.OneHot``
    produces for it determines its loss contribution.
    """

    def __init__(self, num, ignore_label, loss_weight=1.0):
        super(OhemLoss, self).__init__()
        self.mul = P.Mul()
        self.shape = P.Shape()
        self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
        self.num = num
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.not_equal = P.NotEqual()
        self.equal = P.Equal()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.fill = P.Fill()
        self.transpose = P.Transpose()
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight

    def construct(self, logits, labels):
        # Move the class axis (axis 1) last, then flatten to one row per pixel.
        logits = self.transpose(logits, (0, 2, 3, 1))
        logits = self.reshape(logits, (-1, self.num))
        labels = self.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1,))
        one_hot_labels = self.one_hot(labels)
        # Output 0 is the per-row loss; the second output is not needed here.
        losses = self.cross_entropy(logits, one_hot_labels)[0]

        # 1.0 for pixels that contribute to the loss, 0.0 for ignored ones.
        valid = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32)
        weighted_losses = self.mul(losses, valid * self.loss_weight)
        loss = self.reduce_sum(weighted_losses, (0,))

        # Count of non-ignored pixels, clamped to 1 so that a batch in which
        # every pixel is ignored returns 0 instead of dividing by zero.
        present = self.reduce_sum(valid, (0,))
        zeros = self.fill(mstype.float32, self.shape(present), 0.0)
        ones = self.fill(mstype.float32, self.shape(present), 1.0)
        present = self.select(self.equal(present, zeros), ones, present)
        return loss / present
64