• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "LogSoftmax.hpp"
7 
8 #include <armnnUtils/TensorUtils.hpp>
9 #include <armnn/utility/Assert.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/NumericCast.hpp>
12 
13 #include <cmath>
14 
15 namespace
16 {
17 
ValidateAxis(int axis,unsigned int numDimensions)18 inline bool ValidateAxis(int axis, unsigned int numDimensions)
19 {
20     const int sNumDimensions = armnn::numeric_cast<int>(numDimensions);
21     return axis < sNumDimensions && axis >= -sNumDimensions;
22 }
23 
24 } // anonymous namespace
25 
26 namespace armnn
27 {
28 
LogSoftmax(Decoder<float> & input,Encoder<float> & output,const TensorInfo & inputInfo,const LogSoftmaxDescriptor & descriptor)29 void LogSoftmax(Decoder<float>& input,
30                 Encoder<float>& output,
31                 const TensorInfo& inputInfo,
32                 const LogSoftmaxDescriptor& descriptor)
33 {
34     const unsigned int numDimensions = inputInfo.GetNumDimensions();
35 
36     bool axisIsValid = ValidateAxis(descriptor.m_Axis, numDimensions);
37     ARMNN_ASSERT_MSG(axisIsValid,
38         "Axis index is not in range [-numDimensions, numDimensions).");
39     IgnoreUnused(axisIsValid);
40 
41     unsigned int uAxis = descriptor.m_Axis < 0  ?
42         numDimensions - armnn::numeric_cast<unsigned int>(std::abs(descriptor.m_Axis)) :
43         armnn::numeric_cast<unsigned int>(descriptor.m_Axis);
44 
45     const TensorShape& inputShape = inputInfo.GetShape();
46     const unsigned int outerSize  = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
47     const unsigned int axisSize   = inputShape[uAxis];
48     const unsigned int innerSize  = armnnUtils::GetNumElementsBetween(inputShape,
49                                                                       uAxis + 1,
50                                                                       inputShape.GetNumDimensions());
51 
52     for (unsigned int outer = 0; outer < outerSize; ++outer)
53     {
54         for (unsigned int inner = 0; inner < innerSize; ++inner)
55         {
56             // Find max
57             input[outer * axisSize * innerSize + inner];
58             float maxValue = input.Get();
59             for (unsigned int i = 1u; i < axisSize; ++i)
60             {
61                 input[(outer * axisSize + i) * innerSize + inner];
62                 maxValue = std::max(maxValue, input.Get());
63             }
64 
65             // Compute sum
66             float sum = 0.0f;
67             for (unsigned int i = 0u; i < axisSize; ++i)
68             {
69                 input[(outer * axisSize + i) * innerSize + inner];
70                 sum += std::exp((input.Get() - maxValue) * descriptor.m_Beta);
71             }
72 
73             // Compute log sum
74             const float logSum = std::log(sum);
75 
76             // Compute result
77             for (unsigned int i = 0u; i < axisSize; ++i)
78             {
79                 const unsigned int index = (outer * axisSize + i) * innerSize + inner;
80 
81                 input [index];
82                 output[index];
83 
84                 output.Set((input.Get() - maxValue) * descriptor.m_Beta - logSum);
85             }
86         }
87     }
88 }
89 
90 } // namespace armnn
91