//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <CreateWorkload.hpp>
#include <armnnTestUtils/PredicateResult.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <reference/RefWorkloadFactory.hpp>
#include <reference/RefTensorHandle.hpp>

#if defined(ARMCOMPUTECL_ENABLED)
#include <cl/ClTensorHandle.hpp>
#endif

#if defined(ARMCOMPUTENEON_ENABLED)
#include <neon/NeonTensorHandle.hpp>
#endif

#include <doctest/doctest.h>

using namespace armnn;

namespace
{

using namespace std;

// Compares the shape reported by a compute-library tensor handle against an
// expected dimension list, returning a PredicateResult that carries a
// descriptive message on mismatch.
template<typename IComputeTensorHandle>
PredicateResult CompareTensorHandleShape(IComputeTensorHandle* tensorHandle,
                                         std::initializer_list<unsigned int> expectedDimensions)
{
    arm_compute::ITensorInfo* info = tensorHandle->GetTensor().info();

    auto infoNumDims = info->num_dimensions();
    auto numExpectedDims = expectedDimensions.size();
    if (infoNumDims != numExpectedDims)
    {
        PredicateResult res(false);
        res.Message() << "Different number of dimensions [" << infoNumDims
                      << "!=" << numExpectedDims << "]";
        return res;
    }

    // The arm_compute tensor stores its dimensions in reverse order relative to
    // the expected list, so compare from the last index downwards.
    size_t i = info->num_dimensions() - 1;

    for (unsigned int expectedDimension : expectedDimensions)
    {
        if (info->dimension(i) != expectedDimension)
        {
            PredicateResult res(false);
            res.Message() << "For dimension " << i <<
                             " expected size " << expectedDimension <<
                             " got " << info->dimension(i);
            return res;
        }

        i--;
    }

    return PredicateResult(true);
}

// Builds a small graph with two MemCopy layers sandwiched between reference
// input/output layers, creates the MemCopy workloads and checks that each
// workload's input and output tensor handles have the expected type and shape.
template<typename IComputeTensorHandle>
void CreateMemCopyWorkloads(IWorkloadFactory& factory)
{
    TensorHandleFactoryRegistry registry;
    Graph graph;
    RefWorkloadFactory refFactory;

    // Creates the layers we're testing.
    Layer* const layer1 = graph.AddLayer<MemCopyLayer>("layer1");
    Layer* const layer2 = graph.AddLayer<MemCopyLayer>("layer2");

    // Creates extra layers.
    Layer* const input = graph.AddLayer<InputLayer>(0, "input");
    Layer* const output = graph.AddLayer<OutputLayer>(0, "output");

    // Connects up.
    TensorInfo tensorInfo({2, 3}, DataType::Float32);
    Connect(input, layer1, tensorInfo);
    Connect(layer1, layer2, tensorInfo);
    Connect(layer2, output, tensorInfo);

    // layer1 copies from the reference backend to the backend under test;
    // layer2 copies back to the reference backend.
    input->CreateTensorHandles(registry, refFactory);
    layer1->CreateTensorHandles(registry, factory);
    layer2->CreateTensorHandles(registry, refFactory);
    output->CreateTensorHandles(registry, refFactory);

    // Makes the workloads and checks them.
    auto workload1 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer1, factory);
    auto workload2 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer2, refFactory);

    MemCopyQueueDescriptor queueDescriptor1 = workload1->GetData();
    CHECK(queueDescriptor1.m_Inputs.size() == 1);
    CHECK(queueDescriptor1.m_Outputs.size() == 1);
    auto inputHandle1  = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor1.m_Inputs[0]);
    auto outputHandle1 = PolymorphicDowncast<IComputeTensorHandle*>(queueDescriptor1.m_Outputs[0]);
    CHECK((inputHandle1->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
    auto result = CompareTensorHandleShape<IComputeTensorHandle>(outputHandle1, {2, 3});
    CHECK_MESSAGE(result.m_Result, result.m_Message.str());

    MemCopyQueueDescriptor queueDescriptor2 = workload2->GetData();
    CHECK(queueDescriptor2.m_Inputs.size() == 1);
    CHECK(queueDescriptor2.m_Outputs.size() == 1);
    auto inputHandle2  = PolymorphicDowncast<IComputeTensorHandle*>(queueDescriptor2.m_Inputs[0]);
    auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor2.m_Outputs[0]);
    result = CompareTensorHandleShape<IComputeTensorHandle>(inputHandle2, {2, 3});
    CHECK_MESSAGE(result.m_Result, result.m_Message.str());
    CHECK((outputHandle2->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
}
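
// Illustrative usage sketch (an assumption, not taken from this header): a
// backend-specific test would typically instantiate the template above with its
// own tensor handle type and workload factory, along these lines:
//
//     TEST_CASE("CreateMemCopyWorkloadsNeon")
//     {
//         NeonWorkloadFactory factory = ...; // hypothetical factory construction
//         CreateMemCopyWorkloads<NeonTensorHandle>(factory);
//     }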

} // namespace