1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include "LayerTestResult.hpp"
8 #include "TensorCopyUtils.hpp"
9 #include "TensorHelpers.hpp"
10 #include "WorkloadTestUtils.hpp"
11 #include <ResolveType.hpp>
12 #include <armnn/backends/IBackendInternal.hpp>
13 #include <armnnTestUtils/MockBackend.hpp>
14
15 namespace
16 {
17
18 template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
MemCopyTest(armnn::IWorkloadFactory & srcWorkloadFactory,armnn::IWorkloadFactory & dstWorkloadFactory,bool withSubtensors)19 LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
20 armnn::IWorkloadFactory& dstWorkloadFactory,
21 bool withSubtensors)
22 {
23 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
24 const armnn::TensorShape tensorShape(4, shapeData.data());
25 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
26 std::vector<T> inputData =
27 {
28 1, 2, 3, 4, 5,
29 6, 7, 8, 9, 10,
30 11, 12, 13, 14, 15,
31 16, 17, 18, 19, 20,
32 21, 22, 23, 24, 25,
33 26, 27, 28, 29, 30,
34 };
35
36 LayerTestResult<T, 4> ret(tensorInfo);
37 ret.m_ExpectedData = inputData;
38
39 std::vector<T> actualOutput(tensorInfo.GetNumElements());
40
41 ARMNN_NO_DEPRECATE_WARN_BEGIN
42 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
43 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
44 ARMNN_NO_DEPRECATE_WARN_END
45
46 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
47 outputTensorHandle->Allocate();
48
49 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
50 armnn::WorkloadInfo workloadInfo;
51
52 const unsigned int origin[4] = {};
53
54 ARMNN_NO_DEPRECATE_WARN_BEGIN
55 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
56 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
57 : std::move(inputTensorHandle);
58 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
59 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
60 : std::move(outputTensorHandle);
61 ARMNN_NO_DEPRECATE_WARN_END
62
63 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
64 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
65
66 dstWorkloadFactory.CreateWorkload(armnn::LayerType::MemCopy, memCopyQueueDesc, workloadInfo)->Execute();
67
68 CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
69 ret.m_ActualData = actualOutput;
70
71 return ret;
72 }
73
74 template <typename WorkloadFactoryType>
75 struct MemCopyTestHelper
76 {};
77 template <>
78 struct MemCopyTestHelper<armnn::MockWorkloadFactory>
79 {
GetMemoryManager__anon00cf49b20111::MemCopyTestHelper80 static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager()
81 {
82 armnn::MockBackend backend;
83 return backend.CreateMemoryManager();
84 }
85
86 static armnn::MockWorkloadFactory
GetFactory__anon00cf49b20111::MemCopyTestHelper87 GetFactory(const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr)
88 {
89 IgnoreUnused(memoryManager);
90 return armnn::MockWorkloadFactory();
91 }
92 };
93
94 using MockMemCopyTestHelper = MemCopyTestHelper<armnn::MockWorkloadFactory>;
95
96 template <typename SrcWorkloadFactory,
97 typename DstWorkloadFactory,
98 armnn::DataType dataType,
99 typename T = armnn::ResolveType<dataType>>
MemCopyTest(bool withSubtensors)100 LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
101 {
102
103 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
104 MemCopyTestHelper<SrcWorkloadFactory>::GetMemoryManager();
105
106 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
107 MemCopyTestHelper<DstWorkloadFactory>::GetMemoryManager();
108
109 SrcWorkloadFactory srcWorkloadFactory = MemCopyTestHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
110 DstWorkloadFactory dstWorkloadFactory = MemCopyTestHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
111
112 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
113 }
114
115 } // anonymous namespace
116