• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/nllloss.h"
18 #include "nnacl/kernel/default_kernel_base.h"
19 #include "nnacl/fp32/nllloss_fp32.h"
20 #include "nnacl/nllloss_parameter.h"
21 
/**
 * Compute callback: runs the fp32 NLLLoss on the kernel's current tensors.
 *
 * Inputs (by index): 0 = logits, 1 = labels (read as int), 2 = weight.
 * Outputs (by index): 0 = loss, 1 = total_weight.
 * The number of input/output tensors is validated in NlllossPrepare.
 *
 * @param self kernel instance; must actually be an NLLLossStruct.
 * @return an NNACL error code on invalid state or out-of-range labels,
 *         otherwise the value returned by NLLLoss.
 */
int NlllossCompute(KernelBase *self) {
  NLLLossStruct *nllloss = (NLLLossStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(nllloss);
  /* param_ is cast to NLLLossParameter below to read the reduction type. */
  NNACL_CHECK_NULL_RETURN_ERR(self->param_);

  TensorC *logits_tensor = self->in_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(logits_tensor);
  TensorC *labels_tensor = self->in_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(labels_tensor);
  TensorC *weight_tensor = self->in_[Index2];
  NNACL_CHECK_NULL_RETURN_ERR(weight_tensor);
  TensorC *loss_tensor = self->out_[Index0];
  NNACL_CHECK_NULL_RETURN_ERR(loss_tensor);
  TensorC *total_weight_tensor = self->out_[Index1];
  NNACL_CHECK_NULL_RETURN_ERR(total_weight_tensor);

  float *logits = logits_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(logits);
  int *labels = labels_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(labels);
  float *weight = weight_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(weight);
  float *loss = loss_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(loss);
  float *total_weight = total_weight_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(total_weight);

  /* A label selects a class, so anything outside [0, class_num_) is invalid input.
   * NOTE(review): assumes labels holds batch_ entries and that NLLLoss indexes the
   * logits row / weight vector by label without its own bounds check — verify
   * against nllloss_fp32.c. Rejecting here prevents an out-of-bounds read. */
  for (int i = 0; i < nllloss->batch_; ++i) {
    if (labels[i] < 0 || labels[i] >= nllloss->class_num_) {
      return NNACL_ERR;
    }
  }

  ReductionType reduction_type = ((NLLLossParameter *)self->param_)->reduction_type_;
  return NLLLoss(logits, labels, weight, loss, total_weight, nllloss, reduction_type);
}
40 
/**
 * Prepare callback: validates the tensor counts (at least three inputs and
 * two outputs) and caches the logits dimensions used by NlllossCompute.
 *
 * batch_ is read from logits shape_[Index0] and class_num_ from shape_[Index1].
 * NOTE(review): no rank check is done here; logits is assumed to be at least
 * [batch, class_num]. Confirm whether this can run before shape inference.
 *
 * @param self kernel instance; must actually be an NLLLossStruct.
 * @return NNACL_OK on success, an NNACL error code otherwise.
 */
int NlllossPrepare(KernelBase *self) {
  /* Guard self before any field access; nllloss below aliases self, so
   * checking it after reading in_size_/out_size_ would be too late. */
  NNACL_CHECK_NULL_RETURN_ERR(self);
  NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_ERR);
  NNACL_CHECK_FALSE(self->out_size_ < TWO_TENSOR, NNACL_ERR);

  NLLLossStruct *nllloss = (NLLLossStruct *)self;
  TensorC *logits_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(logits_tensor);
  nllloss->batch_ = logits_tensor->shape_[Index0];
  nllloss->class_num_ = logits_tensor->shape_[Index1];
  return NNACL_OK;
}
52 
CreateNLLLoss(OpParameter * param,int data_type)53 KernelBase *CreateNLLLoss(OpParameter *param, int data_type) {
54   NLLLossStruct *nllloss = (NLLLossStruct *)malloc(sizeof(NLLLossStruct));
55   NNACL_CHECK_NULL_RETURN_NULL(nllloss);
56   nllloss->base_.Release = DefaultRelease;
57   nllloss->base_.Prepare = NlllossPrepare;
58   nllloss->base_.Resize = DefaultResize;
59   nllloss->base_.Compute = NlllossCompute;
60   return (KernelBase *)nllloss;
61 }
62 
63 REG_KERNEL_CREATOR(PrimType_NLLLoss, kNumberTypeFloat32, CreateNLLLoss)
64