• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <backendsCommon/TensorHandleFactoryRegistry.hpp>
7 #include <neon/NeonBackend.hpp>
8 #include <neon/NeonTensorHandleFactory.hpp>
9 
10 #include <doctest/doctest.h>
11 
12 using namespace armnn;
13 
14 TEST_SUITE("NeonBackendTests")
15 {
16 TEST_CASE("NeonRegisterTensorHandleFactoriesMatchingImportFactoryId")
17 {
18     auto neonBackend = std::make_unique<NeonBackend>();
19     TensorHandleFactoryRegistry registry;
20     neonBackend->RegisterTensorHandleFactories(registry);
21 
22     // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered
23     // Get matching import factory id correctly
24     CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) ==
25            NeonTensorHandleFactory::GetIdStatic()));
26 }
27 
28 TEST_CASE("NeonCreateWorkloadFactoryMatchingImportFactoryId")
29 {
30     auto neonBackend = std::make_unique<NeonBackend>();
31     TensorHandleFactoryRegistry registry;
32     neonBackend->CreateWorkloadFactory(registry);
33 
34     // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered
35     // Get matching import factory id correctly
36     CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) ==
37            NeonTensorHandleFactory::GetIdStatic()));
38 }
39 
40 TEST_CASE("NeonCreateWorkloadFactoryWithOptionsMatchingImportFactoryId")
41 {
42     auto neonBackend = std::make_unique<NeonBackend>();
43     TensorHandleFactoryRegistry registry;
44     ModelOptions modelOptions;
45     neonBackend->CreateWorkloadFactory(registry, modelOptions);
46 
47     // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered
48     // Get matching import factory id correctly
49     CHECK((registry.GetMatchingImportFactoryId(NeonTensorHandleFactory::GetIdStatic()) ==
50            NeonTensorHandleFactory::GetIdStatic()));
51 }
52 }
53