1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 10 #include <BFloat16.hpp> 11 #include <Half.hpp> 12 13 namespace armnn 14 { 15 16 template<typename T> CompatibleTypes(DataType)17bool CompatibleTypes(DataType) 18 { 19 return false; 20 } 21 22 template<> CompatibleTypes(DataType dataType)23inline bool CompatibleTypes<float>(DataType dataType) 24 { 25 return dataType == DataType::Float32; 26 } 27 28 template<> CompatibleTypes(DataType dataType)29inline bool CompatibleTypes<Half>(DataType dataType) 30 { 31 return dataType == DataType::Float16; 32 } 33 34 template<> CompatibleTypes(DataType dataType)35inline bool CompatibleTypes<BFloat16>(DataType dataType) 36 { 37 return dataType == DataType::BFloat16; 38 } 39 40 template<> CompatibleTypes(DataType dataType)41inline bool CompatibleTypes<uint8_t>(DataType dataType) 42 { 43 return dataType == DataType::Boolean || dataType == DataType::QAsymmU8; 44 } 45 46 template<> CompatibleTypes(DataType dataType)47inline bool CompatibleTypes<int8_t>(DataType dataType) 48 { 49 ARMNN_NO_DEPRECATE_WARN_BEGIN 50 return dataType == DataType::QSymmS8 51 || dataType == DataType::QuantizedSymm8PerAxis 52 || dataType == DataType::QAsymmS8; 53 ARMNN_NO_DEPRECATE_WARN_END 54 } 55 56 template<> CompatibleTypes(DataType dataType)57inline bool CompatibleTypes<int16_t>(DataType dataType) 58 { 59 return dataType == DataType::QSymmS16; 60 } 61 62 template<> CompatibleTypes(DataType dataType)63inline bool CompatibleTypes<int32_t>(DataType dataType) 64 { 65 return dataType == DataType::Signed32; 66 } 67 68 } //namespace armnn 69