• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/train/train_utils.h"
18 #include <vector>
19 #include "include/errorcode.h"
20 #include "include/ms_tensor.h"
21 #include "src/common/utils.h"
22 #include "src/lite_kernel.h"
23 #ifdef ENABLE_FP16
24 #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
25 #endif
26 
27 namespace mindspore {
28 namespace lite {
TSFindTensor(const std::vector<lite::Tensor * > & where,const lite::Tensor * searchParameter)29 size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) {
30   for (size_t i = 0; i < where.size(); i++) {
31     if (where[i] == searchParameter) {
32       return i;
33     }
34   }
35   return where.size();
36 }
37 
TSFindTensorByName(const std::vector<lite::Tensor * > & where,const std::string & searchParameter)38 size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter) {
39   for (size_t i = 0; i < where.size(); i++) {
40     if (where[i]->tensor_name() == searchParameter) {
41       return i;
42     }
43   }
44   return where.size();
45 }
46 
TSFindKernel(const std::vector<kernel::LiteKernel * > & where,const std::string & searchParameter)47 kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter) {
48   auto it = std::find_if(where.begin(), where.end(),
49                          [&searchParameter](const kernel::LiteKernel *k) { return (k->name() == searchParameter); });
50   if (it == where.end()) {
51     return nullptr;
52   }
53   return *it;
54 }
55 
/**
 * Top-1 accuracy of class scores against sparse (class-index) labels.
 *
 * @tparam T             score element type; must support operator>.
 * @param predictions    row-major scores: batch_size rows of num_of_classes entries each.
 * @param labels         batch_size class indices, one per row.
 * @param batch_size     number of rows.
 * @param num_of_classes number of scores per row.
 * @return fraction of rows whose argmax (first index wins on ties) equals the row's label;
 *         0 when batch_size or num_of_classes is not positive.
 */
template <typename T>
float CalcSparseClassificationAccuracy(T *predictions, int *labels, int batch_size, int num_of_classes) {
  // An empty batch would otherwise produce 0/0 = NaN, and zero classes would read from an empty row.
  if (batch_size <= 0 || num_of_classes <= 0) {
    return 0.0f;
  }
  // Count hits in an integer: a float accumulator stops being exact past 2^24 rows.
  int correct = 0;
  for (int b = 0; b < batch_size; b++) {
    // Compute the row offset in size_t so batch_size * num_of_classes cannot overflow int.
    const T *row = predictions + static_cast<size_t>(num_of_classes) * static_cast<size_t>(b);
    int max_idx = 0;
    T max_score = row[0];
    for (int c = 1; c < num_of_classes; c++) {
      if (row[c] > max_score) {
        max_score = row[c];
        max_idx = c;
      }
    }
    if (labels[b] == max_idx) {
      correct++;
    }
  }
  return static_cast<float>(correct) / static_cast<float>(batch_size);
}
72 
CalculateSparseClassification(tensor::MSTensor * input,tensor::MSTensor * output)73 float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *output) {
74   if ((input->shape().size() != 1) || (input->data_type() != kNumberTypeInt32) || (output->shape().size() != 2)) {
75     MS_LOG(WARNING) << "SparseClassification got a " << input->shape() << "-D input tensor, " << output->shape()
76                     << "-D output tensor";
77     return 0.0;
78   }
79 
80   int batch = input->shape().at(0);
81   int num_classes = output->shape().at(1);
82   auto labels = reinterpret_cast<int *>(input->data());
83   float acc = 0.0f;
84   if (output->data_type() == kNumberTypeFloat32) {
85     acc = CalcSparseClassificationAccuracy(reinterpret_cast<float *>(output->data()), labels, batch, num_classes);
86 #ifdef ENABLE_FP16
87   } else if (output->data_type() == kNumberTypeFloat16) {
88     acc = CalcSparseClassificationAccuracy(reinterpret_cast<float16_t *>(output->data()), labels, batch, num_classes);
89 #endif
90   }
91   return acc;
92 }
93 
/**
 * Top-1 accuracy of class scores against one-hot (or soft) labels.
 *
 * @tparam T             score element type; must support operator>.
 * @param predictions    row-major scores: batch_size rows of num_of_classes entries each.
 * @param labels         row-major label scores with the same layout as predictions; the
 *                       row's argmax is taken as the true class.
 * @param batch_size     number of rows.
 * @param num_of_classes number of entries per row.
 * @return fraction of rows where the prediction argmax equals the label argmax (first index
 *         wins on ties in both); 0 when batch_size or num_of_classes is not positive.
 */
