• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #pragma once
6 
#include "Optimization.hpp"

#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>

#include <set>
#include <string>
#include <vector>

12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 static const std::set<armnn::LayerType> broadcastOps {
18     LayerType::Addition,
19     LayerType::Division,
20     LayerType::Maximum,
21     LayerType::Minimum,
22     LayerType::Multiplication,
23     LayerType::Subtraction
24 };
25 
26 class AddBroadcastReshapeLayerImpl
27 {
28 public:
29     /// Run for every ElementwiseBaseLayer. Add Broadcast reshape layer if the inputs shape are different.
Run(Graph & graph,Layer & layer) const30     void Run(Graph& graph, Layer& layer) const
31     {
32         if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
33         {
34             layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
35             layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
36 
37             const TensorInfo &inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
38             const TensorInfo &inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
39 
40             if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
41             {
42                 return;
43             }
44 
45             unsigned int reshapeSlot = 1;
46             TensorInfo reshapeInfo = inputInfo1;
47             TensorInfo inputInfo = inputInfo0;
48 
49             if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
50             {
51                 reshapeSlot = 0;
52                 reshapeInfo = inputInfo0;
53                 inputInfo = inputInfo1;
54             }
55 
56             uint32_t numDimensions = inputInfo.GetNumDimensions();
57 
58             std::vector<unsigned> reshapedDim;
59             for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
60             {
61                 reshapedDim.push_back(reshapeInfo.GetShape()[i]);
62             }
63 
64             std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
65             std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
66 
67             reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
68             const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
69             const ReshapeDescriptor descriptor{reshapeInfo.GetShape()};
70             ReshapeLayer *reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot),
71                                                                             descriptor,
72                                                                             layerName.c_str());
73             reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
74         }
75     }
76 
77 protected:
78     AddBroadcastReshapeLayerImpl() = default;
79     ~AddBroadcastReshapeLayerImpl() = default;
80 };
81 
82 using AddBroadcastReshapeLayer = OptimizeForType<Layer, AddBroadcastReshapeLayerImpl>;
83 
84 } // namespace optimizations
85 } // namespace armnn
86