• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "../Serializer.hpp"
7 
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/INetwork.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnnDeserializer/IDeserializer.hpp>
12 #include <armnn/utility/IgnoreUnused.hpp>
13 
14 #include <doctest/doctest.h>
15 
16 #include <sstream>
17 
18 TEST_SUITE("SerializerTests")
19 {
20 class VerifyActivationName : public armnn::IStrategy
21 {
22 public:
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)23     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
24                          const armnn::BaseDescriptor& descriptor,
25                          const std::vector<armnn::ConstTensor>& constants,
26                          const char* name,
27                          const armnn::LayerBindingId id = 0) override
28     {
29         IgnoreUnused(layer, descriptor, constants, id);
30         if (layer->GetType() == armnn::LayerType::Activation)
31         {
32             CHECK(std::string(name) == "activation");
33         }
34     }
35 };
36 
37 TEST_CASE("ActivationSerialization")
38 {
39     armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
40 
41     armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
42     armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
43 
44     // Construct network
45     armnn::INetworkPtr network = armnn::INetwork::Create();
46 
47     armnn::ActivationDescriptor descriptor;
48     descriptor.m_Function = armnn::ActivationFunction::ReLu;
49     descriptor.m_A = 0;
50     descriptor.m_B = 0;
51 
52     armnn::IConnectableLayer* const inputLayer      = network->AddInputLayer(0, "input");
53     armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
54     armnn::IConnectableLayer* const outputLayer     = network->AddOutputLayer(0, "output");
55 
56     inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
57     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
58 
59     activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
60     activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
61 
62     armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
63 
64     serializer->Serialize(*network);
65 
66     std::stringstream stream;
67     serializer->SaveSerializedToStream(stream);
68 
69     std::string const serializerString{stream.str()};
70     std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
71 
72     armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
73 
74     VerifyActivationName visitor;
75     deserializedNetwork->ExecuteStrategy(visitor);
76 
77     armnn::IRuntime::CreationOptions options; // default options
78     armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
79     auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
80 
81     armnn::NetworkId networkIdentifier;
82 
83     // Load graph into runtime
84     run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
85 
86     std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
87     armnn::TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
88     inputTensorInfo.SetConstant(true);
89     armnn::InputTensors inputTensors
90     {
91         {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
92     };
93 
94     std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
95 
96     std::vector<float> outputData(4);
97     armnn::OutputTensors outputTensors
98     {
99         {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
100     };
101     run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
102     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutputData.begin(), expectedOutputData.end()));
103 }
104 
105 }
106