1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <boost/test/unit_test.hpp>
6
7 #include <armnn/LayerVisitorBase.hpp>
8
9 #include <armnn/backends/IBackendContext.hpp>
10 #include <armnn/backends/IBackendInternal.hpp>
11 #include <armnn/backends/IMemoryManager.hpp>
12 #include <armnn/backends/ITensorHandleFactory.hpp>
13 #include <backendsCommon/TensorHandleFactoryRegistry.hpp>
14
15 #include <optimizations/Optimization.hpp>
16
17 #include <Network.hpp>
18
19 #include <armnn/utility/IgnoreUnused.hpp>
20
21 #include <vector>
22 #include <string>
23
24
25 using namespace armnn;
26
27 class TestMemMgr : public IMemoryManager
28 {
29 public:
30 TestMemMgr() = default;
31
Acquire()32 void Acquire() override {}
Release()33 void Release() override {}
34 };
35
36 class TestFactory1 : public ITensorHandleFactory
37 {
38 public:
TestFactory1(std::weak_ptr<IMemoryManager> mgr,ITensorHandleFactory::FactoryId id)39 TestFactory1(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
40 : m_Id(id)
41 , m_MemMgr(mgr)
42 {}
43
CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const44 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
45 TensorShape const& subTensorShape,
46 unsigned int const* subTensorOrigin) const override
47 {
48 IgnoreUnused(parent, subTensorShape, subTensorOrigin);
49 return nullptr;
50 }
51
CreateTensorHandle(const TensorInfo & tensorInfo) const52 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
53 {
54 IgnoreUnused(tensorInfo);
55 return nullptr;
56 }
57
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const58 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
59 DataLayout dataLayout) const override
60 {
61 IgnoreUnused(tensorInfo, dataLayout);
62 return nullptr;
63 }
64
GetId() const65 const FactoryId& GetId() const override { return m_Id; }
66
SupportsSubTensors() const67 bool SupportsSubTensors() const override { return true; }
68
GetExportFlags() const69 MemorySourceFlags GetExportFlags() const override { return 1; }
70
71 private:
72 FactoryId m_Id = "UninitializedId";
73
74 std::weak_ptr<IMemoryManager> m_MemMgr;
75 };
76
77 class TestFactoryImport : public ITensorHandleFactory
78 {
79 public:
TestFactoryImport(std::weak_ptr<IMemoryManager> mgr,ITensorHandleFactory::FactoryId id)80 TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
81 : m_Id(id)
82 , m_MemMgr(mgr)
83 {}
84
CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const85 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
86 TensorShape const& subTensorShape,
87 unsigned int const* subTensorOrigin) const override
88 {
89 IgnoreUnused(parent, subTensorShape, subTensorOrigin);
90 return nullptr;
91 }
92
CreateTensorHandle(const TensorInfo & tensorInfo) const93 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
94 {
95 IgnoreUnused(tensorInfo);
96 return nullptr;
97 }
98
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const99 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
100 DataLayout dataLayout) const override
101 {
102 IgnoreUnused(tensorInfo, dataLayout);
103 return nullptr;
104 }
105
GetId() const106 const FactoryId& GetId() const override { return m_Id; }
107
SupportsSubTensors() const108 bool SupportsSubTensors() const override { return true; }
109
GetImportFlags() const110 MemorySourceFlags GetImportFlags() const override { return 1; }
111
112 private:
113 FactoryId m_Id = "ImporterId";
114
115 std::weak_ptr<IMemoryManager> m_MemMgr;
116 };
117
118 class TestBackendA : public IBackendInternal
119 {
120 public:
121 TestBackendA() = default;
122
GetId() const123 const BackendId& GetId() const override { return m_Id; }
124
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const125 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
126 {
127 IgnoreUnused(memoryManager);
128 return IWorkloadFactoryPtr{};
129 }
130
GetLayerSupport() const131 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
132 {
133 return ILayerSupportSharedPtr{};
134 }
135
GetHandleFactoryPreferences() const136 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
137 {
138 return std::vector<ITensorHandleFactory::FactoryId>
139 {
140 "TestHandleFactoryA1",
141 "TestHandleFactoryA2",
142 "TestHandleFactoryB1"
143 };
144 }
145
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)146 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
147 {
148 auto mgr = std::make_shared<TestMemMgr>();
149
150 registry.RegisterMemoryManager(mgr);
151 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA1"));
152 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA2"));
153 }
154
155 private:
156 BackendId m_Id = "BackendA";
157 };
158
159 class TestBackendB : public IBackendInternal
160 {
161 public:
162 TestBackendB() = default;
163
GetId() const164 const BackendId& GetId() const override { return m_Id; }
165
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const166 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
167 {
168 IgnoreUnused(memoryManager);
169 return IWorkloadFactoryPtr{};
170 }
171
GetLayerSupport() const172 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
173 {
174 return ILayerSupportSharedPtr{};
175 }
176
GetHandleFactoryPreferences() const177 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
178 {
179 return std::vector<ITensorHandleFactory::FactoryId>
180 {
181 "TestHandleFactoryB1"
182 };
183 }
184
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)185 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
186 {
187 auto mgr = std::make_shared<TestMemMgr>();
188
189 registry.RegisterMemoryManager(mgr);
190 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryB1"));
191 }
192
193 private:
194 BackendId m_Id = "BackendB";
195 };
196
197 class TestBackendC : public IBackendInternal
198 {
199 public:
200 TestBackendC() = default;
201
GetId() const202 const BackendId& GetId() const override { return m_Id; }
203
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const204 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
205 {
206 IgnoreUnused(memoryManager);
207 return IWorkloadFactoryPtr{};
208 }
209
GetLayerSupport() const210 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
211 {
212 return ILayerSupportSharedPtr{};
213 }
214
GetHandleFactoryPreferences() const215 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
216 {
217 return std::vector<ITensorHandleFactory::FactoryId>{
218 "TestHandleFactoryC1"
219 };
220 }
221
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)222 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
223 {
224 auto mgr = std::make_shared<TestMemMgr>();
225
226 registry.RegisterMemoryManager(mgr);
227 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryC1"));
228 }
229
230 private:
231 BackendId m_Id = "BackendC";
232 };
233
234 class TestBackendD : public IBackendInternal
235 {
236 public:
237 TestBackendD() = default;
238
GetId() const239 const BackendId& GetId() const override { return m_Id; }
240
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const241 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
242 {
243 IgnoreUnused(memoryManager);
244 return IWorkloadFactoryPtr{};
245 }
246
GetLayerSupport() const247 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
248 {
249 return ILayerSupportSharedPtr{};
250 }
251
GetHandleFactoryPreferences() const252 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
253 {
254 return std::vector<ITensorHandleFactory::FactoryId>{
255 "TestHandleFactoryD1"
256 };
257 }
258
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)259 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
260 {
261 auto mgr = std::make_shared<TestMemMgr>();
262
263 registry.RegisterMemoryManager(mgr);
264 registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1"));
265 }
266
267 private:
268 BackendId m_Id = "BackendD";
269 };
270
271
272 BOOST_AUTO_TEST_SUITE(TensorHandle)
273
/// Checks two things:
/// - TestBackendA reports its factory preferences in the expected order.
/// - Registering BackendA and BackendB makes factories A1, A2 and B1 discoverable by id,
///   while an unknown id yields nullptr.
BOOST_AUTO_TEST_CASE(RegisterFactories)
{
    TestBackendA backendA;
    TestBackendB backendB;

    // Fetch the preference list once and require its size before indexing. If TestBackendA is
    // ever changed, the element checks below then fail cleanly instead of reading out of bounds.
    const std::vector<ITensorHandleFactory::FactoryId> preferences = backendA.GetHandleFactoryPreferences();
    BOOST_REQUIRE_EQUAL(preferences.size(), 3u);
    BOOST_TEST(preferences[0] == "TestHandleFactoryA1");
    BOOST_TEST(preferences[1] == "TestHandleFactoryA2");
    BOOST_TEST(preferences[2] == "TestHandleFactoryB1");

    TensorHandleFactoryRegistry registry;
    backendA.RegisterTensorHandleFactories(registry);
    backendB.RegisterTensorHandleFactories(registry);

    BOOST_TEST((registry.GetFactory("Non-existing Backend") == nullptr));
    BOOST_TEST((registry.GetFactory("TestHandleFactoryA1") != nullptr));
    BOOST_TEST((registry.GetFactory("TestHandleFactoryA2") != nullptr));
    BOOST_TEST((registry.GetFactory("TestHandleFactoryB1") != nullptr));
}
292
/// Builds the chain
///   input(BackendA) -> softmax1(BackendA) -> softmax2(BackendB) -> softmax3(BackendC)
///     -> softmax4(BackendD) -> output(BackendA)
/// using the stub backends from this file. It then checks:
/// - the tensor handle factory and edge strategy SelectTensorHandleStrategy assigns to each output slot;
/// - that AddCompatibilityLayers inserts exactly one MemCopy layer and one MemImport layer.
BOOST_AUTO_TEST_CASE(TensorHandleSelectionStrategy)
{
    auto backendA = std::make_unique<TestBackendA>();
    auto backendB = std::make_unique<TestBackendB>();
    auto backendC = std::make_unique<TestBackendC>();
    auto backendD = std::make_unique<TestBackendD>();

    TensorHandleFactoryRegistry registry;
    backendA->RegisterTensorHandleFactories(registry);
    backendB->RegisterTensorHandleFactories(registry);
    backendC->RegisterTensorHandleFactories(registry);
    backendD->RegisterTensorHandleFactories(registry);

    // Ownership of the backends moves into the map; the local unique_ptrs are empty afterwards.
    BackendsMap backends;
    backends["BackendA"] = std::move(backendA);
    backends["BackendB"] = std::move(backendB);
    backends["BackendC"] = std::move(backendC);
    backends["BackendD"] = std::move(backendD);

    armnn::Graph graph;

    armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input");
    inputLayer->SetBackendId("BackendA");

    armnn::SoftmaxDescriptor smDesc;
    armnn::SoftmaxLayer* const softmaxLayer1 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax1");
    softmaxLayer1->SetBackendId("BackendA");

    armnn::SoftmaxLayer* const softmaxLayer2 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax2");
    softmaxLayer2->SetBackendId("BackendB");

    armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3");
    softmaxLayer3->SetBackendId("BackendC");

    armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4");
    softmaxLayer4->SetBackendId("BackendD");

    armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
    outputLayer->SetBackendId("BackendA");

    inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0));
    softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0));
    softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0));
    softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0));
    softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));

    graph.TopologicalSort();

    std::vector<std::string> errors;
    // NOTE(review): the 'true' argument is presumably the import-enabled flag; confirm against
    // the SelectTensorHandleStrategy declaration.
    auto result = SelectTensorHandleStrategy(graph, backends, registry, true, errors);

    BOOST_TEST(result.m_Error == false);
    BOOST_TEST(result.m_Warning == false);

    OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0);
    OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
    OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0);
    OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0);
    OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);

    // Check that the correct factory was selected. softmax1 (BackendA) is expected to use B1,
    // presumably because B1 also appears in BackendA's preference list.
    BOOST_TEST(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryA1");
    BOOST_TEST(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
    BOOST_TEST(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
    BOOST_TEST(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1");
    BOOST_TEST(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1");

    // Check that the correct strategy was selected. The stub factories export flag 1 (TestFactory1)
    // and import flag 1 (TestFactoryImport, used by D1). This presumably explains the expectations:
    // - C1 -> D1 is an export;
    // - B1 -> C1 is a copy, since C1 imports nothing.
    BOOST_TEST((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    BOOST_TEST((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
    BOOST_TEST((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget));
    BOOST_TEST((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget));
    BOOST_TEST((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));

    graph.AddCompatibilityLayers(backends, registry);

    // Count the inserted compatibility layers in a single traversal. The test expects:
    // - one MemCopy layer (presumably for the CopyToTarget edge);
    // - one MemImport layer (presumably for the ExportToTarget edge).
    int copyCount = 0;
    int importCount = 0;
    graph.ForEachLayer([&copyCount, &importCount](Layer* layer)
    {
        const LayerType type = layer->GetType();
        if (type == LayerType::MemCopy)
        {
            ++copyCount;
        }
        else if (type == LayerType::MemImport)
        {
            ++importCount;
        }
    });
    BOOST_TEST(copyCount == 1);
    BOOST_TEST(importCount == 1);
}
391
392 BOOST_AUTO_TEST_SUITE_END()
393