• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ClReshapeWorkload.hpp"
7 #include <cl/ClTensorHandle.hpp>
8 #include <backendsCommon/CpuTensorHandle.hpp>
9 
10 #include "ClWorkloadUtils.hpp"
11 
12 namespace armnn
13 {
14 
ClReshapeWorkloadValidate(const TensorInfo & input,const TensorInfo & output)15 arm_compute::Status ClReshapeWorkloadValidate(const TensorInfo& input,
16                                               const TensorInfo& output)
17 {
18     const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
19     const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
20 
21     return arm_compute::CLReshapeLayer::validate(&aclInputInfo, &aclOutputInfo);
22 }
23 
ClReshapeWorkload(const ReshapeQueueDescriptor & descriptor,const WorkloadInfo & info)24 ClReshapeWorkload::ClReshapeWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info)
25     : BaseWorkload<ReshapeQueueDescriptor>(descriptor, info)
26 {
27     m_Data.ValidateInputsOutputs("ClReshapeWorkload", 1, 1);
28 
29     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
30     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
31 
32     m_Layer.configure(&input, &output);
33 }
34 
Execute() const35 void ClReshapeWorkload::Execute() const
36 {
37     ARMNN_SCOPED_PROFILING_EVENT_CL("ClReshapeWorkload_Execute");
38     RunClFunction(m_Layer, CHECK_LOCATION());
39 }
40 
41 } //namespace armnn
42