#include #include #include #include #include #if defined(_MSC_VER) #include #endif namespace { static inline void swapBytes16(void* ptr) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint16_t output; memcpy(&output, ptr, sizeof(uint16_t)); #if defined(_MSC_VER) && !defined(_DEBUG) output = _byteswap_ushort(output); #elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC) output = __builtin_bswap16(output); #else uint16_t Hi = output >> 8; uint16_t Lo = output << 8; output = Hi | Lo; #endif memcpy(ptr, &output, sizeof(uint16_t)); } static inline void swapBytes32(void* ptr) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint32_t output; memcpy(&output, ptr, sizeof(uint32_t)); #if defined(_MSC_VER) && !defined(_DEBUG) output = _byteswap_ulong(output); #elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC) output = __builtin_bswap32(output); #else uint32_t Byte0 = output & 0x000000FF; uint32_t Byte1 = output & 0x0000FF00; uint32_t Byte2 = output & 0x00FF0000; uint32_t Byte3 = output & 0xFF000000; output = (Byte0 << 24) | (Byte1 << 8) | (Byte2 >> 8) | (Byte3 >> 24); #endif memcpy(ptr, &output, sizeof(uint32_t)); } static inline void swapBytes64(void* ptr) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint64_t output; memcpy(&output, ptr, sizeof(uint64_t)); #if defined(_MSC_VER) output = _byteswap_uint64(output); #elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC) output = __builtin_bswap64(output); #else uint64_t Byte0 = output & 0x00000000000000FF; uint64_t Byte1 = output & 0x000000000000FF00; uint64_t Byte2 = output & 0x0000000000FF0000; uint64_t Byte3 = output & 0x00000000FF000000; uint64_t Byte4 = output & 0x000000FF00000000; uint64_t Byte5 = output & 0x0000FF0000000000; uint64_t Byte6 = output & 0x00FF000000000000; uint64_t Byte7 = output & 0xFF00000000000000; output = (Byte0 << (7 * 8)) | (Byte1 << (5 * 8)) | (Byte2 << (3 * 8)) | (Byte3 << (1 * 8)) | (Byte7 >> (7 * 8)) | (Byte6 >> (5 * 8)) | (Byte5 >> (3 * 8)) | (Byte4 >> (1 * 8)); #endif memcpy(ptr, &output, sizeof(uint64_t)); } static inline uint16_t decodeUInt16(const uint8_t* data) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint16_t output; memcpy(&output, data, sizeof(uint16_t)); return output; } static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) { uint16_t output = decodeUInt16(data); swapBytes16(&output); return output; } static inline uint32_t decodeUInt32(const uint8_t* data) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint32_t output; memcpy(&output, data, sizeof(uint32_t)); return output; } static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) { uint32_t output = decodeUInt32(data); swapBytes32(&output); return output; } static inline uint64_t decodeUInt64(const uint8_t* data) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) uint64_t output; memcpy(&output, data, sizeof(uint64_t)); return output; } static inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) { uint64_t output = decodeUInt64(data); swapBytes64(&output); return output; } } // anonymous namespace namespace torch::utils { THPByteOrder THP_nativeByteOrder() { uint32_t x = 1; return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN; } void THP_decodeInt16Buffer( int16_t* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src)); src += sizeof(int16_t); } } void THP_decodeInt32Buffer( int32_t* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src)); src += sizeof(int32_t); } } void THP_decodeInt64Buffer( int64_t* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src)); src += sizeof(int64_t); } } void THP_decodeHalfBuffer( c10::Half* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint16_t x; c10::Half f; }; x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src)); dst[i] = f; src += sizeof(uint16_t); } } void THP_decodeBFloat16Buffer( at::BFloat16* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { uint16_t x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src)); std::memcpy(&dst[i], &x, sizeof(dst[i])); src += sizeof(uint16_t); } } void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, size_t len) { for (const auto i : c10::irange(len)) { dst[i] = (int)src[i] != 0 ? true : false; } } void THP_decodeFloatBuffer( float* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float f; }; x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src)); dst[i] = f; src += sizeof(float); } } void THP_decodeDoubleBuffer( double* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double d; }; x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src)); dst[i] = d; src += sizeof(double); } } void THP_decodeComplexFloatBuffer( c10::complex* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float re; }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t y; float im; }; x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src)); src += sizeof(float); y = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src)); src += sizeof(float); dst[i] = c10::complex(re, im); } } void THP_decodeComplexDoubleBuffer( c10::complex* dst, const uint8_t* src, bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double re; }; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t y; double im; }; static_assert(sizeof(uint64_t) == sizeof(double)); x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src)); src += sizeof(double); y = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src)); src += sizeof(double); dst[i] = c10::complex(re, im); } } void THP_decodeInt16Buffer( int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeInt32Buffer( int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeInt64Buffer( int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeHalfBuffer( c10::Half* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeBFloat16Buffer( at::BFloat16* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeFloatBuffer( float* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeDoubleBuffer( double* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeComplexFloatBuffer( c10::complex* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len); } void THP_decodeComplexDoubleBuffer( c10::complex* dst, const uint8_t* src, THPByteOrder order, size_t len) { THP_decodeComplexDoubleBuffer( dst, src, (order != THP_nativeByteOrder()), len); } void THP_encodeInt16Buffer( uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len) { memcpy(dst, src, sizeof(int16_t) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; swapBytes16(dst); dst += sizeof(int16_t); } } } void THP_encodeInt32Buffer( uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len) { memcpy(dst, src, sizeof(int32_t) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; swapBytes32(dst); dst += sizeof(int32_t); } } } void THP_encodeInt64Buffer( uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len) { memcpy(dst, src, sizeof(int64_t) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; swapBytes64(dst); dst += sizeof(int64_t); } } } void THP_encodeFloatBuffer( uint8_t* dst, const float* src, THPByteOrder order, size_t len) { memcpy(dst, src, sizeof(float) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; swapBytes32(dst); dst += sizeof(float); } } } void THP_encodeDoubleBuffer( uint8_t* dst, const double* src, THPByteOrder order, size_t len) { memcpy(dst, src, sizeof(double) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(len)) { (void)i; swapBytes64(dst); dst += sizeof(double); } } } template std::vector complex_to_float(const c10::complex* src, size_t len) { std::vector new_src; new_src.reserve(2 * len); for (const auto i : c10::irange(len)) { auto elem = src[i]; new_src.emplace_back(elem.real()); new_src.emplace_back(elem.imag()); } return new_src; } void THP_encodeComplexFloatBuffer( uint8_t* dst, const c10::complex* src, THPByteOrder order, size_t len) { auto new_src = complex_to_float(src, len); memcpy(dst, static_cast(&new_src), 2 * sizeof(float) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(2 * len)) { (void)i; // Suppress unused variable warning swapBytes32(dst); dst += sizeof(float); } } } void THP_encodeComplexDoubleBuffer( uint8_t* dst, const c10::complex* src, THPByteOrder order, size_t len) { auto new_src = complex_to_float(src, len); memcpy(dst, static_cast(&new_src), 2 * sizeof(double) * len); if (order != THP_nativeByteOrder()) { for (const auto i : c10::irange(2 * len)) { (void)i; // Suppress unused variable warning swapBytes64(dst); dst += sizeof(double); } } } } // namespace torch::utils