/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cmath>
#include "src/litert/kernel_registry.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/fp32/softmax_fp32.h"
#include "nnacl/fp32_grad/softmax_grad_utils.h"
#include "src/litert/kernel/cpu/fp32_grad/sparse_softmax_cross_entropy_with_logits.h"
#include "include/errorcode.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;

namespace mindspore::kernel {
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return Prepare(); }

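// ForwardPostExecute reduces the per-sample softmax probabilities (held in the `losses`
// workspace after stages 0/1) to the scalar mean cross-entropy loss:
//   output[0] = -1/batch_size * sum_i log(p[i][label_i]).
// Labels must lie in [0, number_of_classes_); out-of-range labels are rejected.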
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses,
                                                                     float *output) const {
  float total_loss = 0;
  MS_CHECK_GT(param->batch_size_, 0, RET_ERROR);
  for (size_t i = 0; i < static_cast<size_t>(param->batch_size_); ++i) {
    if (labels[i] < 0) {
      MS_LOG(ERROR) << "label value must be >= 0";
      return RET_ERROR;
    }
    size_t label = labels[i];
    if (label >= param->number_of_classes_) {
      MS_LOG(ERROR) << "label value must be less than number_of_classes";
      return RET_ERROR;
    } else {
      total_loss -= logf(losses[i * param->number_of_classes_ + label]);
    }
  }
  output[0] = total_loss / static_cast<float>(param->batch_size_);
  return RET_OK;
}

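// GradPostExecute converts the softmax probabilities into the gradient of the mean loss
// with respect to the logits: grads = (softmax(x) - one_hot(label)) / batch_size.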
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses,
                                                                  float *grads) const {
  MS_CHECK_GT(param->batch_size_, 0, RET_ERROR);
  size_t row_start = 0;
  for (int i = 0; i < param->batch_size_; ++i) {
    if (labels[i] < 0) {
      MS_LOG(ERROR) << "label value must be >= 0";
      return RET_ERROR;
    }
    size_t label = labels[i];
    if (label >= param->number_of_classes_) {
      MS_LOG(ERROR) << "label value must be less than number_of_classes";
      return RET_ERROR;
    } else {
      for (size_t j = 0; j < param->number_of_classes_; ++j) {
        size_t index = row_start + j;
        if (j == label) {
          grads[index] = (losses[index] - 1) / static_cast<float>(param->batch_size_);
        } else {
          grads[index] = losses[index] / static_cast<float>(param->batch_size_);
        }
      }
    }
    row_start += param->number_of_classes_;
  }
  return RET_OK;
}

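// DoExecute is the per-task body run by ParallelLaunch. The kernel works in three stages
// (selected by stage_): stages 0 and 1 run the two softmax passes (SoftMaxP1/SoftMaxP2)
// over this task's slice of the outer dimension, and stage 2 performs the single-threaded
// loss or gradient post-processing.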
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::DoExecute(int task_id) {
  auto sce_param = reinterpret_cast<SoftmaxCrossEntropyParameter *>(op_parameter_);
  auto ins = reinterpret_cast<float *>(in_tensors_.at(0)->data());
  CHECK_NULL_RETURN(ins);
  auto labels = reinterpret_cast<int *>(in_tensors_.at(1)->data());
  CHECK_NULL_RETURN(labels);
  float *out = reinterpret_cast<float *>(out_tensors_.at(0)->data());
  CHECK_NULL_RETURN(out);
  size_t data_size = in_tensors_.at(0)->ElementsNum();
  float *losses = static_cast<float *>(workspace());
  CHECK_NULL_RETURN(losses);
  float *sum_data = losses + data_size;
  int length = input_shape_[sm_params_->axis_];
  int stride = UP_DIV(outter_size_, threads_);
  int count = MSMIN(stride, outter_size_ - stride * task_id);
  if (count <= 0) return RET_OK;
  switch (stage_) {
    case 0:
      SoftMaxP1(ins, losses, sum_data, task_id * stride, count, length, inner_size_);
      break;
    case C1NUM:
      SoftMaxP2(ins, losses, sum_data, task_id * stride, count, length, inner_size_);
      break;
    case C2NUM:
      if (sce_param->is_grad_) {
        return GradPostExecute(labels, losses, out);
      } else {
        return ForwardPostExecute(labels, losses, out);
      }
    default:
      MS_LOG(ERROR) << "Unsupported stage";
      return RET_ERROR;
  }
  return RET_OK;
}

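// ParallelLaunch callback: forwards each task_id to DoExecute on the kernel instance.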
int SparseSoftmaxCrossEntropyWithLogitsRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  CHECK_NULL_RETURN(cdata);
  auto sparse_kernel = reinterpret_cast<SparseSoftmaxCrossEntropyWithLogitsCPUKernel *>(cdata);
  auto error_code = sparse_kernel->DoExecute(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogitsRun error task_id[" << task_id << "] error_code["
                  << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

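// Run computes the inner/outer sizes around the softmax axis, clears the workspace, and then
// launches the three stages described in DoExecute, using at most outter_size_ threads for the
// softmax passes and a single thread for the final loss/gradient reduction.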
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Run() {
  int axis = sm_params_->axis_;
  int n_dim = n_dim_;
  const int *input_shape = input_shape_;
  int inner_size = 1;
  int outter_size = 1;
  CHECK_NULL_RETURN(in_tensors_.at(0));
  size_t data_size = in_tensors_.at(0)->ElementsNum();
  float *losses = static_cast<float *>(workspace());
  CHECK_NULL_RETURN(losses);
  float *sum_data = losses + data_size;
  std::fill(losses, losses + data_size, 0.f);
  std::fill(sum_data, sum_data + input_shape_[0], 0.f);
  for (int i = 0; i < axis; i++) {
    outter_size *= input_shape[i];
  }
  for (int i = axis + 1; i < n_dim; i++) {
    inner_size *= input_shape[i];
  }
  inner_size_ = inner_size;
  outter_size_ = outter_size;
  int max_num_of_threads = (outter_size_ < op_parameter_->thread_num_) ? outter_size_ : op_parameter_->thread_num_;
  const std::vector<int> threads = {max_num_of_threads, max_num_of_threads, 1};
  for (int stage = 0; stage < static_cast<int>(threads.size()); stage++) {
    stage_ = stage;
    threads_ = threads.at(stage);
    int error_code = ParallelLaunch(this->ms_context_, SparseSoftmaxCrossEntropyWithLogitsRun, this, threads_);
    if (error_code != RET_OK) {
      MS_LOG(ERROR) << "SparseSoftmaxCrossEntropyWithLogits function error error_code[" << error_code << "]";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

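// Prepare validates the input/output tensors, records the 2-D logits shape
// (batch_size x number_of_classes) in both the loss and softmax parameters, and reserves a
// workspace of data_size + batch_size floats: the softmax output plus one running sum per row.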
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Prepare() {
  CHECK_LESS_RETURN(in_tensors_.size(), 2);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
  CHECK_NULL_RETURN(in_tensors_.at(0));
  CHECK_NULL_RETURN(in_tensors_.at(1));
  CHECK_NULL_RETURN(out_tensors_.at(0));
  auto dims = in_tensors_.at(0)->shape();
  param->n_dim_ = 2;
  param->number_of_classes_ = dims.at(1);
  param->batch_size_ = dims.at(0);
  for (unsigned int i = 0; i < dims.size(); i++) {
    param->input_shape_[i] = dims.at(i);
  }
  if (this->in_tensors_.size() != TWO_TENSOR) {
    MS_LOG(ERROR) << "sparse softmax entropy loss should have two inputs";
    return RET_ERROR;
  }
  auto *in0 = in_tensors_.front();
  if (in0 == nullptr) {
    MS_LOG(ERROR) << "sparse softmax entropy loss in0 has no data";
    return RET_ERROR;
  }
  size_t data_size = in_tensors_.at(0)->ElementsNum();
  set_workspace_size((data_size + static_cast<size_t>(dims.at(0))) * sizeof(float));
  if (sm_params_ == nullptr) {
    sm_params_ = new (std::nothrow) SoftmaxParameter();
    if (sm_params_ == nullptr) {
      MS_LOG(ERROR) << "new softmax param failed.";
      return RET_ERROR;
    }
  }
  n_dim_ = Num2;
  element_size_ = static_cast<int>(data_size);
  sm_params_->axis_ = 1;
  for (size_t i = 0; i < dims.size(); i++) {
    input_shape_[i] = dims.at(i);
  }
  return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
           LiteKernelCreator<SparseSoftmaxCrossEntropyWithLogitsCPUKernel>)
}  // namespace mindspore::kernel