# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.library

namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")

#############
## prepack ##
#############


def prepack_impl(x: torch.Tensor):
    # Eager/reference implementation is an identity; the actual prepacking is
    # performed by the backend at runtime.
    return x


name = "prepack"
lib.define(f"{name}(Tensor x) -> Tensor")
lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
prepack_op = getattr(getattr(torch.ops, namespace), name)

#####################
## conv_with_clamp ##
#####################


def conv_with_clamp_impl(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
):
    # Reference implementation: a convolution followed by a clamp to
    # [output_min, output_max].
    return torch.clamp(
        torch.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ),
        output_min,
        output_max,
    )


name = "conv_with_clamp"
lib.define(
    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor"
)
lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)

#########################
## conv_with_clamp.out ##
#########################


def conv_with_clamp_out_impl(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
    out=None,
):
    out = conv_with_clamp_impl(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        output_min,
        output_max,
    )
    return out


name = "conv_with_clamp.out"
lib.define(
    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")

#################
## grid_priors ##
#################


# x must have at least 2 dimensions; its last two dimensions are treated as
# (height, width).
def grid_priors_impl(
    x,
    stride,
    offset,
):
    height, width = x.shape[-2:]
    # Specify the device for torch.arange to avoid an ExecuTorch export error.
    shift_x = (torch.arange(0, width, device=x.device) + offset) * stride
    shift_y = (torch.arange(0, height, device=x.device) + offset) * stride
    # Specify the indexing parameter ('ij' is the default value) to avoid an
    # ExecuTorch export error.
    shift_xx, shift_yy = torch.meshgrid([shift_y, shift_x], indexing="ij")
    shift_xx = shift_xx.reshape(-1)
    shift_yy = shift_yy.reshape(-1)
    shifts = torch.stack((shift_yy, shift_xx), dim=-1)
    return shifts


name = "grid_priors"
lib.define(f"{name}(Tensor self, int stride, float offset) -> Tensor")
lib.impl(name, grid_priors_impl, "CompositeExplicitAutograd")
grid_priors_op = getattr(getattr(torch.ops, namespace), name)


# When lowering to ExecuTorch, ops are converted from the default variant to the
# out variant. Hence, custom ops define both variants.
def grid_priors_out_impl(
    x,
    stride,
    offset,
    out,
):
    out = grid_priors_impl(x, stride, offset)
    return out


name = "grid_priors_out"
lib.define(
    f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd")

########################
## linear_weight_int4 ##
########################


def linear_weight_int4_impl(
    x: torch.Tensor,
    weights_4x8: torch.Tensor,
    groupsize: int,
    scales_and_zeros: torch.Tensor,
    inner_k_tiles: int,
):
    # Reference implementation in terms of ATen's int4 weight packing and
    # packed-weight matmul ops.
    original_x_size = x.size()
    out_features = weights_4x8.size(0)
    x = x.reshape(-1, original_x_size[-1])
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
        weights_4x8, inner_k_tiles
    )
    out = torch.ops.aten._weight_int4pack_mm(
        x, weight_int4pack, groupsize, scales_and_zeros
    )
    out_shape = original_x_size[:-1] + (out_features,)
    return out.reshape(out_shape)


name = "linear_weight_int4"
lib.define(
    f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor"
)
lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)

######################
## apply_rotary_emb ##
######################


# Note that this implementation is copied from executorch.examples.models.llama.rope
# rather than imported, to avoid introducing a dependency on the llama code.
def apply_rotary_emb_impl(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
):
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        ndim = x.ndim
        freqs_cis_ndim = freqs_cis.ndim
        if freqs_cis_ndim == 3:
            # freqs_cis: (seq_len, n_heads, head_dim // 2)
            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
            shape = [
                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            # freqs_cis: (seq_len, head_dim // 2)
            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(shape)

    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


name = "apply_rotary_emb"
lib.define(
    f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin) -> (Tensor, Tensor)"
)
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
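
######################################
## usage sketch (illustrative only) ##
######################################

# A minimal, hedged smoke-test sketch, not part of the original library. It
# assumes only what is registered above: importing this module exposes the ops
# under `torch.ops.et_vk` with eager (CompositeExplicitAutograd) reference
# implementations, so they can be sanity-checked before export. The shapes and
# values below are arbitrary examples.
if __name__ == "__main__":
    # conv_with_clamp: a 3x3 convolution followed by a clamp to [0, 6]
    # (i.e. a conv fused with a ReLU6-style activation).
    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    b = torch.randn(4)
    y = torch.ops.et_vk.conv_with_clamp(
        x, w, b, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, 0.0, 6.0
    )
    print(y.shape)  # torch.Size([1, 4, 8, 8])

    # grid_priors: anchor-point grid for an 8x8 feature map with stride 4 and
    # offset 0.5; one (x, y)-style pair per spatial location.
    points = torch.ops.et_vk.grid_priors(torch.randn(1, 3, 8, 8), 4, 0.5)
    print(points.shape)  # torch.Size([64, 2])

    # apply_rotary_emb: rotary position embeddings applied to query/key tensors
    # of shape (batch, seq_len, n_heads, head_dim); freqs_cos and freqs_sin have
    # shape (seq_len, head_dim // 2).
    xq = torch.randn(1, 4, 2, 8)
    xk = torch.randn(1, 4, 2, 8)
    freqs = torch.arange(4).unsqueeze(-1) * torch.ones(4)
    xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(
        xq, xk, torch.cos(freqs), torch.sin(freqs)
    )
    print(xq_out.shape, xk_out.shape)  # torch.Size([1, 4, 2, 8]) each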