# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""sparsify utils"""
import ast
import builtins
import inspect
import types
from enum import Enum, auto
from typing import NamedTuple, Any, Optional, Callable, Union

import mindspore
from mindspore import ops, Tensor, CSRTensor, COOTensor
from mindspore.rewrite.common.namespace import get_functional


class ArgType(Enum):
    """
    Argument types for sparsify.

    - CSR represents a CSRTensor.
    - COO represents a COOTensor.
    - NONSPARSE represents a non-sparse value.

    .. warning::
        This is a set of experimental APIs that is subject to change or deletion.
    """
    NONSPARSE = auto()
    CSR = auto()
    COO = auto()
    # NOTE(review): ANY is not described in the docstring above; presumably a wildcard that
    # stands in for any of the other members — confirm against its users.
    ANY = auto()


class SparseFunc(NamedTuple):
    """
    Represents a sparse function in sparsify.

    Note:
        `inputs` and/or `outputs`, when provided, override the type hints of `fn` (or of the
        mindspore functional op that `fn` names, when `fn` is a string).

    .. warning::
        This is a set of experimental APIs that is subject to change or deletion.

    Args:
        fn (Union[str, Callable]): a sparse function. If `fn` is a string, the function represents a mindspore
            functional op; or `fn` can be any function object.
        inputs (Any, optional): input types for the function. If `inputs` is None, use the input types in function
            type hints. A single type is treated as a one-element tuple. Default: ``None`` .
        outputs (Any, optional): output types for the function. If `outputs` is None, use the output types in function
            type hints. A single type is treated as a one-element tuple. Default: ``None`` .
    """
    fn: Union[str, Callable]
    inputs: Optional[Any] = None
    outputs: Optional[Any] = None


# Maps an op to a list of sparse rules (strings or SparseFunc), each naming a sparse_func.
# Keys are either mindspore ops or binary-operator symbols as produced by `get_binop_name`;
# an empty list means no sparse implementation is registered for that operator.
sparse_rules = {
    ops.reduce_sum: ["csr_reduce_sum"],
    ops.mul: ["csr_mul"],
    ops.matmul: ["csr_mv"],
    "+": [],
    "-": [],
    "*": ["csr_mul"],
    "/": ["csr_div"]
}


# Names of builtins that are builtin *functions* (e.g. `len`); builtin types such as `int` are excluded.
builtin_ops = {i for i, v in vars(builtins).items() if isinstance(v, types.BuiltinFunctionType)}
tensor_to_arg_type_map = {Tensor: ArgType.NONSPARSE, CSRTensor: ArgType.CSR, COOTensor: ArgType.COO}
arg_type_to_tensor_map = {ArgType.CSR: CSRTensor, ArgType.COO: COOTensor}
arg_type_to_prefix_map = {ArgType.CSR: "csr", ArgType.COO: "coo"}


def get_arg_type(annotation):
    """
    Returns arg_type based on typing annotation.

    Args:
        annotation: a type (``Tensor``, ``CSRTensor``, ``COOTensor``, ``int``, ``float``, ``bool`` or
            ``str``), or the name of such a type as a string (as produced by postponed annotations).
            String names are looked up on the ``mindspore`` package first, then among builtins.

    Returns:
        ArgType: the matching argument type; primitive Python types map to ``ArgType.NONSPARSE``.

    Raises:
        ValueError: if the annotation cannot be mapped to an ArgType.
    """
    resolved = annotation
    if isinstance(annotation, str):
        resolved = getattr(mindspore, annotation, None)
        if resolved is None:
            # e.g. "int" under `from __future__ import annotations`; only primitives pass the check below
            resolved = getattr(builtins, annotation, None)
    arg_type = tensor_to_arg_type_map.get(resolved, None)
    if arg_type is not None:
        return arg_type
    if resolved in (int, float, bool, str):
        return ArgType.NONSPARSE
    # report the annotation as written, not the (possibly None) lookup result
    raise ValueError(f"Type {annotation} cannot be mapped to ArgType!")


def get_tuple(x):
    """Returns `x` as a tuple: lists/tuples are converted, any other value is wrapped as ``(x,)``."""
    if not isinstance(x, (tuple, list)):
        return (x,)
    return tuple(x)


