/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_sdpa.h>
#include <executorch/extension/llm/custom_ops/op_update_quantized_cache.h>

#include <torch/library.h>

namespace torch {
namespace executor {

namespace native {

// Wrapper that supplies a default KernelRuntimeContext so the ExecuTorch
// out-variant kernel can be invoked from ATen via WRAP_TO_ATEN.
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::sdpa_with_kv_cache_out(
      context,
      q_projected,
      k_projected,
      v_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}

// Functional ATen entry point: allocates the output tensor and forwards to
// the out-variant wrapper above.
at::Tensor sdpa_with_kv_cache_aten(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  auto output = at::empty_like(q_projected);
  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
  (q_projected,
   k_projected,
   v_projected,
   key_cache,
   value_cache,
   start_pos,
   seq_len,
   attn_mask,
   dropout_p,
   is_causal,
   scale,
   output);
  return output;
}

Tensor& custom_sdpa_out_no_context(
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  exec_aten::RuntimeContext context{};
  return torch::executor::native::custom_sdpa_out(
      context,
      q,
      k,
      v,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}

at::Tensor custom_sdpa_aten(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const int64_t start_pos,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  auto output = at::empty_like(q);
  WRAP_TO_ATEN(custom_sdpa_out_no_context, 8)
  (q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, output);
  return output;
}

Tensor& update_quantized_cache_out_no_context(
    const Tensor& value,
    Tensor& cache,
    const int64_t start_pos,
    Tensor& output) {
  exec_aten::RuntimeContext context{};
  return torch::executor::native::update_quantized_cache_out(
      context, value, cache, start_pos, output);
}

at::Tensor update_quantized_cache_aten(
    const at::Tensor& value,
    at::Tensor& cache,
    const int64_t start_pos) {
  auto output = at::empty({1});
  WRAP_TO_ATEN(update_quantized_cache_out_no_context, 3)
  (value, cache, start_pos, output);
  return output;
}

} // namespace native
} // namespace executor
} // namespace torch
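
// Ahead-of-time registration: expose the ExecuTorch custom ops to PyTorch
// (eager mode and torch.export) under the `llama` namespace so that models
// referencing these ops can be traced and lowered. Schemas are declared in
// the fragment below; implementations are bound in the TORCH_LIBRARY_IMPL
// block that follows.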
TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def(
      "sdpa_with_kv_cache(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float drpout_p=0.0, bool is_causal=False, float? scale=None) -> Tensor");
  m.def(
      "sdpa_with_kv_cache.out(Tensor query, Tensor key, Tensor value, Tensor(a!) key_cache, "
      "Tensor(b!) value_cache, SymInt start_pos, SymInt seq_len, Tensor? attn_mask=None, "
      "float drpout_p=0.0, bool is_causal=False, float? scale=None, *, Tensor(c!) out) -> Tensor(c!)");
  m.def(
      "custom_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
      "float? scale=None) -> Tensor");
  m.def(
      "custom_sdpa.out(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
      "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
      "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
  m.def(
      "update_quantized_cache(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos) -> Tensor");
  m.def(
      "update_quantized_cache.out(Tensor value, Tensor(a!) cache, "
      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
}

// TODO: Rename this file to op_custom_ops_aot.cpp
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "sdpa_with_kv_cache", torch::executor::native::sdpa_with_kv_cache_aten);
  m.impl(
      "sdpa_with_kv_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::sdpa_with_kv_cache_out_no_context, 11));
  m.impl("custom_sdpa", torch::executor::native::custom_sdpa_aten);
  m.impl(
      "custom_sdpa.out",
      WRAP_TO_ATEN(torch::executor::native::custom_sdpa_out_no_context, 8));
  m.impl(
      "update_quantized_cache",
      torch::executor::native::update_quantized_cache_aten);
  m.impl(
      "update_quantized_cache.out",
      WRAP_TO_ATEN(
          torch::executor::native::update_quantized_cache_out_no_context, 3));
}
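
// Usage note (illustrative, not part of the registration): once the shared
// library built from this file is loaded into a PyTorch process (for example
// via torch.ops.load_library on the built artifact; the exact build target
// and path depend on your setup), the ops are callable as
// torch.ops.llama.sdpa_with_kv_cache, torch.ops.llama.custom_sdpa, and
// torch.ops.llama.update_quantized_cache, following the schemas defined
// above.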