• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Dict, Optional
8
9import torch
10from torch._export.utils import (
11    get_buffer,
12    get_lifted_tensor_constant,
13    get_param,
14    is_buffer,
15    is_lifted_tensor_constant,
16    is_param,
17)
18
19
20def is_parameter(
21    node: torch.fx.Node, edge_program: torch.export.ExportedProgram
22) -> bool:
23    return (
24        is_param(edge_program, node)
25        or is_buffer(edge_program, node)
26        or is_lifted_tensor_constant(edge_program, node)
27    )
28
29
30def get_parameter(
31    node: torch.fx.Node, edge_program: torch.export.ExportedProgram
32) -> torch.Tensor:
33    param = None
34    if is_param(edge_program, node):
35        param = get_param(edge_program, node)
36    if is_buffer(edge_program, node):
37        param = get_buffer(edge_program, node)
38    if is_lifted_tensor_constant(edge_program, node):
39        param = get_lifted_tensor_constant(edge_program, node)
40    if param is not None:
41        # update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
42        assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
43        param = param.type(node.meta["val"].dtype)
44    return param
45
46
47def set_parameter(
48    param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram
49):
50    status = False
51    if is_param(edge_program, node):
52        edge_program.state_dict[
53            edge_program.graph_signature.inputs_to_parameters[node.name]
54        ] = param
55        status = True
56    if is_buffer(edge_program, node):
57        buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
58        if buffer_name in edge_program.graph_signature.non_persistent_buffers:
59            edge_program.constants[buffer_name] = param
60        else:
61            edge_program.state_dict[buffer_name] = param
62        status = True
63    assert status, "Failed to set parameter"
64
65
def is_graph_input(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check if the given tensor is a graph input

    A graph input is a placeholder node that is_parameter does not recognize.

    Args:
        tensor: EdgeIR Tensor that is being checked for graph input
        edge_program: exported program used to rule out parameter placeholders
    """
    # anything that is not a placeholder can never be an input
    if tensor.op != "placeholder":
        return False
    return not is_parameter(tensor, edge_program)
76
77
def is_graph_output(tensor: torch.fx.Node) -> bool:
    """
    Check if the given tensor is used as a graph output

    A tensor counts as a graph output when one of its users is the output
    node, or when it reaches the output node through a chain of getitem users.

    Args:
        tensor: EdgeIR Tensor that is being checked for graph output
    """
    for user in tensor.users:
        if user.op == "output":
            return True
        # getitem node is skipped, check the op_skip_ops.py
        # a target without __name__ (e.g. the string target of a call_method or
        # call_module node) cannot be getitem, so the name is read defensively
        target_name = getattr(user.target, "__name__", None)
        if target_name == "getitem" and is_graph_output(user):
            return True
    return False
92
93
def is_constant(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check if the given tensor is a constant

    Args:
        tensor: EdgeIR Tensor that is being checked for being a constant
        edge_program: exported program used to decide whether tensor is a parameter
    """
    # only nodes recognized by is_parameter can be constants
    if not is_parameter(tensor, edge_program):
        return False

    # constants should not be treated as input placeholder
    # pay attention to the pytorch design, change this if
    # breakage happened:
    # pytorch/torch/_export/passes/lift_constant_tensor_pass.py
    meta_val = tensor.meta["val"]
    return meta_val.constant is not None
111
112
113def deduce_dtype(
114    tensor: torch.Tensor, quant_infos: Optional[Dict] = None
115) -> torch.dtype:
116    if quant_infos:
117        quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
118        unsigned = quant_infos["quant_min"] >= 0
119        if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
120            return torch.uint8 if unsigned else torch.int8
121
122        elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
123            return torch.uint16 if unsigned else torch.int16
124
125        return quant_infos["dtype"]
126
127    return tensor.dtype
128