#define TORCH_ASSERT_NO_OPERATORS #include #include #include #include #include #include #include #include namespace at::native { namespace { template void fill_non_native_type(TensorIterator& iter, const Scalar& value_scalar) { auto value = value_scalar.to().x; using H = typename std::make_signed::type; // Signed type has more acceleration // Reserve the representation of value. static_cast(value) is implementation defined. H val = *reinterpret_cast(std::addressof(value)); cpu_kernel_vec( iter, [val]() -> H { return val; }, [val]() { return Vectorized(val); }); } template <> void fill_non_native_type>(TensorIterator& iter, const Scalar& value_scalar) { static_assert(sizeof(c10::complex) == sizeof(int32_t), "Size of ComplexHalf should be 32-bits"); auto value = c10::complex(value_scalar.to>()); auto val = *reinterpret_cast(std::addressof(value)); cpu_kernel_vec( iter, [val]() -> int32_t { return val; }, [val]() { return Vectorized(val); }); } void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) { if (iter.dtype() == ScalarType::Half) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::BFloat16) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::ComplexHalf) { fill_non_native_type>(iter, value_scalar); } else if (iter.dtype() == ScalarType::Float8_e4m3fn) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::Float8_e5m2) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::Float8_e4m3fnuz) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) { fill_non_native_type(iter, value_scalar); } else { AT_DISPATCH_V2( iter.dtype(), "fill_cpu", AT_WRAP([&]() { scalar_t value = value_scalar.to(); cpu_kernel_vec( iter, [=]() -> scalar_t { return value; }, [=]() { return Vectorized(value); }); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kBool, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) ); } } } // namespace 

// Hook fill_kernel up to fill_stub through the REGISTER_DISPATCH macro.
REGISTER_DISPATCH(fill_stub, &fill_kernel);

} // namespace at::native