1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/backends/OptimizationViews.hpp>
9
10 namespace armnn
11 {
12
13 namespace
14 {
15
16 //
17 // this helper only works if all layers where the inputs connect to are not selected
18 //
CreateInputsFrom(const std::vector<Layer * > & layers)19 SubgraphView::InputSlots CreateInputsFrom(const std::vector<Layer*>& layers)
20 {
21 SubgraphView::InputSlots result;
22 for (auto&& layer : layers)
23 {
24 for (auto&& it = layer->BeginInputSlots(); it != layer->EndInputSlots(); ++it)
25 {
26 result.push_back(&(*it));
27 }
28 }
29 return result;
30 }
31
32 //
33 // this helper only works if all layers where the outputs connect to are not selected
34 //
CreateOutputsFrom(const std::vector<Layer * > & layers)35 SubgraphView::OutputSlots CreateOutputsFrom(const std::vector<Layer*>& layers)
36 {
37 SubgraphView::OutputSlots result;
38 for (auto&& layer : layers)
39 {
40 for (auto&& it = layer->BeginOutputSlots(); it != layer->EndOutputSlots(); ++it)
41 {
42 result.push_back(&(*it));
43 }
44 }
45 return result;
46 }
47
48 } // namespace
49
ReportUntouchedLayers(OptimizationViews & optimizationViews,std::map<LayerGuid,Layer * > untouched)50 inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
51 {
52 std::vector<Layer*> untouchedVector;
53 for (const auto& pair : untouched)
54 {
55 Layer* layer = pair.second;
56 SubgraphView subgraphView(CreateInputsFrom({layer}),
57 CreateOutputsFrom({layer}),
58 {layer});
59 optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
60 }
61 }
62
63 template<typename LayerType>
FuseLayerWithoutParameters(OptimizationViews & optimizationViews,LayerType * baseLayer,ActivationLayer * activationLayer,ActivationDescriptor & activationDesc,std::string name)64 LayerType* FuseLayerWithoutParameters(OptimizationViews& optimizationViews,
65 LayerType* baseLayer,
66 ActivationLayer* activationLayer,
67 ActivationDescriptor& activationDesc,
68 std::string name)
69 {
70 LayerType* replacementLayer = optimizationViews.GetGraph().AddLayer<LayerType>(name.c_str());
71
72 replacementLayer->SetAdditionalInfoForObject(std::make_shared<ActivationDescriptor>(activationDesc));
73
74 SubgraphView substitutionSubgraph(CreateInputsFrom({baseLayer}),
75 CreateOutputsFrom({activationLayer}),
76 {baseLayer, activationLayer});
77 SubgraphView replacementSubgraph(replacementLayer);
78
79 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
80 return replacementLayer;
81 }
82
83 template<typename LayerType>
FuseLayerWithParameters(OptimizationViews & optimizationViews,LayerType * baseLayer,ActivationLayer * activationLayer,ActivationDescriptor & activationDesc,std::string name)84 LayerType* FuseLayerWithParameters(OptimizationViews& optimizationViews,
85 LayerType* baseLayer,
86 ActivationLayer* activationLayer,
87 ActivationDescriptor& activationDesc,
88 std::string name)
89 {
90 LayerType* replacementLayer = optimizationViews.GetGraph().AddLayer<LayerType>(baseLayer->GetParameters(),
91 name.c_str());
92
93 replacementLayer->SetAdditionalInfoForObject(std::make_shared<ActivationDescriptor>(activationDesc));
94
95 SubgraphView substitutionSubgraph(CreateInputsFrom({baseLayer}),
96 CreateOutputsFrom({activationLayer}),
97 {baseLayer, activationLayer});
98 SubgraphView replacementSubgraph(replacementLayer);
99
100 optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
101 return replacementLayer;
102 }
103
104 template<typename LayerType>
FuseLayerWithWeightsAndBiases(OptimizationViews & optimizationViews,LayerType * baseLayer,ActivationLayer * activationLayer,ActivationDescriptor & activationDesc,std::string name)105 LayerType* FuseLayerWithWeightsAndBiases(OptimizationViews& optimizationViews,
106 LayerType* baseLayer,
107 ActivationLayer* activationLayer,
108 ActivationDescriptor& activationDesc,
109 std::string name)
110 {
111 LayerType* replacementLayer = FuseLayerWithParameters(optimizationViews,
112 baseLayer,
113 activationLayer,
114 activationDesc,
115 name);
116
117 replacementLayer->m_Weight = std::move(baseLayer->m_Weight);
118 replacementLayer->m_Bias = std::move(baseLayer->m_Bias);
119
120 return replacementLayer;
121 }
122
123 } // namespace armnn
124