• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClConstantWorkload.hpp"
7 
8 #include <Half.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <cl/ClTensorHandle.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
12 
13 #include "ClWorkloadUtils.hpp"
14 
15 namespace armnn
16 {
17 
ClConstantWorkloadValidate(const TensorInfo & output)18 arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output)
19 {
20     const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
21 
22     std::array<arm_compute::DataType,8> supportedTypes = {
23             arm_compute::DataType::F16,
24             arm_compute::DataType::F32,
25             arm_compute::DataType::QASYMM8,
26             arm_compute::DataType::QASYMM8_SIGNED,
27             arm_compute::DataType::QSYMM16,
28             arm_compute::DataType::QSYMM8,
29             arm_compute::DataType::QSYMM8_PER_CHANNEL,
30             arm_compute::DataType::S32
31     };
32     auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
33 
34     if (it != end(supportedTypes))
35     {
36         return arm_compute::Status{};
37     }
38     else
39     {
40         return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
41     }
42 }
43 
ClConstantWorkload(const ConstantQueueDescriptor & descriptor,const WorkloadInfo & info)44 ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
45     : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
46     , m_RanOnce(false)
47 {
48 }
49 
Execute() const50 void ClConstantWorkload::Execute() const
51 {
52     ARMNN_SCOPED_PROFILING_EVENT_CL("ClConstantWorkload_Execute");
53 
54     // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
55     // on the first inference, then reused for subsequent inferences.
56     // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
57     // have been configured at the time.
58     if (!m_RanOnce)
59     {
60         const ConstantQueueDescriptor& data = this->m_Data;
61 
62         ARMNN_ASSERT(data.m_LayerOutput != nullptr);
63         arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
64         arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
65 
66         switch (computeDataType)
67         {
68             case arm_compute::DataType::F16:
69             {
70                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
71                 break;
72             }
73             case arm_compute::DataType::F32:
74             {
75                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
76                 break;
77             }
78             case arm_compute::DataType::QASYMM8:
79             {
80                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
81                 break;
82             }
83             case arm_compute::DataType::QASYMM8_SIGNED:
84             {
85                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
86                 break;
87             }
88             case arm_compute::DataType::QSYMM16:
89             {
90                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
91                 break;
92             }
93             case arm_compute::DataType::QSYMM8:
94             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
95             {
96                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
97                 break;
98             }
99             case arm_compute::DataType::S32:
100             {
101                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
102                 break;
103             }
104             default:
105             {
106                 ARMNN_ASSERT_MSG(false, "Unknown data type");
107                 break;
108             }
109         }
110 
111         m_RanOnce = true;
112     }
113 }
114 
115 } //namespace armnn
116