1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefFullyConnectedWorkload.hpp"
7
8 #include "FullyConnected.hpp"
9 #include "RefWorkloadUtils.hpp"
10
11 #include "Profiling.hpp"
12
13 namespace armnn
14 {
RefFullyConnectedWorkload(const FullyConnectedQueueDescriptor & descriptor,const WorkloadInfo & info)15 RefFullyConnectedWorkload::RefFullyConnectedWorkload(
16 const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
17 : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
18 m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
19 {
20 const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
21 m_WeightShape = rWeightInfo.GetShape();
22 m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
23
24 if (descriptor.m_Parameters.m_BiasEnabled)
25 {
26 m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
27 const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
28 m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
29 }
30 }
31
PostAllocationConfigure()32 void RefFullyConnectedWorkload::PostAllocationConfigure()
33 {
34 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
35 ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
36 m_InputShape = inputInfo.GetShape();
37 m_InputDecoder = MakeDecoder<float>(inputInfo);
38
39 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
40 m_OutputShape = outputInfo.GetShape();
41 m_OutputEncoder = MakeEncoder<float>(outputInfo);
42
43 m_NumActivations = 1; // Total number of activations in the input.
44 for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
45 {
46 m_NumActivations *= inputInfo.GetShape()[i];
47 }
48 }
49
Execute() const50 void RefFullyConnectedWorkload::Execute() const
51 {
52 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
53
54 m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
55 m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
56
57 FullyConnected(m_InputShape,
58 *m_InputDecoder,
59 m_OutputShape,
60 *m_OutputEncoder,
61 m_WeightShape,
62 *m_WeightDecoder,
63 *m_BiasDecoder,
64 m_Data.m_Parameters.m_BiasEnabled,
65 m_NumActivations,
66 m_Data.m_Parameters.m_TransposeWeightMatrix);
67 }
68
69 } //namespace armnn
70