/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "src/core/CL/kernels/CLDepthConvertLayerKernel.h"
#include "src/core/CL/kernels/CLFillBorderKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "src/core/CL/kernels/CLTransposeKernel.h"
#include "support/Cast.h"
#include "support/MemorySupport.h"

#include <algorithm>

namespace arm_compute
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::utils::cast;

namespace
{
Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output,
                                       GEMMLowpOutputStageInfo &gemmlowp_output_stage, ActivationLayerInfo activation_info)
{
    gemmlowp_output_stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_output_stage.gemmlowp_offset     = 0;
    gemmlowp_output_stage.gemmlowp_multiplier = 0;
    gemmlowp_output_stage.gemmlowp_shift      = 0;

    const auto data_type = input.data_type();

    // Configure output stage for quantized case
    if(is_data_type_quantized_asymmetric(data_type))
    {
        const QuantizationInfo        oq_info = output.quantization_info();
        const UniformQuantizationInfo iq_unif = input.quantization_info().uniform();
        const UniformQuantizationInfo wq_unif = weights.quantization_info().uniform();
        const UniformQuantizationInfo oq_unif = oq_info.uniform();

        const auto output_quant_info = (output.total_size() == 0) ? iq_unif : oq_unif;

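        // Requantization multiplier: with affine asymmetric quantization, real = scale * (quantized - offset),
        // so the int32 accumulator holds sum_k (q_in - off_in) * (q_w - off_w), whose real value is
        // iq_unif.scale * wq_unif.scale * acc. Mapping that back to the output scale requires
        // multiplier = (iq_unif.scale * wq_unif.scale) / output_quant_info.scale, which is then
        // decomposed below into an integer multiplier and shift for fixed-point requantization.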
        const float multiplier        = (iq_unif.scale * wq_unif.scale) / output_quant_info.scale;
        int         output_multiplier = 0;
        int         output_shift      = 0;
        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));

        PixelValue type_min{};
        PixelValue type_max{};
        std::tie(type_min, type_max) = get_min_max(data_type);

        if(activation_info.enabled())
        {
            std::tie(type_min, type_max) = get_quantized_activation_min_max(activation_info, data_type, output_quant_info);
        }

        // Set the GEMMLowp output stage info
        gemmlowp_output_stage.gemmlowp_offset     = output_quant_info.offset;
        gemmlowp_output_stage.gemmlowp_multiplier = output_multiplier;
        gemmlowp_output_stage.gemmlowp_shift      = output_shift;
        gemmlowp_output_stage.gemmlowp_multipliers.push_back(output_multiplier);
        gemmlowp_output_stage.gemmlowp_shifts.push_back(output_shift);
        type_min.get(gemmlowp_output_stage.gemmlowp_min_bound);
        type_max.get(gemmlowp_output_stage.gemmlowp_max_bound);
    }

    return Status{};
}

Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &output, const FullyConnectedLayerInfo &fc_info)
{
    GEMMLowpOutputStageInfo gemmlowp_output_stage;
    ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, gemmlowp_output_stage, fc_info.activation_info));

    const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                         false,                           // is_b_reshaped
                                         true,                            // reshape_b_only_on_first_run
                                         0,                               // depth_output_gemm3d
                                         false,                           // reinterpret_input_as_3d
                                         fc_info.retain_internal_weights, // retain_internal_weights
                                         gemmlowp_output_stage,           // gemmlowp_output_stage
                                         fc_info.fp_mixed_precision,      // fp_mixed_precision
                                         true,                            // broadcast_bias
                                         ActivationLayerInfo());          // activation_info

    if(is_data_type_quantized_asymmetric(input.data_type()))
    {
        const UniformQuantizationInfo iq_info = input.quantization_info().uniform();
        const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();

        // The quantized matrix multiplication needs the input and weights offsets negated,
        // so extract them and build new QuantizationInfo objects with the negated offsets
        const QuantizationInfo input_quantization_info(iq_info.scale, -iq_info.offset);
        const QuantizationInfo weights_quantization_info(wq_info.scale, -wq_info.offset);

        // Validate gemmlowp function
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(&input.clone()->set_quantization_info(input_quantization_info),
                                                                           &weights.clone()->set_quantization_info(weights_quantization_info),
                                                                           bias,
                                                                           &output,
                                                                           gemm_info));
    }
    else
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input, &weights, bias, &output, 1.f, 1.f, gemm_info));
    }

    return Status{};
}
} // namespace

void CLFullyConnectedLayerReshapeWeights::configure(const ICLTensor *input, ICLTensor *output)
{
    configure(CLKernelLibrary::get().get_compile_context(), input, output);
}

void CLFullyConnectedLayerReshapeWeights::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output)
{
    auto k = arm_compute::support::cpp14::make_unique<CLTransposeKernel>();
    k->configure(compile_context, input, output);
    _kernel = std::move(k);
}

Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    return CLTransposeKernel::validate(input, output);
}
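
// Note: the weights "reshape" above is a plain 2D transpose, so weights of shape [A, B] become [B, A]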

CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
    : _memory_group(memory_manager), _weights_manager(weights_manager), _convert_weights(), _convert_weights_managed(), _reshape_weights_managed_function(), _flatten_layer(), _reshape_weights_function(),
      _mm_gemm(memory_manager, weights_manager), _mm_gemmlowp(memory_manager), _flatten_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true),
      _are_weights_reshaped(true), _is_fc_after_conv(true), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
{
}

void CLFullyConnectedLayer::configure_mm(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
                                         const FullyConnectedLayerInfo &fc_info)
{
    GEMMLowpOutputStageInfo gemmlowp_output_stage;
    construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage, fc_info.activation_info);

    const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                         false,                           // is_b_reshaped
                                         true,                            // reshape_b_only_on_first_run
                                         0,                               // depth_output_gemm3d
                                         false,                           // reinterpret_input_as_3d
                                         fc_info.retain_internal_weights, // retain_internal_weights
                                         gemmlowp_output_stage,           // gemmlowp_output_stage
                                         fc_info.fp_mixed_precision,      // fp_mixed_precision
                                         true,                            // broadcast_bias
                                         fc_info.activation_info);        // activation_info

    if(_is_quantized)
    {
        // The quantized matrix multiplication needs the input and weights offsets negated,
        // so temporarily replace the QuantizationInfo of input and weights with negated offsets
        const QuantizationInfo input_quantization_info   = input->info()->quantization_info();
        const QuantizationInfo weights_quantization_info = weights->info()->quantization_info();

        input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.uniform().scale, -input_quantization_info.uniform().offset));
        weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));

        // Configure gemmlowp function
        _mm_gemmlowp.configure(compile_context, input, weights, bias, output, gemm_info);

        // Restore the original QuantizationInfo, as input and weights could be used in other fully connected layers
        input->info()->set_quantization_info(input_quantization_info);
        weights->info()->set_quantization_info(weights_quantization_info);
    }
    else
    {
        // Configure matrix multiply kernel
        _mm_gemm.configure(compile_context, input, weights, bias, output, 1.f, 1.f, gemm_info);
    }
}

void CLFullyConnectedLayer::configure_conv_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
                                              const FullyConnectedLayerInfo &fc_info)
{
    ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));

    // If the fully connected layer is called after a convolution layer, the input tensor must be linearized
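    // e.g. an input of shape [W, H, C, N] (NCHW) is flattened to [W*H*C, N] before the matrix multiplication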

    // Initialize output tensor for flatten
    TensorShape shape_flatten = compute_flatten_shape(input->info());
    _flatten_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_flatten).set_data_layout(DataLayout::NCHW));

    // Configure flatten kernel
    _memory_group.manage(&_flatten_output);
    _flatten_layer.configure(compile_context, input, &_flatten_output);

    // Configure matrix multiply kernel
    configure_mm(compile_context, &_flatten_output, weights, bias, output, fc_info);

    // Allocate the output tensor for flatten once all the configure methods have been called
    _flatten_output.allocator()->allocate();
}

void CLFullyConnectedLayer::configure_fc_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
                                            const FullyConnectedLayerInfo &fc_info)
{
    ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));

    // Configure matrix multiply kernel
    configure_mm(compile_context, input, weights, bias, output, fc_info);
}

void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
                                      FullyConnectedLayerInfo fc_info)
{
    configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, fc_info);
}

void CLFullyConnectedLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
                                      FullyConnectedLayerInfo fc_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    // Perform validate step
    ARM_COMPUTE_ERROR_THROW_ON(CLFullyConnectedLayer::validate(input->info(),
                                                               weights->info(),
                                                               biases != nullptr ? biases->info() : nullptr,
                                                               output->info(),
                                                               fc_info));

    _are_weights_converted = true;
    _are_weights_reshaped  = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
    _is_fc_after_conv      = true;
    _is_quantized          = is_data_type_quantized_asymmetric(input->info()->data_type());
    _is_prepared           = fc_info.retain_internal_weights;
    _original_weights      = weights;

    if(_weights_manager)
    {
        _weights_manager->manage(weights);
    }

    const ICLTensor *weights_to_use = weights;

    // With the Fully Connected layer we can have 4 different cases:
    // 1) Convolution layer -> Fully Connected layer without batches
    // 2) Fully Connected layer -> Fully Connected layer without batches
    // 3) Convolution layer -> Fully Connected layer with batches
    // 4) Fully Connected layer -> Fully Connected layer with batches

    // Check if we have a fully connected layer with batches
    const bool is_batched_fc_layer = output->info()->dimension(1) > 1;
    if(is_batched_fc_layer)
    {
        _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
                                                                                  input->info()->tensor_shape().cend(),
                                                                                  output->info()->tensor_shape().cbegin() + 1));
    }
    else
    {
        _is_fc_after_conv = input->info()->num_dimensions() > 1;
    }
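
    // i.e. the input is treated as coming from a convolution when it still carries unflattened
    // spatial/channel dimensions in front of the batch dimension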

    // Reshape weights if needed
    if(!_are_weights_reshaped)
    {
        if(_weights_manager && _weights_manager->are_weights_managed(weights))
        {
            _reshape_weights_managed_function.configure(compile_context, weights);
            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed_function));
        }
        else
        {
            // Reshape the weights
            _reshape_weights_function.configure(compile_context, weights, &_reshape_weights_output);
            weights_to_use = &_reshape_weights_output;
        }
    }

    // Convert weights if needed
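    // (the weights were trained with fc_info.weights_trained_layout; when the runtime data layout differs,
    // the rows of the weight matrix must be re-ordered to match the flattened input)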
    if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
    {
        if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
        {
            _convert_weights_managed.configure(compile_context, weights_to_use,
                                               input->info()->tensor_shape(),
                                               fc_info.weights_trained_layout);
            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_convert_weights_managed));
        }
        else
        {
            // Convert weights
            _convert_weights.configure(compile_context, weights_to_use,
                                       &_converted_weights_output,
                                       input->info()->tensor_shape(),
                                       fc_info.weights_trained_layout);

            weights_to_use = &_converted_weights_output;
        }
        _are_weights_converted = false;
    }

    if(_is_fc_after_conv)
    {
        // Fully Connected layer after a Convolution Layer without batches
        configure_conv_fc(compile_context, input, weights_to_use, biases, output, fc_info);
    }
    else
    {
        // Fully Connected layer after a Fully Connected Layer without batches
        configure_fc_fc(compile_context, input, weights_to_use, biases, output, fc_info);
    }
}
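
// A minimal usage sketch (tensor names and shapes are illustrative, not part of this file), assuming the
// CL backend has been initialised with CLScheduler::get().default_init():
//
//     CLTensor src, weights, bias, dst;
//     src.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));          // 128 inputs
//     weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32)); // 64 outputs
//     bias.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));
//     dst.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));
//
//     CLFullyConnectedLayer fc;
//     FullyConnectedLayerInfo fc_info{};
//     fc.configure(&src, &weights, &bias, &dst, fc_info);
//
//     src.allocator()->allocate();
//     weights.allocator()->allocate();
//     bias.allocator()->allocate();
//     dst.allocator()->allocate();
//
//     // ... fill src/weights/bias, then:
//     fc.run(); // the first call also triggers prepare(), which reshapes/converts the weights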

Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
                                       FullyConnectedLayerInfo fc_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(input->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);

    bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
    bool is_fc_after_conv = true;

    const ITensorInfo &flatten_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)).set_data_layout(DataLayout::NCHW));
    const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
    const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());

    // With the Fully Connected layer we can have 4 different cases:
    // 1) Convolution layer -> Fully Connected layer without batches
    // 2) Fully Connected layer -> Fully Connected layer without batches
    // 3) Convolution layer -> Fully Connected layer with batches
    // 4) Fully Connected layer -> Fully Connected layer with batches

    const ITensorInfo *input_to_use   = input;
    const ITensorInfo *weights_to_use = weights;

    // Check if we have a fully connected layer with batches
    const bool is_batched_fc_layer = output->dimension(1) > 1;
    if(is_batched_fc_layer)
    {
        is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->tensor_shape().cbegin() + 3,
                                                                                 input->tensor_shape().cend(),
                                                                                 output->tensor_shape().cbegin() + 1));
    }
    else
    {
        is_fc_after_conv = input->num_dimensions() > 1;
    }

    if(!weights_reshaped)
    {
        // Validate reshape weights kernel
        ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
        weights_to_use = &reshaped_weights;
    }

    if(is_fc_after_conv && (input->data_layout() != fc_info.weights_trained_layout))
    {
        // Validate convert weights kernel
        ARM_COMPUTE_RETURN_ON_ERROR(CLConvertFullyConnectedWeights::validate(weights_to_use,
                                                                             &converted_weights,
                                                                             input->tensor_shape(),
                                                                             fc_info.weights_trained_layout));
        weights_to_use = &converted_weights;
    }

    if(is_fc_after_conv)
    {
        // Fully Connected layer after a Convolution Layer without batches
        ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2))));

        // Validate flatten kernel
        ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayer::validate(input, &flatten_input));
        input_to_use = &flatten_input;
    }
    else
    {
        // Fully Connected layer after a Fully Connected Layer without batches
        ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1));
    }

    // Validate matrix multiply kernel
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, biases, *output, fc_info));

    return Status{};
}

void CLFullyConnectedLayer::run()
{
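    // prepare() is a no-op after the first call; on the first run it reshapes/converts the weights
    // and lets the underlying GEMM function prepare its internal buffers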
    prepare();

    MemoryGroupResourceScope scope_mg(_memory_group);

    // Linearize input if it comes from a convolutional layer
    if(_is_fc_after_conv)
    {
        _flatten_layer.run();
    }

    // Run matrix multiply
    if(_is_quantized)
    {
        _mm_gemmlowp.run();
    }
    else
    {
        _mm_gemm.run();
    }
}

void CLFullyConnectedLayer::prepare()
{
    if(!_is_prepared)
    {
        if(!_weights_manager)
        {
            ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
        }

        auto release_unused = [](CLTensor *w)
        {
            if(!w->is_used())
            {
                CLScheduler::get().queue().finish();
                w->allocator()->free();
            }
        };

        // Pointer to current weights
        const ICLTensor *cur_weights = _original_weights;

        // Reshape of the weights if needed (happens only once)
        if(!_are_weights_reshaped)
        {
            if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
            {
                cur_weights = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->run(cur_weights, &_reshape_weights_managed_function));
            }
            else
            {
                // Run reshape weights kernel and mark weights as unused
                _reshape_weights_output.allocator()->allocate();
                _reshape_weights_function.run();

                cur_weights->mark_as_unused();
                cur_weights = &_reshape_weights_output;
            }
            _are_weights_reshaped = true;
        }

        // Convert weights if needed (happens only once)
        if(!_are_weights_converted)
        {
            if(_weights_manager && _weights_manager->are_weights_managed(cur_weights))
            {
                _weights_manager->run(cur_weights, &_convert_weights_managed);
            }
            else
            {
                _converted_weights_output.allocator()->allocate();
                _convert_weights.run();
                cur_weights->mark_as_unused();
            }

            _are_weights_converted = true;
        }

        // Release reshaped weights if unused
        release_unused(&_reshape_weights_output);

        // Prepare GEMM and release unused weights
        if(!_is_quantized)
        {
            _mm_gemm.prepare();
        }

        // Release converted weights if unused
        release_unused(&_reshape_weights_output);
        release_unused(&_converted_weights_output);

        _is_prepared = true;
    }
}
} // namespace arm_compute