#pragma once

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Load.h>
#include <c10/util/TypeCast.h>

namespace c10 {

// Dynamic type casting utils:
// - fetch_and_cast
// - cast_and_store
//
// fetch_and_cast fetches a value with dynamic type specified by a ScalarType
// from a void pointer and casts it to a static type.
//
// cast_and_store casts a statically typed value into the dynamic type
// specified by a ScalarType, and stores it into a void pointer.
//
// NOTE:
//
// Dynamic casting allows us to support type promotion without blowing up
// the combination space: For example, without dynamic cast, in order to
// implement `add_` with type promotion, we would need something like
//
// AT_DISPATCH_ALL_TYPES(output.dtype(),
//     AT_DISPATCH_ALL_TYPES(input1.dtype(),
//         AT_DISPATCH_ALL_TYPES(input2.dtype(),
//             [](arg0_t a, arg1_t b) -> out_t { return a + b; }
//         )
//     )
// )
//
// If we support N dtypes, the above code would generate the a+b kernel for
// all the N * N * N different supported types; the compilation time and
// binary size would become horrible.
//
// Dynamic casting might sound like a bad idea in terms of performance.
// Especially if you ever do it in a loop, you are going to do a billion tests.
// But in practice it is not as bad as it might look:
//
// - on CPU, this is a branch that always has the same outcome, therefore
//   hopefully the branch predictor could do the job pretty well
// - on GPU, these branches will not diverge, so we could still have the same
//   warp executing the same line of code
// - Most kernels, like `add`, are bandwidth bound, adding a few clock cycles
//   to check an integer does not hurt the performance much because the ALUs
//   would wait for load instructions anyway.
//
// For the discussion and benchmark, refer to:
// - https://github.com/pytorch/pytorch/pull/28343
// - https://github.com/pytorch/pytorch/pull/28344
// - https://github.com/pytorch/pytorch/pull/28345
//

#ifdef C10_HOST_DEVICE
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#else
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#endif

// Fetch a value with dynamic type src_type from ptr, and cast it to static
// type dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    return c10::convert<dest_t>(c10::load<type>(ptr));

template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
    const ScalarType src_type,
    const void* ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
    FETCH_AND_CAST_CASE(uint16_t, UInt16)
    FETCH_AND_CAST_CASE(uint32_t, UInt32)
    FETCH_AND_CAST_CASE(uint64_t, UInt64)
    default:
      ERROR_UNSUPPORTED_CAST
  }
  return dest_t(0); // just to avoid compiler warning
}
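// Illustrative sketch (not part of this header; `src_dtype` and `elem` are
// hypothetical): reading one element from a type-erased buffer whose runtime
// dtype is `src_dtype`, converted to a compile-time float:
//
//   const void* elem = ...; // points to a value of runtime dtype src_dtype
//   float x = c10::fetch_and_cast<float>(src_dtype, elem);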
// Cast a value with static type src_t into the dynamic type dest_type, and
// store it to ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    *(type*)ptr = c10::convert<type>(value);  \
    return;

template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
    const ScalarType dest_type,
    void* ptr,
    src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
    CAST_AND_STORE_CASE(uint16_t, UInt16)
    CAST_AND_STORE_CASE(uint32_t, UInt32)
    CAST_AND_STORE_CASE(uint64_t, UInt64)
    default:;
  }
  ERROR_UNSUPPORTED_CAST
}

#define DEFINE_UNCASTABLE(T, scalartype_)                     \
  template <>                                                 \
  C10_HOST_DEVICE inline T fetch_and_cast<T>(                 \
      const ScalarType src_type, const void* ptr) {           \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type);  \
    return c10::load<T>(ptr);                                 \
  }                                                           \
  template <>                                                 \
  C10_HOST_DEVICE inline void cast_and_store<T>(              \
      const ScalarType dest_type, void* ptr, T value) {       \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \
    *(T*)ptr = value;                                         \
  }

AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)

#undef FETCH_AND_CAST_CASE
#undef CAST_AND_STORE_CASE
#undef DEFINE_UNCASTABLE
#undef ERROR_UNSUPPORTED_CAST

} // namespace c10
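// Illustrative sketch (not part of this header; all names other than
// fetch_and_cast and cast_and_store are hypothetical): an elementwise add
// loop that is compiled once for float but handles any supported input and
// output dtypes at runtime by combining fetch_and_cast and cast_and_store.
//
//   void add_loop(
//       char* out, const char* a, const char* b,
//       c10::ScalarType out_t, c10::ScalarType a_t, c10::ScalarType b_t,
//       int64_t n, int64_t out_size, int64_t a_size, int64_t b_size) {
//     for (int64_t i = 0; i < n; i++) {
//       float x = c10::fetch_and_cast<float>(a_t, a + i * a_size);
//       float y = c10::fetch_and_cast<float>(b_t, b + i * b_size);
//       c10::cast_and_store<float>(out_t, out + i * out_size, x + y);
//     }
//   }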