# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def get_meandim_decomposition(op) -> tuple:
    """Return the ``(sum, full, mul)`` operators used to decompose ``op``.

    The returned operators belong to the same namespace as ``op``: the edge
    dialect for ``exir_ops.edge.aten.mean.dim`` and ``torch.ops.aten`` for
    ``torch.ops.aten.mean.dim``.

    Raises:
        RuntimeError: if ``op`` is neither of the two supported mean.dim ops.
    """
    if op == exir_ops.edge.aten.mean.dim:
        return (
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.mul.Tensor,
        )
    if op == torch.ops.aten.mean.dim:
        return (
            torch.ops.aten.sum.dim_IntList,
            torch.ops.aten.full.default,
            torch.ops.aten.mul.Tensor,
        )
    raise RuntimeError(f"Can't get meandim decomposition for op {op}")


class DecomposeMeanDimPass(ExportPass):
    """
    This pass decomposes meandim into a sum and mul node.

    Only calls with keepdim=True are decomposed. Calls with keepdim=False,
    and calls with dim == [-1, -2], are emitted unchanged.

    Example:
        y = mean_dim(x, dim, keepdim)
    Becomes:
        sum = sum.dim_IntList(x, dim, keepdim)
        y = mul(sum, 1/N)
    """

    def call_operator(self, op, args, kwargs, meta):
        """Rewrite a keepdim mean.dim call as ``sum(x, dim) * (1 / N)``.

        ``N`` is the number of input elements folded into each output
        element, i.e. the product of the input sizes along the reduced
        dimensions. Every other operator is forwarded untouched.
        """
        if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
            return super().call_operator(op, args, kwargs, meta)

        x = args[0]
        dim = args[1]
        keepdim = args[2] if len(args) > 2 else False
        if not keepdim:
            return super().call_operator(op, args, kwargs, meta)

        # if keepdim == True and dim == [-1, -2], mean.dim can be
        # decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
        if dim == [-1, -2]:
            # Simply return the mean.dim operator for future decomposition.
            return super().call_operator(op, args, kwargs, meta)

        shape = meta["val"].size()
        dtype = meta["val"].dtype
        input_shape = x.data.size()

        # mean.dim declares ``dim`` as ``int[1]?``: it may be None, a bare int
        # or a sequence. None and an empty sequence both mean "reduce over
        # every dimension", so N must then cover the whole input.
        if isinstance(dim, int):
            reduced_dims = (dim,)
            sum_dims = [dim]
        elif dim is None or len(dim) == 0:
            reduced_dims = range(len(input_shape))
            sum_dims = dim
        else:
            reduced_dims = dim
            sum_dims = dim

        N = 1
        for d in reduced_dims:
            N *= input_shape[d]

        sum_op, full_op, mul_op = get_meandim_decomposition(op)

        # Named *_node so the builtin ``sum`` is not shadowed.
        sum_node = super().call_operator(sum_op, (x, sum_dims, keepdim), {}, meta)
        # A tensor of ones-shaped rank matching the output that holds 1/N;
        # it broadcasts against the keepdim sum in the multiplication.
        scale_node = super().call_operator(
            full_op, ([1] * len(shape), 1 / N), {"dtype": dtype}, meta
        )
        return super().call_operator(mul_op, (sum_node, scale_node), {}, meta)