1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/train/train_utils.h"
18 #include <vector>
19 #include "include/errorcode.h"
20 #include "src/common/utils.h"
21 #include "src/executor/kernel_exec.h"
22 #ifdef ENABLE_FP16
23 #include "src/litert/kernel/cpu/fp16/fp16_op_handler.h"
24 #endif
25
26 namespace mindspore {
27 namespace lite {
TSFindTensor(const std::vector<lite::Tensor * > & where,const lite::Tensor * searchParameter)28 size_t TSFindTensor(const std::vector<lite::Tensor *> &where, const lite::Tensor *searchParameter) {
29 for (size_t i = 0; i < where.size(); i++) {
30 if (where[i] == searchParameter) {
31 return i;
32 }
33 }
34 return where.size();
35 }
36
TSFindTensorByName(const std::vector<lite::Tensor * > & where,const std::string & searchParameter)37 size_t TSFindTensorByName(const std::vector<lite::Tensor *> &where, const std::string &searchParameter) {
38 for (size_t i = 0; i < where.size(); i++) {
39 if (where[i]->tensor_name() == searchParameter) {
40 return i;
41 }
42 }
43 return where.size();
44 }
45
TSFindKernel(const std::vector<kernel::KernelExec * > & where,const std::string & searchParameter)46 kernel::KernelExec *TSFindKernel(const std::vector<kernel::KernelExec *> &where, const std::string &searchParameter) {
47 auto it = std::find_if(where.begin(), where.end(),
48 [&searchParameter](const kernel::KernelExec *k) { return (k->name() == searchParameter); });
49 if (it == where.end()) {
50 return nullptr;
51 }
52 return *it;
53 }
54
// Fraction of samples whose arg-max prediction equals the integer label.
//   predictions:    row-major buffer of batch_size * num_of_classes scores
//   labels:         batch_size class indices
//   batch_size:     number of samples (rows)
//   num_of_classes: number of scores per sample (columns)
// Ties resolve to the lowest class index (strict '>' comparison).
// Returns 0.0f for null buffers or a non-positive batch_size / num_of_classes.
// Without that guard the division would be 0/0 (NaN) or the rows would be
// read out of bounds.
template <typename T>
float CalcSparseClassificationAccuracy(T *predictions, int *labels, int batch_size, int num_of_classes) {
  if (predictions == nullptr || labels == nullptr || batch_size <= 0 || num_of_classes <= 0) {
    return 0.0f;
  }
  // Integer counter: a float accumulator stops incrementing exactly beyond 2^24.
  int correct = 0;
  const size_t stride = static_cast<size_t>(num_of_classes);
  for (int b = 0; b < batch_size; b++) {
    // size_t offset avoids int overflow of num_of_classes * b on large outputs.
    const T *row = predictions + static_cast<size_t>(b) * stride;
    int max_idx = 0;
    T max_score = row[0];
    for (int c = 1; c < num_of_classes; c++) {
      if (row[c] > max_score) {
        max_score = row[c];
        max_idx = c;
      }
    }
    if (labels[b] == max_idx) {
      correct++;
    }
  }
  return static_cast<float>(correct) / static_cast<float>(batch_size);
}
73
CalculateSparseClassification(lite::Tensor * input,lite::Tensor * output)74 float CalculateSparseClassification(lite::Tensor *input, lite::Tensor *output) {
75 if ((input->shape().size() != 1) || (input->data_type() != kNumberTypeInt32) || (output->shape().size() != 2)) {
76 MS_LOG(WARNING) << "SparseClassification got a " << input->shape() << "-D input tensor, " << output->shape()
77 << "-D output tensor";
78 return 0.0;
79 }
80
81 int batch = input->shape().at(0);
82 int num_classes = output->shape().at(1);
83 auto labels = reinterpret_cast<int *>(input->data());
84 float acc = 0.0f;
85 if (output->data_type() == kNumberTypeFloat32) {
86 acc = CalcSparseClassificationAccuracy(reinterpret_cast<float *>(output->data()), labels, batch, num_classes);
87 #ifdef ENABLE_FP16
88 } else if (output->data_type() == kNumberTypeFloat16) {
89 acc = CalcSparseClassificationAccuracy(reinterpret_cast<float16_t *>(output->data()), labels, batch, num_classes);
90 #endif
91 }
92 return acc;
93 }
94
// Fraction of samples where the arg-max prediction equals the arg-max label.
//   predictions:    row-major buffer of batch_size * num_of_classes scores
//   labels:         row-major buffer of batch_size * num_of_classes one-hot
//                   (or soft) label scores
//   batch_size:     number of samples (rows)
//   num_of_classes: number of entries per sample (columns)
// For both arg-max searches, ties resolve to the lowest class index (strict '>').
// Returns 0.0f for null buffers or a non-positive batch_size / num_of_classes.
// Without that guard the division would be 0/0 (NaN) or the rows would be
// read out of bounds.
template <typename T>
float CalcOneHotClassificationAccuracy(T *predictions, float *labels, int batch_size, int num_of_classes) {
  if (predictions == nullptr || labels == nullptr || batch_size <= 0 || num_of_classes <= 0) {
    return 0.0f;
  }
  // Integer counter: a float accumulator stops incrementing exactly beyond 2^24.
  int correct = 0;
  const size_t stride = static_cast<size_t>(num_of_classes);
  for (int b = 0; b < batch_size; b++) {
    // size_t offset avoids int overflow of num_of_classes * b on large outputs.
    const size_t offset = static_cast<size_t>(b) * stride;
    const T *pred_row = predictions + offset;
    const float *label_row = labels + offset;
    int label = 0;
    int max_idx = 0;
    float max_label_score = label_row[0];
    T max_score = pred_row[0];
    for (int c = 1; c < num_of_classes; c++) {
      if (pred_row[c] > max_score) {
        max_score = pred_row[c];
        max_idx = c;
      }
      if (label_row[c] > max_label_score) {
        max_label_score = label_row[c];
        label = c;
      }
    }
    if (label == max_idx) {
      correct++;
    }
  }
  return static_cast<float>(correct) / static_cast<float>(batch_size);
}
119
CalculateOneHotClassification(lite::Tensor * input,lite::Tensor * output)120 float CalculateOneHotClassification(lite::Tensor *input, lite::Tensor *output) {
121 if ((input->shape().size() != 2) || (output->shape().size() != 2)) {
122 MS_LOG(WARNING) << "OneHotClassification got a " << input->shape() << "-D input tensor, " << output->shape()
123 << "-D output tensor";
124 return 0.0;
125 }
126
127 int batch = input->shape().at(0);
128 int num_classes = input->shape().at(1);
129 auto labels = reinterpret_cast<float *>(input->data());
130 float acc = 0.0f;
131 if (output->data_type() == kNumberTypeFloat32) {
132 acc = CalcOneHotClassificationAccuracy(reinterpret_cast<float *>(output->data()), labels, batch, num_classes);
133 #ifdef ENABLE_FP16
134 } else if (output->data_type() == kNumberTypeFloat16) {
135 acc = CalcOneHotClassificationAccuracy(reinterpret_cast<float16_t *>(output->data()), labels, batch, num_classes);
136 #endif
137 }
138 return acc;
139 }
140
CastTensor(Tensor * tensor,TypeId dst_data_type,bool support_fp16)141 Tensor *CastTensor(Tensor *tensor, TypeId dst_data_type, bool support_fp16) {
142 #ifdef ENABLE_FP16
143 MS_ASSERT(tensor != nullptr);
144 std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat16, kNumberTypeFloat};
145 std::vector<TypeId> fp32_type = {kNumberTypeFloat32, kNumberTypeFloat};
146 if (!IsContain(valid_type, tensor->data_type())) {
147 MS_LOG(ERROR) << "source data type must be fp32 or fp16,cur is " << tensor->data_type();
148 return nullptr;
149 }
150
151 if (!IsContain(valid_type, dst_data_type)) {
152 MS_LOG(ERROR) << "destination data type must be fp32 or fp16";
153 return nullptr;
154 }
155
156 auto origin_data = tensor->data();
157 MS_ASSERT(origin_data != nullptr);
158 auto restore_tensor = Tensor::CopyTensor(*tensor, false);
159 restore_tensor->set_data(origin_data);
160 restore_tensor->set_own_data(tensor->own_data());
161 restore_tensor->set_allocator(tensor->allocator());
162 restore_tensor->set_scale(tensor->get_scale());
163 if (IsContain(fp32_type, tensor->data_type()) && dst_data_type == kNumberTypeFloat16) {
164 tensor->set_data(nullptr);
165 tensor->set_data_type(kNumberTypeFloat16);
166 auto ret = tensor->MallocData();
167 auto new_tensor_data = tensor->data();
168 MS_ASSERT(new_tensor_data != nullptr);
169 if (RET_OK != ret) {
170 MS_LOG(ERROR) << "malloc data failed";
171 delete restore_tensor;
172 return nullptr;
173 }
174 MS_LOG(DEBUG) << "Convert tensor to fp16 " << tensor->tensor_name();
175 Float32ToFloat16_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
176 } else {
177 tensor->set_data(nullptr);
178 tensor->set_data_type(kNumberTypeFloat32);
179 auto ret = tensor->MallocData();
180 if (RET_OK != ret) {
181 MS_LOG(ERROR) << "malloc data failed";
182 delete restore_tensor;
183 return nullptr;
184 }
185 auto new_tensor_data = tensor->data();
186 MS_ASSERT(new_tensor_data != nullptr);
187 MS_LOG(DEBUG) << "Convert tensor to fp32 " << tensor->tensor_name();
188 Float16ToFloat32_fp16_handler(origin_data, new_tensor_data, tensor->ElementsNum(), support_fp16);
189 }
190 return restore_tensor;
191 #else
192 return nullptr;
193 #endif
194 }
195
ScaleTensor(Tensor * tensor,float scale)196 int ScaleTensor(Tensor *tensor, float scale) {
197 MS_ASSERT(tensor != nullptr);
198 std::vector<TypeId> valid_type = {kNumberTypeFloat32, kNumberTypeFloat};
199 if (!IsContain(valid_type, tensor->data_type())) {
200 MS_LOG(DEBUG) << "Tensor: " << tensor->tensor_name() << " type is " << tensor->data_type();
201 return RET_OK;
202 }
203
204 MS_LOG(DEBUG) << "Scale tensor: " << tensor->tensor_name() << " " << scale;
205 return tensor->Scale<float>(scale);
206 }
207
TSFindTensors(const kernel::KernelExec * pre_kernel,const kernel::KernelExec * post_kernel)208 std::vector<Tensor *> TSFindTensors(const kernel::KernelExec *pre_kernel, const kernel::KernelExec *post_kernel) {
209 std::vector<Tensor *> res;
210 MS_CHECK_TRUE_RET(pre_kernel != nullptr, res);
211 MS_CHECK_TRUE_RET(post_kernel != nullptr, res);
212 auto out_tensors = pre_kernel->out_tensors();
213 auto in_tensors = post_kernel->in_tensors();
214 for (auto tensor : out_tensors) {
215 if (std::find(in_tensors.begin(), in_tensors.end(), tensor) == in_tensors.end()) {
216 continue;
217 }
218 res.push_back(tensor);
219 }
220 return res;
221 }
222 } // namespace lite
223 } // namespace mindspore
224