Lines Matching refs:axis
54 int32_t axis, float* outputData, const Shape& outputShape) { in softmaxSlowFloat32() argument
56 const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis); in softmaxSlowFloat32()
57 const uint32_t axisSize = getSizeOfDimension(inputShape, axis); in softmaxSlowFloat32()
59 getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape)); in softmaxSlowFloat32()
85 bool softmaxFloat32(const float* inputData, const Shape& inputShape, const float beta, int32_t axis, in softmaxFloat32() argument
88 NN_CHECK(handleNegativeAxis(inputShape, &axis)); in softmaxFloat32()
90 if (axis == ndim - 1) { in softmaxFloat32()
97 return softmaxSlowFloat32(inputData, inputShape, beta, axis, outputData, outputShape); in softmaxFloat32()
102 int32_t axis, _Float16* outputData, const Shape& outputShape) { in softmaxFloat16() argument
108 softmaxFloat32(inputData_float32.data(), inputShape, beta, axis, outputData_float32.data(), in softmaxFloat16()
116 bool softmaxQuant8Impl(const T* inputData, const Shape& inputShape, const float beta, int32_t axis, in softmaxQuant8Impl() argument
131 const uint32_t outerSize = getNumberOfElements(inputShape, 0, axis); in softmaxQuant8Impl()
132 const uint32_t axisSize = getSizeOfDimension(inputShape, axis); in softmaxQuant8Impl()
134 getNumberOfElements(inputShape, axis + 1, getNumberOfDimensions(inputShape)); in softmaxQuant8Impl()
203 bool softmaxQuant8(const T* inputData, const Shape& inputShape, const float beta, int32_t axis, in softmaxQuant8() argument
206 NN_CHECK(handleNegativeAxis(inputShape, &axis)); in softmaxQuant8()
228 return softmaxQuant8Impl(inputData, inputShape, beta, axis, inputMultiplier, inputLeftShift, in softmaxQuant8()
287 int32_t axis = (context->getNumInputs() == kNumInputs) in execute() local
294 context->getInputValue<_Float16>(kBetaScalar), axis, in execute()
300 context->getInputValue<float>(kBetaScalar), axis, in execute()
306 context->getInputValue<float>(kBetaScalar), axis, in execute()
312 context->getInputValue<float>(kBetaScalar), axis, in execute()