1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <backendsCommon/CpuTensorHandle.hpp>
9
10 #include <armnn/Tensor.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13
14 #include <reference/RefTensorHandle.hpp>
15
16 #include <BFloat16.hpp>
17 #include <Half.hpp>
18
19 namespace armnn
20 {
21
22 ////////////////////////////////////////////
/// Tensor data access helpers (float32, Half and BFloat16)
24 ////////////////////////////////////////////
25
GetTensorInfo(const ITensorHandle * tensorHandle)26 inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
27 {
28 // We know that reference workloads use RefTensorHandles for inputs and outputs
29 const RefTensorHandle* refTensorHandle =
30 PolymorphicDowncast<const RefTensorHandle*>(tensorHandle);
31 return refTensorHandle->GetTensorInfo();
32 }
33
34 template <typename DataType, typename PayloadType>
GetInputTensorData(unsigned int idx,const PayloadType & data)35 const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data)
36 {
37 const ITensorHandle* tensorHandle = data.m_Inputs[idx];
38 return reinterpret_cast<const DataType*>(tensorHandle->Map());
39 }
40
41 template <typename DataType, typename PayloadType>
GetOutputTensorData(unsigned int idx,const PayloadType & data)42 DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data)
43 {
44 ITensorHandle* tensorHandle = data.m_Outputs[idx];
45 return reinterpret_cast<DataType*>(tensorHandle->Map());
46 }
47
48 template <typename PayloadType>
GetInputTensorDataFloat(unsigned int idx,const PayloadType & data)49 const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data)
50 {
51 return GetInputTensorData<float>(idx, data);
52 }
53
54 template <typename PayloadType>
GetOutputTensorDataFloat(unsigned int idx,const PayloadType & data)55 float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data)
56 {
57 return GetOutputTensorData<float>(idx, data);
58 }
59
60 template <typename PayloadType>
GetInputTensorDataHalf(unsigned int idx,const PayloadType & data)61 const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data)
62 {
63 return GetInputTensorData<Half>(idx, data);
64 }
65
66 template <typename PayloadType>
GetOutputTensorDataHalf(unsigned int idx,const PayloadType & data)67 Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data)
68 {
69 return GetOutputTensorData<Half>(idx, data);
70 }
71
72 template <typename PayloadType>
GetInputTensorDataBFloat16(unsigned int idx,const PayloadType & data)73 const BFloat16* GetInputTensorDataBFloat16(unsigned int idx, const PayloadType& data)
74 {
75 return GetInputTensorData<BFloat16>(idx, data);
76 }
77
78 template <typename PayloadType>
GetOutputTensorDataBFloat16(unsigned int idx,const PayloadType & data)79 BFloat16* GetOutputTensorDataBFloat16(unsigned int idx, const PayloadType& data)
80 {
81 return GetOutputTensorData<BFloat16>(idx, data);
82 }
83
84 ////////////////////////////////////////////
/// Quantization helpers (Dequantize is element-type generic; Quantize writes uint8_t)
86 ////////////////////////////////////////////
87
88 template<typename T>
Dequantize(const T * quant,const TensorInfo & info)89 std::vector<float> Dequantize(const T* quant, const TensorInfo& info)
90 {
91 std::vector<float> ret(info.GetNumElements());
92 for (size_t i = 0; i < info.GetNumElements(); i++)
93 {
94 ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
95 }
96 return ret;
97 }
98
99 template<typename T>
Dequantize(const T * inputData,float * outputData,const TensorInfo & info)100 inline void Dequantize(const T* inputData, float* outputData, const TensorInfo& info)
101 {
102 for (unsigned int i = 0; i < info.GetNumElements(); i++)
103 {
104 outputData[i] = Dequantize<T>(inputData[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
105 }
106 }
107
Quantize(uint8_t * quant,const float * dequant,const TensorInfo & info)108 inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info)
109 {
110 for (size_t i = 0; i < info.GetNumElements(); i++)
111 {
112 quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
113 }
114 }
115
116 } //namespace armnn
117