1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include "SerializerTestUtils.hpp"
#include "../Serializer.hpp"

#include <doctest/doctest.h>

#include <algorithm>
10
11 using armnnDeserializer::IDeserializer;
12
LayerVerifierBase(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos)13 LayerVerifierBase::LayerVerifierBase(const std::string& layerName,
14 const std::vector<armnn::TensorInfo>& inputInfos,
15 const std::vector<armnn::TensorInfo>& outputInfos)
16 : m_LayerName(layerName)
17 , m_InputTensorInfos(inputInfos)
18 , m_OutputTensorInfos(outputInfos)
19 {}
20
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id)21 void LayerVerifierBase::ExecuteStrategy(const armnn::IConnectableLayer* layer,
22 const armnn::BaseDescriptor& descriptor,
23 const std::vector<armnn::ConstTensor>& constants,
24 const char* name,
25 const armnn::LayerBindingId id)
26 {
27 armnn::IgnoreUnused(descriptor, constants, id);
28 switch (layer->GetType())
29 {
30 case armnn::LayerType::Input: break;
31 case armnn::LayerType::Output: break;
32 default:
33 {
34 VerifyNameAndConnections(layer, name);
35 }
36 }
37 }
38
39
VerifyNameAndConnections(const armnn::IConnectableLayer * layer,const char * name)40 void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name)
41 {
42 CHECK(std::string(name) == m_LayerName.c_str());
43
44 CHECK(layer->GetNumInputSlots() == m_InputTensorInfos.size());
45 CHECK(layer->GetNumOutputSlots() == m_OutputTensorInfos.size());
46
47 for (unsigned int i = 0; i < m_InputTensorInfos.size(); i++)
48 {
49 const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection();
50 CHECK(connectedOutput);
51
52 const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo();
53 CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape());
54 CHECK(GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType()));
55
56 if (connectedInfo.HasMultipleQuantizationScales())
57 {
58 CHECK(connectedInfo.GetQuantizationScales() == m_InputTensorInfos[i].GetQuantizationScales());
59 }
60 else
61 {
62 CHECK(connectedInfo.GetQuantizationScale() == m_InputTensorInfos[i].GetQuantizationScale());
63 }
64 CHECK(connectedInfo.GetQuantizationOffset() == m_InputTensorInfos[i].GetQuantizationOffset());
65 }
66
67 for (unsigned int i = 0; i < m_OutputTensorInfos.size(); i++)
68 {
69 const armnn::TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
70 CHECK(outputInfo.GetShape() == m_OutputTensorInfos[i].GetShape());
71 CHECK(GetDataTypeName(outputInfo.GetDataType()) == GetDataTypeName(m_OutputTensorInfos[i].GetDataType()));
72
73 CHECK(outputInfo.GetQuantizationScale() == m_OutputTensorInfos[i].GetQuantizationScale());
74 CHECK(outputInfo.GetQuantizationOffset() == m_OutputTensorInfos[i].GetQuantizationOffset());
75 }
76 }
77
VerifyConstTensors(const std::string & tensorName,const armnn::ConstTensor * expectedPtr,const armnn::ConstTensor * actualPtr)78 void LayerVerifierBase::VerifyConstTensors(const std::string& tensorName,
79 const armnn::ConstTensor* expectedPtr,
80 const armnn::ConstTensor* actualPtr)
81 {
82 if (expectedPtr == nullptr)
83 {
84 CHECK_MESSAGE(actualPtr == nullptr, (tensorName + " should not exist"));
85 }
86 else
87 {
88 CHECK_MESSAGE(actualPtr != nullptr, (tensorName + " should have been set"));
89 if (actualPtr != nullptr)
90 {
91 const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
92 const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
93
94 CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
95 (tensorName + " shapes don't match"));
96 CHECK_MESSAGE(
97 GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
98 (tensorName + " data types don't match"));
99
100 CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
101 (tensorName + " (GetNumBytes) data sizes do not match"));
102 if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
103 {
104 //check the data is identical
105 const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
106 const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
107 bool same = true;
108 for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
109 {
110 same = expectedData[i] == actualData[i];
111 if (!same)
112 {
113 break;
114 }
115 }
116 CHECK_MESSAGE(same, (tensorName + " data does not match"));
117 }
118 }
119 }
120 }
121
CompareConstTensor(const armnn::ConstTensor & tensor1,const armnn::ConstTensor & tensor2)122 void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2)
123 {
124 CHECK(tensor1.GetShape() == tensor2.GetShape());
125 CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
126
127 switch (tensor1.GetDataType())
128 {
129 case armnn::DataType::Float32:
130 CompareConstTensorData<const float*>(
131 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
132 break;
133 case armnn::DataType::QAsymmU8:
134 case armnn::DataType::Boolean:
135 CompareConstTensorData<const uint8_t*>(
136 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
137 break;
138 case armnn::DataType::QSymmS8:
139 CompareConstTensorData<const int8_t*>(
140 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
141 break;
142 case armnn::DataType::Signed32:
143 CompareConstTensorData<const int32_t*>(
144 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
145 break;
146 default:
147 // Note that Float16 is not yet implemented
148 MESSAGE("Unexpected datatype");
149 CHECK(false);
150 }
151 }
152
DeserializeNetwork(const std::string & serializerString)153 armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString)
154 {
155 std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
156 return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
157 }
158
SerializeNetwork(const armnn::INetwork & network)159 std::string SerializeNetwork(const armnn::INetwork& network)
160 {
161 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
162
163 serializer->Serialize(network);
164
165 std::stringstream stream;
166 serializer->SaveSerializedToStream(stream);
167
168 std::string serializerString{stream.str()};
169 return serializerString;
170 }
171