1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClBatchNormalizationFloatWorkload.hpp"
7 #include "ClWorkloadUtils.hpp"
8
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <aclCommon/ArmComputeUtils.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
12 #include <cl/ClLayerSupport.hpp>
13 #include <cl/ClTensorHandle.hpp>
14
15 namespace armnn
16 {
17 using namespace armcomputetensorutils;
18
ClBatchNormalizationValidate(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & var,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & desc,const ActivationDescriptor * activationDescriptor)19 arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const TensorInfo& mean,
22 const TensorInfo& var,
23 const TensorInfo& beta,
24 const TensorInfo& gamma,
25 const BatchNormalizationDescriptor& desc,
26 const ActivationDescriptor* activationDescriptor)
27 {
28 const arm_compute::TensorInfo aclInputInfo =
29 armcomputetensorutils::BuildArmComputeTensorInfo(input, desc.m_DataLayout);
30 const arm_compute::TensorInfo aclOutputInfo =
31 armcomputetensorutils::BuildArmComputeTensorInfo(output, desc.m_DataLayout);
32 const arm_compute::TensorInfo aclMeanInfo =
33 armcomputetensorutils::BuildArmComputeTensorInfo(mean, desc.m_DataLayout);
34 const arm_compute::TensorInfo aclVarInfo =
35 armcomputetensorutils::BuildArmComputeTensorInfo(var, desc.m_DataLayout);
36 const arm_compute::TensorInfo aclBetaInfo =
37 armcomputetensorutils::BuildArmComputeTensorInfo(beta, desc.m_DataLayout);
38 const arm_compute::TensorInfo aclGammaInfo =
39 armcomputetensorutils::BuildArmComputeTensorInfo(gamma, desc.m_DataLayout);
40
41 const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
42 activationDescriptor);
43
44 return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
45 &aclOutputInfo,
46 &aclMeanInfo,
47 &aclVarInfo,
48 &aclBetaInfo,
49 &aclGammaInfo,
50 desc.m_Eps,
51 activationInfo);
52 }
53
ClBatchNormalizationFloatWorkload(const BatchNormalizationQueueDescriptor & descriptor,const WorkloadInfo & info)54 ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
55 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
56 : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
57 {
58 m_Mean = std::make_unique<arm_compute::CLTensor>();
59 BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
60
61 m_Variance = std::make_unique<arm_compute::CLTensor>();
62 BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
63
64 m_Gamma = std::make_unique<arm_compute::CLTensor>();
65 BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
66
67 m_Beta = std::make_unique<arm_compute::CLTensor>();
68 BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
69
70 m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
71
72 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
73 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
74
75 arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
76 input.info()->set_data_layout(aclDataLayout);
77 output.info()->set_data_layout(aclDataLayout);
78
79 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
80
81 m_Layer.configure(&input,
82 &output,
83 m_Mean.get(),
84 m_Variance.get(),
85 m_Beta.get(),
86 m_Gamma.get(),
87 m_Data.m_Parameters.m_Eps,
88 activationInfo);
89
90 InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
91 InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
92 InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
93 InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
94
95 // Force Compute Library to perform the necessary copying and reshaping, after which
96 // delete all the input tensors that will no longer be needed
97 m_Layer.prepare();
98 FreeUnusedTensors();
99 }
100
// Runs the batch-normalization layer configured in the constructor.
void ClBatchNormalizationFloatWorkload::Execute() const
{
    // Opens a CL profiling scope named after this workload for the duration of this call.
    ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloatWorkload_Execute");
    // Dispatches m_Layer through RunClFunction; CHECK_LOCATION() passes this call site along,
    // presumably so that any reported failure points here.
    RunClFunction(m_Layer, CHECK_LOCATION());
}
106
FreeUnusedTensors()107 void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
108 {
109 FreeTensorIfUnused(m_Mean);
110 FreeTensorIfUnused(m_Variance);
111 FreeTensorIfUnused(m_Gamma);
112 FreeTensorIfUnused(m_Beta);
113 }
114
115 } //namespace armnn
116