/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/train/train_utils.h"
#include <algorithm>
#include <string>
#include <vector>
#include "include/errorcode.h"
#include "include/ms_tensor.h"
#include "src/common/utils.h"
#include "src/lite_kernel.h"
#ifdef ENABLE_FP16
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif

namespace mindspore {
namespace lite {
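// Linear search for `searchParameter` in `where`; returns its index,
// or where.size() if the tensor is not present.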
size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) {
  for (size_t i = 0; i < where.size(); i++) {
    if (where[i] == searchParameter) {
      return i;
    }
  }
  return where.size();
}

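// Linear search by tensor name; returns the index of the first match,
// or where.size() if no tensor in `where` has that name.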
size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter) {
  for (size_t i = 0; i < where.size(); i++) {
    if (where[i]->tensor_name() == searchParameter) {
      return i;
    }
  }
  return where.size();
}

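// Returns the first kernel in `where` whose name equals `searchParameter`,
// or nullptr if no kernel matches.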
kernel::LiteKernel *TSFindKernel(const std::vector<kernel::LiteKernel *> &where, const std::string &searchParameter) {
  auto it = std::find_if(where.begin(), where.end(),
                         [&searchParameter](const kernel::LiteKernel *k) { return (k->name() == searchParameter); });
  if (it == where.end()) {
    return nullptr;
  }
  return *it;
}

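// Top-1 accuracy for sparse (integer) labels: for each sample, the arg-max of
// the `num_of_classes` prediction scores is compared against labels[b].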
template <typename T>
float CalcSparseClassificationAccuracy(T *predictions, int *labels, int batch_size, int num_of_classes) {
  float accuracy = 0.0;
  for (int b = 0; b < batch_size; b++) {
    int max_idx = 0;
    T max_score = predictions[num_of_classes * b];
    for (int c = 1; c < num_of_classes; c++) {
      if (predictions[num_of_classes * b + c] > max_score) {
        max_score = predictions[num_of_classes * b + c];
        max_idx = c;
      }
    }
    if (labels[b] == max_idx) accuracy += 1.0;
  }
  return accuracy / (static_cast<float>(batch_size));
}

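// Accuracy of a 1-D int32 label tensor against a [batch, num_classes]
// prediction tensor; dispatches on the prediction data type (fp32/fp16).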
float CalculateSparseClassification(tensor::MSTensor *input, tensor::MSTensor *output) {
  if ((input->shape().size() != 1) || (input->data_type() != kNumberTypeInt32) || (output->shape().size() != 2)) {
    MS_LOG(WARNING) << "SparseClassification expects a 1-D int32 input tensor and a 2-D output tensor, got input shape "
                    << input->shape() << " and output shape " << output->shape();
    return 0.0;
  }

  int batch = input->shape().at(0);
  int num_classes = output->shape().at(1);
  auto labels = reinterpret_cast<int *>(input->data());
  float acc = 0.0f;
  if (output->data_type() == kNumberTypeFloat32) {
    acc = CalcSparseClassificationAccuracy(reinterpret_cast<float *>(output->data()), labels, batch, num_classes);
#ifdef ENABLE_FP16
  } else if (output->data_type() == kNumberTypeFloat16) {
    acc = CalcSparseClassificationAccuracy(reinterpret_cast<float16_t *>(output->data()), labels, batch, num_classes);
#endif
  }
  return acc;
}

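// Top-1 accuracy for one-hot labels: the arg-max of each label row is
// compared against the arg-max of the corresponding prediction row.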
template <typename T>
float CalcOneHotClassificationAccuracy(T *predictions, float *labels, int batch_size, int num_of_classes) {
  float accuracy = 0.0;
  for (int b = 0; b < batch_size; b++) {
    int label = 0;
    int max_idx = 0;
    float max_label_score = labels[num_of_classes * b];
    T max_score = predictions[num_of_classes * b];
    for (int c = 1; c < num_of_classes; c++) {
      if (predictions[num_of_classes * b + c] > max_score) {
        max_score = predictions[num_of_classes * b + c];
        max_idx = c;
      }
      if (labels[num_of_classes * b + c] > max_label_score) {
        max_label_score = labels[num_of_classes * b + c];
        label = c;
      }
    }
    if (label == max_idx) accuracy += 1.0;
  }
  return accuracy / (static_cast<float>(batch_size));
}

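// Accuracy of a [batch, num_classes] one-hot label tensor against a
// [batch, num_classes] prediction tensor; dispatches on the prediction type.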
float CalculateOneHotClassification(tensor::MSTensor *input, tensor::MSTensor *output) {
  if ((input->shape().size() != 2) || (output->shape().size() != 2)) {
    MS_LOG(WARNING) << "OneHotClassification expects 2-D input and output tensors, got input shape " << input->shape()
                    << " and output shape " << output->shape();
    return 0.0;
  }

  int batch = input->shape().at(0);
  int num_classes = input->shape().at(1);
  auto labels = reinterpret_cast<float *>(input->data());
  float acc = 0.0f;
  if (output->data_type() == kNumberTypeFloat32) {
    acc = CalcOneHotClassificationAccuracy(reinterpret_cast<float *>(output->data()), labels, batch, num_classes);
#ifdef ENABLE_FP16
  } else if (output->data_type() == kNumberTypeFloat16) {
    acc = CalcOneHotClassificationAccuracy(reinterpret_cast<float16_t *>(output->data()), labels, batch, num_classes);
#endif
  }
  return acc;
}

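// Casts `tensor` in place between fp32 and fp16 and returns a "restore"
// tensor that still points at the original data with the original type, so
// the caller can undo the cast later; returns nullptr on failure or when
// the build has no fp16 support.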
Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
#ifdef ENABLE_FP16
  MS_ASSERT(tensor != nullptr);
  std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat};
  std::vector<TypeId> fp32_type = {kNumberTypeFloat32, kNumberTypeFloat};
  if (!IsContain(valid_type, tensor->data_type())) {
    MS_LOG(ERROR) << "source data type must be fp32 or fp16, but is " << tensor->data_type();
    return nullptr;
  }

  if (!IsContain(valid_type, dst_data_type)) {
    MS_LOG(ERROR) << "destination data type must be fp32 or fp16";
    return nullptr;
  }

  // Keep the original buffer alive in a shallow copy so the cast can be undone.
  auto origin_data = tensor->data();
  MS_ASSERT(origin_data != nullptr);
  auto restore_tensor = Tensor::CopyTensor(*tensor, false);
  if (restore_tensor == nullptr) {
    MS_LOG(ERROR) << "copy tensor failed";
    return nullptr;
  }
  restore_tensor->set_data(origin_data);
  restore_tensor->set_own_data(tensor->own_data());
  restore_tensor->set_allocator(tensor->allocator());
  restore_tensor->set_scale(tensor->get_scale());
  if (IsContain(fp32_type, tensor->data_type()) && dst_data_type == kNumberTypeFloat16) {
    tensor->set_data(nullptr);
    tensor->set_data_type(kNumberTypeFloat16);
    auto ret = tensor->MallocData();
    if (RET_OK != ret) {
      MS_LOG(ERROR) << "malloc data failed";
      delete restore_tensor;
      return nullptr;
    }
    auto new_tensor_data = tensor->data();
    MS_ASSERT(new_tensor_data != nullptr);
    MS_LOG(DEBUG) << "Convert tensor to fp16 " << tensor->tensor_name();
    Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
  } else {
    tensor->set_data(nullptr);
    tensor->set_data_type(kNumberTypeFloat32);
    auto ret = tensor->MallocData();
    if (RET_OK != ret) {
      MS_LOG(ERROR) << "malloc data failed";
      delete restore_tensor;
      return nullptr;
    }
    auto new_tensor_data = tensor->data();
    MS_ASSERT(new_tensor_data != nullptr);
    MS_LOG(DEBUG) << "Convert tensor to fp32 " << tensor->tensor_name();
    Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
  }
  return restore_tensor;
#else
  return nullptr;
#endif
}

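// Multiplies every element of an fp32 tensor by `scale`; non-float tensors
// are left untouched and reported as success.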
int ScaleTensor(Tensor *tensor, float scale) {
  MS_ASSERT(tensor != nullptr);
  std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat};
  if (!IsContain(valid_type, tensor->data_type())) {
    MS_LOG(DEBUG) << "Tensor: " << tensor->tensor_name() << " is not fp32 (type " << tensor->data_type()
                  << "), skipping scale";
    return RET_OK;
  }

  MS_LOG(DEBUG) << "Scale tensor: " << tensor->tensor_name() << " " << scale;
  return tensor->Scale<float>(scale);
}
}  // namespace lite
}  // namespace mindspore