1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 #include <armnn/utility/IgnoreUnused.hpp> 10 #include <armnn/utility/PolymorphicDowncast.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 17 static const std::set<armnn::LayerType> broadcastOps { 18 LayerType::Addition, 19 LayerType::Division, 20 LayerType::Maximum, 21 LayerType::Minimum, 22 LayerType::Multiplication, 23 LayerType::Subtraction 24 }; 25 26 class AddBroadcastReshapeLayerImpl 27 { 28 public: 29 /// Run for every ElementwiseBaseLayer. Add Broadcast reshape layer if the inputs shape are different. Run(Graph & graph,Layer & layer) const30 void Run(Graph& graph, Layer& layer) const 31 { 32 if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end()) 33 { 34 layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet(); 35 layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet(); 36 37 const TensorInfo &inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(); 38 const TensorInfo &inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(); 39 40 if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions()) 41 { 42 return; 43 } 44 45 unsigned int reshapeSlot = 1; 46 TensorInfo reshapeInfo = inputInfo1; 47 TensorInfo inputInfo = inputInfo0; 48 49 if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions()) 50 { 51 reshapeSlot = 0; 52 reshapeInfo = inputInfo0; 53 inputInfo = inputInfo1; 54 } 55 56 uint32_t numDimensions = inputInfo.GetNumDimensions(); 57 58 std::vector<unsigned> reshapedDim; 59 for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i) 60 { 61 reshapedDim.push_back(reshapeInfo.GetShape()[i]); 62 } 63 64 std::vector<unsigned int> reshapedDimensions(numDimensions, 1); 65 std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end()); 66 67 reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() }); 68 const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot); 69 const ReshapeDescriptor descriptor{reshapeInfo.GetShape()}; 70 ReshapeLayer *reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), 71 descriptor, 72 layerName.c_str()); 73 reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo); 74 } 75 } 76 77 protected: 78 AddBroadcastReshapeLayerImpl() = default; 79 ~AddBroadcastReshapeLayerImpl() = default; 80 }; 81 82 using AddBroadcastReshapeLayer = OptimizeForType<Layer, AddBroadcastReshapeLayerImpl>; 83 84 } // namespace optimizations 85 } // namespace armnn 86