• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: ignore-errors
2
3""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc
4in the 'public' layer.
5
6Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc
7"""
8from __future__ import annotations
9
10import functools
11from typing import Optional, TYPE_CHECKING
12
13import torch
14
15from . import _dtypes_impl, _util
16
17
18if TYPE_CHECKING:
19    from ._normalizations import (
20        ArrayLike,
21        AxisLike,
22        DTypeLike,
23        KeepDims,
24        NotImplementedType,
25        OutArray,
26    )
27
28
def _deco_axis_expand(func):
    """Decorator that normalizes the ``axis`` argument of a reduction.

    ``axis`` must be the second positional parameter of ``func``; the
    signature is deliberately not inspected. A non-None ``axis`` is
    normalized with ``_util.normalize_axis_tuple(axis, a.ndim)``. An empty
    tuple ``axis=()`` is turned into a reduction over an inserted
    length-one axis (position 0), reported to ``func`` as ``axis=(0,)``.
    """

    @functools.wraps(func)
    def wrapped(a, axis=None, *args, **kwds):
        if axis is None:
            return func(a, axis, *args, **kwds)

        axis = _util.normalize_axis_tuple(axis, a.ndim)
        if axis == ():
            # Reduce along a freshly inserted length-one axis instead of
            # returning a.clone(): that way func still runs its own checks.
            a = a.reshape(_util.expand_shape(a.shape, axis=0))
            axis = (0,)
        return func(a, axis, *args, **kwds)

    return wrapped
50
51
def _atleast_float(dtype, other_dtype):
    """Return a dtype that is real or complex floating-point.

    If ``dtype`` is None, ``other_dtype`` is used in its place. Real
    floating-point dtypes and complex dtypes are returned unchanged; any
    other dtype (e.g. bool or integer) is replaced by the default float
    dtype, ``_dtypes_impl.default_dtypes().float_dtype``.
    """
    if dtype is None:
        dtype = other_dtype
    if dtype.is_floating_point or dtype.is_complex:
        return dtype
    return _dtypes_impl.default_dtypes().float_dtype
64
65
@_deco_axis_expand
def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False):
    """Count the non-zero entries of ``a`` along ``axis`` (everything if None).

    ``keepdims`` is accepted for signature compatibility but not consulted
    here (presumably applied by the wrapping layer -- verify).
    """
    return a.count_nonzero(dim=axis)
69
70
@_deco_axis_expand
def argmax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    """Return the index of the maximum of ``a`` along a single ``axis``.

    With ``axis=None``, ``torch.argmax`` indexes into the flattened input.
    Complex inputs raise NotImplementedError. ``out`` and ``keepdims`` are
    not consulted in this function.
    """
    if a.is_complex():
        raise NotImplementedError(f"argmax with dtype={a.dtype}.")

    dim = _util.allow_only_single_axis(axis)

    # RuntimeError: "argmax_cpu" not implemented for 'Bool' -- compare the
    # values as uint8 0/1 instead.
    values = a.to(torch.uint8) if a.dtype == torch.bool else a
    return torch.argmax(values, dim)
89
90
@_deco_axis_expand
def argmin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    *,
    keepdims: KeepDims = False,
):
    """Return the index of the minimum of ``a`` along a single ``axis``.

    With ``axis=None``, ``torch.argmin`` indexes into the flattened input.
    Complex inputs raise NotImplementedError. ``out`` and ``keepdims`` are
    not consulted in this function.
    """
    if a.is_complex():
        raise NotImplementedError(f"argmin with dtype={a.dtype}.")

    dim = _util.allow_only_single_axis(axis)

    # RuntimeError: "argmin_cpu" not implemented for 'Bool' -- compare the
    # values as uint8 0/1 instead.
    values = a.to(torch.uint8) if a.dtype == torch.bool else a
    return torch.argmin(values, dim)
109
110
@_deco_axis_expand
def any(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Test whether any element of ``a`` is truthy along a single ``axis``.

    ``axis=None`` reduces over all elements. The result is always a bool
    tensor, as with ``np.any``. ``out``, ``keepdims`` and ``where`` are not
    consulted in this function.
    """
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    # torch.any keeps uint8 as the output dtype for uint8 inputs, while NumPy
    # always returns bool. `.to` is a no-op when the result is already bool.
    return torch.any(a, **axis_kw).to(torch.bool)
