• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <ResolveType.hpp>
8 
9 #include <armnn/INetwork.hpp>
10 
11 #include <doctest/doctest.h>
12 #include <CommonTestUtils.hpp>
13 
14 namespace
15 {
16 
17 template<typename armnn::DataType DataType>
CreateBatchMatMulNetwork(const armnn::TensorShape & inputXShape,const armnn::TensorShape & inputYShape,const armnn::TensorShape & outputShape,const float qScale=1.0f,const int32_t qOffset=0)18 armnn::INetworkPtr CreateBatchMatMulNetwork(const armnn::TensorShape& inputXShape,
19                                      const armnn::TensorShape& inputYShape,
20                                      const armnn::TensorShape& outputShape,
21                                      const float qScale = 1.0f,
22                                      const int32_t qOffset = 0)
23 {
24     using namespace armnn;
25 
26     INetworkPtr network(INetwork::Create());
27 
28     TensorInfo inputXTensorInfo(inputXShape, DataType, qScale, qOffset, true);
29     TensorInfo inputYTensorInfo(inputYShape, DataType, qScale, qOffset, true);
30 
31     TensorInfo outputTensorInfo(outputShape, DataType, qScale, qOffset);
32 
33     BatchMatMulDescriptor batchMatMulDesc;
34     batchMatMulDesc.m_TransposeX = false;
35     batchMatMulDesc.m_TransposeY = true;
36 
37     IConnectableLayer* batchMatMul = network->AddBatchMatMulLayer(batchMatMulDesc, "batchMatMul");
38     IConnectableLayer* inputX = network->AddInputLayer(0, "inputX");
39     IConnectableLayer* inputY = network->AddInputLayer(1, "inputY");
40     IConnectableLayer* output = network->AddOutputLayer(0, "output");
41 
42     Connect(inputX, batchMatMul, inputXTensorInfo, 0, 0);
43     Connect(inputY, batchMatMul, inputYTensorInfo, 0, 1);
44     Connect(batchMatMul, output, outputTensorInfo, 0, 0);
45 
46     return network;
47 }
48 
49 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
BatchMatMulEndToEnd(const std::vector<armnn::BackendId> & backends)50 void BatchMatMulEndToEnd(const std::vector<armnn::BackendId>& backends)
51 {
52     using namespace armnn;
53 
54     const TensorShape& inputXShape = { 2, 2, 2 };
55     const TensorShape& inputYShape = { 2, 2, 2 };
56     const TensorShape& outputShape = { 2, 2, 2 };
57 
58     INetworkPtr network = CreateBatchMatMulNetwork<ArmnnType>(inputXShape, inputYShape, outputShape);
59 
60     CHECK(network);
61 
62     std::vector<T> inputXData{ 1, 2,
63                                3, 4,
64 
65                                9, 10,
66                                11, 12 };
67     std::vector<T> inputYData{ 5, 7,
68                                6, 8,
69 
70                                13, 15,
71                                14, 16 };
72     std::vector<T> expectedOutput{ 19, 22,
73                                    43, 50,
74 
75                                    267, 286,
76                                    323, 346 };
77 
78     std::map<int, std::vector<T>> inputTensorData = {{ 0, inputXData }, {1, inputYData}};
79     std::map<int, std::vector<T>> expectedOutputData = { { 0, expectedOutput } };
80 
81     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(std::move(network), inputTensorData, expectedOutputData, backends);
82 }
83 
84 } // anonymous namespace