• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "SerializerTestUtils.hpp"
#include "../Serializer.hpp"

#include <doctest/doctest.h>

#include <algorithm>
10 
11 using armnnDeserializer::IDeserializer;
12 
LayerVerifierBase(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos)13 LayerVerifierBase::LayerVerifierBase(const std::string& layerName,
14                                      const std::vector<armnn::TensorInfo>& inputInfos,
15                                      const std::vector<armnn::TensorInfo>& outputInfos)
16                                      : m_LayerName(layerName)
17                                      , m_InputTensorInfos(inputInfos)
18                                      , m_OutputTensorInfos(outputInfos)
19 {}
20 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id)21 void LayerVerifierBase::ExecuteStrategy(const armnn::IConnectableLayer* layer,
22                      const armnn::BaseDescriptor& descriptor,
23                      const std::vector<armnn::ConstTensor>& constants,
24                      const char* name,
25                      const armnn::LayerBindingId id)
26 {
27     armnn::IgnoreUnused(descriptor, constants, id);
28     switch (layer->GetType())
29     {
30         case armnn::LayerType::Input: break;
31         case armnn::LayerType::Output: break;
32         default:
33         {
34             VerifyNameAndConnections(layer, name);
35         }
36     }
37 }
38 
39 
VerifyNameAndConnections(const armnn::IConnectableLayer * layer,const char * name)40 void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name)
41 {
42     CHECK(std::string(name) == m_LayerName.c_str());
43 
44     CHECK(layer->GetNumInputSlots() == m_InputTensorInfos.size());
45     CHECK(layer->GetNumOutputSlots() == m_OutputTensorInfos.size());
46 
47     for (unsigned int i = 0; i < m_InputTensorInfos.size(); i++)
48     {
49         const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection();
50         CHECK(connectedOutput);
51 
52         const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo();
53         CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape());
54         CHECK(GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType()));
55 
56         if (connectedInfo.HasMultipleQuantizationScales())
57         {
58             CHECK(connectedInfo.GetQuantizationScales() == m_InputTensorInfos[i].GetQuantizationScales());
59         }
60         else
61         {
62             CHECK(connectedInfo.GetQuantizationScale() == m_InputTensorInfos[i].GetQuantizationScale());
63         }
64         CHECK(connectedInfo.GetQuantizationOffset() == m_InputTensorInfos[i].GetQuantizationOffset());
65     }
66 
67     for (unsigned int i = 0; i < m_OutputTensorInfos.size(); i++)
68     {
69         const armnn::TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
70         CHECK(outputInfo.GetShape() == m_OutputTensorInfos[i].GetShape());
71         CHECK(GetDataTypeName(outputInfo.GetDataType()) == GetDataTypeName(m_OutputTensorInfos[i].GetDataType()));
72 
73         CHECK(outputInfo.GetQuantizationScale() == m_OutputTensorInfos[i].GetQuantizationScale());
74         CHECK(outputInfo.GetQuantizationOffset() == m_OutputTensorInfos[i].GetQuantizationOffset());
75     }
76 }
77 
VerifyConstTensors(const std::string & tensorName,const armnn::ConstTensor * expectedPtr,const armnn::ConstTensor * actualPtr)78 void LayerVerifierBase::VerifyConstTensors(const std::string& tensorName,
79                                            const armnn::ConstTensor* expectedPtr,
80                                            const armnn::ConstTensor* actualPtr)
81 {
82     if (expectedPtr == nullptr)
83     {
84         CHECK_MESSAGE(actualPtr == nullptr, (tensorName + " should not exist"));
85     }
86     else
87     {
88         CHECK_MESSAGE(actualPtr != nullptr, (tensorName + " should have been set"));
89         if (actualPtr != nullptr)
90         {
91             const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
92             const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
93 
94             CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
95                           (tensorName + " shapes don't match"));
96             CHECK_MESSAGE(
97                     GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
98                     (tensorName + " data types don't match"));
99 
100             CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
101                           (tensorName + " (GetNumBytes) data sizes do not match"));
102             if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
103             {
104                 //check the data is identical
105                 const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
106                 const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
107                 bool same = true;
108                 for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
109                 {
110                     same = expectedData[i] == actualData[i];
111                     if (!same)
112                     {
113                         break;
114                     }
115                 }
116                 CHECK_MESSAGE(same, (tensorName + " data does not match"));
117             }
118         }
119     }
120 }
121 
CompareConstTensor(const armnn::ConstTensor & tensor1,const armnn::ConstTensor & tensor2)122 void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2)
123 {
124     CHECK(tensor1.GetShape() == tensor2.GetShape());
125     CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
126 
127     switch (tensor1.GetDataType())
128     {
129         case armnn::DataType::Float32:
130             CompareConstTensorData<const float*>(
131                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
132             break;
133         case armnn::DataType::QAsymmU8:
134         case armnn::DataType::Boolean:
135             CompareConstTensorData<const uint8_t*>(
136                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
137             break;
138         case armnn::DataType::QSymmS8:
139             CompareConstTensorData<const int8_t*>(
140                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
141             break;
142         case armnn::DataType::Signed32:
143             CompareConstTensorData<const int32_t*>(
144                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
145             break;
146         default:
147             // Note that Float16 is not yet implemented
148             MESSAGE("Unexpected datatype");
149             CHECK(false);
150     }
151 }
152 
DeserializeNetwork(const std::string & serializerString)153 armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString)
154 {
155     std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
156     return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
157 }
158 
SerializeNetwork(const armnn::INetwork & network)159 std::string SerializeNetwork(const armnn::INetwork& network)
160 {
161     armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
162 
163     serializer->Serialize(network);
164 
165     std::stringstream stream;
166     serializer->SaveSerializedToStream(stream);
167 
168     std::string serializerString{stream.str()};
169     return serializerString;
170 }
171