# mypy: ignore-errors

"""Implementation of reduction operations, to be wrapped into arrays, dtypes etc
in the 'public' layer.

Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc

Many functions accept ``out``, ``keepdims``, ``initial`` or ``where`` without
using them in their bodies.  They are part of the signature only; presumably
the wrapping layer acts on them (see the keepdims comment in ``average``).
The annotations are kept verbatim because that layer may inspect them.
"""
from __future__ import annotations

import functools
from typing import Optional, TYPE_CHECKING

import torch

from . import _dtypes_impl, _util


if TYPE_CHECKING:
    from ._normalizations import (
        ArrayLike,
        AxisLike,
        DTypeLike,
        KeepDims,
        NotImplementedType,
        OutArray,
    )


def _deco_axis_expand(func):
    """
    Generically handle axis arguments in reductions.
    axis is *always* the 2nd arg in the function so no need to have a look at its signature

    A non-None ``axis`` is normalized to a tuple via
    ``_util.normalize_axis_tuple`` before ``func`` is called; ``axis=None`` is
    forwarded unchanged.
    """

    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        if axis is not None:
            axis = _util.normalize_axis_tuple(axis, a.ndim)

        if axis == ():
            # An empty axis tuple means "reduce over nothing".
            # So we insert a length-one axis and run the reduction along it.
            # We cannot return a.clone() as this would sidestep the checks inside the function
            newshape = _util.expand_shape(a.shape, axis=0)
            a = a.reshape(newshape)
            axis = (0,)

        return func(a, axis, *args, **kwds)

    return wrapped


def _atleast_float(dtype, other_dtype):
    """Return a dtype that is real or complex floating-point.

    If ``dtype`` is None, ``other_dtype`` is used in its place.  Boolean and
    integer dtypes are replaced by the default float dtype; real
    floating-point and complex dtypes are passed through unchanged.
    """
    if dtype is None:
        dtype = other_dtype
    if not (dtype.is_floating_point or dtype.is_complex):
        return _dtypes_impl.default_dtypes().float_dtype
    return dtype


@_deco_axis_expand
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False):
    """Count the non-zero elements of ``a`` along ``axis``."""
    return a.count_nonzero(axis)


