1 // 2 // Copyright © 2020 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "NetworkUtils.hpp" 8 #include "Optimization.hpp" 9 10 #include <armnn/utility/PolymorphicDowncast.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 17 template <typename LayerT> ConvertWeight(Layer * l)18inline LayerT* ConvertWeight(Layer* l) 19 { 20 LayerT* layer = PolymorphicDowncast<LayerT*>(l); 21 if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected) 22 && layer->m_Weight) 23 { 24 const TensorInfo& info = layer->m_Weight->GetTensorInfo(); 25 26 if (info.GetDataType() == DataType::Float32) 27 { 28 std::vector<BFloat16> newValues(info.GetNumElements()); 29 30 armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(layer->m_Weight->template GetTensor<float>(), 31 info.GetNumElements(), 32 newValues.data()); 33 34 TensorInfo newInfo(info.GetShape(), DataType::BFloat16); 35 ConstTensor newInput(newInfo, newValues); 36 layer->m_Weight.reset(new ScopedCpuTensorHandle(newInput)); 37 } 38 } 39 return layer; 40 } 41 42 class ConvertFp32NetworkToBf16Impl 43 { 44 public: 45 Run(Graph & graph,Layer & layer) const46 void Run(Graph& graph, Layer& layer) const 47 { 48 // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer. 49 // And also convert weight data type from Float32 to Bfloat16. 50 // Do not convert bias data type. 51 if (layer.GetType() == LayerType::Convolution2d) 52 { 53 if (layer.GetDataType() == DataType::Float32) 54 { 55 InsertConvertFp32ToBf16LayersBefore(graph,layer); 56 ConvertWeight<Convolution2dLayer>(&layer); 57 } 58 } 59 else if (layer.GetType() == LayerType::FullyConnected) 60 { 61 if (layer.GetDataType() == DataType::Float32) 62 { 63 InsertConvertFp32ToBf16LayersBefore(graph,layer); 64 ConvertWeight<FullyConnectedLayer>(&layer); 65 } 66 } 67 } 68 69 protected: 70 ConvertFp32NetworkToBf16Impl() = default; 71 ~ConvertFp32NetworkToBf16Impl() = default; 72 }; 73 74 using Fp32NetworkToBf16Converter = OptimizeForType<Layer, ConvertFp32NetworkToBf16Impl>; 75 76 } // namespace optimizations 77 } // namespace armnn 78