• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <Layer.hpp>
9 
10 namespace armnn
11 {
12 
13 /// NOTE: this is an abstract class to encapsulate the element wise operations, it does not implement:
14 /// std::unique_ptr<IWorkload> Layer::CreateWorkload(const IWorkloadFactory& factory) const = 0;
15 /// Layer* Clone(Graph& graph) const = 0;
16 class ElementwiseBaseLayer : public Layer
17 {
18 public:
19     /// Check if the input tensor shape(s)
20     /// will lead to a valid configuration of the element wise operation.
21     /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated.
22     void ValidateTensorShapesFromInputs() override;
23 
24     /// By default returns inputShapes if the number of inputs are equal to number of outputs,
25     /// otherwise infers the output shapes from given input shapes and layer properties.
26     /// @param [in] inputShapes The input shapes layer has.
27     /// @return A vector to the inferred output shape.
28     std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
29 
30 protected:
31     /// @param numInputSlots The number of input slots for the layer.
32     /// @param numOutputSlots The number of output slots for the layer.
33     /// @param type The layer type.
34     /// @param name Optional name for the layer.
35     ElementwiseBaseLayer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name);
36 
37     /// Default destructor
38     ~ElementwiseBaseLayer() = default;
39 };
40 
41 } // namespace
42