/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

bool check_quantized_mixed_linear_args(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_is_rank(weight_scales, 1) || tensor_is_rank(weight_scales, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 1));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensors_have_same_size_at_dims(weight_scales, 0, weight, 0));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight_scales));

  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        dtype.value() == ScalarType::Float || dtype.value() == ScalarType::Half,
        "dtype must be Float or Half");
  }

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.scalar_type() == ScalarType::Char, "weight dtype must be int8");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.scalar_type() == ScalarType::Float ||
          in.scalar_type() == ScalarType::Half,
      "input dtype must be Float or Half");

  if (opt_weight_zero_points.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape(
        opt_weight_zero_points.value(), weight_scales));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_dtype(opt_weight_zero_points.value(), in));
  }

  // Support for non-null zero points is not implemented yet.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !opt_weight_zero_points.has_value(), "zero points not supported yet.");
  return true;
}

Tensor& quantized_mixed_linear_out(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available.
  ET_CHECK(check_quantized_mixed_linear_args(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out));

  ScalarType out_dtype = dtype.has_value() ? dtype.value() : out.scalar_type();

  // `weight` is stored transposed as [p, n], so the output is [m, p] for an
  // [m, n] input.
  size_t output_ndim = 2;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  output_sizes[0] = in.size(0);
  output_sizes[1] = weight.size(0);

  // TODO (gjcomer) Replace with ET_KERNEL_CHECK when context is available.
  ET_CHECK(resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok);

  constexpr auto name = "quantized_decomposed::mixed_linear.out";

  ET_SWITCH_TWO_TYPES(Float, Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    ET_SWITCH_FLOAT_TYPES_AND(Half, out_dtype, ctx, name, CTYPE_OUT, [&]() {
      size_t m = in.size(0);
      size_t n = in.size(1);
      size_t p = weight.size(0);

      // Group-wise scales: a 2-D weight_scales of shape [p, num_groups]
      // implies a group size of ceil(n / num_groups) along the input
      // dimension; a 1-D weight_scales means one scale per output row (g = n).
      size_t g = n;
      if (weight_scales.dim() == 2) {
        g = (n + weight_scales.size(1) - 1) / weight_scales.size(1);
      }
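
      // Weight-only-quantized matmul against the transposed weight. A rough
      // reading of the arguments (an interpretation from the names and shapes
      // above, not a verbatim spec of the vec_ops helper):
      //   out[i, j] = sum_k in[i, k] * weight_scales[j, k / g] * weight[j, k]
      // for i in [0, m), j in [0, p), k in [0, n). Zero points are never
      // applied here; the argument check above rejects non-null zero points.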
      // FIXME: this currently ignores dtype
      vec_quantized_matmul_transb_int8<
          CTYPE_OUT, // T *z
          CTYPE>( // U *x, U *s
          out.mutable_data_ptr<CTYPE_OUT>(),
          in.const_data_ptr<CTYPE>(),
          weight.const_data_ptr<int8_t>(),
          weight_scales.const_data_ptr<CTYPE>(),
          m,
          n,
          p,
          g);
    });
  });

  return out;
}

Tensor& quantized_mixed_linear_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const exec_aten::optional<Tensor>& opt_weight_zero_points,
    const exec_aten::optional<ScalarType> dtype,
    Tensor& out) {
  // TODO(mcandales): Remove the need for this wrapper
  // TODO(mkg): add support for dtype
  (void)ctx;
  return quantized_mixed_linear_out(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out);
}

} // namespace native
} // namespace executor
} // namespace torch