• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "LayerTestResult.hpp"
8 #include "TensorCopyUtils.hpp"
9 #include "TensorHelpers.hpp"
10 #include "WorkloadTestUtils.hpp"
11 #include <ResolveType.hpp>
12 #include <armnn/backends/IBackendInternal.hpp>
13 #include <armnnTestUtils/MockBackend.hpp>
14 
15 namespace
16 {
17 
18 template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
MemCopyTest(armnn::IWorkloadFactory & srcWorkloadFactory,armnn::IWorkloadFactory & dstWorkloadFactory,bool withSubtensors)19 LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
20                                   armnn::IWorkloadFactory& dstWorkloadFactory,
21                                   bool withSubtensors)
22 {
23     const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
24     const armnn::TensorShape tensorShape(4, shapeData.data());
25     const armnn::TensorInfo tensorInfo(tensorShape, dataType);
26     std::vector<T> inputData =
27     {
28          1,  2,  3,  4,  5,
29          6,  7,  8,  9, 10,
30         11, 12, 13, 14, 15,
31         16, 17, 18, 19, 20,
32         21, 22, 23, 24, 25,
33         26, 27, 28, 29, 30,
34     };
35 
36     LayerTestResult<T, 4> ret(tensorInfo);
37     ret.m_ExpectedData = inputData;
38 
39     std::vector<T> actualOutput(tensorInfo.GetNumElements());
40 
41     ARMNN_NO_DEPRECATE_WARN_BEGIN
42     auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
43     auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
44     ARMNN_NO_DEPRECATE_WARN_END
45 
46     AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
47     outputTensorHandle->Allocate();
48 
49     armnn::MemCopyQueueDescriptor memCopyQueueDesc;
50     armnn::WorkloadInfo workloadInfo;
51 
52     const unsigned int origin[4] = {};
53 
54     ARMNN_NO_DEPRECATE_WARN_BEGIN
55     auto workloadInput  = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
56                               ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
57                               : std::move(inputTensorHandle);
58     auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
59                               ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
60                               : std::move(outputTensorHandle);
61     ARMNN_NO_DEPRECATE_WARN_END
62 
63     AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
64     AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
65 
66     dstWorkloadFactory.CreateWorkload(armnn::LayerType::MemCopy, memCopyQueueDesc, workloadInfo)->Execute();
67 
68     CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
69     ret.m_ActualData = actualOutput;
70 
71     return ret;
72 }
73 
74 template <typename WorkloadFactoryType>
75 struct MemCopyTestHelper
76 {};
77 template <>
78 struct MemCopyTestHelper<armnn::MockWorkloadFactory>
79 {
GetMemoryManager__anon00cf49b20111::MemCopyTestHelper80     static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager()
81     {
82         armnn::MockBackend backend;
83         return backend.CreateMemoryManager();
84     }
85 
86     static armnn::MockWorkloadFactory
GetFactory__anon00cf49b20111::MemCopyTestHelper87         GetFactory(const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr)
88     {
89         IgnoreUnused(memoryManager);
90         return armnn::MockWorkloadFactory();
91     }
92 };
93 
94 using MockMemCopyTestHelper = MemCopyTestHelper<armnn::MockWorkloadFactory>;
95 
96 template <typename SrcWorkloadFactory,
97           typename DstWorkloadFactory,
98           armnn::DataType dataType,
99           typename T = armnn::ResolveType<dataType>>
MemCopyTest(bool withSubtensors)100 LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
101 {
102 
103     armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
104         MemCopyTestHelper<SrcWorkloadFactory>::GetMemoryManager();
105 
106     armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
107         MemCopyTestHelper<DstWorkloadFactory>::GetMemoryManager();
108 
109     SrcWorkloadFactory srcWorkloadFactory = MemCopyTestHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
110     DstWorkloadFactory dstWorkloadFactory = MemCopyTestHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
111 
112     return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
113 }
114 
115 }    // anonymous namespace
116