• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "NetworkUtils.hpp"
8 #include "Optimization.hpp"
9 
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename LayerT>
ConvertWeight(Layer * l)18 inline LayerT* ConvertWeight(Layer* l)
19 {
20     LayerT* layer = PolymorphicDowncast<LayerT*>(l);
21     if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
22          && layer->m_Weight)
23     {
24         const TensorInfo& info = layer->m_Weight->GetTensorInfo();
25 
26         if (info.GetDataType() == DataType::Float32)
27         {
28             std::vector<BFloat16> newValues(info.GetNumElements());
29 
30             armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(layer->m_Weight->template GetTensor<float>(),
31                                                                          info.GetNumElements(),
32                                                                          newValues.data());
33 
34             TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
35             ConstTensor newInput(newInfo, newValues);
36             layer->m_Weight.reset(new ScopedCpuTensorHandle(newInput));
37         }
38     }
39     return layer;
40 }
41 
42 class ConvertFp32NetworkToBf16Impl
43 {
44 public:
45 
Run(Graph & graph,Layer & layer) const46     void Run(Graph& graph, Layer& layer) const
47     {
48         // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
49         // And also convert weight data type from Float32 to Bfloat16.
50         // Do not convert bias data type.
51         if (layer.GetType() == LayerType::Convolution2d)
52         {
53             if (layer.GetDataType() == DataType::Float32)
54             {
55                 InsertConvertFp32ToBf16LayersBefore(graph,layer);
56                 ConvertWeight<Convolution2dLayer>(&layer);
57             }
58         }
59         else if (layer.GetType() == LayerType::FullyConnected)
60         {
61             if (layer.GetDataType() == DataType::Float32)
62             {
63                 InsertConvertFp32ToBf16LayersBefore(graph,layer);
64                 ConvertWeight<FullyConnectedLayer>(&layer);
65             }
66         }
67     }
68 
69 protected:
70     ConvertFp32NetworkToBf16Impl() = default;
71     ~ConvertFp32NetworkToBf16Impl() = default;
72 };
73 
74 using Fp32NetworkToBf16Converter = OptimizeForType<Layer, ConvertFp32NetworkToBf16Impl>;
75 
76 } // namespace optimizations
77 } // namespace armnn
78