@_deco_axis_expand
def argmax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    """Return the indices of the maximum values of ``a`` along a single axis.

    Raises NotImplementedError for complex input.
    """
    if a.is_complex():
        raise NotImplementedError(f"argmax with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmax_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmax(a, axis)


@_deco_axis_expand
def argmin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    """Return the indices of the minimum values of ``a`` along a single axis.

    Raises NotImplementedError for complex input.
    """
    if a.is_complex():
        raise NotImplementedError(f"argmin with dtype={a.dtype}.")

    axis = _util.allow_only_single_axis(axis)

    if a.dtype == torch.bool:
        # RuntimeError: "argmin_cpu" not implemented for 'Bool'
        a = a.to(torch.uint8)

    return torch.argmin(a, axis)


@_deco_axis_expand
def any(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Test whether any element of ``a`` is truthy, over all elements or one axis."""
    axis = _util.allow_only_single_axis(axis)
    # Omit ``dim`` entirely when reducing over everything.
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.any(a, **axis_kw)


@_deco_axis_expand
def all(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Test whether all elements of ``a`` are truthy, over all elements or one axis."""
    axis = _util.allow_only_single_axis(axis)
    # Omit ``dim`` entirely when reducing over everything.
    axis_kw = {} if axis is None else {"dim": axis}
    return torch.all(a, **axis_kw)


@_deco_axis_expand
def amax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Return the maximum of ``a`` along ``axis``.

    Raises NotImplementedError for complex input.
    """
    if a.is_complex():
        raise NotImplementedError(f"amax with dtype={a.dtype}")

    return a.amax(axis)


max = amax


@_deco_axis_expand
def amin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Return the minimum of ``a`` along ``axis``.

    Raises NotImplementedError for complex input.
    """
    if a.is_complex():
        raise NotImplementedError(f"amin with dtype={a.dtype}")

    return a.amin(axis)


min = amin


@_deco_axis_expand
def ptp(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
):
    """Return the peak-to-peak range (maximum minus minimum) of ``a`` along ``axis``."""
    return a.amax(axis) - a.amin(axis)


@_deco_axis_expand
def sum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Sum the elements of ``a`` along ``axis``.

    A requested ``dtype`` of ``torch.bool`` is replaced by the default
    integer dtype.
    """
    assert dtype is None or isinstance(dtype, torch.dtype)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    # Omit ``dim`` entirely when reducing over everything.
    axis_kw = {} if axis is None else {"dim": axis}
    return a.sum(dtype=dtype, **axis_kw)


@_deco_axis_expand
def prod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Multiply the elements of ``a`` over all elements or a single axis.

    A requested ``dtype`` of ``torch.bool`` is replaced by the default
    integer dtype.
    """
    axis = _util.allow_only_single_axis(axis)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    # Omit ``dim`` entirely when reducing over everything.
    axis_kw = {} if axis is None else {"dim": axis}
    return a.prod(dtype=dtype, **axis_kw)


product = prod


@_deco_axis_expand
def mean(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Return the mean of ``a`` along ``axis``.

    The computation dtype is ``dtype`` (or ``a.dtype`` when ``dtype`` is
    None), replaced by the default float dtype if it is boolean or integer.
    """
    dtype = _atleast_float(dtype, a.dtype)

    # Omit ``dim`` entirely when reducing over everything.
    axis_kw = {} if axis is None else {"dim": axis}
    result = a.mean(dtype=dtype, **axis_kw)

    return result


@_deco_axis_expand
def std(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Return the standard deviation of ``a`` along ``axis``.

    ``ddof`` is forwarded as the ``correction`` argument of ``Tensor.std``.
    The input is cast to a floating-point/complex dtype for the computation
    and the result is cast to the originally requested ``dtype``.
    """
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.std(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


@_deco_axis_expand
def var(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Return the variance of ``a`` along ``axis``.

    ``ddof`` is forwarded as the ``correction`` argument of ``Tensor.var``.
    The input is cast to a floating-point/complex dtype for the computation
    and the result is cast to the originally requested ``dtype``.
    """
    in_dtype = dtype
    dtype = _atleast_float(dtype, a.dtype)
    tensor = _util.cast_if_needed(a, dtype)
    result = tensor.var(dim=axis, correction=ddof)
    return _util.cast_if_needed(result, in_dtype)


# cumsum / cumprod are almost reductions:
#   1. no keepdims
#   2. axis=None flattens


def cumsum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    """Return the cumulative sum of ``a`` along ``axis``.

    A requested ``dtype`` of ``torch.bool`` is replaced by the default
    integer dtype; ``dtype=None`` means "use ``a.dtype``".
    """
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumsum(axis=axis, dtype=dtype)


def cumprod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    """Return the cumulative product of ``a`` along ``axis``.

    A requested ``dtype`` of ``torch.bool`` is replaced by the default
    integer dtype; ``dtype=None`` means "use ``a.dtype``".
    """
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype
    if dtype is None:
        dtype = a.dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumprod(axis=axis, dtype=dtype)


cumproduct = cumprod


def average(
    a: ArrayLike,
    axis=None,
    weights: ArrayLike = None,
    returned=False,
    *,
    keepdims=False,
):
    """Return the (optionally weighted) average of ``a`` along ``axis``.

    Without ``weights`` this is ``mean(a, axis=axis)``.  With ``weights`` the
    result is ``sum(a * weights) / sum(weights)`` along ``axis``.

    Returns ``result``, or ``(result, wsum)`` when ``returned`` is true,
    where ``wsum`` is the sum of weights broadcast to ``result.shape``.

    Raises:
        TypeError: shapes of ``a`` and ``weights`` differ and either ``axis``
            is None or ``weights`` is not 1D.
        ValueError: 1D ``weights`` whose length differs from
            ``a.shape[axis]``.
    """
    if weights is None:
        result = mean(a, axis=axis)
        # With no weights, each output element averages this many inputs.
        wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
    else:
        # Promote bool/integer input to double.  Complex input must be left
        # alone: ``is_floating_point`` is False for complex dtypes, and
        # ``.double()`` would discard the imaginary part.
        if not (a.dtype.is_floating_point or a.dtype.is_complex):
            a = a.double()

        # axis & weights
        if a.shape != weights.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if weights.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if weights.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # setup weight to broadcast along axis
            weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
            weights = weights.swapaxes(-1, axis)

        # do the work
        result_dtype = _dtypes_impl.result_type_impl(a, weights)
        numerator = sum(a * weights, axis, dtype=result_dtype)
        wsum = sum(weights, axis, dtype=result_dtype)
        result = numerator / wsum

    # We process keepdims manually because the decorator does not deal with variadic returns
    if keepdims:
        result = _util.apply_keepdims(result, axis, a.ndim)

    if returned:
        if wsum.shape != result.shape:
            wsum = torch.broadcast_to(wsum, result.shape).clone()
        return result, wsum
    else:
        return result


# Not using deco_axis_expand as it assumes that axis is the second arg
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    """Compute the ``q``-th quantile(s) of ``a`` over all elements or a single axis.

    ``method`` is forwarded to ``torch.quantile`` as its ``interpolation``
    argument.  Inputs that are not real floating-point are cast to the
    default float dtype, and float16 input is computed in float32.
    """
    # ``overwrite_input`` is deliberately ignored.  NumPy documents that it
    # MAY modify inputs:
    # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
    # Here we always work out-of-place.

    if not a.dtype.is_floating_point:
        dtype = _dtypes_impl.default_dtypes().float_dtype
        a = a.to(dtype)

    # edge case: torch.quantile only supports float32 and float64
    if a.dtype == torch.float16:
        a = a.to(torch.float32)

    if axis is None:
        a = a.flatten()
        q = q.flatten()
        axis = (0,)
    else:
        axis = _util.normalize_axis_tuple(axis, a.ndim)

    # FIXME(Mario) Doesn't np.quantile accept a tuple?
    # torch.quantile does accept a number. If we don't want to implement the
    # tuple behaviour (it is low priority), change `normalize_axis_tuple`
    # above into `normalize_axis_index`.
    axis = _util.allow_only_single_axis(axis)

    q = _util.cast_if_needed(q, a.dtype)

    return torch.quantile(a, q, axis=axis, interpolation=method)


def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    """Compute the ``q``-th percentile(s) of ``a`` by delegating to ``quantile``.

    ``q`` is divided by 100 before the call; ``out`` is not forwarded.
    """
    # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32
    if _dtypes_impl.python_type_for_torch(q.dtype) == int:
        q = q.to(_dtypes_impl.default_dtypes().float_dtype)
    qq = q / 100.0

    return quantile(
        a,
        qq,
        axis=axis,
        overwrite_input=overwrite_input,
        method=method,
        keepdims=keepdims,
        interpolation=interpolation,
    )


def median(
    a: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    keepdims: KeepDims = False,
):
    """Compute the median of ``a`` as its 0.5 quantile."""
    return quantile(
        a,
        torch.as_tensor(0.5),
        axis=axis,
        overwrite_input=overwrite_input,
        out=out,
        keepdims=keepdims,
    )