• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <CreateWorkload.hpp>
8 #include <armnnTestUtils/PredicateResult.hpp>
9 #include <armnn/utility/PolymorphicDowncast.hpp>
10 #include <armnn/backends/MemCopyWorkload.hpp>
11 #include <reference/RefWorkloadFactory.hpp>
12 #include <reference/RefTensorHandle.hpp>
13 
14 #if defined(ARMCOMPUTECL_ENABLED)
15 #include <cl/ClTensorHandle.hpp>
16 #endif
17 
18 #if defined(ARMCOMPUTENEON_ENABLED)
19 #include <neon/NeonTensorHandle.hpp>
20 #endif
21 
22 #include <doctest/doctest.h>
23 
24 using namespace armnn;
25 
26 namespace
27 {
28 
29 using namespace std;
30 
31 template<typename IComputeTensorHandle>
CompareTensorHandleShape(IComputeTensorHandle * tensorHandle,std::initializer_list<unsigned int> expectedDimensions)32 PredicateResult CompareTensorHandleShape(IComputeTensorHandle* tensorHandle,
33                                          std::initializer_list<unsigned int> expectedDimensions)
34 {
35     arm_compute::ITensorInfo* info = tensorHandle->GetTensor().info();
36 
37     auto infoNumDims = info->num_dimensions();
38     auto numExpectedDims = expectedDimensions.size();
39     if (infoNumDims != numExpectedDims)
40     {
41         PredicateResult res(false);
42         res.Message() << "Different number of dimensions [" << info->num_dimensions()
43                       << "!=" << expectedDimensions.size() << "]";
44         return res;
45     }
46 
47     size_t i = info->num_dimensions() - 1;
48 
49     for (unsigned int expectedDimension : expectedDimensions)
50     {
51         if (info->dimension(i) != expectedDimension)
52         {
53             PredicateResult res(false);
54             res.Message() << "For dimension " << i <<
55                              " expected size " << expectedDimension <<
56                              " got " << info->dimension(i);
57             return res;
58         }
59 
60         i--;
61     }
62 
63     return PredicateResult(true);
64 }
65 
66 template<typename IComputeTensorHandle>
CreateMemCopyWorkloads(IWorkloadFactory & factory)67 void CreateMemCopyWorkloads(IWorkloadFactory& factory)
68 {
69     TensorHandleFactoryRegistry registry;
70     Graph graph;
71     RefWorkloadFactory refFactory;
72 
73     // Creates the layers we're testing.
74     Layer* const layer1 = graph.AddLayer<MemCopyLayer>("layer1");
75     Layer* const layer2 = graph.AddLayer<MemCopyLayer>("layer2");
76 
77     // Creates extra layers.
78     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
79     Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
80 
81     // Connects up.
82     TensorInfo tensorInfo({2, 3}, DataType::Float32);
83     Connect(input, layer1, tensorInfo);
84     Connect(layer1, layer2, tensorInfo);
85     Connect(layer2, output, tensorInfo);
86 
87     input->CreateTensorHandles(registry, refFactory);
88     layer1->CreateTensorHandles(registry, factory);
89     layer2->CreateTensorHandles(registry, refFactory);
90     output->CreateTensorHandles(registry, refFactory);
91 
92     // make the workloads and check them
93     auto workload1 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer1, factory);
94     auto workload2 = MakeAndCheckWorkload<CopyMemGenericWorkload>(*layer2, refFactory);
95 
96     MemCopyQueueDescriptor queueDescriptor1 = workload1->GetData();
97     CHECK(queueDescriptor1.m_Inputs.size() == 1);
98     CHECK(queueDescriptor1.m_Outputs.size() == 1);
99     auto inputHandle1  = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor1.m_Inputs[0]);
100     auto outputHandle1 = PolymorphicDowncast<IComputeTensorHandle*>(queueDescriptor1.m_Outputs[0]);
101     CHECK((inputHandle1->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
102     auto result = CompareTensorHandleShape<IComputeTensorHandle>(outputHandle1, {2, 3});
103     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
104 
105 
106     MemCopyQueueDescriptor queueDescriptor2 = workload2->GetData();
107     CHECK(queueDescriptor2.m_Inputs.size() == 1);
108     CHECK(queueDescriptor2.m_Outputs.size() == 1);
109     auto inputHandle2  = PolymorphicDowncast<IComputeTensorHandle*>(queueDescriptor2.m_Inputs[0]);
110     auto outputHandle2 = PolymorphicDowncast<RefTensorHandle*>(queueDescriptor2.m_Outputs[0]);
111     result = CompareTensorHandleShape<IComputeTensorHandle>(inputHandle2, {2, 3});
112     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
113     CHECK((outputHandle2->GetTensorInfo() == TensorInfo({2, 3}, DataType::Float32)));
114 }
115 
116 } //namespace
117