• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClBatchNormalizationFloatWorkload.hpp"
7 #include "ClWorkloadUtils.hpp"
8 
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <aclCommon/ArmComputeUtils.hpp>
11 #include <backendsCommon/CpuTensorHandle.hpp>
12 #include <cl/ClLayerSupport.hpp>
13 #include <cl/ClTensorHandle.hpp>
14 
15 namespace armnn
16 {
17 using namespace armcomputetensorutils;
18 
ClBatchNormalizationValidate(const TensorInfo & input,const TensorInfo & output,const TensorInfo & mean,const TensorInfo & var,const TensorInfo & beta,const TensorInfo & gamma,const BatchNormalizationDescriptor & desc,const ActivationDescriptor * activationDescriptor)19 arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
20                                                  const TensorInfo& output,
21                                                  const TensorInfo& mean,
22                                                  const TensorInfo& var,
23                                                  const TensorInfo& beta,
24                                                  const TensorInfo& gamma,
25                                                  const BatchNormalizationDescriptor& desc,
26                                                  const ActivationDescriptor* activationDescriptor)
27 {
28     const arm_compute::TensorInfo aclInputInfo =
29           armcomputetensorutils::BuildArmComputeTensorInfo(input, desc.m_DataLayout);
30     const arm_compute::TensorInfo aclOutputInfo =
31           armcomputetensorutils::BuildArmComputeTensorInfo(output, desc.m_DataLayout);
32     const arm_compute::TensorInfo aclMeanInfo =
33           armcomputetensorutils::BuildArmComputeTensorInfo(mean, desc.m_DataLayout);
34     const arm_compute::TensorInfo aclVarInfo =
35           armcomputetensorutils::BuildArmComputeTensorInfo(var, desc.m_DataLayout);
36     const arm_compute::TensorInfo aclBetaInfo =
37           armcomputetensorutils::BuildArmComputeTensorInfo(beta, desc.m_DataLayout);
38     const arm_compute::TensorInfo aclGammaInfo =
39           armcomputetensorutils::BuildArmComputeTensorInfo(gamma, desc.m_DataLayout);
40 
41     const arm_compute::ActivationLayerInfo activationInfo = ConvertActivationDescriptorToAclActivationLayerInfo(
42             activationDescriptor);
43 
44     return arm_compute::CLBatchNormalizationLayer::validate(&aclInputInfo,
45                                                             &aclOutputInfo,
46                                                             &aclMeanInfo,
47                                                             &aclVarInfo,
48                                                             &aclBetaInfo,
49                                                             &aclGammaInfo,
50                                                             desc.m_Eps,
51                                                             activationInfo);
52 }
53 
ClBatchNormalizationFloatWorkload(const BatchNormalizationQueueDescriptor & descriptor,const WorkloadInfo & info)54 ClBatchNormalizationFloatWorkload::ClBatchNormalizationFloatWorkload(
55     const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info)
56     : FloatWorkload<BatchNormalizationQueueDescriptor>(descriptor, info)
57 {
58     m_Mean = std::make_unique<arm_compute::CLTensor>();
59     BuildArmComputeTensor(*m_Mean, m_Data.m_Mean->GetTensorInfo());
60 
61     m_Variance = std::make_unique<arm_compute::CLTensor>();
62     BuildArmComputeTensor(*m_Variance, m_Data.m_Variance->GetTensorInfo());
63 
64     m_Gamma = std::make_unique<arm_compute::CLTensor>();
65     BuildArmComputeTensor(*m_Gamma, m_Data.m_Gamma->GetTensorInfo());
66 
67     m_Beta = std::make_unique<arm_compute::CLTensor>();
68     BuildArmComputeTensor(*m_Beta, m_Data.m_Beta->GetTensorInfo());
69 
70     m_Data.ValidateInputsOutputs("ClBatchNormalizationFloatWorkload", 1, 1);
71 
72     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
73     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
74 
75     arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
76     input.info()->set_data_layout(aclDataLayout);
77     output.info()->set_data_layout(aclDataLayout);
78 
79     const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
80 
81     m_Layer.configure(&input,
82                       &output,
83                       m_Mean.get(),
84                       m_Variance.get(),
85                       m_Beta.get(),
86                       m_Gamma.get(),
87                       m_Data.m_Parameters.m_Eps,
88                       activationInfo);
89 
90     InitializeArmComputeClTensorData(*m_Mean, m_Data.m_Mean);
91     InitializeArmComputeClTensorData(*m_Variance, m_Data.m_Variance);
92     InitializeArmComputeClTensorData(*m_Beta, m_Data.m_Beta);
93     InitializeArmComputeClTensorData(*m_Gamma, m_Data.m_Gamma);
94 
95     // Force Compute Library to perform the necessary copying and reshaping, after which
96     // delete all the input tensors that will no longer be needed
97     m_Layer.prepare();
98     FreeUnusedTensors();
99 }
100 
Execute() const101 void ClBatchNormalizationFloatWorkload::Execute() const
102 {
103     ARMNN_SCOPED_PROFILING_EVENT_CL("ClBatchNormalizationFloatWorkload_Execute");
104     RunClFunction(m_Layer, CHECK_LOCATION());
105 }
106 
FreeUnusedTensors()107 void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
108 {
109     FreeTensorIfUnused(m_Mean);
110     FreeTensorIfUnused(m_Variance);
111     FreeTensorIfUnused(m_Gamma);
112     FreeTensorIfUnused(m_Beta);
113 }
114 
115 } //namespace armnn
116