//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "Optimization.hpp"

#include <armnnUtils/FloatingPointConverter.hpp>

#include <backendsCommon/CpuTensorHandle.hpp>

#include <armnn/utility/IgnoreUnused.hpp>

#include <BFloat16.hpp>
#include <Half.hpp>

namespace armnn
{
namespace optimizations
{

// Converts a BFloat16 constant tensor handle to Float32, replacing the handle in place.
struct BFloat16ToFloat32
{
    static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
    {
        const TensorInfo& info = handle->GetTensorInfo();

        if (info.GetDataType() == DataType::BFloat16)
        {
            std::vector<float> newValues(info.GetNumElements());

            armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(handle->GetTensor<BFloat16>(),
                                                                          info.GetNumElements(),
                                                                          newValues.data());

            TensorInfo newInfo(info.GetShape(), DataType::Float32);
            ConstTensor newInput(newInfo, newValues);
            handle.reset(new ScopedCpuTensorHandle(newInput));
        }
    }
};

// Converts a Float16 constant tensor handle to Float32, replacing the handle in place.
struct Float16ToFloat32
{
    static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
    {
        const TensorInfo& info = handle->GetTensorInfo();

        if (info.GetDataType() == DataType::Float16)
        {
            std::vector<float> newValues(info.GetNumElements());

            armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetTensor<Half>(),
                                                                   info.GetNumElements(),
                                                                   newValues.data());

            TensorInfo newInfo(info.GetShape(), DataType::Float32);
            ConstTensor newInput(newInfo, newValues);
            handle.reset(new ScopedCpuTensorHandle(newInput));
        }
    }
};

// Converts a Float32 constant tensor handle to BFloat16, replacing the handle in place.
struct Float32ToBFloat16
{
    static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
    {
        const TensorInfo& info = handle->GetTensorInfo();

        if (info.GetDataType() == DataType::Float32)
        {
            std::vector<BFloat16> newValues(info.GetNumElements());

            armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(handle->GetTensor<float>(),
                                                                          info.GetNumElements(),
                                                                          newValues.data());

            TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
            ConstTensor newInput(newInfo, newValues);
            handle.reset(new ScopedCpuTensorHandle(newInput));
        }
    }
};

// Converts a Float32 constant tensor handle to Float16, replacing the handle in place.
struct Float32ToFloat16
{
    static void Func(std::unique_ptr<ScopedCpuTensorHandle>& handle)
    {
        const TensorInfo& info = handle->GetTensorInfo();

        if (info.GetDataType() == DataType::Float32)
        {
            std::vector<Half> newValues(info.GetNumElements());

            armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetTensor<float>(),
                                                                   info.GetNumElements(),
                                                                   newValues.data());

            TensorInfo newInfo(info.GetShape(), DataType::Float16);
            ConstTensor newInput(newInfo, newValues);
            handle.reset(new ScopedCpuTensorHandle(newInput));
        }
    }
};

// Graph optimization that applies Converter::Func to every constant tensor of each
// layer for which Predicate::Test returns true.
template<typename Converter, typename Predicate>
class ConvertConstants : public Optimization
{
public:
    ConvertConstants() = default;
    ConvertConstants(const ConvertConstants&) = default;
    virtual ~ConvertConstants() = default;

    void Run(Graph& graph, Layer& layer) const override
    {
        IgnoreUnused(graph);
        if (Predicate::Test(layer))
        {
            layer.OperateOnConstantTensors(Converter::Func);
        }
    }
protected:
};

// Predicates selecting layers by their data type.
struct IsFloat32Layer
{
    static bool Test(const Layer& layer)
    {
        return layer.GetDataType() == DataType::Float32;
    }
};

struct IsFloat16Layer
{
    static bool Test(const Layer& layer)
    {
        return layer.GetDataType() == DataType::Float16;
    }
};

struct IsBFloat16Layer
{
    static bool Test(const Layer& layer)
    {
        return layer.GetDataType() == DataType::BFloat16;
    }
};

// Convert BFloat16 constants on Float32 layers to Float32, and vice versa.
using ConvertConstantsBFloatToFloat = ConvertConstants<BFloat16ToFloat32, IsFloat32Layer>;
using ConvertConstantsFloatToBFloat = ConvertConstants<Float32ToBFloat16, IsBFloat16Layer>;

// Convert Float16 constants on Float32 layers to Float32, and vice versa.
using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;

} // namespace optimizations
} // namespace armnn
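
// Usage sketch (an assumption about how these passes are wired up elsewhere, not part of
// this header): the aliases above are typically run as graph passes via the
// Optimizer::Pass / MakeOptimizations helpers declared in Optimizer.hpp, e.g.
//
//   armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(
//       armnn::optimizations::ConvertConstantsFloatToHalf()));
//
// This visits every layer in the graph and, for each Float16 layer, converts its Float32
// constant tensors (e.g. weights and biases) to Float16 in place.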