• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6from typing import cast
7
8import torch
9from executorch.backends.apple.mps.operators.node_visitor import (
10    NodeVisitor,
11    register_node_visitor,
12)
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    MPSClamp,
15    MPSGraph,
16    MPSMinMax,
17    MPSWhere,
18)
19
20
@register_node_visitor
class ClampVisitor(NodeVisitor):
    """Lower ``aten.clamp.default`` to an ``MPSClamp`` node.

    The clamp bounds are read from the FX node and attached to the created
    node as an ``MPSMinMax``. A bound that is missing or ``None`` is encoded
    with the string sentinel ``"-inf"`` (lower) or ``"inf"`` (upper).
    """

    target = "aten.clamp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_bound(
        node: torch.fx.Node, index: int, name: str, default: str
    ) -> "float | str":
        """Return clamp bound ``name`` for ``node``, or ``default`` if unset.

        Args:
            node: The ``aten.clamp.default`` FX node.
            index: Positional slot of the bound in ``node.args``.
            name: Keyword name of the bound in ``node.kwargs``.
            default: Sentinel returned when the bound is absent or ``None``.
        """
        value = node.args[index] if len(node.args) > index else None
        # The positional slot wins; the keyword form is only a fallback so
        # a bound passed as ``min=``/``max=`` is not silently dropped.
        if value is None:
            value = node.kwargs.get(name)
        if value is None:
            return default
        return cast(float, value)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Create the ``MPSClamp`` node, set its bounds, and append it.

        Args:
            node: The ``aten.clamp.default`` FX node being lowered.
            mps_graph: Graph whose ``mps_nodes`` list receives the new node.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSClamp)

        # aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None):
        # bounds occupy args[1] / args[2] or the "min" / "max" kwargs.
        # NOTE(review): the "-inf"/"inf" string sentinels are kept from the
        # original code; presumably MPSMinMax accepts them -- confirm against
        # the serialization schema.
        min_value = self._get_bound(node, 1, "min", "-inf")
        max_value = self._get_bound(node, 2, "max", "inf")

        mps_node.min_max = MPSMinMax(min_value=min_value, max_value=max_value)
        mps_graph.mps_nodes.append(mps_node)
46
47
@register_node_visitor
class WhereVisitor(NodeVisitor):
    """Lower ``aten.where.self`` to an ``MPSWhere`` node.

    Node construction is delegated to ``create_tertiary_node`` (presumably
    wiring the op's three operands -- condition, self, other; confirm in
    NodeVisitor). The resulting node is appended to the graph.
    """

    target = "aten.where.self"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build the ``MPSWhere`` node for ``node`` and append it to ``mps_graph``.

        Args:
            node: The ``aten.where.self`` FX node being lowered.
            mps_graph: Graph whose ``mps_nodes`` list receives the new node.
        """
        where_node = self.create_tertiary_node(node, mps_graph, MPSWhere)
        mps_graph.mps_nodes.append(where_node)
61