def get_inputs_outputs(fn):
    """
    Returns input and output types for function based on typing.

    Args:
        fn (Callable): the function whose signature is inspected.

    Returns:
        tuple: ``(inputs, outputs)``. `inputs` is a list with one ArgType per parameter, or None as soon
        as any parameter lacks an annotation. `outputs` is a one-element tuple with the ArgType of the
        return annotation, or None when there is no return annotation.

    Raises:
        ValueError: if an annotation cannot be mapped to an ArgType (``inspect.signature`` may also raise
            for callables without an introspectable signature).
    """
    sig = inspect.signature(fn)
    inputs = []
    for param in sig.parameters.values():
        if param.annotation is inspect.Parameter.empty:
            # one unannotated parameter makes the whole input list unknown
            inputs = None
            break
        inputs.append(get_arg_type(param.annotation))
    if sig.return_annotation is inspect.Signature.empty:
        outputs = None
    else:
        outputs = get_tuple(get_arg_type(sig.return_annotation))
    return inputs, outputs


def get_sparse_method_outputs(method_name, sparse_type):
    """
    Returns output types for sparse tensor method.

    Args:
        method_name (str): name of the method on the sparse tensor class.
        sparse_type (ArgType): ``ArgType.CSR`` or ``ArgType.COO``.

    Returns:
        Optional[tuple]: the output types from the method's return annotation, or None if it has none.

    Raises:
        ValueError: if `sparse_type` is not a sparse type, or the tensor class lacks `method_name`.
    """
    tensor = arg_type_to_tensor_map.get(sparse_type, None)
    if tensor is None:
        raise ValueError(f"Unrecognized sparse type {sparse_type}!")
    method = getattr(tensor, method_name, None)
    if method is None:
        raise ValueError(f"{tensor} does not have attr {method_name}!")
    _, outputs = get_inputs_outputs(method)
    return outputs


def get_sparse_func(rule):
    """
    Returns SparseFunc with string for `fn`, `inputs` and `outputs` extracted from
    function annotation.

    Args:
        rule (Union[str, SparseFunc, Callable]): the sparse rule.
            A string must name a mindspore functional op.
            A SparseFunc may carry explicit `inputs`/`outputs` that take precedence over type hints.
            Any other callable must be fully type-annotated.

    Returns:
        SparseFunc: `fn` is the function's name; `inputs`/`outputs` are the resolved types.

    Raises:
        ValueError: if the rule cannot be resolved to a function, or if input/output types are missing
            for a SparseFunc wrapping a function object or for a bare callable.
    """
    if isinstance(rule, str):
        # only mindspore functional ops can be passed as strings
        sparse_func = get_functional(rule)
        if not sparse_func:
            raise ValueError(f"{rule} not a valid name for mindspore functional op!")
        inputs, outputs = get_inputs_outputs(sparse_func)
        return SparseFunc(rule, inputs, outputs)
    if isinstance(rule, SparseFunc):
        if isinstance(rule.fn, str):
            # resolve the functional op, then let explicitly provided types override its type hints
            resolved = get_sparse_func(rule.fn)
            inputs = resolved.inputs if rule.inputs is None else get_tuple(rule.inputs)
            outputs = resolved.outputs if rule.outputs is None else get_tuple(rule.outputs)
            return SparseFunc(resolved.fn, inputs, outputs)
        if callable(rule.fn):
            if rule.inputs is not None and rule.outputs is not None:
                # both overrides given: the type hints are irrelevant, so don't inspect (or fail on) them
                inputs, outputs = None, None
            else:
                inputs, outputs = get_inputs_outputs(rule.fn)
            if rule.inputs is not None:
                inputs = get_tuple(rule.inputs)
            elif inputs is None:
                raise ValueError(f"Input types not provided for {rule}!")
            if rule.outputs is not None:
                outputs = get_tuple(rule.outputs)
            elif outputs is None:
                raise ValueError(f"Output types not provided for {rule}!")
            return SparseFunc(rule.fn.__name__, inputs, outputs)
        raise ValueError(f"`fn` {rule.fn} for SparseFunc should be either a string or a function!")
    if callable(rule):
        inputs, outputs = get_inputs_outputs(rule)
        if inputs is None or outputs is None:
            raise ValueError(f"Both input types and output types should be provided for {rule}!")
        return SparseFunc(rule.__name__, inputs, outputs)
    raise ValueError(f"Sparse rule {rule} should be either a string or a SparseFunc!")


def get_binop_name(binop):
    """
    Maps ast.BinOp operator to string.

    Args:
        binop (ast.operator): the `op` of an ``ast.BinOp`` node, e.g. an ``ast.Add`` instance.

    Returns:
        str: ``"+"``, ``"-"``, ``"*"`` or ``"/"``; ``""`` for any other operator.
    """
    # AST nodes compare by identity, so `binop == ast.Add()` would never match; test the node type.
    if isinstance(binop, ast.Add):
        return "+"
    if isinstance(binop, ast.Sub):
        return "-"
    if isinstance(binop, ast.Mult):
        return "*"
    if isinstance(binop, ast.Div):
        return "/"
    return ""