• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "BaseIterator.hpp"
9 
10 #include <armnnUtils/TensorUtils.hpp>
11 
12 #include <armnn/utility/Assert.hpp>
13 
14 namespace armnn
15 {
16 
17 template<typename T>
18 inline std::unique_ptr<Encoder<T>> MakeEncoder(const TensorInfo& info, void* data = nullptr);
19 
20 template<>
MakeEncoder(const TensorInfo & info,void * data)21 inline std::unique_ptr<Encoder<float>> MakeEncoder(const TensorInfo& info, void* data)
22 {
23     switch(info.GetDataType())
24     {
25         ARMNN_NO_DEPRECATE_WARN_BEGIN
26         case armnn::DataType::QuantizedSymm8PerAxis:
27         {
28             std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
29             return std::make_unique<QSymm8PerAxisEncoder>(
30                 static_cast<int8_t*>(data),
31                 params.second,
32                 params.first);
33         }
34         ARMNN_NO_DEPRECATE_WARN_END
35         case armnn::DataType::QAsymmS8:
36         {
37             return std::make_unique<QASymmS8Encoder>(
38                 static_cast<int8_t*>(data),
39                 info.GetQuantizationScale(),
40                 info.GetQuantizationOffset());
41         }
42         case armnn::DataType::QAsymmU8:
43         {
44             return std::make_unique<QASymm8Encoder>(
45                 static_cast<uint8_t*>(data),
46                 info.GetQuantizationScale(),
47                 info.GetQuantizationOffset());
48         }
49         case DataType::QSymmS8:
50         {
51             if (info.HasPerAxisQuantization())
52             {
53                 std::pair<unsigned int, std::vector<float>> params = armnnUtils::GetPerAxisParams(info);
54                 return std::make_unique<QSymm8PerAxisEncoder>(
55                         static_cast<int8_t*>(data),
56                         params.second,
57                         params.first);
58             }
59             else
60             {
61                 return std::make_unique<QSymmS8Encoder>(
62                         static_cast<int8_t*>(data),
63                         info.GetQuantizationScale(),
64                         info.GetQuantizationOffset());
65             }
66         }
67         case armnn::DataType::QSymmS16:
68         {
69             return std::make_unique<QSymm16Encoder>(
70                 static_cast<int16_t*>(data),
71                 info.GetQuantizationScale(),
72                 info.GetQuantizationOffset());
73         }
74         case armnn::DataType::Signed32:
75         {
76             return std::make_unique<Int32Encoder>(static_cast<int32_t*>(data));
77         }
78         case armnn::DataType::BFloat16:
79         {
80             return std::make_unique<BFloat16Encoder>(static_cast<armnn::BFloat16*>(data));
81         }
82         case armnn::DataType::Float16:
83         {
84             return std::make_unique<Float16Encoder>(static_cast<Half*>(data));
85         }
86         case armnn::DataType::Float32:
87         {
88             return std::make_unique<Float32Encoder>(static_cast<float*>(data));
89         }
90         default:
91         {
92             ARMNN_ASSERT_MSG(false, "Unsupported target Data Type!");
93             break;
94         }
95     }
96     return nullptr;
97 }
98 
99 template<>
MakeEncoder(const TensorInfo & info,void * data)100 inline std::unique_ptr<Encoder<bool>> MakeEncoder(const TensorInfo& info, void* data)
101 {
102     switch(info.GetDataType())
103     {
104         case armnn::DataType::Boolean:
105         {
106             return std::make_unique<BooleanEncoder>(static_cast<uint8_t*>(data));
107         }
108         default:
109         {
110             ARMNN_ASSERT_MSG(false, "Cannot encode from boolean. Not supported target Data Type!");
111             break;
112         }
113     }
114     return nullptr;
115 }
116 
117 template<>
MakeEncoder(const TensorInfo & info,void * data)118 inline std::unique_ptr<Encoder<int32_t>> MakeEncoder(const TensorInfo& info, void* data)
119 {
120     switch(info.GetDataType())
121     {
122         case DataType::Signed32:
123         {
124             return std::make_unique<Int32ToInt32tEncoder>(static_cast<int32_t*>(data));
125         }
126         default:
127         {
128             ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
129             break;
130         }
131     }
132     return nullptr;
133 }
134 
135 } //namespace armnn
136