• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <ResolveType.hpp>
8 
9 #include <armnn/backends/IBackendInternal.hpp>
10 
11 #include <backendsCommon/test/LayerTests.hpp>
12 #include <backendsCommon/test/TensorCopyUtils.hpp>
13 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
14 #include <backendsCommon/test/WorkloadTestUtils.hpp>
15 
16 #include <test/TensorHelpers.hpp>
17 
18 #include <boost/multi_array.hpp>
19 
20 namespace
21 {
22 
23 template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
MemCopyTest(armnn::IWorkloadFactory & srcWorkloadFactory,armnn::IWorkloadFactory & dstWorkloadFactory,bool withSubtensors)24 LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
25                                   armnn::IWorkloadFactory& dstWorkloadFactory,
26                                   bool withSubtensors)
27 {
28     const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
29     const armnn::TensorShape tensorShape(4, shapeData.data());
30     const armnn::TensorInfo tensorInfo(tensorShape, dataType);
31     boost::multi_array<T, 4> inputData = MakeTensor<T, 4>(tensorInfo, std::vector<T>(
32         {
33              1,  2,  3,  4,  5,
34              6,  7,  8,  9, 10,
35             11, 12, 13, 14, 15,
36             16, 17, 18, 19, 20,
37             21, 22, 23, 24, 25,
38             26, 27, 28, 29, 30,
39         })
40     );
41 
42     LayerTestResult<T, 4> ret(tensorInfo);
43     ret.outputExpected = inputData;
44 
45     boost::multi_array<T, 4> outputData(shapeData);
46 
47     ARMNN_NO_DEPRECATE_WARN_BEGIN
48     auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
49     auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
50     ARMNN_NO_DEPRECATE_WARN_END
51 
52     AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
53     outputTensorHandle->Allocate();
54 
55     armnn::MemCopyQueueDescriptor memCopyQueueDesc;
56     armnn::WorkloadInfo workloadInfo;
57 
58     const unsigned int origin[4] = {};
59 
60     ARMNN_NO_DEPRECATE_WARN_BEGIN
61     auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
62                          ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
63                          : std::move(inputTensorHandle);
64     auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
65                           ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
66                           : std::move(outputTensorHandle);
67     ARMNN_NO_DEPRECATE_WARN_END
68 
69     AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
70     AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
71 
72     dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute();
73 
74     CopyDataFromITensorHandle(outputData.data(), workloadOutput.get());
75     ret.output = outputData;
76 
77     return ret;
78 }
79 
80 template<typename SrcWorkloadFactory,
81          typename DstWorkloadFactory,
82          armnn::DataType dataType,
83          typename T = armnn::ResolveType<dataType>>
MemCopyTest(bool withSubtensors)84 LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
85 {
86     armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
87         WorkloadFactoryHelper<SrcWorkloadFactory>::GetMemoryManager();
88 
89     armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
90         WorkloadFactoryHelper<DstWorkloadFactory>::GetMemoryManager();
91 
92     SrcWorkloadFactory srcWorkloadFactory = WorkloadFactoryHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
93     DstWorkloadFactory dstWorkloadFactory = WorkloadFactoryHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
94 
95     return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
96 }
97 
98 } // anonymous namespace
99