//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <backendsCommon/CpuTensorHandle.hpp>

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <reference/RefTensorHandle.hpp>

#include <BFloat16.hpp>
#include <Half.hpp>

#include <cstdint>
#include <vector>

namespace armnn
{

////////////////////////////////////////////
/// float32 helpers
////////////////////////////////////////////

inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle)
{
    // We know that reference workloads use RefTensorHandles for inputs and outputs,
    // so this downcast is safe here.
    const RefTensorHandle* refTensorHandle =
        PolymorphicDowncast<const RefTensorHandle*>(tensorHandle);
    return refTensorHandle->GetTensorInfo();
}
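
// Usage sketch (hypothetical caller code, not part of this header): a reference
// workload would typically look up shape and type information for its first
// input like this, where 'data' is the workload's queue descriptor:
//
//     const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
//     unsigned int numElements = inputInfo.GetNumElements();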

template <typename DataType, typename PayloadType>
const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data)
{
    const ITensorHandle* tensorHandle = data.m_Inputs[idx];
    return reinterpret_cast<const DataType*>(tensorHandle->Map());
}

template <typename DataType, typename PayloadType>
DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data)
{
    ITensorHandle* tensorHandle = data.m_Outputs[idx];
    return reinterpret_cast<DataType*>(tensorHandle->Map());
}
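
// Note: both accessors above reinterpret the pointer returned by Map(), which
// is only valid when the caller's DataType matches the tensor's actual
// armnn::DataType. A hedged usage sketch ('data' is a hypothetical queue
// descriptor):
//
//     const float* in  = GetInputTensorData<float>(0, data);
//     float*       out = GetOutputTensorData<float>(0, data);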

template <typename PayloadType>
const float* GetInputTensorDataFloat(unsigned int idx, const PayloadType& data)
{
    return GetInputTensorData<float>(idx, data);
}

template <typename PayloadType>
float* GetOutputTensorDataFloat(unsigned int idx, const PayloadType& data)
{
    return GetOutputTensorData<float>(idx, data);
}

template <typename PayloadType>
const Half* GetInputTensorDataHalf(unsigned int idx, const PayloadType& data)
{
    return GetInputTensorData<Half>(idx, data);
}

template <typename PayloadType>
Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data)
{
    return GetOutputTensorData<Half>(idx, data);
}

template <typename PayloadType>
const BFloat16* GetInputTensorDataBFloat16(unsigned int idx, const PayloadType& data)
{
    return GetInputTensorData<BFloat16>(idx, data);
}

template <typename PayloadType>
BFloat16* GetOutputTensorDataBFloat16(unsigned int idx, const PayloadType& data)
{
    return GetOutputTensorData<BFloat16>(idx, data);
}
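
// The typed wrappers above simply fix DataType for the common float, Half and
// BFloat16 cases, so workload code can avoid spelling out the template
// argument. For example (hypothetical workload code):
//
//     const Half* in  = GetInputTensorDataHalf(0, data);
//     Half*       out = GetOutputTensorDataHalf(0, data);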

////////////////////////////////////////////
/// u8 helpers
////////////////////////////////////////////

template<typename T>
std::vector<float> Dequantize(const T* quant, const TensorInfo& info)
{
    std::vector<float> ret(info.GetNumElements());
    for (unsigned int i = 0; i < info.GetNumElements(); i++)
    {
        ret[i] = armnn::Dequantize(quant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
    }
    return ret;
}
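
// Usage sketch (hypothetical): dequantizing a u8 input into a temporary float
// buffer before running a float reference implementation:
//
//     const uint8_t* quantIn = GetInputTensorData<uint8_t>(0, data);
//     std::vector<float> dequantIn = Dequantize(quantIn, GetTensorInfo(data.m_Inputs[0]));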

template<typename T>
inline void Dequantize(const T* inputData, float* outputData, const TensorInfo& info)
{
    for (unsigned int i = 0; i < info.GetNumElements(); i++)
    {
        outputData[i] = armnn::Dequantize<T>(inputData[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
    }
}
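
// The overload above writes into a caller-provided buffer, avoiding the
// allocation in the std::vector overload. Hypothetical usage:
//
//     std::vector<float> dequant(info.GetNumElements());
//     Dequantize(GetInputTensorData<uint8_t>(0, data), dequant.data(), info);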

inline void Quantize(uint8_t* quant, const float* dequant, const TensorInfo& info)
{
    for (unsigned int i = 0; i < info.GetNumElements(); i++)
    {
        quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
    }
}
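
// Round-trip usage sketch (hypothetical): a u8 workload dequantizes its input,
// computes in float, then requantizes into the mapped output buffer:
//
//     std::vector<float> result = /* ... float computation ... */;
//     Quantize(GetOutputTensorData<uint8_t>(0, data), result.data(),
//              GetTensorInfo(data.m_Outputs[0]));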

} // namespace armnn