1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/Tensor.hpp>
9
10 #include <fmt/format.h>
11 #include <mapbox/variant.hpp>
12
13 namespace armnnUtils
14 {
15
16 template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<TContainer> & inputDataContainers)17 inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18 const std::vector<TContainer>& inputDataContainers)
19 {
20 armnn::InputTensors inputTensors;
21
22 const size_t numInputs = inputBindings.size();
23 if (numInputs != inputDataContainers.size())
24 {
25 throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26 "tensor data containers: {0} != {1}",
27 numInputs,
28 inputDataContainers.size()));
29 }
30
31 for (size_t i = 0; i < numInputs; i++)
32 {
33 const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34 const TContainer& inputData = inputDataContainers[i];
35
36 mapbox::util::apply_visitor([&](auto&& value)
37 {
38 if (value.size() != inputBinding.second.GetNumElements())
39 {
40 throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41 inputBinding.second.GetNumElements(),
42 value.size()));
43 }
44
45 armnn::ConstTensor inputTensor(inputBinding.second, value.data());
46 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
47 },
48 inputData);
49 }
50
51 return inputTensors;
52 }
53
54 template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,std::vector<TContainer> & outputDataContainers)55 inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
56 std::vector<TContainer>& outputDataContainers)
57 {
58 armnn::OutputTensors outputTensors;
59
60 const size_t numOutputs = outputBindings.size();
61 if (numOutputs != outputDataContainers.size())
62 {
63 throw armnn::Exception(fmt::format("Number of outputs does not match number"
64 "of tensor data containers: {0} != {1}",
65 numOutputs,
66 outputDataContainers.size()));
67 }
68
69 for (size_t i = 0; i < numOutputs; i++)
70 {
71 const armnn::BindingPointInfo& outputBinding = outputBindings[i];
72 TContainer& outputData = outputDataContainers[i];
73
74 mapbox::util::apply_visitor([&](auto&& value)
75 {
76 if (value.size() != outputBinding.second.GetNumElements())
77 {
78 throw armnn::Exception("Output tensor has incorrect size");
79 }
80
81 armnn::Tensor outputTensor(outputBinding.second, value.data());
82 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
83 },
84 outputData);
85 }
86
87 return outputTensors;
88 }
89
90 } // namespace armnnUtils
91