123
124
@_deco_axis_expand
def all(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Test whether all elements of ``a`` are truthy along a single ``axis``.

    ``axis=None`` reduces over all elements. The result is always a bool
    tensor, as with ``np.all``. ``out``, ``keepdims`` and ``where`` are not
    consulted in this function.
    """
    axis = _util.allow_only_single_axis(axis)
    axis_kw = {} if axis is None else {"dim": axis}
    # torch.all keeps uint8 as the output dtype for uint8 inputs, while NumPy
    # always returns bool. `.to` is a no-op when the result is already bool.
    return torch.all(a, **axis_kw).to(torch.bool)
137
138
@_deco_axis_expand
def amax(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Maximum of ``a`` along ``axis``; complex inputs are not supported.

    ``out``, ``keepdims``, ``initial`` and ``where`` are not consulted here.
    """
    if not a.is_complex():
        return a.amax(axis)
    raise NotImplementedError(f"amax with dtype={a.dtype}")


max = amax
155
156
@_deco_axis_expand
def amin(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Minimum of ``a`` along ``axis``; complex inputs are not supported.

    ``out``, ``keepdims``, ``initial`` and ``where`` are not consulted here.
    """
    if not a.is_complex():
        return a.amin(axis)
    raise NotImplementedError(f"amin with dtype={a.dtype}")


min = amin
173
174
@_deco_axis_expand
def ptp(
    a: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
):
    """Peak-to-peak range of ``a`` along ``axis``: ``amax - amin``.

    ``out`` and ``keepdims`` are not consulted in this function.
    """
    highest = a.amax(axis)
    lowest = a.amin(axis)
    return highest - lowest
183
184
@_deco_axis_expand
def sum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Sum ``a`` over ``axis`` (all elements when ``axis`` is None).

    ``dtype`` is expected to already be a ``torch.dtype`` or None; an
    explicit ``dtype=torch.bool`` is replaced by the default integer dtype.
    ``out``, ``keepdims``, ``initial`` and ``where`` are not consulted here.
    """
    # Invariant check: callers are presumably responsible for converting dtype.
    assert dtype is None or isinstance(dtype, torch.dtype)

    acc_dtype = (
        _dtypes_impl.default_dtypes().int_dtype if dtype == torch.bool else dtype
    )
    if axis is None:
        return a.sum(dtype=acc_dtype)
    return a.sum(dim=axis, dtype=acc_dtype)
202
203
@_deco_axis_expand
def prod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    initial: NotImplementedType = None,
    where: NotImplementedType = None,
):
    """Product of ``a`` along a single ``axis`` (all elements when None).

    An explicit ``dtype=torch.bool`` is replaced by the default integer
    dtype. ``out``, ``keepdims``, ``initial`` and ``where`` are not
    consulted here.
    """
    dim = _util.allow_only_single_axis(axis)

    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    if dim is None:
        return a.prod(dtype=dtype)
    return a.prod(dim=dim, dtype=dtype)


product = prod
224
225
@_deco_axis_expand
def mean(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Arithmetic mean of ``a`` over ``axis`` (all elements when None).

    The computation dtype comes from ``_atleast_float(dtype, a.dtype)``, so
    bool/integer inputs are averaged in a floating-point dtype.
    ``out``, ``keepdims`` and ``where`` are not consulted here.
    """
    acc_dtype = _atleast_float(dtype, a.dtype)
    if axis is None:
        return a.mean(dtype=acc_dtype)
    return a.mean(dim=axis, dtype=acc_dtype)
242
243
@_deco_axis_expand
def std(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Standard deviation of ``a`` over ``axis`` with ``ddof`` correction.

    The input is cast to ``_atleast_float(dtype, a.dtype)`` for the
    computation. The result is then cast to the caller's ``dtype`` via
    ``_util.cast_if_needed`` (when ``dtype`` was None, that helper is given
    None). ``out``, ``keepdims`` and ``where`` are not consulted here.
    """
    requested_dtype = dtype
    compute_dtype = _atleast_float(dtype, a.dtype)
    spread = _util.cast_if_needed(a, compute_dtype).std(dim=axis, correction=ddof)
    return _util.cast_if_needed(spread, requested_dtype)
260
261
@_deco_axis_expand
def var(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
    ddof=0,
    keepdims: KeepDims = False,
    *,
    where: NotImplementedType = None,
):
    """Variance of ``a`` over ``axis`` with ``ddof`` correction.

    The input is cast to ``_atleast_float(dtype, a.dtype)`` for the
    computation. The result is then cast to the caller's ``dtype`` via
    ``_util.cast_if_needed`` (when ``dtype`` was None, that helper is given
    None). ``out``, ``keepdims`` and ``where`` are not consulted here.
    """
    requested_dtype = dtype
    compute_dtype = _atleast_float(dtype, a.dtype)
    variance = _util.cast_if_needed(a, compute_dtype).var(dim=axis, correction=ddof)
    return _util.cast_if_needed(variance, requested_dtype)
278
279
280# cumsum / cumprod are almost reductions:
281#   1. no keepdims
282#   2. axis=None flattens
283
284
def cumsum(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    """Cumulative sum of ``a`` along ``axis``; ``axis=None`` flattens ``a`` first.

    When ``dtype`` is None the input's dtype is used. A bool dtype, whether
    requested explicitly or inherited from ``a``, is replaced by the default
    integer dtype, since ``np.cumsum([True, True])`` gives ``[1, 2]``.
    ``out`` is not consulted in this function.
    """
    if dtype is None:
        dtype = a.dtype
    # Resolve bool only *after* defaulting, so that bool inputs also
    # accumulate in the default integer dtype instead of in bool.
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumsum(axis=axis, dtype=dtype)
300
301
def cumprod(
    a: ArrayLike,
    axis: AxisLike = None,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    """Cumulative product of ``a`` along ``axis``; ``axis=None`` flattens ``a`` first.

    When ``dtype`` is None the input's dtype is used. A bool dtype, whether
    requested explicitly or inherited from ``a``, is replaced by the default
    integer dtype, since ``np.cumprod`` returns integers for bool input.
    ``out`` is not consulted in this function.
    """
    if dtype is None:
        dtype = a.dtype
    # Resolve bool only *after* defaulting, so that bool inputs also
    # accumulate in the default integer dtype instead of in bool.
    if dtype == torch.bool:
        dtype = _dtypes_impl.default_dtypes().int_dtype

    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)

    return a.cumprod(axis=axis, dtype=dtype)


cumproduct = cumprod
320
321
def average(
    a: ArrayLike,
    axis=None,
    weights: ArrayLike = None,
    returned=False,
    *,
    keepdims=False,
):
    """Weighted average of ``a`` along ``axis``.

    Without ``weights`` this is ``mean(a, axis=axis)``. In that case the
    weight sum is the number of input elements per output element. With
    ``weights``:

    * bool/integer inputs are promoted to float64;
    * complex inputs stay complex;
    * a 1-D ``weights`` may be supplied along ``axis`` when its shape
      differs from ``a``'s.

    Returns ``result``, or ``(result, wsum)`` when ``returned`` is true, with
    ``wsum`` broadcast to ``result``'s shape.

    Raises:
        TypeError: shapes differ and ``axis`` is None, or ``weights`` is not 1-D.
        ValueError: ``len(weights)`` does not match ``a.shape[axis]``.
    """
    if weights is None:
        result = mean(a, axis=axis)
        wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype)
    else:
        # Promote bool/integer inputs to float64. Complex inputs must not go
        # through .double(), which would silently drop the imaginary part.
        if not (a.dtype.is_floating_point or a.dtype.is_complex):
            a = a.double()

        # axis & weights
        if a.shape != weights.shape:
            if axis is None:
                raise TypeError(
                    "Axis must be specified when shapes of a and weights differ."
                )
            if weights.ndim != 1:
                raise TypeError(
                    "1D weights expected when shapes of a and weights differ."
                )
            if weights.shape[0] != a.shape[axis]:
                raise ValueError(
                    "Length of weights not compatible with specified axis."
                )

            # Set up the 1-D weights to broadcast along `axis`.
            weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape)
            weights = weights.swapaxes(-1, axis)

        # do the work
        result_dtype = _dtypes_impl.result_type_impl(a, weights)
        numerator = sum(a * weights, axis, dtype=result_dtype)
        wsum = sum(weights, axis, dtype=result_dtype)
        result = numerator / wsum

    # We process keepdims manually because the decorator does not deal with variadic returns
    if keepdims:
        result = _util.apply_keepdims(result, axis, a.ndim)

    if returned:
        if wsum.shape != result.shape:
            wsum = torch.broadcast_to(wsum, result.shape).clone()
        return result, wsum
    return result
372
373
# Not using deco_axis_expand as it assumes that axis is the second arg
def quantile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    """Compute the ``q``-th quantile(s) of ``a`` along a single ``axis``.

    ``axis=None`` flattens ``a`` first. ``q`` may have any number of
    dimensions, and the result has shape ``q.shape + reduced_shape``; a 0-d
    ``q`` therefore adds no quantile dimension. ``method`` is forwarded to
    ``torch.quantile`` as its ``interpolation`` argument.

    Inputs that are not floating-point are converted to the default float
    dtype. Half-precision inputs (float16/bfloat16) are upcast to float32,
    because ``torch.quantile`` only supports float32 and float64.

    ``overwrite_input``, ``out``, ``keepdims`` and ``interpolation`` are
    accepted but not consulted in this function.
    """
    # NumPy documents that `overwrite_input` MAY modify inputs:
    # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile
    # We always work out-of-place, so the flag needs no handling.

    if not a.dtype.is_floating_point:
        a = a.to(_dtypes_impl.default_dtypes().float_dtype)

    # edge case: torch.quantile only supports float32 and float64
    if a.dtype in (torch.float16, torch.bfloat16):
        a = a.to(torch.float32)

    if axis is None:
        a = a.flatten()
        axis = 0
    else:
        # FIXME(Mario): np.quantile also accepts a tuple of axes; torch.quantile
        # takes a single dim, so only length-1 axis tuples are supported here.
        axis = _util.allow_only_single_axis(_util.normalize_axis_tuple(axis, a.ndim))

    # torch.quantile accepts only a scalar or 1-D q. Flatten higher-rank q and
    # restore its shape as the leading dimensions of the result below. A 0-d
    # q is left alone so that the result has no quantile dimension.
    q_shape = q.shape
    if q.ndim > 1:
        q = q.flatten()
    q = q.to(a.dtype)

    result = torch.quantile(a, q, dim=axis, interpolation=method)
    if len(q_shape) > 1:
        result = result.reshape(tuple(q_shape) + tuple(result.shape[1:]))
    return result
416
417
def percentile(
    a: ArrayLike,
    q: ArrayLike,
    axis: AxisLike = None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    method="linear",
    keepdims: KeepDims = False,
    *,
    interpolation: NotImplementedType = None,
):
    """Compute percentile(s) of ``a``, with ``q`` given on a 0-100 scale.

    This delegates to ``quantile`` with ``q / 100.0``; ``out`` is not
    forwarded.
    """
    # An integer q divided by 100.0 would land in torch's own default float
    # (e.g. np.percentile(float_tensor, 30): int64 q / 100.0 -> float32), so
    # move integer q to this layer's default float dtype first.
    q_is_integer = _dtypes_impl.python_type_for_torch(q.dtype) == int
    if q_is_integer:
        q = q.to(_dtypes_impl.default_dtypes().float_dtype)
    fractions = q / 100.0

    return quantile(
        a,
        fractions,
        axis=axis,
        overwrite_input=overwrite_input,
        method=method,
        keepdims=keepdims,
        interpolation=interpolation,
    )
443
444
def median(
    a: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    overwrite_input=False,
    keepdims: KeepDims = False,
):
    """Median of ``a`` along ``axis``, computed as its 0.5 quantile."""
    half = torch.as_tensor(0.5)
    return quantile(
        a,
        half,
        axis=axis,
        out=out,
        overwrite_input=overwrite_input,
        keepdims=keepdims,
    )
460