• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefFullyConnectedWorkload.hpp"
7 
8 #include "FullyConnected.hpp"
9 #include "RefWorkloadUtils.hpp"
10 
11 #include "Profiling.hpp"
12 
13 namespace armnn
14 {
RefFullyConnectedWorkload(const FullyConnectedQueueDescriptor & descriptor,const WorkloadInfo & info)15 RefFullyConnectedWorkload::RefFullyConnectedWorkload(
16     const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
17         : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info),
18           m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight)))
19 {
20     const TensorInfo& rWeightInfo = m_Weight->GetTensorInfo();
21     m_WeightShape = rWeightInfo.GetShape();
22     m_WeightDecoder = MakeDecoder<float>(rWeightInfo, m_Weight->Map(true));
23 
24     if (descriptor.m_Parameters.m_BiasEnabled)
25     {
26         m_Bias = std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias));
27         const TensorInfo& biasInfo = m_Bias->GetTensorInfo();
28         m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true));
29     }
30 }
31 
PostAllocationConfigure()32 void RefFullyConnectedWorkload::PostAllocationConfigure()
33 {
34     const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
35     ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1);
36     m_InputShape = inputInfo.GetShape();
37     m_InputDecoder = MakeDecoder<float>(inputInfo);
38 
39     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
40     m_OutputShape = outputInfo.GetShape();
41     m_OutputEncoder = MakeEncoder<float>(outputInfo);
42 
43     m_NumActivations = 1; // Total number of activations in the input.
44     for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++)
45     {
46         m_NumActivations *= inputInfo.GetShape()[i];
47     }
48 }
49 
Execute() const50 void RefFullyConnectedWorkload::Execute() const
51 {
52     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute");
53 
54     m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map());
55     m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map());
56 
57     FullyConnected(m_InputShape,
58                    *m_InputDecoder,
59                    m_OutputShape,
60                    *m_OutputEncoder,
61                    m_WeightShape,
62                    *m_WeightDecoder,
63                    *m_BiasDecoder,
64                    m_Data.m_Parameters.m_BiasEnabled,
65                    m_NumActivations,
66                    m_Data.m_Parameters.m_TransposeWeightMatrix);
67 }
68 
69 } //namespace armnn
70