• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    MPSConstantPadND,
15    MPSGraph,
16)
17from executorch.exir.sym_util import eval_shape
18
19
@register_node_visitor
class ConstantPadNDVisitor(NodeVisitor):
    """Lower ``aten.constant_pad_nd.default`` to an ``MPSConstantPadND`` node.

    The aten schema is
    ``constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0)``: ``pad`` is
    a list of (possibly symbolic) ints and ``value`` is optional.
    """

    target = "aten.constant_pad_nd.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` as an ``MPSConstantPadND`` and append it to the graph.

        Args:
            node: FX node calling ``aten.constant_pad_nd.default``.
            mps_graph: Graph under construction; the serialized node is
                appended to ``mps_graph.mps_nodes``.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSConstantPadND)

        # `pad` is a list of ints / SymInts (not a single SymInt). It is passed
        # through eval_shape, which presumably turns symbolic entries into
        # concrete ints -- confirm against executorch.exir.sym_util.
        pad = cast(list, node.args[1])
        mps_node.mpsnode_union.pad = eval_shape(pad)

        # `value` has a schema default of 0, so it may be absent from the
        # positional args or supplied as a keyword; indexing args[2]
        # unconditionally would raise IndexError in that case.
        if len(node.args) > 2:
            value = node.args[2]
        else:
            value = node.kwargs.get("value", 0)
        mps_node.mpsnode_union.value = float(value)

        mps_graph.mps_nodes.append(mps_node)
38