/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "src/core/CL/kernels/CLDepthConvertLayerKernel.h"
#include "src/core/CL/kernels/CLFillBorderKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h"
#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
#include "src/core/CL/kernels/CLTransposeKernel.h"
#include "support/Cast.h"
#include "support/MemorySupport.h"

#include <algorithm>

namespace arm_compute
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::utils::cast;

namespace
{
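// Fills @p gemmlowp_output_stage with the requantization parameters (offset, multiplier, shift and clamping bounds)
// derived from the input/weights/output quantization info; values are only computed for asymmetric quantized data types.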
Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output,
                                       GEMMLowpOutputStageInfo &gemmlowp_output_stage, ActivationLayerInfo activation_info)
{
    gemmlowp_output_stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    gemmlowp_output_stage.gemmlowp_offset     = 0;
    gemmlowp_output_stage.gemmlowp_multiplier = 0;
    gemmlowp_output_stage.gemmlowp_shift      = 0;

    const auto data_type = input.data_type();

    // Configure output stage for quantized case
    if(is_data_type_quantized_asymmetric(data_type))
    {
        const QuantizationInfo        oq_info = output.quantization_info();
        const UniformQuantizationInfo iq_unif = input.quantization_info().uniform();
        const UniformQuantizationInfo wq_unif = weights.quantization_info().uniform();
        const UniformQuantizationInfo oq_unif = oq_info.uniform();

        const auto output_quant_info = (output.total_size() == 0) ? iq_unif : oq_unif;

        const float multiplier        = (iq_unif.scale * wq_unif.scale) / output_quant_info.scale;
        int         output_multiplier = 0;
        int         output_shift      = 0;
        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));

        PixelValue type_min{};
        PixelValue type_max{};
        std::tie(type_min, type_max) = get_min_max(data_type);

        if(activation_info.enabled())
        {
            std::tie(type_min, type_max) = get_quantized_activation_min_max(activation_info, data_type, output_quant_info);
        }

        // Set the GEMMLowp output stage info
        gemmlowp_output_stage.gemmlowp_offset     = output_quant_info.offset;
        gemmlowp_output_stage.gemmlowp_multiplier = output_multiplier;
        gemmlowp_output_stage.gemmlowp_shift      = output_shift;
        gemmlowp_output_stage.gemmlowp_multipliers.push_back(output_multiplier);
        gemmlowp_output_stage.gemmlowp_shifts.push_back(output_shift);
        type_min.get(gemmlowp_output_stage.gemmlowp_min_bound);
        type_max.get(gemmlowp_output_stage.gemmlowp_max_bound);
    }

    return Status{};
}

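// Validates the matrix multiplication stage: CLGEMMLowpMatrixMultiplyCore for asymmetric quantized inputs
// (with negated input/weights offsets), CLGEMM otherwise.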
Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &output, const FullyConnectedLayerInfo &fc_info)
{
    GEMMLowpOutputStageInfo gemmlowp_output_stage;
    ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, gemmlowp_output_stage, fc_info.activation_info));

    const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                         false,                           // is_b_reshaped
                                         true,                            // reshape_b_only_on_first_run
                                         0,                               // depth_output_gemm3d
                                         false,                           // reinterpret_input_as_3d
                                         fc_info.retain_internal_weights, // retain_internal_weights
                                         gemmlowp_output_stage,           // gemmlowp_output_stage
                                         fc_info.fp_mixed_precision,      // fp_mixed_precision
                                         true,                            // broadcast_bias
                                         ActivationLayerInfo());          // activation_info

    if(is_data_type_quantized_asymmetric(input.data_type()))
    {
        const UniformQuantizationInfo iq_info = input.quantization_info().uniform();
        const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();

        // Since the quantized matrix multiplication expects negated offsets, we need to change QuantizationInfo()
        // Extract and negate input and weights offset
        const QuantizationInfo input_quantization_info(iq_info.scale, -iq_info.offset);
        const QuantizationInfo weights_quantization_info(wq_info.scale, -wq_info.offset);

        // Validate gemmlowp function
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyCore::validate(&input.clone()->set_quantization_info(input_quantization_info),
                                                                           &weights.clone()->set_quantization_info(weights_quantization_info),
                                                                           bias,
                                                                           &output,
                                                                           gemm_info));
    }
    else
    {
        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input, &weights, bias, &output, 1.f, 1.f, gemm_info));
    }

    return Status{};
}
} // namespace

void CLFullyConnectedLayerReshapeWeights::configure(const ICLTensor *input, ICLTensor *output)
{
    configure(CLKernelLibrary::get().get_compile_context(), input, output);
}

void CLFullyConnectedLayerReshapeWeights::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output)
{
    auto k = arm_compute::support::cpp14::make_unique<CLTransposeKernel>();
    k->configure(compile_context, input, output);
    _kernel = std::move(k);
}

Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, const ITensorInfo *output)
{
    return CLTransposeKernel::validate(input, output);
}

CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
    : _memory_group(memory_manager), _weights_manager(weights_manager), _convert_weights(), _convert_weights_managed(), _reshape_weights_managed_function(), _flatten_layer(), _reshape_weights_function(),
      _mm_gemm(memory_manager, weights_manager), _mm_gemmlowp(memory_manager), _flatten_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true),
      _are_weights_reshaped(true), _is_fc_after_conv(true), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
{
}
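// Configures the matrix multiplication stage: CLGEMMLowpMatrixMultiplyCore for quantized inputs, CLGEMM otherwise.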
void CLFullyConnectedLayer::configure_mm(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
                                         const FullyConnectedLayerInfo &fc_info)
{
    GEMMLowpOutputStageInfo gemmlowp_output_stage;
    construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage, fc_info.activation_info);

    const GEMMInfo &gemm_info = GEMMInfo(false,                           // is_a_reshaped
                                         false,                           // is_b_reshaped
                                         true,                            // reshape_b_only_on_first_run
                                         0,                               // depth_output_gemm3d
                                         false,                           // reinterpret_input_as_3d
                                         fc_info.retain_internal_weights, // retain_internal_weights
                                         gemmlowp_output_stage,           // gemmlowp_output_stage
                                         fc_info.fp_mixed_precision,      // fp_mixed_precision
                                         true,                            // broadcast_bias
                                         fc_info.activation_info);        // activation_info

    if(_is_quantized)
    {
        // Since the quantized matrix multiplication expects negated offsets, we need to change QuantizationInfo()
        // Extract and negate input and weights offset
        const QuantizationInfo input_quantization_info   = input->info()->quantization_info();
        const QuantizationInfo weights_quantization_info = weights->info()->quantization_info();

        input->info()->set_quantization_info(QuantizationInfo(input_quantization_info.uniform().scale, -input_quantization_info.uniform().offset));
        weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));

        // Configure gemmlowp function
        _mm_gemmlowp.configure(compile_context, input, weights, bias, output, gemm_info);

        // Restore the original QuantizationInfo as input and weights could be used in other fully connected layers
        input->info()->set_quantization_info(input_quantization_info);
        weights->info()->set_quantization_info(weights_quantization_info);
    }
    else
    {
        // Configure matrix multiply kernel
        _mm_gemm.configure(compile_context, input, weights, bias, output, 1.f, 1.f, gemm_info);
    }
}

void CLFullyConnectedLayer::configure_conv_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
                                              const FullyConnectedLayerInfo &fc_info)
{
    ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));

    // If the fully connected layer is called after a convolution layer, the input tensor must be linearized

    // Initialize output tensor for flatten
    TensorShape shape_flatten = compute_flatten_shape(input->info());
    _flatten_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_flatten).set_data_layout(DataLayout::NCHW));

    // Configure flatten kernel
    _memory_group.manage(&_flatten_output);
    _flatten_layer.configure(compile_context, input, &_flatten_output);

    // Configure matrix multiply kernel
    configure_mm(compile_context, &_flatten_output, weights, bias, output, fc_info);

    // Allocate the output tensor for flatten once all the configure methods have been called
    _flatten_output.allocator()->allocate();
}

void CLFullyConnectedLayer::configure_fc_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
                                            const FullyConnectedLayerInfo &fc_info)
{
    ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));

    // Configure matrix multiply kernel
    configure_mm(compile_context, input, weights, bias, output, fc_info);
}

void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
                                      FullyConnectedLayerInfo fc_info)
{
    configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, fc_info);
}

void CLFullyConnectedLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
                                      FullyConnectedLayerInfo fc_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    // Perform validate step
    ARM_COMPUTE_ERROR_THROW_ON(CLFullyConnectedLayer::validate(input->info(),
                                                               weights->info(),
                                                               biases != nullptr ? biases->info() : nullptr,
                                                               output->info(),
                                                               fc_info));

    _are_weights_converted = true;
    _are_weights_reshaped  = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
    _is_fc_after_conv      = true;
    _is_quantized          = is_data_type_quantized_asymmetric(input->info()->data_type());
    _is_prepared           = fc_info.retain_internal_weights;
    _original_weights      = weights;

    if(_weights_manager)
    {
        _weights_manager->manage(weights);
    }

    const ICLTensor *weights_to_use = weights;

    // With the Fully Connected layer we can have 4 different cases:
    //  1) Convolution layer -> Fully Connected layer without batches
    //  2) Fully Connected layer -> Fully Connected layer without batches
    //  3) Convolution layer -> Fully Connected layer with batches
    //  4) Fully Connected layer -> Fully Connected layer with batches

    // Check if we have a fully connected layer with batches
    const bool is_batched_fc_layer = output->info()->dimension(1) > 1;
    if(is_batched_fc_layer)
    {
        _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->info()->tensor_shape().cbegin() + 3,
                                                                                  input->info()->tensor_shape().cend(),
                                                                                  output->info()->tensor_shape().cbegin() + 1));
    }
    else
    {
        _is_fc_after_conv = input->info()->num_dimensions() > 1;
    }

    // Reshape weights if needed
    if(!_are_weights_reshaped)
    {
        if(_weights_manager && _weights_manager->are_weights_managed(weights))
        {
            _reshape_weights_managed_function.configure(compile_context, weights);
            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed_function));
        }
        else
        {
            // Reshape the weights
            _reshape_weights_function.configure(compile_context, weights, &_reshape_weights_output);
            weights_to_use = &_reshape_weights_output;
        }
    }

    // Convert weights if needed
    if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
    {
        if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
        {
            _convert_weights_managed.configure(compile_context, weights_to_use,
                                               input->info()->tensor_shape(),
                                               fc_info.weights_trained_layout);
            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_convert_weights_managed));
        }
        else
        {
            // Convert weights
            _convert_weights.configure(compile_context, weights_to_use,
                                       &_converted_weights_output,
                                       input->info()->tensor_shape(),
                                       fc_info.weights_trained_layout);

            weights_to_use = &_converted_weights_output;
        }
        _are_weights_converted = false;
    }

    if(_is_fc_after_conv)
    {
        // Fully Connected layer after a Convolution Layer without batches
        configure_conv_fc(compile_context, input, weights_to_use, biases, output, fc_info);
    }
    else
    {
        // Fully Connected layer after a Fully Connected Layer without batches
        configure_fc_fc(compile_context, input, weights_to_use, biases, output, fc_info);
    }
}

Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
                                       FullyConnectedLayerInfo fc_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
    ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(input->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
                                && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);

    bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
    bool is_fc_after_conv = true;

    const ITensorInfo &flatten_input     = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(input)).set_data_layout(DataLayout::NCHW));
    const ITensorInfo &reshaped_weights  = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
    const ITensorInfo &converted_weights = weights_reshaped ? TensorInfo(weights->clone()->set_is_resizable(true).reset_padding()) : TensorInfo(*reshaped_weights.clone());

    // With the Fully Connected layer we can have 4 different cases:
    //  1) Convolution layer -> Fully Connected layer without batches
    //  2) Fully Connected layer -> Fully Connected layer without batches
    //  3) Convolution layer -> Fully Connected layer with batches
    //  4) Fully Connected layer -> Fully Connected layer with batches

    const ITensorInfo *input_to_use   = input;
    const ITensorInfo *weights_to_use = weights;

    // Check if we have a fully connected layer with batches
    const bool is_batched_fc_layer = output->dimension(1) > 1;
    if(is_batched_fc_layer)
    {
        is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(input->tensor_shape().cbegin() + 3,
                                                                                 input->tensor_shape().cend(),
                                                                                 output->tensor_shape().cbegin() + 1));
    }
    else
    {
        is_fc_after_conv = input->num_dimensions() > 1;
    }

    if(!weights_reshaped)
    {
        // Validate reshape weights kernel
        ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayerReshapeWeights::validate(weights, &reshaped_weights));
        weights_to_use = &reshaped_weights;
    }

    if(is_fc_after_conv && (input->data_layout() != fc_info.weights_trained_layout))
    {
        // Validate convert weights kernel
        ARM_COMPUTE_RETURN_ON_ERROR(CLConvertFullyConnectedWeights::validate(weights_to_use,
                                                                             &converted_weights,
                                                                             input->tensor_shape(),
                                                                             fc_info.weights_trained_layout));
        weights_to_use = &converted_weights;
    }

    if(is_fc_after_conv)
    {
        // Fully Connected layer after a Convolution Layer without batches
        ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (input->dimension(0) * input->dimension(1) * input->dimension(2))));

        // Validate flatten kernel
        ARM_COMPUTE_RETURN_ON_ERROR(CLFlattenLayer::validate(input, &flatten_input));
        input_to_use = &flatten_input;
    }
    else
    {
        // Fully Connected layer after a Fully Connected Layer without batches
        ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1));
    }

    // Validate matrix multiply kernel
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, biases, *output, fc_info));

    return Status{};
}

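// Executes the function: flattens the input when it comes from a convolution layer, then runs the (quantized) GEMM.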
void CLFullyConnectedLayer::run()
{
    prepare();

    MemoryGroupResourceScope scope_mg(_memory_group);

    // Linearize input if it comes from a convolutional layer
    if(_is_fc_after_conv)
    {
        _flatten_layer.run();
    }

    // Run matrix multiply
    if(_is_quantized)
    {
        _mm_gemmlowp.run();
    }
    else
    {
        _mm_gemm.run();
    }
}

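// One-off preparation: reshapes/converts the weights if required, prepares the GEMM and frees any transformed
// weight tensors that are no longer needed.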
void CLFullyConnectedLayer::prepare()
{
    if(!_is_prepared)
    {
        if(!_weights_manager)
        {
            ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
        }

        auto release_unused = [](CLTensor * w)
        {
            if(!w->is_used())
            {
                CLScheduler::get().queue().finish();
                w->allocator()->free();
            }
        };

        // Pointer to current weights
        const ICLTensor *cur_weights = _original_weights;

        // Reshape of the weights if needed (happens only once)
        if(!_are_weights_reshaped)
        {
            if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
            {
                cur_weights = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->run(cur_weights, &_reshape_weights_managed_function));
            }
            else
            {
                // Run reshape weights kernel and mark weights as unused
                _reshape_weights_output.allocator()->allocate();
                _reshape_weights_function.run();

                cur_weights->mark_as_unused();
                cur_weights = &_reshape_weights_output;
            }
            _are_weights_reshaped = true;
        }

        // Convert weights if needed (happens only once)
        if(!_are_weights_converted)
        {
            if(_weights_manager && _weights_manager->are_weights_managed(cur_weights))
            {
                _weights_manager->run(cur_weights, &_convert_weights_managed);
            }
            else
            {
                _converted_weights_output.allocator()->allocate();
                _convert_weights.run();
                cur_weights->mark_as_unused();
            }

            _are_weights_converted = true;
        }

        // Release reshaped weights if unused
        release_unused(&_reshape_weights_output);

        // Prepare GEMM and release unused weights
        if(!_is_quantized)
        {
            _mm_gemm.prepare();
        }

        // Release reshaped and converted weights if unused
        release_unused(&_reshape_weights_output);
        release_unused(&_converted_weights_output);

        _is_prepared = true;
    }
}
} // namespace arm_compute