• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include <doctest/doctest.h>

#include <armnn/backends/IBackendContext.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/IMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <backendsCommon/TensorHandleFactoryRegistry.hpp>

#include <optimizations/Optimization.hpp>

#include <Network.hpp>

#include <armnn/utility/IgnoreUnused.hpp>

#include <memory>
#include <string>
#include <utility>
#include <vector>

using namespace armnn;
26 
27 class TestMemMgr : public IMemoryManager
28 {
29 public:
30     TestMemMgr() = default;
31 
Acquire()32     void Acquire() override {}
Release()33     void Release() override {}
34 };
35 
36 class TestFactory1 : public ITensorHandleFactory
37 {
38 public:
TestFactory1(std::weak_ptr<IMemoryManager> mgr,ITensorHandleFactory::FactoryId id)39     TestFactory1(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
40         : m_Id(id)
41         , m_MemMgr(mgr)
42     {}
43 
CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const44     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
45                                                          TensorShape const& subTensorShape,
46                                                          unsigned int const* subTensorOrigin) const override
47     {
48         IgnoreUnused(parent, subTensorShape, subTensorOrigin);
49         return nullptr;
50     }
51 
CreateTensorHandle(const TensorInfo & tensorInfo) const52     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
53     {
54         IgnoreUnused(tensorInfo);
55         return nullptr;
56     }
57 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const58     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
59                                                       DataLayout dataLayout) const override
60     {
61         IgnoreUnused(tensorInfo, dataLayout);
62         return nullptr;
63     }
64 
GetId() const65     const FactoryId& GetId() const override { return m_Id; }
66 
SupportsSubTensors() const67     bool SupportsSubTensors() const override { return true; }
68 
GetExportFlags() const69     MemorySourceFlags GetExportFlags() const override { return 1; }
70 
71 private:
72     FactoryId m_Id = "UninitializedId";
73 
74     std::weak_ptr<IMemoryManager> m_MemMgr;
75 };
76 
77 class TestFactoryImport : public ITensorHandleFactory
78 {
79 public:
TestFactoryImport(std::weak_ptr<IMemoryManager> mgr,ITensorHandleFactory::FactoryId id)80     TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id)
81         : m_Id(id)
82         , m_MemMgr(mgr)
83     {}
84 
CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const85     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
86                                                          TensorShape const& subTensorShape,
87                                                          unsigned int const* subTensorOrigin) const override
88     {
89         IgnoreUnused(parent, subTensorShape, subTensorOrigin);
90         return nullptr;
91     }
92 
CreateTensorHandle(const TensorInfo & tensorInfo) const93     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override
94     {
95         IgnoreUnused(tensorInfo);
96         return nullptr;
97     }
98 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const99     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
100                                                       DataLayout dataLayout) const override
101     {
102         IgnoreUnused(tensorInfo, dataLayout);
103         return nullptr;
104     }
105 
GetId() const106     const FactoryId& GetId() const override { return m_Id; }
107 
SupportsSubTensors() const108     bool SupportsSubTensors() const override { return true; }
109 
GetImportFlags() const110     MemorySourceFlags GetImportFlags() const override { return 1; }
111 
112 private:
113     FactoryId m_Id = "ImporterId";
114 
115     std::weak_ptr<IMemoryManager> m_MemMgr;
116 };
117 
118 class TestBackendA : public IBackendInternal
119 {
120 public:
121     TestBackendA() = default;
122 
GetId() const123     const BackendId& GetId() const override { return m_Id; }
124 
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const125     IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
126     {
127         IgnoreUnused(memoryManager);
128         return IWorkloadFactoryPtr{};
129     }
130 
GetLayerSupport() const131     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
132     {
133         return ILayerSupportSharedPtr{};
134     }
135 
GetHandleFactoryPreferences() const136     std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
137     {
138         return std::vector<ITensorHandleFactory::FactoryId>
139         {
140             "TestHandleFactoryA1",
141             "TestHandleFactoryA2",
142             "TestHandleFactoryB1",
143             "TestHandleFactoryD1"
144         };
145     }
146 
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)147     void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
148     {
149         auto mgr = std::make_shared<TestMemMgr>();
150 
151         registry.RegisterMemoryManager(mgr);
152         registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA1"));
153         registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA2"));
154     }
155 
156 private:
157     BackendId m_Id = "BackendA";
158 };
159 
160 class TestBackendB : public IBackendInternal
161 {
162 public:
163     TestBackendB() = default;
164 
GetId() const165     const BackendId& GetId() const override { return m_Id; }
166 
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const167     IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
168     {
169         IgnoreUnused(memoryManager);
170         return IWorkloadFactoryPtr{};
171     }
172 
GetLayerSupport() const173     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
174     {
175         return ILayerSupportSharedPtr{};
176     }
177 
GetHandleFactoryPreferences() const178     std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
179     {
180         return std::vector<ITensorHandleFactory::FactoryId>
181         {
182             "TestHandleFactoryB1"
183         };
184     }
185 
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)186     void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
187     {
188         auto mgr = std::make_shared<TestMemMgr>();
189 
190         registry.RegisterMemoryManager(mgr);
191         registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryB1"));
192     }
193 
194 private:
195     BackendId m_Id = "BackendB";
196 };
197 
198 class TestBackendC : public IBackendInternal
199 {
200 public:
201     TestBackendC() = default;
202 
GetId() const203     const BackendId& GetId() const override { return m_Id; }
204 
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const205     IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
206     {
207         IgnoreUnused(memoryManager);
208         return IWorkloadFactoryPtr{};
209     }
210 
GetLayerSupport() const211     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
212     {
213         return ILayerSupportSharedPtr{};
214     }
215 
GetHandleFactoryPreferences() const216     std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
217     {
218         return std::vector<ITensorHandleFactory::FactoryId>{
219             "TestHandleFactoryC1"
220         };
221     }
222 
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)223     void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
224     {
225         auto mgr = std::make_shared<TestMemMgr>();
226 
227         registry.RegisterMemoryManager(mgr);
228         registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryC1"));
229     }
230 
231 private:
232     BackendId m_Id = "BackendC";
233 };
234 
235 class TestBackendD : public IBackendInternal
236 {
237 public:
238     TestBackendD() = default;
239 
GetId() const240     const BackendId& GetId() const override { return m_Id; }
241 
CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const242     IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override
243     {
244         IgnoreUnused(memoryManager);
245         return IWorkloadFactoryPtr{};
246     }
247 
GetLayerSupport() const248     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
249     {
250         return ILayerSupportSharedPtr{};
251     }
252 
GetHandleFactoryPreferences() const253     std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override
254     {
255         return std::vector<ITensorHandleFactory::FactoryId>{
256             "TestHandleFactoryD1",
257         };
258     }
259 
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)260     void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override
261     {
262         auto mgr = std::make_shared<TestMemMgr>();
263 
264         registry.RegisterMemoryManager(mgr);
265         registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1"));
266     }
267 
268 private:
269     BackendId m_Id = "BackendD";
270 };
271 
272 
273 TEST_SUITE("TensorHandle")
274 {
275 TEST_CASE("RegisterFactories")
276 {
277     TestBackendA backendA;
278     TestBackendB backendB;
279 
280     CHECK(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1");
281     CHECK(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2");
282     CHECK(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1");
283     CHECK(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1");
284 
285     TensorHandleFactoryRegistry registry;
286     backendA.RegisterTensorHandleFactories(registry);
287     backendB.RegisterTensorHandleFactories(registry);
288 
289     CHECK((registry.GetFactory("Non-existing Backend") == nullptr));
290     CHECK((registry.GetFactory("TestHandleFactoryA1") != nullptr));
291     CHECK((registry.GetFactory("TestHandleFactoryA2") != nullptr));
292     CHECK((registry.GetFactory("TestHandleFactoryB1") != nullptr));
293 }
294 
295 TEST_CASE("TensorHandleSelectionStrategy")
296 {
297     auto backendA = std::make_unique<TestBackendA>();
298     auto backendB = std::make_unique<TestBackendB>();
299     auto backendC = std::make_unique<TestBackendC>();
300     auto backendD = std::make_unique<TestBackendD>();
301 
302     TensorHandleFactoryRegistry registry;
303     backendA->RegisterTensorHandleFactories(registry);
304     backendB->RegisterTensorHandleFactories(registry);
305     backendC->RegisterTensorHandleFactories(registry);
306     backendD->RegisterTensorHandleFactories(registry);
307 
308     BackendsMap backends;
309     backends["BackendA"] = std::move(backendA);
310     backends["BackendB"] = std::move(backendB);
311     backends["BackendC"] = std::move(backendC);
312     backends["BackendD"] = std::move(backendD);
313 
314     armnn::Graph graph;
315 
316     armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input");
317     inputLayer->SetBackendId("BackendA");
318 
319     armnn::SoftmaxDescriptor smDesc;
320     armnn::SoftmaxLayer* const softmaxLayer1 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax1");
321     softmaxLayer1->SetBackendId("BackendA");
322 
323     armnn::SoftmaxLayer* const softmaxLayer2 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax2");
324     softmaxLayer2->SetBackendId("BackendB");
325 
326     armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3");
327     softmaxLayer3->SetBackendId("BackendC");
328 
329     armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4");
330     softmaxLayer4->SetBackendId("BackendD");
331 
332     armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
333     outputLayer->SetBackendId("BackendA");
334 
335     inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0));
336     softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0));
337     softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0));
338     softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0));
339     softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
340 
341     graph.TopologicalSort();
342 
343     std::vector<std::string> errors;
344     auto result = SelectTensorHandleStrategy(graph, backends, registry, true, true, errors);
345 
346     CHECK(result.m_Error == false);
347     CHECK(result.m_Warning == false);
348 
349     OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0);
350     OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0);
351     OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0);
352     OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0);
353     OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0);
354 
355     // Check that the correct factory was selected
356     CHECK(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
357     CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
358     CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1");
359     CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1");
360     CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1");
361 
362     // Check that the correct strategy was selected
363     CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
364     CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
365     CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget));
366     CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget));
367     CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility));
368 
369     graph.AddCompatibilityLayers(backends, registry);
370 
371     // Test for copy layers
372     int copyCount= 0;
373     graph.ForEachLayer([&copyCount](Layer* layer)
__anon0ec13aec0102(Layer* layer) 374     {
375         if (layer->GetType() == LayerType::MemCopy)
376         {
377             copyCount++;
378         }
379     });
380     CHECK(copyCount == 1);
381 
382     // Test for import layers
383     int importCount= 0;
384     graph.ForEachLayer([&importCount](Layer *layer)
__anon0ec13aec0202(Layer *layer) 385     {
386         if (layer->GetType() == LayerType::MemImport)
387         {
388             importCount++;
389         }
390     });
391     CHECK(importCount == 1);
392 }
393 
394 TEST_CASE("RegisterCopyAndImportFactoryPairTest")
395 {
396     TensorHandleFactoryRegistry registry;
397     ITensorHandleFactory::FactoryId copyId = "CopyFactoryId";
398     ITensorHandleFactory::FactoryId importId = "ImportFactoryId";
399     registry.RegisterCopyAndImportFactoryPair(copyId, importId);
400 
401     // Get mathing import factory id correctly
402     CHECK((registry.GetMatchingImportFactoryId(copyId) == importId));
403 
404     // Return empty id when Invalid Id is given
405     CHECK((registry.GetMatchingImportFactoryId("InvalidFactoryId") == ""));
406 }
407 
408 }
409