# mypy: allow-untyped-defs
import collections
from collections import defaultdict
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree


aten = torch.ops.aten

# We would like to split modules into two subgraphs for runtime weight updates to work correctly.
# The use case and more information can be found at:
# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing
META_TAG = "MODULE_TYPE"
MODULE_TAG = "_MAIN_MODULE"
CONST_MODULE_TAG = "_CONST_MODULE"


def replace_node_with_constant(gm, node, constant, name=None):
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


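# Illustrative usage sketch, not part of the original module: freezing the
# value produced by a traced constructor node into a module buffer. The
# helper name `_example_replace_node_with_constant` is hypothetical.
def _example_replace_node_with_constant():
    class M(torch.nn.Module):
        def forward(self, x):
            return x + torch.ones(2)

    gm = torch.fx.symbolic_trace(M())
    # `torch.ones` traces to a call_function node; swap it for a get_attr
    # node that reads a frozen buffer holding the same value.
    ones_node = next(n for n in gm.graph.nodes if n.target is torch.ones)
    replace_node_with_constant(gm, ones_node, torch.ones(2))
    gm.recompile()
    assert torch.equal(gm(torch.zeros(2)), torch.ones(2))

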
class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if (
            node.target == torch.ops.prims.convert_element_type.default
            and node.args[0].op == "get_attr"  # type: ignore[union-attr]
            and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
            and node.args[1] == torch.bfloat16
        ):
            # For int8_weight -> dq -> bf16_weight
            return True
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into int8_weight and leave dq in the graph
            # to be fused
            return True
        return False

    def node_to_last_non_output_use(self):
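        """
        Map each node to the list of inputs whose last non-output use
        occurs at that node. A node whose only user is the output is also
        mapped to itself, so its env entry can be freed as soon as it is
        computed.
        """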
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from the env on their last non-output
            # use, redefine them now or we'll get an error from the interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # Compare types before comparing values, because when flattened_inputs
        # contains a ScriptObject, equality checking raises a type error if
        # the types are different.
        if any(
            type(self.unknown_value) == type(input_) and self.unknown_value == input_
            for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # TODO - constant folding a triton kernel returns the inputs -- fix this
        if (
            node.op == "call_function"
            and node.name == "triton_kernel_wrapper_functional_proxy"
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them
        # already and turning them into tensors would result in an additional
        # global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or be on inputs which we did
        # not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        env = {}
        for n in self.module.graph.find_nodes(op="placeholder"):
            env[n] = self.unknown_value
        return super().run(initial_env=env)


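# Illustrative usage sketch, not part of the original module: running the
# folder directly and inspecting which nodes it would replace. The helper
# name `_example_constant_folder` is hypothetical.
def _example_constant_folder():
    class M(torch.nn.Module):
        def forward(self, x):
            return x + torch.ones(2) * 2

    gm = torch.fx.symbolic_trace(M())
    cf = ConstantFolder(gm)
    cf.run()
    # `torch.ones(2) * 2` is computable without any placeholder input, so the
    # mul node is recorded for replacement; the intermediate `ones` node is
    # dropped again because its only use was itself folded, and `x + ...`
    # depends on a placeholder so it propagates unknown_value.
    for node, value in cf.node_replacements.items():
        print(node.name, value)  # e.g. "mul tensor([2., 2.])"

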
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            replace_node_with_constant(gm, node, constant)

        erased_params = []
        # Collect the users of each attribute by walking the graph rather than
        # reading node.users, because in a case like the one below
        # _tensor_constant0 and _tensor_constant0_1 actually refer to the same
        # tensor.

        #     opcode         name                 target            args                         kwargs
        # -------------  -------------------  ----------------  ---------------------------  --------
        # placeholder    arg0_1               arg0              ()                           {}
        # get_attr       _tensor_constant0    state             ()                           {}
        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0)  {}
        # get_attr       _tensor_constant0_1  state             ()                           {}
        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)     {}
        # output         output               output            ([add],)                     {}

        get_attr_node_users = defaultdict(list)
        for node in gm.graph.nodes:
            if node.op == "get_attr":
                get_attr_node_users[node.target].extend(node.users.keys())
        for node in gm.graph.find_nodes(op="get_attr"):
            if len(get_attr_node_users[node.target]) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)
        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()


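# Illustrative usage sketch, not part of the original module: constant_fold
# freezes the foldable subexpression into a `_frozen_param*` buffer. The
# helper name `_example_constant_fold` is hypothetical.
def _example_constant_fold():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.ones(2))

        def forward(self, x):
            return x + self.w * 2

    gm = torch.fx.symbolic_trace(M())
    constant_fold(gm)
    # `self.w * 2` has no placeholder inputs, so it is now a frozen buffer,
    # and the original `w` get_attr node has been erased as dead.
    assert any(
        n.target.startswith("_frozen_param")
        for n in gm.graph.find_nodes(op="get_attr")
    )

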
def constant_graph_tag(gm: torch.fx.GraphModule):
    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node in gm.graph.nodes:
            if (
                node.op == "get_attr"
                or node in cf.node_replacements
                or node in cf.replaced_uses
            ):
                node.meta[META_TAG] = CONST_MODULE_TAG
            else:
                node.meta[META_TAG] = MODULE_TAG


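# Illustrative usage sketch, not part of the original module: after tagging,
# every node carries a META_TAG entry that partitions the graph. The helper
# name `_example_constant_graph_tag` is hypothetical.
def _example_constant_graph_tag():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.ones(2))

        def forward(self, x):
            return x + self.w * 2

    gm = torch.fx.symbolic_trace(M())
    constant_graph_tag(gm)
    # `w` and the mul are tagged _CONST_MODULE; `x`, the add, and the output
    # are tagged _MAIN_MODULE.
    for node in gm.graph.nodes:
        print(node.name, node.meta[META_TAG])

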
def run_and_get_constant_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Construct a GraphModule which corresponds to the part of the provided
    gm that can be constant folded.
    """

    constant_graph_tag(gm)
    # Rewrite the tags: if a constant is consumed directly, without any
    # folding opportunity, we keep it in the main gm.
    for node in gm.graph.find_nodes(op="get_attr"):
        used_to_fold = False
        for u in node.users:
            if u.meta[META_TAG] == CONST_MODULE_TAG:
                used_to_fold = True
                break
        if not used_to_fold:
            node.meta[META_TAG] = MODULE_TAG

    new_graph = torch.fx.Graph()

    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    output_nodes = []
    for node in gm.graph.nodes:
        if node.meta[META_TAG] == MODULE_TAG:
            continue

        new_node = new_graph.node_copy(node, lambda x: node_remapping[x])
        node_remapping[node] = new_node

        for user in node.users:
            if user.meta[META_TAG] == MODULE_TAG:
                output_nodes.append(new_node)
                break

    new_graph.output(tuple(output_nodes))
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)

    return new_gm


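# Illustrative usage sketch, not part of the original module: extracting the
# foldable subgraph as its own GraphModule. The helper name
# `_example_run_and_get_constant_graph` is hypothetical.
def _example_run_and_get_constant_graph():
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.ones(2))

        def forward(self, x):
            return x + self.w * 2

    gm = torch.fx.symbolic_trace(M())
    const_gm = run_and_get_constant_graph(gm)
    # The extracted graph takes no inputs and returns the folded values,
    # here a single tensor for `self.w * 2`.
    (folded,) = const_gm()
    assert torch.equal(folded, torch.ones(2) * 2)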