• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Tensor.hpp>
9 
10 #include <fmt/format.h>
11 #include <mapbox/variant.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 template<typename TContainer>
MakeInputTensors(const std::vector<armnn::BindingPointInfo> & inputBindings,const std::vector<TContainer> & inputDataContainers)17 inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
18                                             const std::vector<TContainer>& inputDataContainers)
19 {
20     armnn::InputTensors inputTensors;
21 
22     const size_t numInputs = inputBindings.size();
23     if (numInputs != inputDataContainers.size())
24     {
25         throw armnn::Exception(fmt::format("The number of inputs does not match number of "
26                                            "tensor data containers: {0} != {1}",
27                                            numInputs,
28                                            inputDataContainers.size()));
29     }
30 
31     for (size_t i = 0; i < numInputs; i++)
32     {
33         const armnn::BindingPointInfo& inputBinding = inputBindings[i];
34         const TContainer& inputData = inputDataContainers[i];
35 
36         mapbox::util::apply_visitor([&](auto&& value)
37         {
38             if (value.size() != inputBinding.second.GetNumElements())
39             {
40                throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
41                                                   inputBinding.second.GetNumElements(),
42                                                   value.size()));
43             }
44 
45             armnn::ConstTensor inputTensor(inputBinding.second, value.data());
46             inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
47         },
48         inputData);
49     }
50 
51     return inputTensors;
52 }
53 
54 template<typename TContainer>
MakeOutputTensors(const std::vector<armnn::BindingPointInfo> & outputBindings,std::vector<TContainer> & outputDataContainers)55 inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
56                                               std::vector<TContainer>& outputDataContainers)
57 {
58     armnn::OutputTensors outputTensors;
59 
60     const size_t numOutputs = outputBindings.size();
61     if (numOutputs != outputDataContainers.size())
62     {
63         throw armnn::Exception(fmt::format("Number of outputs does not match number"
64                                            "of tensor data containers: {0} != {1}",
65                                            numOutputs,
66                                            outputDataContainers.size()));
67     }
68 
69     for (size_t i = 0; i < numOutputs; i++)
70     {
71         const armnn::BindingPointInfo& outputBinding = outputBindings[i];
72         TContainer& outputData = outputDataContainers[i];
73 
74         mapbox::util::apply_visitor([&](auto&& value)
75         {
76             if (value.size() != outputBinding.second.GetNumElements())
77             {
78                 throw armnn::Exception("Output tensor has incorrect size");
79             }
80 
81             armnn::Tensor outputTensor(outputBinding.second, value.data());
82             outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
83         },
84         outputData);
85     }
86 
87     return outputTensors;
88 }
89 
90 } // namespace armnnUtils
91