1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <doctest/doctest.h> 7 8 9 #include <armnn/backends/IBackendContext.hpp> 10 #include <armnn/backends/IBackendInternal.hpp> 11 #include <armnn/backends/IMemoryManager.hpp> 12 #include <armnn/backends/ITensorHandleFactory.hpp> 13 #include <backendsCommon/TensorHandleFactoryRegistry.hpp> 14 15 #include <optimizations/Optimization.hpp> 16 17 #include <Network.hpp> 18 19 #include <armnn/utility/IgnoreUnused.hpp> 20 21 #include <vector> 22 #include <string> 23 24 25 using namespace armnn; 26 27 class TestMemMgr : public IMemoryManager 28 { 29 public: 30 TestMemMgr() = default; 31 Acquire()32 void Acquire() override {} Release()33 void Release() override {} 34 }; 35 36 class TestFactory1 : public ITensorHandleFactory 37 { 38 public: TestFactory1(std::weak_ptr<IMemoryManager> mgr,ITensorHandleFactory::FactoryId id)39 TestFactory1(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id) 40 : m_Id(id) 41 , m_MemMgr(mgr) 42 {} 43 CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const44 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 45 TensorShape const& subTensorShape, 46 unsigned int const* subTensorOrigin) const override 47 { 48 IgnoreUnused(parent, subTensorShape, subTensorOrigin); 49 return nullptr; 50 } 51 CreateTensorHandle(const TensorInfo & tensorInfo) const52 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override 53 { 54 IgnoreUnused(tensorInfo); 55 return nullptr; 56 } 57 CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const58 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 59 DataLayout dataLayout) const override 60 { 61 IgnoreUnused(tensorInfo, dataLayout); 62 return nullptr; 63 } 64 GetId() const65 const FactoryId& GetId() const override { return m_Id; } 66 SupportsSubTensors() const67 bool SupportsSubTensors() const override { return true; } 68 GetExportFlags() const69 MemorySourceFlags GetExportFlags() const override { return 1; } 70 71 private: 72 FactoryId m_Id = "UninitializedId"; 73 74 std::weak_ptr<IMemoryManager> m_MemMgr; 75 }; 76 77 class TestFactoryImport : public ITensorHandleFactory 78 { 79 public: TestFactoryImport(std::weak_ptr<IMemoryManager> mgr,ITensorHandleFactory::FactoryId id)80 TestFactoryImport(std::weak_ptr<IMemoryManager> mgr, ITensorHandleFactory::FactoryId id) 81 : m_Id(id) 82 , m_MemMgr(mgr) 83 {} 84 CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const85 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 86 TensorShape const& subTensorShape, 87 unsigned int const* subTensorOrigin) const override 88 { 89 IgnoreUnused(parent, subTensorShape, subTensorOrigin); 90 return nullptr; 91 } 92 CreateTensorHandle(const TensorInfo & tensorInfo) const93 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override 94 { 95 IgnoreUnused(tensorInfo); 96 return nullptr; 97 } 98 CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const99 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 100 DataLayout dataLayout) const override 101 { 102 IgnoreUnused(tensorInfo, dataLayout); 103 return nullptr; 104 } 105 GetId() const106 const FactoryId& GetId() const override { return m_Id; } 107 SupportsSubTensors() const108 bool SupportsSubTensors() const override { return true; } 109 GetImportFlags() const110 MemorySourceFlags GetImportFlags() const override { return 1; } 111 112 private: 113 FactoryId m_Id = "ImporterId"; 114 115 std::weak_ptr<IMemoryManager> m_MemMgr; 116 }; 117 118 class TestBackendA : public IBackendInternal 119 { 120 public: 121 TestBackendA() = default; 122 GetId() const123 const BackendId& GetId() const override { return m_Id; } 124 CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const125 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override 126 { 127 IgnoreUnused(memoryManager); 128 return IWorkloadFactoryPtr{}; 129 } 130 GetLayerSupport() const131 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override 132 { 133 return ILayerSupportSharedPtr{}; 134 } 135 GetHandleFactoryPreferences() const136 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override 137 { 138 return std::vector<ITensorHandleFactory::FactoryId> 139 { 140 "TestHandleFactoryA1", 141 "TestHandleFactoryA2", 142 "TestHandleFactoryB1", 143 "TestHandleFactoryD1" 144 }; 145 } 146 RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)147 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override 148 { 149 auto mgr = std::make_shared<TestMemMgr>(); 150 151 registry.RegisterMemoryManager(mgr); 152 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA1")); 153 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryA2")); 154 } 155 156 private: 157 BackendId m_Id = "BackendA"; 158 }; 159 160 class TestBackendB : public IBackendInternal 161 { 162 public: 163 TestBackendB() = default; 164 GetId() const165 const BackendId& GetId() const override { return m_Id; } 166 CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const167 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override 168 { 169 IgnoreUnused(memoryManager); 170 return IWorkloadFactoryPtr{}; 171 } 172 GetLayerSupport() const173 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override 174 { 175 return ILayerSupportSharedPtr{}; 176 } 177 GetHandleFactoryPreferences() const178 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override 179 { 180 return std::vector<ITensorHandleFactory::FactoryId> 181 { 182 "TestHandleFactoryB1" 183 }; 184 } 185 RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)186 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override 187 { 188 auto mgr = std::make_shared<TestMemMgr>(); 189 190 registry.RegisterMemoryManager(mgr); 191 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryB1")); 192 } 193 194 private: 195 BackendId m_Id = "BackendB"; 196 }; 197 198 class TestBackendC : public IBackendInternal 199 { 200 public: 201 TestBackendC() = default; 202 GetId() const203 const BackendId& GetId() const override { return m_Id; } 204 CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const205 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override 206 { 207 IgnoreUnused(memoryManager); 208 return IWorkloadFactoryPtr{}; 209 } 210 GetLayerSupport() const211 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override 212 { 213 return ILayerSupportSharedPtr{}; 214 } 215 GetHandleFactoryPreferences() const216 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override 217 { 218 return std::vector<ITensorHandleFactory::FactoryId>{ 219 "TestHandleFactoryC1" 220 }; 221 } 222 RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)223 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override 224 { 225 auto mgr = std::make_shared<TestMemMgr>(); 226 227 registry.RegisterMemoryManager(mgr); 228 registry.RegisterFactory(std::make_unique<TestFactory1>(mgr, "TestHandleFactoryC1")); 229 } 230 231 private: 232 BackendId m_Id = "BackendC"; 233 }; 234 235 class TestBackendD : public IBackendInternal 236 { 237 public: 238 TestBackendD() = default; 239 GetId() const240 const BackendId& GetId() const override { return m_Id; } 241 CreateWorkloadFactory(const IMemoryManagerSharedPtr & memoryManager=nullptr) const242 IWorkloadFactoryPtr CreateWorkloadFactory(const IMemoryManagerSharedPtr& memoryManager = nullptr) const override 243 { 244 IgnoreUnused(memoryManager); 245 return IWorkloadFactoryPtr{}; 246 } 247 GetLayerSupport() const248 IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override 249 { 250 return ILayerSupportSharedPtr{}; 251 } 252 GetHandleFactoryPreferences() const253 std::vector<ITensorHandleFactory::FactoryId> GetHandleFactoryPreferences() const override 254 { 255 return std::vector<ITensorHandleFactory::FactoryId>{ 256 "TestHandleFactoryD1", 257 }; 258 } 259 RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)260 void RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry) override 261 { 262 auto mgr = std::make_shared<TestMemMgr>(); 263 264 registry.RegisterMemoryManager(mgr); 265 registry.RegisterFactory(std::make_unique<TestFactoryImport>(mgr, "TestHandleFactoryD1")); 266 } 267 268 private: 269 BackendId m_Id = "BackendD"; 270 }; 271 272 273 TEST_SUITE("TensorHandle") 274 { 275 TEST_CASE("RegisterFactories") 276 { 277 TestBackendA backendA; 278 TestBackendB backendB; 279 280 CHECK(backendA.GetHandleFactoryPreferences()[0] == "TestHandleFactoryA1"); 281 CHECK(backendA.GetHandleFactoryPreferences()[1] == "TestHandleFactoryA2"); 282 CHECK(backendA.GetHandleFactoryPreferences()[2] == "TestHandleFactoryB1"); 283 CHECK(backendA.GetHandleFactoryPreferences()[3] == "TestHandleFactoryD1"); 284 285 TensorHandleFactoryRegistry registry; 286 backendA.RegisterTensorHandleFactories(registry); 287 backendB.RegisterTensorHandleFactories(registry); 288 289 CHECK((registry.GetFactory("Non-existing Backend") == nullptr)); 290 CHECK((registry.GetFactory("TestHandleFactoryA1") != nullptr)); 291 CHECK((registry.GetFactory("TestHandleFactoryA2") != nullptr)); 292 CHECK((registry.GetFactory("TestHandleFactoryB1") != nullptr)); 293 } 294 295 TEST_CASE("TensorHandleSelectionStrategy") 296 { 297 auto backendA = std::make_unique<TestBackendA>(); 298 auto backendB = std::make_unique<TestBackendB>(); 299 auto backendC = std::make_unique<TestBackendC>(); 300 auto backendD = std::make_unique<TestBackendD>(); 301 302 TensorHandleFactoryRegistry registry; 303 backendA->RegisterTensorHandleFactories(registry); 304 backendB->RegisterTensorHandleFactories(registry); 305 backendC->RegisterTensorHandleFactories(registry); 306 backendD->RegisterTensorHandleFactories(registry); 307 308 BackendsMap backends; 309 backends["BackendA"] = std::move(backendA); 310 backends["BackendB"] = std::move(backendB); 311 backends["BackendC"] = std::move(backendC); 312 backends["BackendD"] = std::move(backendD); 313 314 armnn::Graph graph; 315 316 armnn::InputLayer* const inputLayer = graph.AddLayer<armnn::InputLayer>(0, "input"); 317 inputLayer->SetBackendId("BackendA"); 318 319 armnn::SoftmaxDescriptor smDesc; 320 armnn::SoftmaxLayer* const softmaxLayer1 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax1"); 321 softmaxLayer1->SetBackendId("BackendA"); 322 323 armnn::SoftmaxLayer* const softmaxLayer2 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax2"); 324 softmaxLayer2->SetBackendId("BackendB"); 325 326 armnn::SoftmaxLayer* const softmaxLayer3 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax3"); 327 softmaxLayer3->SetBackendId("BackendC"); 328 329 armnn::SoftmaxLayer* const softmaxLayer4 = graph.AddLayer<armnn::SoftmaxLayer>(smDesc, "softmax4"); 330 softmaxLayer4->SetBackendId("BackendD"); 331 332 armnn::OutputLayer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output"); 333 outputLayer->SetBackendId("BackendA"); 334 335 inputLayer->GetOutputSlot(0).Connect(softmaxLayer1->GetInputSlot(0)); 336 softmaxLayer1->GetOutputSlot(0).Connect(softmaxLayer2->GetInputSlot(0)); 337 softmaxLayer2->GetOutputSlot(0).Connect(softmaxLayer3->GetInputSlot(0)); 338 softmaxLayer3->GetOutputSlot(0).Connect(softmaxLayer4->GetInputSlot(0)); 339 softmaxLayer4->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 340 341 graph.TopologicalSort(); 342 343 std::vector<std::string> errors; 344 auto result = SelectTensorHandleStrategy(graph, backends, registry, true, true, errors); 345 346 CHECK(result.m_Error == false); 347 CHECK(result.m_Warning == false); 348 349 OutputSlot& inputLayerOut = inputLayer->GetOutputSlot(0); 350 OutputSlot& softmaxLayer1Out = softmaxLayer1->GetOutputSlot(0); 351 OutputSlot& softmaxLayer2Out = softmaxLayer2->GetOutputSlot(0); 352 OutputSlot& softmaxLayer3Out = softmaxLayer3->GetOutputSlot(0); 353 OutputSlot& softmaxLayer4Out = softmaxLayer4->GetOutputSlot(0); 354 355 // Check that the correct factory was selected 356 CHECK(inputLayerOut.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); 357 CHECK(softmaxLayer1Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); 358 CHECK(softmaxLayer2Out.GetTensorHandleFactoryId() == "TestHandleFactoryB1"); 359 CHECK(softmaxLayer3Out.GetTensorHandleFactoryId() == "TestHandleFactoryC1"); 360 CHECK(softmaxLayer4Out.GetTensorHandleFactoryId() == "TestHandleFactoryD1"); 361 362 // Check that the correct strategy was selected 363 CHECK((inputLayerOut.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); 364 CHECK((softmaxLayer1Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); 365 CHECK((softmaxLayer2Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::CopyToTarget)); 366 CHECK((softmaxLayer3Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::ExportToTarget)); 367 CHECK((softmaxLayer4Out.GetEdgeStrategyForConnection(0) == EdgeStrategy::DirectCompatibility)); 368 369 graph.AddCompatibilityLayers(backends, registry); 370 371 // Test for copy layers 372 int copyCount= 0; 373 graph.ForEachLayer([©Count](Layer* layer) __anon0ec13aec0102(Layer* layer) 374 { 375 if (layer->GetType() == LayerType::MemCopy) 376 { 377 copyCount++; 378 } 379 }); 380 CHECK(copyCount == 1); 381 382 // Test for import layers 383 int importCount= 0; 384 graph.ForEachLayer([&importCount](Layer *layer) __anon0ec13aec0202(Layer *layer) 385 { 386 if (layer->GetType() == LayerType::MemImport) 387 { 388 importCount++; 389 } 390 }); 391 CHECK(importCount == 1); 392 } 393 394 TEST_CASE("RegisterCopyAndImportFactoryPairTest") 395 { 396 TensorHandleFactoryRegistry registry; 397 ITensorHandleFactory::FactoryId copyId = "CopyFactoryId"; 398 ITensorHandleFactory::FactoryId importId = "ImportFactoryId"; 399 registry.RegisterCopyAndImportFactoryPair(copyId, importId); 400 401 // Get mathing import factory id correctly 402 CHECK((registry.GetMatchingImportFactoryId(copyId) == importId)); 403 404 // Return empty id when Invalid Id is given 405 CHECK((registry.GetMatchingImportFactoryId("InvalidFactoryId") == "")); 406 } 407 408 } 409