#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    MPSClamp,
    MPSGraph,
    MPSMinMax,
    MPSWhere,
)


@register_node_visitor
class ClampVisitor(NodeVisitor):
    """Lower ``aten.clamp.default`` into an ``MPSClamp`` graph node.

    The optional lower/upper bounds are taken from positional args 1 and 2 of
    the FX node. A bound that is absent or ``None`` is recorded as the string
    ``"-inf"`` (lower) or ``"inf"`` (upper).
    """

    target = "aten.clamp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSClamp`` node for ``node`` and append it to ``mps_graph``.

        Args:
            node: The FX node being lowered.
            mps_graph: The serialized graph whose ``mps_nodes`` list receives
                the new node.
        """
        clamp_node = self.create_unary_node(node, mps_graph, MPSClamp)

        operands = node.args
        has_lower = len(operands) > 1 and operands[1] is not None
        has_upper = len(operands) > 2 and operands[2] is not None

        # Missing bounds are encoded as infinity strings rather than floats.
        # NOTE(review): only positional args are inspected; node.kwargs is ignored.
        lower = cast(float, operands[1]) if has_lower else "-inf"
        upper = cast(float, operands[2]) if has_upper else "inf"

        clamp_node.min_max = MPSMinMax(min_value=lower, max_value=upper)
        mps_graph.mps_nodes.append(clamp_node)


@register_node_visitor
class WhereVisitor(NodeVisitor):
    """Lower ``aten.where.self`` into an ``MPSWhere`` graph node."""

    target = "aten.where.self"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Build an ``MPSWhere`` node for ``node`` and append it to ``mps_graph``.

        The node is constructed by the base-class ``create_tertiary_node``
        helper. Presumably that helper wires the three operands of ``where``
        (condition, self, other); verify against ``NodeVisitor``.
        """
        where_node = self.create_tertiary_node(node, mps_graph, MPSWhere)
        mps_graph.mps_nodes.append(where_node)