#pragma once #include #include #include #include #include namespace torch::utils { template inline T unpackIntegral(PyObject* obj, const char* type) { #if PY_VERSION_HEX >= 0x030a00f0 // In Python-3.10 floats can no longer be silently converted to integers // Keep backward compatible behavior for now if (PyFloat_Check(obj)) { return c10::checked_convert(THPUtils_unpackDouble(obj), type); } return c10::checked_convert(THPUtils_unpackLong(obj), type); #else return static_cast(THPUtils_unpackLong(obj)); #endif } inline void store_scalar(void* data, at::ScalarType scalarType, PyObject* obj) { switch (scalarType) { case at::kByte: *(uint8_t*)data = unpackIntegral(obj, "uint8"); break; case at::kUInt16: *(uint16_t*)data = unpackIntegral(obj, "uint16"); break; case at::kUInt32: *(uint32_t*)data = unpackIntegral(obj, "uint32"); break; case at::kUInt64: // NB: This doesn't allow implicit conversion of float to int *(uint64_t*)data = THPUtils_unpackUInt64(obj); break; case at::kChar: *(int8_t*)data = unpackIntegral(obj, "int8"); break; case at::kShort: *(int16_t*)data = unpackIntegral(obj, "int16"); break; case at::kInt: *(int32_t*)data = unpackIntegral(obj, "int32"); break; case at::kLong: *(int64_t*)data = unpackIntegral(obj, "int64"); break; case at::kHalf: *(at::Half*)data = at::convert(THPUtils_unpackDouble(obj)); break; case at::kFloat: *(float*)data = (float)THPUtils_unpackDouble(obj); break; case at::kDouble: *(double*)data = THPUtils_unpackDouble(obj); break; case at::kComplexHalf: *(c10::complex*)data = (c10::complex)static_cast>( THPUtils_unpackComplexDouble(obj)); break; case at::kComplexFloat: *(c10::complex*)data = (c10::complex)THPUtils_unpackComplexDouble(obj); break; case at::kComplexDouble: *(c10::complex*)data = THPUtils_unpackComplexDouble(obj); break; case at::kBool: *(bool*)data = THPUtils_unpackNumberAsBool(obj); break; case at::kBFloat16: *(at::BFloat16*)data = at::convert(THPUtils_unpackDouble(obj)); break; case at::kFloat8_e5m2: *(at::Float8_e5m2*)data = at::convert(THPUtils_unpackDouble(obj)); break; case at::kFloat8_e5m2fnuz: *(at::Float8_e5m2fnuz*)data = at::convert(THPUtils_unpackDouble(obj)); break; case at::kFloat8_e4m3fn: *(at::Float8_e4m3fn*)data = at::convert(THPUtils_unpackDouble(obj)); break; case at::kFloat8_e4m3fnuz: *(at::Float8_e4m3fnuz*)data = at::convert(THPUtils_unpackDouble(obj)); break; default: throw std::runtime_error("invalid type"); } } inline PyObject* load_scalar(const void* data, at::ScalarType scalarType) { switch (scalarType) { case at::kByte: return THPUtils_packInt64(*(uint8_t*)data); case at::kUInt16: return THPUtils_packInt64(*(uint16_t*)data); case at::kUInt32: return THPUtils_packUInt32(*(uint32_t*)data); case at::kUInt64: return THPUtils_packUInt64(*(uint64_t*)data); case at::kChar: return THPUtils_packInt64(*(int8_t*)data); case at::kShort: return THPUtils_packInt64(*(int16_t*)data); case at::kInt: return THPUtils_packInt64(*(int32_t*)data); case at::kLong: return THPUtils_packInt64(*(int64_t*)data); case at::kHalf: return PyFloat_FromDouble( at::convert(*(at::Half*)data)); case at::kFloat: return PyFloat_FromDouble(*(float*)data); case at::kDouble: return PyFloat_FromDouble(*(double*)data); case at::kComplexHalf: { auto data_ = reinterpret_cast*>(data); return PyComplex_FromDoubles(data_->real(), data_->imag()); } case at::kComplexFloat: { auto data_ = reinterpret_cast*>(data); return PyComplex_FromDoubles(data_->real(), data_->imag()); } case at::kComplexDouble: return PyComplex_FromCComplex( *reinterpret_cast((c10::complex*)data)); case at::kBool: return PyBool_FromLong(*(bool*)data); case at::kBFloat16: return PyFloat_FromDouble( at::convert(*(at::BFloat16*)data)); case at::kFloat8_e5m2: return PyFloat_FromDouble( at::convert(*(at::Float8_e5m2*)data)); case at::kFloat8_e4m3fn: return PyFloat_FromDouble( at::convert(*(at::Float8_e4m3fn*)data)); case at::kFloat8_e5m2fnuz: return PyFloat_FromDouble(at::convert( *(at::Float8_e5m2fnuz*)data)); case at::kFloat8_e4m3fnuz: return PyFloat_FromDouble(at::convert( *(at::Float8_e4m3fnuz*)data)); default: throw std::runtime_error("invalid type"); } } } // namespace torch::utils