• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""sparsify utils"""
16import ast
17import builtins
18import inspect
19import types
20from enum import Enum, auto
21from typing import NamedTuple, Any, Optional, Callable, Union
22
23import mindspore
24from mindspore import ops, Tensor, CSRTensor, COOTensor
25from mindspore.rewrite.common.namespace import get_functional
26
27
class ArgType(Enum):
    """
    Kinds of argument that sparsify distinguishes.

    - NONSPARSE represents a non-sparse value.
    - CSR represents a CSRTensor.
    - COO represents a COOTensor.
    - ANY is an additional member; NOTE(review): presumably a wildcard that matches every kind, but its
      consumers are not visible in this file -- confirm before relying on it.

    .. warning::
        This is a set of experimental APIs that is subject to change or deletion.
    """
    # Values are spelled out so they stay stable if members are ever reordered.
    NONSPARSE = 1
    CSR = 2
    COO = 3
    ANY = 4
43
44
class SparseFunc(NamedTuple):
    """
    Represents a sparse function in sparsify.

    Note:
        If `fn` is a function with type hints, `inputs` and/or `outputs`, when provided, override the function
        type hints. When `fn` is a string, `get_sparse_func` in this module derives the types from the named
        functional op and does not consult `inputs`/`outputs`.

    .. warning::
        This is a set of experimental APIs that is subject to change or deletion.

    Args:
        fn (Union[str, Callable]): a sparse function. If `fn` is a string, it names a mindspore functional op;
            otherwise `fn` can be any function object.
        inputs (Any, optional): input types for the function. If `inputs` is None, use the input types in function
            type hints. Default: ``None`` .
        outputs (Any, optional): output types for the function. If `outputs` is None, use the output types in
            function type hints. Default: ``None`` .
    """
    # Name of a mindspore functional op, or the function object itself.
    fn: Union[str, Callable]
    # Explicit input types; None means "read them from the type hints of `fn`".
    inputs: Optional[Any] = None
    # Explicit output types; None means "read them from the type hints of `fn`".
    outputs: Optional[Any] = None
66
67
# Maps an op to the sparse functions that may replace it. Keys are either function objects
# (e.g. `ops.mul`) or operator symbols (the same "+", "-", "*", "/" strings that `get_binop_name`
# returns). Each value is a list whose items are either the name of a sparse function (a string)
# or a SparseFunc. "+" and "-" currently map to empty lists.
sparse_rules = {
    ops.reduce_sum: ["csr_reduce_sum"],
    ops.mul: ["csr_mul"],
    ops.matmul: ["csr_mv"],
    "+": [],
    "-": [],
    "*": ["csr_mul"],
    "/": ["csr_div"]
}
78
79
# Names of all attributes of the `builtins` module that are builtin functions.
builtin_ops = {
    name
    for name, value in vars(builtins).items()
    if isinstance(value, types.BuiltinFunctionType)
}
# Tensor class -> ArgType used to describe it.
tensor_to_arg_type_map = {
    Tensor: ArgType.NONSPARSE,
    CSRTensor: ArgType.CSR,
    COOTensor: ArgType.COO,
}
# Sparse ArgType -> tensor class implementing it (NONSPARSE and ANY have no entry).
arg_type_to_tensor_map = {
    ArgType.CSR: CSRTensor,
    ArgType.COO: COOTensor,
}
# Sparse ArgType -> lowercase prefix string ("csr" / "coo").
arg_type_to_prefix_map = {
    ArgType.CSR: "csr",
    ArgType.COO: "coo",
}
84
85
def get_arg_type(annotation):
    """
    Returns the ArgType that corresponds to a typing annotation.

    Args:
        annotation: a type object, or a string naming either an attribute of the `mindspore` package
            (e.g. ``"CSRTensor"``) or one of the builtin scalar types ``"int"``, ``"float"``, ``"bool"``, ``"str"``.

    Returns:
        ArgType, the entry of `tensor_to_arg_type_map` for the annotation, or `ArgType.NONSPARSE` for the
        builtin scalar types int, float, bool and str.

    Raises:
        ValueError: If the annotation cannot be mapped to an ArgType.
    """
    resolved = annotation
    if isinstance(annotation, str):
        # String annotations are looked up in the mindspore namespace first.
        resolved = getattr(mindspore, annotation, None)
        if resolved is None:
            # Fall back to the builtin scalar types so that "int" behaves like int.
            resolved = {"int": int, "float": float, "bool": bool, "str": str}.get(annotation)
    arg_type = tensor_to_arg_type_map.get(resolved, None)
    if arg_type is not None:
        return arg_type
    if resolved in (int, float, bool, str):
        return ArgType.NONSPARSE
    # Report the annotation exactly as the caller wrote it; the resolved value is None for an
    # unknown name and would make the message useless.
    raise ValueError(f"Type {annotation} cannot be mapped to ArgType!")
96
97
def get_tuple(x):
    """
    Coerces `x` to a tuple.

    A tuple or list is converted to a tuple with the same elements; any other value (including strings
    and None) is wrapped in a one-element tuple.
    """
    return tuple(x) if isinstance(x, (tuple, list)) else (x,)
103
104
def get_inputs_outputs(fn):
    """
    Reads the input and output ArgTypes of `fn` from its type annotations.

    Returns:
        A pair ``(inputs, outputs)``. `inputs` is a list with one ArgType per parameter, or None as soon as
        a parameter without annotation is met. `outputs` is a tuple holding the ArgType of the return
        annotation, or None when `fn` has no return annotation.
    """
    signature = inspect.signature(fn)
    missing = inspect.Parameter.empty

    # Parameters are visited in order and the scan stops at the first unannotated one, so an
    # unmappable annotation that precedes it still raises from get_arg_type.
    inputs = []
    for param in signature.parameters.values():
        if param.annotation == missing:
            inputs = None
            break
        inputs.append(get_arg_type(param.annotation))

    returned = signature.return_annotation
    outputs = None if returned == missing else get_tuple(get_arg_type(returned))
    return inputs, outputs
120
121
def get_sparse_method_outputs(method_name, sparse_type):
    """
    Returns the output types of method `method_name` of the tensor class registered for `sparse_type`.

    Raises:
        ValueError: If `sparse_type` has no entry in `arg_type_to_tensor_map`, or if the tensor class has no
            attribute called `method_name`.
    """
    tensor_cls = arg_type_to_tensor_map.get(sparse_type)
    if tensor_cls is None:
        raise ValueError(f"Unrecognized sparse type {sparse_type}!")

    method = getattr(tensor_cls, method_name, None)
    if method is None:
        raise ValueError(f"{tensor_cls} does not have attr {method_name}!")

    # Only the return annotation matters here; the input types are discarded.
    return get_inputs_outputs(method)[1]
132
133
def get_sparse_func(rule):
    """
    Normalizes a sparse rule into a SparseFunc whose `fn` is a function name.

    Args:
        rule (Union[str, SparseFunc, Callable]): the rule to normalize.

            - str: name of a mindspore functional op; `inputs` and `outputs` come from its annotations.
            - SparseFunc: if `rule.fn` is a string, it is handled exactly like a plain string rule and
              `rule.inputs`/`rule.outputs` are not consulted. If `rule.fn` is callable, `rule.inputs` and
              `rule.outputs` that are not None override the types read from the annotations.
            - Callable: both input and output types must be available from the annotations.

    Returns:
        SparseFunc, with `fn` set to the function name and `inputs`/`outputs` set to the resolved types.

    Raises:
        ValueError: If a string does not name a mindspore functional op, if required input or output types
            are available neither explicitly nor from annotations, or if `rule` (or `rule.fn`) has an
            unsupported type.
    """
    if isinstance(rule, str):
        # only mindspore functional ops can be passed as strings
        functional = get_functional(rule)
        if not functional:
            raise ValueError(f"{rule} not a valid name for mindspore functional op!")
        inputs, outputs = get_inputs_outputs(functional)
        return SparseFunc(rule, inputs, outputs)
    if isinstance(rule, SparseFunc):
        if isinstance(rule.fn, str):
            return get_sparse_func(rule.fn)
        if callable(rule.fn):
            inputs, outputs = get_inputs_outputs(rule.fn)
            # Explicit types win over annotations. Compare against None instead of relying on
            # truthiness so that an explicitly supplied empty tuple/list is honoured: only None
            # means "fall back to the type hints".
            if rule.inputs is not None:
                inputs = get_tuple(rule.inputs)
            elif inputs is None:
                raise ValueError(f"Input types not provided for {rule}!")
            if rule.outputs is not None:
                outputs = get_tuple(rule.outputs)
            elif outputs is None:
                raise ValueError(f"Output types not provided for {rule}!")
            return SparseFunc(rule.fn.__name__, inputs, outputs)
        raise ValueError(f"`fn` {rule.fn} for SparseFunc should be either a string or a function!")
    if callable(rule):
        inputs, outputs = get_inputs_outputs(rule)
        if inputs is None or outputs is None:
            raise ValueError(f"Both input types and output types should be provided for {rule}!")
        return SparseFunc(rule.__name__, inputs, outputs)
    # Plain callables are accepted above, so the message lists them as well.
    raise ValueError(f"Sparse rule {rule} should be a string, a SparseFunc or a function!")
167
168
def get_binop_name(binop):
    """
    Maps an ast.BinOp operator node to its operator symbol.

    Args:
        binop (ast.operator): operator node, e.g. the `op` field of an `ast.BinOp`.

    Returns:
        str, one of "+", "-", "*", "/"; an empty string for any other value.
    """
    # ast nodes do not implement value equality, so `binop == ast.Add()` compares object identity
    # with a freshly built node and is never true. Dispatch on the node's class instead.
    symbols = (
        (ast.Add, "+"),
        (ast.Sub, "-"),
        (ast.Mult, "*"),
        (ast.Div, "/"),
    )
    for node_type, symbol in symbols:
        if isinstance(binop, node_type):
            return symbol
    return ""
180