/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

struct Dropout {

    const unsigned long long seed, offset;
    const uint8_t p_dropout_in_uint8_t;

    // Each (batch, head) pair gets its own block of 32 Philox offsets, one per
    // lane of the warp; the block row/column indices supply the Philox
    // subsequence in apply_dropout, so every tile draws independent randomness.
    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                                       const uint8_t p_dropout_in_uint8_t,
                                       const int bid, const int hid, const int tid, const int nheads)
            : seed(seed)
            , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
            , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
    }

    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                                  int block_row_start, int block_col_start, int block_row_stride) {
        // Convert the layout from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2),
        // so that each Philox call (16 random bytes) covers one (8, 2) tile.
        Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
        using T = typename Engine::value_type;
        auto encode_dropout = [](bool keep, T val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
        };
        static_assert(decltype(size<2>(tensor))::value % 2 == 0);
        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
        const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
        // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
        #pragma unroll
        for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
            uint2 rowcol = make_uint2(block_row_start, block_col_start);
            #pragma unroll
            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
                // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
                uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
                uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
                // Special implementation for 16-bit types: we duplicate the threshold
                // into the low and high 16 bits of a 32-bit value, then use the f16x2
                // comparison instruction to get a mask. Each 16-bit half of the mask
                // will be either 0xffff or 0x0000, depending on whether the corresponding
                // random value is less than or equal to the threshold.
                // We then do a bit-wise AND between the mask and the original values (in 32-bit).
                // We're exploiting the fact that floating point comparison is equivalent to integer
                // comparison here, since we're comparing unsigned integers whose top 8 bits are zero.
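                // Worked example (illustrative, not from the original source): with
                // dropout threshold byte 0xB0, the packed threshold is 0x00B000B0.
                // Two random bytes 0x42 and 0xC3, zero-extended and packed as
                // 0x00C30042, are compared per 16-bit half by set.le.u32.f16x2:
                // 0x0042 <= 0x00B0 (true) and 0x00C3 <= 0x00B0 (false), yielding
                // the mask 0x0000ffff. The AND below then keeps the low element
                // and zeroes the high one.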
                if (!encode_dropout_in_sign_bit
                    && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                    // Zero-extend each random byte to 16 bits so its top 8 bits are
                    // zero and the f16x2 comparison behaves like an unsigned integer
                    // comparison.
                    uint16_t rnd_16[16];
                    #pragma unroll
                    for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                    uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                        #pragma unroll
                        for (int i = 0; i < 4; i++) {
                            uint32_t mask;
                            asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                            tensor_uint32(i) &= mask;
                        }
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                } else {
                    // Generic path: compare byte by byte, zeroing dropped elements or,
                    // when encode_dropout_in_sign_bit is set, flipping their sign so
                    // the backward pass can recover the mask.
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        #pragma unroll
                        for (int i = 0; i < 8; i++) {
                            tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                        }
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    }
                }
                // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
                // // }
            }
        }
    }

};

} // namespace pytorch_flash
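// ----------------------------------------------------------------------------
// Usage sketch (illustrative only; `seed`, `off`, `params`, `bidb`, `bidh`,
// `tidx`, `acc_s`, the block indices, and `kNWarps` are assumed kernel-side
// names, not defined in this header). Inside an attention kernel, one Dropout
// object is constructed per thread and applied to each accumulator tile of
// attention probabilities:
//
//     // seed/off come from the host-side Philox state; bidb/bidh/tidx are the
//     // batch index, head index, and thread index within the block.
//     pytorch_flash::Dropout dropout(seed, off, params.p_dropout_in_uint8_t,
//                                    bidb, bidh, tidx, params.h);
//     // Encode the mask in the sign bit when the probabilities are also
//     // written out for the backward pass:
//     dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
//         acc_s, block_row_idx, block_col_idx, kNWarps);
// ----------------------------------------------------------------------------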