• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 
10 #include <armnnUtils/FloatingPointConverter.hpp>
11 
12 #include <backendsCommon/CpuTensorHandle.hpp>
13 
14 #include <armnn/utility/IgnoreUnused.hpp>
15 
16 #include <BFloat16.hpp>
17 #include <Half.hpp>
18 
19 namespace armnn
20 {
21 namespace optimizations
22 {
23 
24 struct BFloat16ToFloat32
25 {
Funcarmnn::optimizations::BFloat16ToFloat3226     static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
27     {
28         const TensorInfo& info = handle->GetTensorInfo();
29 
30         if (info.GetDataType() == DataType::BFloat16)
31         {
32             std::vector<float> newValues(info.GetNumElements());
33 
34             armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(handle->GetTensor<BFloat16>(),
35                                                                          info.GetNumElements(),
36                                                                          newValues.data());
37 
38             TensorInfo newInfo(info.GetShape(), DataType::Float32);
39             ConstTensor newInput(newInfo, newValues);
40             handle.reset(new ScopedCpuTensorHandle(newInput));
41         }
42     }
43 };
44 
45 struct Float16ToFloat32
46 {
Funcarmnn::optimizations::Float16ToFloat3247     static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
48     {
49         const TensorInfo& info = handle->GetTensorInfo();
50 
51         if (info.GetDataType() == DataType::Float16)
52         {
53             std::vector<float> newValues(info.GetNumElements());
54 
55             armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetTensor<Half>(),
56                                                                    info.GetNumElements(),
57                                                                    newValues.data());
58 
59             TensorInfo newInfo(info.GetShape(), DataType::Float32);
60             ConstTensor newInput(newInfo, newValues);
61             handle.reset(new ScopedCpuTensorHandle(newInput));
62         }
63     }
64 };
65 
66 struct Float32ToBFloat16
67 {
Funcarmnn::optimizations::Float32ToBFloat1668     static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
69     {
70         const TensorInfo& info = handle->GetTensorInfo();
71 
72         if (info.GetDataType() == DataType::Float32)
73         {
74             std::vector<BFloat16> newValues(info.GetNumElements());
75 
76             armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(handle->GetTensor<float>(),
77                                                                          info.GetNumElements(),
78                                                                          newValues.data());
79 
80             TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
81             ConstTensor newInput(newInfo, newValues);
82             handle.reset(new ScopedCpuTensorHandle(newInput));
83         }
84     }
85 };
86 
87 struct Float32ToFloat16
88 {
Funcarmnn::optimizations::Float32ToFloat1689     static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
90     {
91         const TensorInfo& info = handle->GetTensorInfo();
92 
93         if (info.GetDataType() == DataType::Float32)
94         {
95             std::vector<Half> newValues(info.GetNumElements());
96 
97             armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetTensor<float>(),
98                                                                    info.GetNumElements(),
99                                                                    newValues.data());
100 
101             TensorInfo newInfo(info.GetShape(), DataType::Float16);
102             ConstTensor newInput(newInfo, newValues);
103             handle.reset(new ScopedCpuTensorHandle(newInput));
104         }
105     }
106 };
107 
108 template<typename Converter, typename Predicate>
109 class ConvertConstants : public Optimization
110 {
111 public:
112     ConvertConstants() = default;
113     ConvertConstants(const ConvertConstants&) = default;
114     virtual ~ConvertConstants() = default;
115 
Run(Graph & graph,Layer & layer) const116     void Run(Graph& graph, Layer& layer) const override
117     {
118         IgnoreUnused(graph);
119         if (Predicate::Test(layer))
120         {
121             layer.OperateOnConstantTensors(Converter::Func);
122         }
123     }
124 protected:
125 };
126 
127 struct IsFloat32Layer
128 {
Testarmnn::optimizations::IsFloat32Layer129     static bool Test(const Layer& layer)
130     {
131         return layer.GetDataType() == DataType::Float32;
132     }
133 };
134 
135 struct IsFloat16Layer
136 {
Testarmnn::optimizations::IsFloat16Layer137     static bool Test(const Layer& layer)
138     {
139         return layer.GetDataType() == DataType::Float16;
140     }
141 };
142 
143 struct IsBFloat16Layer
144 {
Testarmnn::optimizations::IsBFloat16Layer145     static bool Test(const Layer& layer)
146     {
147         return layer.GetDataType() == DataType::BFloat16;
148     }
149 };
150 
// Each alias pairs a converter with the predicate selecting the layers it
// applies to, e.g. BFloat16 constants are widened for layers computing in
// Float32, and Float32 constants are narrowed for BFloat16 layers.
using ConvertConstantsBFloatToFloat = ConvertConstants<BFloat16ToFloat32, IsFloat32Layer>;
using ConvertConstantsFloatToBFloat = ConvertConstants<Float32ToBFloat16, IsBFloat16Layer>;

// Same pairing for IEEE half precision (Float16).
using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;
156 
157 } //namespace optimizations
158 } //namespace armnn
159