# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""OhemLoss."""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class OhemLoss(nn.Cell):
    """Masked mean softmax cross-entropy over dense (per-pixel) labels.

    Despite the name, no online hard example mining is performed: every pixel
    whose label differs from ``ignore_label`` contributes its cross-entropy
    (scaled by ``loss_weight``), and the sum is divided by the number of
    contributing pixels.

    Args:
        num (int): Number of classes; the size of the class axis (axis 1)
            of the logits.
        ignore_label (int): Label value whose pixels are excluded both from
            the loss sum and from the averaging denominator.
        loss_weight (float): Uniform scale applied to every non-ignored
            pixel's loss. Default: 1.0 (the value previously hard-coded).

    Inputs:
        - **logits** (Tensor) - rank-4 tensor with classes on axis 1,
          e.g. (N, num, H, W).
        - **labels** (Tensor) - per-pixel class indices; cast to int32 and
          flattened, so it must hold exactly one entry per logits pixel.

    Outputs:
        Tensor, scalar loss.
    """

    def __init__(self, num, ignore_label, loss_weight=1.0):
        super(OhemLoss, self).__init__()
        self.mul = P.Mul()
        self.shape = P.Shape()
        # axis=-1, depth=num, on_value=1.0, off_value=0.0
        self.one_hot = nn.OneHot(-1, num, 1.0, 0.0)
        self.squeeze = P.Squeeze()
        self.num = num
        self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.not_equal = P.NotEqual()
        self.equal = P.Equal()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.fill = P.Fill()
        self.transpose = P.Transpose()
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight

    def construct(self, logits, labels):
        """Compute the masked mean cross-entropy.

        Args:
            logits (Tensor): Class scores with the class axis at position 1.
            labels (Tensor): Per-pixel class indices.

        Returns:
            Tensor, scalar: sum of weighted per-pixel losses divided by the
            number of pixels with non-zero weight. The denominator is clamped
            to 1 when every pixel is ignored, so that case yields 0, not NaN.
        """
        # Move the class axis last, then flatten to one row of `num` scores per pixel.
        logits = self.transpose(logits, (0, 2, 3, 1))
        logits = self.reshape(logits, (-1, self.num))
        labels = F.cast(labels, mstype.int32)
        labels = self.reshape(labels, (-1,))
        # NOTE(review): labels outside [0, num) presumably one-hot to all-zero
        # rows (confirm for nn.OneHot); they are masked out by `weights` below.
        one_hot_labels = self.one_hot(labels)
        # The op returns a tuple; element 0 is the per-row loss.
        losses = self.cross_entropy(logits, one_hot_labels)[0]
        # loss_weight for kept pixels, 0.0 for pixels carrying ignore_label.
        weights = self.cast(self.not_equal(labels, self.ignore_label), mstype.float32) * self.loss_weight
        loss = self.reduce_sum(self.mul(losses, weights), (0,))

        # Number of pixels that actually contribute (non-zero weight).
        present = self.reduce_sum(self.cast(self.not_equal(weights, 0.0), mstype.float32), (0,))
        # Avoid 0/0 when no pixel contributes: use a denominator of 1 instead.
        zeros = self.fill(mstype.float32, self.shape(present), 0.0)
        min_control = self.fill(mstype.float32, self.shape(present), 1.0)
        present = self.select(self.equal(present, zeros), min_control, present)
        return loss / present