1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "StridedSliceLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/utility/NumericCast.hpp>
10
11 #include <armnn/backends/WorkloadData.hpp>
12 #include <armnn/backends/WorkloadFactory.hpp>
13
14 namespace armnn
15 {
16
StridedSliceLayer(const armnn::StridedSliceDescriptor & param,const char * name)17 StridedSliceLayer::StridedSliceLayer(const armnn::StridedSliceDescriptor& param, const char* name)
18 : LayerWithParameters(1, 1, LayerType::StridedSlice, param, name)
19 {
20 }
21
CreateWorkload(const IWorkloadFactory & factory) const22 std::unique_ptr<IWorkload> StridedSliceLayer::CreateWorkload(const IWorkloadFactory& factory) const
23 {
24 StridedSliceQueueDescriptor descriptor;
25
26 descriptor.m_Parameters.m_Begin = m_Param.m_Begin;
27 descriptor.m_Parameters.m_End = m_Param.m_End;
28 descriptor.m_Parameters.m_Stride = m_Param.m_Stride;
29
30 // Optional parameters
31 descriptor.m_Parameters.m_BeginMask = m_Param.m_BeginMask;
32 descriptor.m_Parameters.m_EndMask = m_Param.m_EndMask;
33 descriptor.m_Parameters.m_EllipsisMask = m_Param.m_EllipsisMask;
34 descriptor.m_Parameters.m_NewAxisMask = m_Param.m_NewAxisMask;
35 descriptor.m_Parameters.m_ShrinkAxisMask = m_Param.m_ShrinkAxisMask;
36
37 SetAdditionalInfo(descriptor);
38
39 return factory.CreateWorkload(LayerType::StridedSlice, descriptor, PrepInfoAndDesc(descriptor));
40 }
41
Clone(Graph & graph) const42 StridedSliceLayer* StridedSliceLayer::Clone(Graph& graph) const
43 {
44 return CloneBase<StridedSliceLayer>(graph, m_Param, GetName());
45 }
46
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const47 std::vector<TensorShape> StridedSliceLayer::InferOutputShapes(
48 const std::vector<TensorShape>& inputShapes) const
49 {
50 ARMNN_ASSERT(inputShapes.size() == 1);
51
52 TensorShape inputShape = inputShapes[0];
53 std::vector<unsigned int> outputShape;
54 unsigned int amountDimShrunk{0};
55
56 for (unsigned int i = 0; i < inputShape.GetNumDimensions(); i++)
57 {
58 int stride = m_Param.m_Stride[i];
59 int start = m_Param.GetStartForAxis(inputShape, i);
60 int stop = m_Param.GetStopForAxis(inputShape, i, start);
61
62 if (m_Param.m_ShrinkAxisMask & (1 << i))
63 {
64 amountDimShrunk+=1;
65
66 // If the difference between the start point and the end point of the slice on an axis being shrunk
67 // is greater than 1 then throw an error as the output will not be large enough to hold the slice
68 if (((m_Param.m_Begin[i] - m_Param.m_End[i]) > 1) || ((m_Param.m_Begin[i] - m_Param.m_End[i]) < -1))
69 {
70 throw LayerValidationException(
71 "StridedSlice: Attempting to take a larger slice than can fit in inferred output");
72 }
73
74 if (stride < 0)
75 {
76 throw LayerValidationException(
77 "StridedSlice: Stride can not be negative with Shrink Axis Mask set.");
78 }
79 continue;
80 }
81
82 int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
83 ((start - stop) - stride - 1) / -stride;
84
85 newSize = std::max(0, newSize);
86
87 outputShape.push_back(armnn::numeric_cast<unsigned int>(newSize));
88 }
89
90 if (outputShape.size() == 0 && (inputShape.GetNumDimensions() - amountDimShrunk) == 0)
91 {
92 outputShape.push_back(1);
93 }
94
95 return std::vector<TensorShape>({
96 TensorShape(armnn::numeric_cast<unsigned int>(outputShape.size()), &outputShape[0]) });
97 }
98
ValidateTensorShapesFromInputs()99 void StridedSliceLayer::ValidateTensorShapesFromInputs()
100 {
101 VerifyLayerConnections(1, CHECK_LOCATION());
102
103 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
104
105 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
106
107 auto inferredShapes = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape()});
108
109 ARMNN_ASSERT(inferredShapes.size() == 1);
110
111 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "StridedSliceLayer");
112 }
113
ExecuteStrategy(IStrategy & strategy) const114 void StridedSliceLayer::ExecuteStrategy(IStrategy& strategy) const
115 {
116 strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
117 }
118
119 } // namespace armnn
120