1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <ResolveType.hpp>
8
9 #include <armnn/backends/IBackendInternal.hpp>
10
11 #include <backendsCommon/test/LayerTests.hpp>
12 #include <backendsCommon/test/TensorCopyUtils.hpp>
13 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
14 #include <backendsCommon/test/WorkloadTestUtils.hpp>
15
16 #include <test/TensorHelpers.hpp>
17
18 #include <boost/multi_array.hpp>
19
20 namespace
21 {
22
23 template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
MemCopyTest(armnn::IWorkloadFactory & srcWorkloadFactory,armnn::IWorkloadFactory & dstWorkloadFactory,bool withSubtensors)24 LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
25 armnn::IWorkloadFactory& dstWorkloadFactory,
26 bool withSubtensors)
27 {
28 const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
29 const armnn::TensorShape tensorShape(4, shapeData.data());
30 const armnn::TensorInfo tensorInfo(tensorShape, dataType);
31 boost::multi_array<T, 4> inputData = MakeTensor<T, 4>(tensorInfo, std::vector<T>(
32 {
33 1, 2, 3, 4, 5,
34 6, 7, 8, 9, 10,
35 11, 12, 13, 14, 15,
36 16, 17, 18, 19, 20,
37 21, 22, 23, 24, 25,
38 26, 27, 28, 29, 30,
39 })
40 );
41
42 LayerTestResult<T, 4> ret(tensorInfo);
43 ret.outputExpected = inputData;
44
45 boost::multi_array<T, 4> outputData(shapeData);
46
47 ARMNN_NO_DEPRECATE_WARN_BEGIN
48 auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
49 auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
50 ARMNN_NO_DEPRECATE_WARN_END
51
52 AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
53 outputTensorHandle->Allocate();
54
55 armnn::MemCopyQueueDescriptor memCopyQueueDesc;
56 armnn::WorkloadInfo workloadInfo;
57
58 const unsigned int origin[4] = {};
59
60 ARMNN_NO_DEPRECATE_WARN_BEGIN
61 auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
62 ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
63 : std::move(inputTensorHandle);
64 auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
65 ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
66 : std::move(outputTensorHandle);
67 ARMNN_NO_DEPRECATE_WARN_END
68
69 AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
70 AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
71
72 dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
73
74 CopyDataFromITensorHandle(outputData.data(), workloadOutput.get());
75 ret.output = outputData;
76
77 return ret;
78 }
79
80 template<typename SrcWorkloadFactory,
81 typename DstWorkloadFactory,
82 armnn::DataType dataType,
83 typename T = armnn::ResolveType<dataType>>
MemCopyTest(bool withSubtensors)84 LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
85 {
86 armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
87 WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
88
89 armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
90 WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
91
92 SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
93 DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
94
95 return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
96 }
97
98 } // anonymous namespace
99