• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClSpaceToBatchNdWorkload.hpp"
7 
8 #include "ClWorkloadUtils.hpp"
9 
10 #include <aclCommon/ArmComputeUtils.hpp>
11 #include <aclCommon/ArmComputeTensorUtils.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13 #include <armnn/utility/PolymorphicDowncast.hpp>
14 #include <backendsCommon/CpuTensorHandle.hpp>
15 #include <cl/ClLayerSupport.hpp>
16 #include <cl/ClTensorHandle.hpp>
17 #include <cl/ClLayerSupport.hpp>
18 
19 namespace armnn
20 {
21 using namespace armcomputetensorutils;
22 
ClSpaceToBatchNdWorkloadValidate(const TensorInfo & input,const TensorInfo & output,const SpaceToBatchNdDescriptor & descriptor)23 arm_compute::Status ClSpaceToBatchNdWorkloadValidate(const TensorInfo& input,
24                                                      const TensorInfo& output,
25                                                      const SpaceToBatchNdDescriptor& descriptor)
26 {
27     const arm_compute::TensorInfo aclInputInfo  = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
28     const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
29 
30     // ArmNN blockShape is [H, W] Cl asks for W, H
31     int32_t blockHeight = armnn::numeric_cast<int32_t>(descriptor.m_BlockShape[0]);
32     int32_t blockWidth  = armnn::numeric_cast<int32_t>(descriptor.m_BlockShape[1]);
33 
34     arm_compute::Size2D paddingLeftTop = BuildArmComputeSize2D(
35         descriptor.m_PadList[1].first, descriptor.m_PadList[0].first);
36     arm_compute::Size2D paddingRightBottom  = BuildArmComputeSize2D(
37         descriptor.m_PadList[1].second, descriptor.m_PadList[0].second);
38 
39     return arm_compute::CLSpaceToBatchLayer::validate(&aclInputInfo,
40                                                       blockWidth,
41                                                       blockHeight,
42                                                       paddingLeftTop,
43                                                       paddingRightBottom,
44                                                       &aclOutputInfo);
45 }
46 
ClSpaceToBatchNdWorkload(const SpaceToBatchNdQueueDescriptor & descriptor,const WorkloadInfo & info)47 ClSpaceToBatchNdWorkload::ClSpaceToBatchNdWorkload(
48     const SpaceToBatchNdQueueDescriptor& descriptor, const WorkloadInfo& info)
49     : BaseWorkload<SpaceToBatchNdQueueDescriptor>(descriptor, info)
50 {
51     m_Data.ValidateInputsOutputs("ClSpaceToBatchNdWorkload", 1, 1);
52 
53     arm_compute::ICLTensor& input  =
54         armnn::PolymorphicPointerDowncast<IClTensorHandle>(m_Data.m_Inputs[0])->GetTensor();
55     arm_compute::ICLTensor& output =
56         armnn::PolymorphicPointerDowncast<IClTensorHandle>(m_Data.m_Outputs[0])->GetTensor();
57 
58     // ArmNN blockShape is [H, W] Cl asks for W, H
59     int32_t blockHeight = armnn::numeric_cast<int32_t>(m_Data.m_Parameters.m_BlockShape[0]);
60     int32_t blockWidth  = armnn::numeric_cast<int32_t>(m_Data.m_Parameters.m_BlockShape[1]);
61 
62     arm_compute::Size2D paddingLeftTop = BuildArmComputeSize2D(
63         m_Data.m_Parameters.m_PadList[1].first, m_Data.m_Parameters.m_PadList[0].first);
64     arm_compute::Size2D paddingRightBottom  = BuildArmComputeSize2D(
65         m_Data.m_Parameters.m_PadList[1].second, m_Data.m_Parameters.m_PadList[0].second);
66 
67     arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
68     input.info()->set_data_layout(aclDataLayout);
69     output.info()->set_data_layout(aclDataLayout);
70 
71     m_SpaceToBatchLayer.configure(&input,
72                                   blockWidth,
73                                   blockHeight,
74                                   paddingLeftTop,
75                                   paddingRightBottom,
76                                   &output);
77 }
78 
Execute() const79 void ClSpaceToBatchNdWorkload::Execute() const
80 {
81     ARMNN_SCOPED_PROFILING_EVENT_CL("ClSpaceToBatchNdWorkload_Execute");
82     RunClFunction(m_SpaceToBatchLayer, CHECK_LOCATION());
83 }
84 
85 } //namespace armnn
86