//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClConstantWorkload.hpp"

#include <Half.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <cl/ClTensorHandle.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>

#include "ClWorkloadUtils.hpp"

namespace armnn
{

ClConstantWorkloadValidate(const TensorInfo & output)18 arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output)
19 {
20 const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
21
22 std::array<arm_compute::DataType,8> supportedTypes = {
23 arm_compute::DataType::F16,
24 arm_compute::DataType::F32,
25 arm_compute::DataType::QASYMM8,
26 arm_compute::DataType::QASYMM8_SIGNED,
27 arm_compute::DataType::QSYMM16,
28 arm_compute::DataType::QSYMM8,
29 arm_compute::DataType::QSYMM8_PER_CHANNEL,
30 arm_compute::DataType::S32
31 };
32 auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
33
34 if (it != end(supportedTypes))
35 {
36 return arm_compute::Status{};
37 }
38 else
39 {
40 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
41 }
42 }
ClConstantWorkload(const ConstantQueueDescriptor & descriptor,const WorkloadInfo & info)44 ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
45 : BaseWorkload<ConstantQueueDescriptor>(descriptor, info)
46 , m_RanOnce(false)
47 {
48 }
Execute() const50 void ClConstantWorkload::Execute() const
51 {
52 ARMNN_SCOPED_PROFILING_EVENT_CL("ClConstantWorkload_Execute");
53
54 // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
55 // on the first inference, then reused for subsequent inferences.
56 // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
57 // have been configured at the time.
58 if (!m_RanOnce)
59 {
60 const ConstantQueueDescriptor& data = this->m_Data;
61
62 ARMNN_ASSERT(data.m_LayerOutput != nullptr);
63 arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
64 arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
65
66 switch (computeDataType)
67 {
68 case arm_compute::DataType::F16:
69 {
70 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
71 break;
72 }
73 case arm_compute::DataType::F32:
74 {
75 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
76 break;
77 }
78 case arm_compute::DataType::QASYMM8:
79 {
80 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
81 break;
82 }
83 case arm_compute::DataType::QASYMM8_SIGNED:
84 {
85 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
86 break;
87 }
88 case arm_compute::DataType::QSYMM16:
89 {
90 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
91 break;
92 }
93 case arm_compute::DataType::QSYMM8:
94 case arm_compute::DataType::QSYMM8_PER_CHANNEL:
95 {
96 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
97 break;
98 }
99 case arm_compute::DataType::S32:
100 {
101 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
102 break;
103 }
104 default:
105 {
106 ARMNN_ASSERT_MSG(false, "Unknown data type");
107 break;
108 }
109 }
110
111 m_RanOnce = true;
112 }
113 }

} //namespace armnn