1 /* 2 * Copyright (c) 2017-2020 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H 25 #define ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H 26 27 #include "arm_compute/core/KernelDescriptors.h" 28 #include "src/core/CL/ICLSimple3DKernel.h" 29 30 namespace arm_compute 31 { 32 class ICLTensor; 33 34 /** Interface for max, shifting, exponentiating and summing the logits */ 35 class CLLogits1DMaxShiftExpSumKernel : public ICLKernel 36 { 37 public: 38 /** Info for whether a parallel reduction will be run and the vector size of the execution. */ 39 using ParallelReductionInfo = std::tuple<bool, unsigned int>; 40 41 public: 42 /** Default constructor */ 43 CLLogits1DMaxShiftExpSumKernel(); 44 /** Prevent instances of this class from being copied (As this class contains pointers) */ 45 CLLogits1DMaxShiftExpSumKernel(const CLLogits1DMaxShiftExpSumKernel &) = delete; 46 /** Prevent instances of this class from being copied (As this class contains pointers) */ 47 CLLogits1DMaxShiftExpSumKernel &operator=(const CLLogits1DMaxShiftExpSumKernel &) = delete; 48 /** Allow instances of this class to be moved */ 49 CLLogits1DMaxShiftExpSumKernel(CLLogits1DMaxShiftExpSumKernel &&) = default; 50 /** Allow instances of this class to be moved */ 51 CLLogits1DMaxShiftExpSumKernel &operator=(CLLogits1DMaxShiftExpSumKernel &&) = default; 52 /** Set the input and output tensors. 53 * 54 * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 55 * @param[in,out] max Max values tensor. Data types supported: same as @p input 56 * @param[out] output Destination tensor. Data types supported: same as @p input 57 * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p input 58 * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. 59 */ 60 void configure(const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, const SoftmaxKernelInfo &info); 61 /** Set the input and output tensors. 62 * 63 * @param[in] compile_context The compile context to be used. 64 * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 65 * @param[in,out] max Max values tensor. Data types supported: same as @p input 66 * @param[out] output Destination tensor. Data types supported: same as @p input 67 * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p input 68 * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. 69 */ 70 void configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, const SoftmaxKernelInfo &info); 71 /** Static function to check if given info will lead to a valid configuration of @ref CLLogits1DMaxShiftExpSumKernel 72 * 73 * @param[in] input Source tensor. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 74 * @param[in] max Max values tensor. Data types supported: same as @p input 75 * @param[in] output Destination tensor. Data types supported: same as @p input 76 * @param[in] sum Sum of 1D logits tensor. Data types supported: same as @p input 77 * 78 * @return a status 79 */ 80 static Status validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum); 81 /** Checks if the given size is eligible for parallel reduction 82 * 83 * @note Serial reduction is launched for width < (_grid_size * _serial_vector_size). 84 * @note Parallel reduction is launched for width >= (_grid_size * _serial_vector_size) and vector_size is forced to 4. 85 * 86 * @param[in] size Size to check 87 * 88 * @return A two-element tuple where the first element is a boolean specifying if a parallel reduction will be run, 89 * while the second element is the vector size of the execution. 90 */ 91 static ParallelReductionInfo is_parallel_reduction(size_t size); 92 93 // Inherited methods overridden: 94 void run(const Window &window, cl::CommandQueue &queue) override; 95 96 private: 97 const ICLTensor *_input; 98 ICLTensor *_max; 99 ICLTensor *_output; 100 ICLTensor *_sum; 101 102 private: 103 static const unsigned int _grid_size; 104 static const unsigned int _serial_vector_size; 105 static const unsigned int _parallel_vector_size; 106 }; 107 /** Interface for calculating the final step of the Softmax Layer where each logit value is multiplied by the inverse of the sum of the logits. */ 108 class CLLogits1DNormKernel : public ICLKernel 109 { 110 public: 111 /** Default constructor */ 112 CLLogits1DNormKernel(); 113 /** Prevent instances of this class from being copied (As this class contains pointers) */ 114 CLLogits1DNormKernel(const CLLogits1DNormKernel &) = delete; 115 /** Prevent instances of this class from being copied (As this class contains pointers) */ 116 CLLogits1DNormKernel &operator=(const CLLogits1DNormKernel &) = delete; 117 /** Allow instances of this class to be moved */ 118 CLLogits1DNormKernel(CLLogits1DNormKernel &&) = default; 119 /** Allow instances of this class to be moved */ 120 CLLogits1DNormKernel &operator=(CLLogits1DNormKernel &&) = default; 121 /** Set the input and output tensors. 122 * 123 * @param[in] input Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. 124 * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input 125 * @param[out] output Destination tensor. Data types supported: QASYMM8/QASYMM8_SIGNED for S32 @p input, or same as @p input 126 * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. 127 */ 128 void configure(const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, const SoftmaxKernelInfo &info); 129 /** Set the input and output tensors. 130 * 131 * @param[in] compile_context The compile context to be used. 132 * @param[in] input Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. 133 * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input 134 * @param[out] output Destination tensor. Data types supported: QASYMM8/QASYMM8_SIGNED for S32 @p input, or same as @p input 135 * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. 136 */ 137 void configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, const SoftmaxKernelInfo &info); 138 /** Static function to check if given info will lead to a valid configuration of @ref CLLogits1DNormKernel 139 * 140 * @param[in] input Source tensor. Data types supported: S32/F16/F32. If this kernel is used for log softmax, only F32/F16 is supported. 141 * @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input 142 * @param[in] output Destination tensor. Data types supported: QASYMM8 for S32 @p input, or same as @p input 143 * @param[in] info Contains information consumed by kernels for softmax described in @ref SoftmaxKernelInfo. 144 * 145 * @return a status 146 */ 147 static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info); 148 149 // Inherited methods overridden: 150 void run(const Window &window, cl::CommandQueue &queue) override; 151 152 private: 153 const ICLTensor *_input; 154 const ICLTensor *_sum; 155 ICLTensor *_output; 156 }; 157 } // namespace arm_compute 158 #endif /*ARM_COMPUTE_CLSOFTMAXLAYERKERNEL_H */ 159