1 //
2 // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RankLayer.hpp"
7
8 #include "LayerCloneBase.hpp"
9
10 #include <armnn/backends/WorkloadData.hpp>
11 #include <armnn/backends/WorkloadFactory.hpp>
12
13 namespace armnn
14 {
15
RankLayer(const char * name)16 RankLayer::RankLayer(const char* name)
17 : Layer(1, 1, LayerType::Rank, name)
18 {}
19
CreateWorkload(const IWorkloadFactory & factory) const20 std::unique_ptr<IWorkload> RankLayer::CreateWorkload(const IWorkloadFactory& factory) const
21 {
22 RankQueueDescriptor descriptor;
23 SetAdditionalInfo(descriptor);
24
25 return factory.CreateWorkload(LayerType::Rank, descriptor, PrepInfoAndDesc(descriptor));
26 }
27
Clone(Graph & graph) const28 Layer* RankLayer::Clone(Graph& graph) const
29 {
30 RankLayer* clone = CloneBase<RankLayer>(graph, GetName());
31 return clone;
32 }
33
InferOutputShapes(const std::vector<TensorShape> &) const34 std::vector<TensorShape> RankLayer::InferOutputShapes(const std::vector<TensorShape>&) const
35 {
36 return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar) });
37 }
38
ValidateTensorShapesFromInputs()39 void RankLayer::ValidateTensorShapesFromInputs()
40 {
41 VerifyLayerConnections(1, CHECK_LOCATION());
42
43 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
44 const TensorShape inferredShape = TensorShape(Dimensionality::Scalar);
45
46 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
47 ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "RankLayer");
48 }
49
ExecuteStrategy(IStrategy & strategy) const50 void RankLayer::ExecuteStrategy(IStrategy& strategy) const
51 {
52 strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName());
53 }
54
55 } //namespace armnn