• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RankLayer.hpp"
7 
8 #include "LayerCloneBase.hpp"
9 
10 #include <armnn/backends/WorkloadData.hpp>
11 #include <armnn/backends/WorkloadFactory.hpp>
12 
13 namespace armnn
14 {
15 
RankLayer(const char * name)16 RankLayer::RankLayer(const char* name)
17         : Layer(1, 1, LayerType::Rank, name)
18 {}
19 
CreateWorkload(const IWorkloadFactory & factory) const20 std::unique_ptr<IWorkload> RankLayer::CreateWorkload(const IWorkloadFactory& factory) const
21 {
22     RankQueueDescriptor descriptor;
23     SetAdditionalInfo(descriptor);
24 
25     return factory.CreateWorkload(LayerType::Rank, descriptor, PrepInfoAndDesc(descriptor));
26 }
27 
Clone(Graph & graph) const28 Layer* RankLayer::Clone(Graph& graph) const
29 {
30     RankLayer* clone = CloneBase<RankLayer>(graph, GetName());
31     return clone;
32 }
33 
InferOutputShapes(const std::vector<TensorShape> &) const34 std::vector<TensorShape> RankLayer::InferOutputShapes(const std::vector<TensorShape>&) const
35 {
36     return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar) });
37 }
38 
ValidateTensorShapesFromInputs()39 void RankLayer::ValidateTensorShapesFromInputs()
40 {
41     VerifyLayerConnections(1, CHECK_LOCATION());
42 
43     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
44     const TensorShape inferredShape = TensorShape(Dimensionality::Scalar);
45 
46     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
47     ValidateAndCopyShape(outputShape, inferredShape, m_ShapeInferenceMethod, "RankLayer");
48 }
49 
ExecuteStrategy(IStrategy & strategy) const50 void RankLayer::ExecuteStrategy(IStrategy& strategy) const
51 {
52     strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName());
53 }
54 
55 } //namespace armnn