• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "TestUtils.hpp"
7 
8 #include <armnn/utility/Assert.hpp>
9 
10 #include "armnnTestUtils/Version.hpp"
11 
12 using namespace armnn;
13 
Connect(armnn::IConnectableLayer * from,armnn::IConnectableLayer * to,const armnn::TensorInfo & tensorInfo,unsigned int fromIndex,unsigned int toIndex)14 void Connect(armnn::IConnectableLayer* from, armnn::IConnectableLayer* to, const armnn::TensorInfo& tensorInfo,
15              unsigned int fromIndex, unsigned int toIndex)
16 {
17     ARMNN_ASSERT(from);
18     ARMNN_ASSERT(to);
19 
20     try
21     {
22         from->GetOutputSlot(fromIndex).Connect(to->GetInputSlot(toIndex));
23     }
24     catch (const std::out_of_range& exc)
25     {
26         std::ostringstream message;
27 
28         if (to->GetType() == armnn::LayerType::FullyConnected && toIndex == 2)
29         {
30             message << "Tried to connect bias to FullyConnected layer when bias is not enabled: ";
31         }
32 
33         message << "Failed to connect to input slot "
34                 << toIndex
35                 << " on "
36                 << GetLayerTypeAsCString(to->GetType())
37                 << " layer "
38                 << std::quoted(to->GetName())
39                 << " as the slot does not exist or is unavailable";
40         throw LayerValidationException(message.str());
41     }
42 
43     from->GetOutputSlot(fromIndex).SetTensorInfo(tensorInfo);
44 }
45 
46 namespace armnn
47 {
48 
GetGraphForTesting(IOptimizedNetwork * optNet)49 Graph& GetGraphForTesting(IOptimizedNetwork* optNet)
50 {
51     return optNet->pOptimizedNetworkImpl->GetGraph();
52 }
53 
GetModelOptionsForTesting(IOptimizedNetwork * optNet)54 ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNet)
55 {
56     return optNet->pOptimizedNetworkImpl->GetModelOptions();
57 }
58 
GetProfilingService(armnn::RuntimeImpl * runtime)59 arm::pipe::IProfilingService& GetProfilingService(armnn::RuntimeImpl* runtime)
60 {
61     return *(runtime->m_ProfilingService.get());
62 }
63 
64 }
65