1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "LogSoftmax.hpp"
7
#include <armnnUtils/TensorUtils.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <algorithm>
#include <cmath>
14
15 namespace
16 {
17
ValidateAxis(int axis,unsigned int numDimensions)18 inline bool ValidateAxis(int axis, unsigned int numDimensions)
19 {
20 const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
21 return axis < sNumDimensions && axis >= -sNumDimensions;
22 }
23
24 } // anonymous namespace
25
26 namespace armnn
27 {
28
LogSoftmax(Decoder<float> & input,Encoder<float> & output,const TensorInfo & inputInfo,const LogSoftmaxDescriptor & descriptor)29 void LogSoftmax(Decoder<float>& input,
30 Encoder<float>& output,
31 const TensorInfo& inputInfo,
32 const LogSoftmaxDescriptor& descriptor)
33 {
34 const unsigned int numDimensions = inputInfo.GetNumDimensions();
35
36 bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
37 ARMNN_ASSERT_MSG(axisIsValid,
38 "Axis index is not in range [-numDimensions, numDimensions).");
39 IgnoreUnused(axisIsValid);
40
41 unsigned int uAxis = descriptor.m_Axis < 0 ?
42 numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
43 armnn::numeric_cast<unsigned int>(descriptor.m_Axis);
44
45 const TensorShape& inputShape = inputInfo.GetShape();
46 const unsigned int outerSize = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
47 const unsigned int axisSize = inputShape[uAxis];
48 const unsigned int innerSize = armnnUtils::GetNumElementsBetween(inputShape,
49 uAxis + 1,
50 inputShape.GetNumDimensions());
51
52 for (unsigned int outer = 0; outer < outerSize; ++outer)
53 {
54 for (unsigned int inner = 0; inner < innerSize; ++inner)
55 {
56 // Find max
57 input[outer * axisSize * innerSize + inner];
58 float maxValue = input.Get();
59 for (unsigned int i = 1u; i < axisSize; ++i)
60 {
61 input[(outer * axisSize + i) * innerSize + inner];
62 maxValue = std::max(maxValue, input.Get());
63 }
64
65 // Compute sum
66 float sum = 0.0f;
67 for (unsigned int i = 0u; i < axisSize; ++i)
68 {
69 input[(outer * axisSize + i) * innerSize + inner];
70 sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
71 }
72
73 // Compute log sum
74 const float logSum = std::log(sum);
75
76 // Compute result
77 for (unsigned int i = 0u; i < axisSize; ++i)
78 {
79 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
80
81 input [index];
82 output[index];
83
84 output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
85 }
86 }
87 }
88 }
89
90 } // namespace armnn
91