#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

from typing import cast

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSConstantPadND,
    MPSGraph,
)
from executorch.exir.sym_util import eval_shape


@register_node_visitor
class ConstantPadNDVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.constant_pad_nd.default`` nodes.

    Each visited node is turned into an ``MPSConstantPadND`` entry that is
    appended to the ``mps_nodes`` list of the target ``MPSGraph``.
    """

    target = "aten.constant_pad_nd.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSConstantPadND`` node describing ``node`` to ``mps_graph``.

        Args:
            node: The FX node to serialize. ``node.args[1]`` is read as the
                padding specification; the fill value is read from
                ``node.args[2]`` when present, otherwise from
                ``node.kwargs["value"]``, otherwise it defaults to ``0``.
            mps_graph: Graph whose ``mps_nodes`` list receives the new node.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSConstantPadND)

        mps_node.mpsnode_union.pad = eval_shape(cast(torch.SymInt, node.args[1]))

        # The ATen schema declares `Scalar value=0`, so the fill value may be
        # omitted from the positional args (or passed as a keyword). Indexing
        # `node.args[2]` unconditionally would raise IndexError in that case.
        if len(node.args) > 2:
            fill_value = node.args[2]
        else:
            fill_value = node.kwargs.get("value", 0)
        mps_node.mpsnode_union.value = float(fill_value)

        mps_graph.mps_nodes.append(mps_node)