1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include "BaseIterator.hpp"
9
10 #include <armnnUtils/TensorUtils.hpp>
11
12 #include <armnn/utility/Assert.hpp>
13
14 namespace armnn
15 {
16
/// Factory: builds an Encoder<T> that writes values of logical type T into the raw
/// buffer `data`, converting/quantizing according to info.GetDataType() — see the
/// specializations below for the supported target types.
template<typename T>
inline std::unique_ptr<Encoder<T>> MakeEncoder(const TensorInfo& info, void* data = nullptr);
19
20 template<>
MakeEncoder(const TensorInfo & info,void * data)21 inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* data)
22 {
23 switch(info.GetDataType())
24 {
25 ARMNN_NO_DEPRECATE_WARN_BEGIN
26 case armnn::DataType::QuantizedSymm8PerAxis:
27 {
28 std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
29 return std::make_unique<QSymm8PerAxisEncoder>(
30 static_cast<int8_t*>(data),
31 params.second,
32 params.first);
33 }
34 ARMNN_NO_DEPRECATE_WARN_END
35 case armnn::DataType::QAsymmS8:
36 {
37 return std::make_unique<QASymmS8Encoder>(
38 static_cast<int8_t*>(data),
39 info.GetQuantizationScale(),
40 info.GetQuantizationOffset());
41 }
42 case armnn::DataType::QAsymmU8:
43 {
44 return std::make_unique<QASymm8Encoder>(
45 static_cast<uint8_t*>(data),
46 info.GetQuantizationScale(),
47 info.GetQuantizationOffset());
48 }
49 case DataType::QSymmS8:
50 {
51 if (info.HasPerAxisQuantization())
52 {
53 std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
54 return std::make_unique<QSymm8PerAxisEncoder>(
55 static_cast<int8_t*>(data),
56 params.second,
57 params.first);
58 }
59 else
60 {
61 return std::make_unique<QSymmS8Encoder>(
62 static_cast<int8_t*>(data),
63 info.GetQuantizationScale(),
64 info.GetQuantizationOffset());
65 }
66 }
67 case armnn::DataType::QSymmS16:
68 {
69 return std::make_unique<QSymm16Encoder>(
70 static_cast<int16_t*>(data),
71 info.GetQuantizationScale(),
72 info.GetQuantizationOffset());
73 }
74 case armnn::DataType::Signed32:
75 {
76 return std::make_unique<Int32Encoder>(static_cast<int32_t*>(data));
77 }
78 case armnn::DataType::BFloat16:
79 {
80 return std::make_unique<BFloat16Encoder>(static_cast<armnn::BFloat16*>(data));
81 }
82 case armnn::DataType::Float16:
83 {
84 return std::make_unique<Float16Encoder>(static_cast<Half*>(data));
85 }
86 case armnn::DataType::Float32:
87 {
88 return std::make_unique<Float32Encoder>(static_cast<float*>(data));
89 }
90 default:
91 {
92 ARMNN_ASSERT_MSG(false, "Unsupported target Data Type!");
93 break;
94 }
95 }
96 return nullptr;
97 }
98
99 template<>
MakeEncoder(const TensorInfo & info,void * data)100 inline std::unique_ptr<Encoder<bool>> MakeEncoder(const TensorInfo& info, void* data)
101 {
102 switch(info.GetDataType())
103 {
104 case armnn::DataType::Boolean:
105 {
106 return std::make_unique<BooleanEncoder>(static_cast<uint8_t*>(data));
107 }
108 default:
109 {
110 ARMNN_ASSERT_MSG(false, "Cannot encode from boolean. Not supported target Data Type!");
111 break;
112 }
113 }
114 return nullptr;
115 }
116
117 template<>
MakeEncoder(const TensorInfo & info,void * data)118 inline std::unique_ptr<Encoder<int32_t>> MakeEncoder(const TensorInfo& info, void* data)
119 {
120 switch(info.GetDataType())
121 {
122 case DataType::Signed32:
123 {
124 return std::make_unique<Int32ToInt32tEncoder>(static_cast<int32_t*>(data));
125 }
126 default:
127 {
128 ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
129 break;
130 }
131 }
132 return nullptr;
133 }
134
135 } //namespace armnn
136