template <typename T>
float CalcOneHotClassificationAccuracy(T *predictions, float *labels, int batch_size, int num_of_classes) {
  // An empty batch would otherwise produce 0/0 = NaN, and zero classes would read from an empty row.
  if (batch_size <= 0 || num_of_classes <= 0) {
    return 0.0f;
  }
  // Count hits in an integer: a float accumulator stops being exact past 2^24 rows.
  int correct = 0;
  for (int b = 0; b < batch_size; b++) {
    // Compute the row offset in size_t so batch_size * num_of_classes cannot overflow int.
    const size_t offset = static_cast<size_t>(num_of_classes) * static_cast<size_t>(b);
    const T *pred_row = predictions + offset;
    const float *label_row = labels + offset;
    int label = 0;
    int max_idx = 0;
    float max_label_score = label_row[0];
    T max_score = pred_row[0];
    for (int c = 1; c < num_of_classes; c++) {
      if (pred_row[c] > max_score) {
        max_score = pred_row[c];
        max_idx = c;
      }
      if (label_row[c] > max_label_score) {
        max_label_score = label_row[c];
        label = c;
      }
    }
    if (label == max_idx) {
      correct++;
    }
  }
  return static_cast<float>(correct) / static_cast<float>(batch_size);
}
116 
CalculateOneHotClassification(tensor::MSTensor * input,tensor::MSTensor * output)117 float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *output) {
118   if ((input->shape().size() != 2) || (output->shape().size() != 2)) {
119     MS_LOG(WARNING) << "OneHotClassification got a " << input->shape() << "-D input tensor, " << output->shape()
120                     << "-D output tensor";
121     return 0.0;
122   }
123 
124   int batch = input->shape().at(0);
125   int num_classes = input->shape().at(1);
126   auto labels = reinterpret_cast<float *>(input->data());
127   float acc = 0.0f;
128   if (output->data_type() == kNumberTypeFloat32) {
129     acc = CalcOneHotClassificationAccuracy(reinterpret_cast<float *>(output->data()), labels, batch, num_classes);
130 #ifdef ENABLE_FP16
131   } else if (output->data_type() == kNumberTypeFloat16) {
132     acc = CalcOneHotClassificationAccuracy(reinterpret_cast<float16_t *>(output->data()), labels, batch, num_classes);
133 #endif
134   }
135   return acc;
136 }
137 
CastTensor(Tensor * tensor,TypeId dst_data_type,bool support_fp16)138 Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
139 #ifdef ENABLE_FP16
140   MS_ASSERT(tensor != nullptr);
141   std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat};
142   std::vector<TypeId> fp32_type = {kNumberTypeFloat32, kNumberTypeFloat};
143   if (!IsContain(valid_type, tensor->data_type())) {
144     MS_LOG(ERROR) << "source data type must be fp32 or fp16,cur is " << tensor->data_type();
145     return nullptr;
146   }
147 
148   if (!IsContain(valid_type, dst_data_type)) {
149     MS_LOG(ERROR) << "destination data type must be fp32 or fp16";
150     return nullptr;
151   }
152 
153   auto origin_data = tensor->data();
154   MS_ASSERT(origin_data != nullptr);
155   auto restore_tensor = Tensor::CopyTensor(*tensor, false);
156   restore_tensor->set_data(origin_data);
157   restore_tensor->set_own_data(tensor->own_data());
158   restore_tensor->set_allocator(tensor->allocator());
159   restore_tensor->set_scale(tensor->get_scale());
160   if (IsContain(fp32_type, tensor->data_type()) && dst_data_type == kNumberTypeFloat16) {
161     tensor->set_data(nullptr);
162     tensor->set_data_type(kNumberTypeFloat16);
163     auto ret = tensor->MallocData();
164     auto new_tensor_data = tensor->data();
165     MS_ASSERT(new_tensor_data != nullptr);
166     if (RET_OK != ret) {
167       MS_LOG(ERROR) << "malloc data failed";
168       delete restore_tensor;
169       return nullptr;
170     }
171     MS_LOG(DEBUG) << "Convert tensor to fp16 " << tensor->tensor_name();
172     Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
173   } else {
174     tensor->set_data(nullptr);
175     tensor->set_data_type(kNumberTypeFloat32);
176     auto ret = tensor->MallocData();
177     if (RET_OK != ret) {
178       MS_LOG(ERROR) << "malloc data failed";
179       delete restore_tensor;
180       return nullptr;
181     }
182     auto new_tensor_data = tensor->data();
183     MS_ASSERT(new_tensor_data != nullptr);
184     MS_LOG(DEBUG) << "Convert tensor to fp32 " << tensor->tensor_name();
185     Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
186   }
187   return restore_tensor;
188 #else
189   return nullptr;
190 #endif
191 }
192 
ScaleTensor(Tensor * tensor,float scale)193 int ScaleTensor(Tensor *tensor, float scale) {
194   MS_ASSERT(tensor != nullptr);
195   std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat};
196   if (!IsContain(valid_type, tensor->data_type())) {
197     MS_LOG(DEBUG) << "Tensor: " << tensor->tensor_name() << " type is " << tensor->data_type();
198     return RET_OK;
199   }
200 
201   MS_LOG(DEBUG) << "Scale tensor: " << tensor->tensor_name() << " " << scale;
202   return tensor->Scale<float>(scale);
203 }
204 }  // namespace lite
205 }  // namespace mindspore
206