# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""VM implementations based on numpy."""
16
17import numpy as np
18from mindspore._checkparam import Validator as validator
19
20
def avg_pooling(x, pool_h, pool_w, stride):
    """
    Applies average pooling over an input array.

    Args:
        x (numpy.ndarray): The input array to be average pooled.
        pool_h (int): Height of the pooling window.
        pool_w (int): Width of the pooling window.
        stride (int): The stride of the sliding window.

    Returns:
        numpy.ndarray, an output array after applying average pooling on the input array.
    """
    validator.check_positive_int(stride, "stride")
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1

    col = im2col(x, pool_h, pool_w, stride)
    col = col.reshape(-1, pool_h * pool_w)

    out = np.mean(col, axis=1)
    out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)

    return out
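
# Illustrative usage (a minimal sketch, not part of the original module;
# the input shape below is assumed for demonstration):
#     x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
#     avg_pooling(x, pool_h=2, pool_w=2, stride=2)   # shape (1, 1, 2, 2)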


def avg_pool_grad(dout, origin_shape, pool_h, pool_w, stride):
    """
    Gets the gradient of average pooling.

    Args:
        dout (numpy.ndarray): The gradient from the next layer, of shape
            (batch, channel, out_h, out_w).
        origin_shape (tuple[int]): Shape of the original input array.
        pool_h (int): Height of the pooling window.
        pool_w (int): Width of the pooling window.
        stride (int): The stride of the sliding window.

    Returns:
        numpy.ndarray, gradient with respect to the input of average pooling.
    """
    _, _, height, width = dout.shape
    dx = np.zeros(origin_shape)
    for i in range(height):
        for j in range(width):
            # Each input element inside a window receives an equal share of
            # the output gradient for that window.
            dx[:, :, i * stride:(i * stride + pool_h), j * stride:(j * stride + pool_w)] += \
                dout[:, :, i, j][:, :, None, None] / (pool_h * pool_w)
    return dx


def _batch_norm(x, scale, shift, running_mean=None, running_var=None,
                eps=1e-05, momentum=0.1, is_training=True):
    """Batch Normalization over an array."""
    _, c_h_w = x.shape
    # Initialize the running statistics when they are not provided.
    if running_mean is None:
        running_mean = np.zeros(c_h_w)
    if running_var is None:
        running_var = np.zeros(c_h_w)
    if np.ndim(scale) > 0:
        scale = scale.mean()
    if np.ndim(shift) > 0:
        shift = shift.mean()

    if is_training:
        x_mean = np.mean(x, axis=0)
        x_var = np.var(x, axis=0)

        # Normalization followed by affine transformation
        x_norm = (x - x_mean) / np.sqrt(x_var + eps)

        # Estimate the running average of mean and variance to use at test time
        running_mean = momentum * running_mean + (1 - momentum) * x_mean
        running_var = momentum * running_var + (1 - momentum) * x_var
    else:
        # Normalize using the running averages
        x_norm = (x - running_mean) / np.sqrt(running_var + eps)
        x_mean = running_mean
        x_var = running_var

    out = scale * x_norm + shift

    return out, x_mean, x_var, running_mean, running_var


def batch_norm(x, scale=1, shift=0, mean=None, variance=None,
               eps=1e-05, momentum=0.1, is_training=True):
    """Batch Normalization over an array."""
    input_shape = x.shape
    if x.ndim != 2:
        batch_num = x.shape[0]
        x = x.reshape(batch_num, -1)

    out, _, _, running_mean, running_var = _batch_norm(x, scale, shift, mean, variance,
                                                       eps, momentum, is_training)

    return out.reshape(*input_shape), np.array(scale), np.array(shift), running_mean, running_var
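
# Illustrative usage (a minimal sketch, not part of the original module;
# the input shape below is assumed for demonstration):
#     x = np.random.randn(8, 3, 2, 2).astype(np.float32)
#     out, gamma, beta, r_mean, r_var = batch_norm(x, scale=1, shift=0)
#     out.shape == x.shape   # True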


def _batch_norm_grad(dout, x, scale, save_mean, save_inv_variance,
                     eps=1e-05, momentum=0.1, is_training=True):
    """Gradient of Batch Normalization over an array."""
    if x.ndim != 2:
        batch_num = x.shape[0]
        x = x.reshape(batch_num, -1)
    if np.ndim(scale) > 0:
        scale = scale.mean()
    # Recompute the normalized input with scale=1 and shift=0 so that x_norm
    # is the plain normalized value, not the affine-transformed output.
    x_norm, x_mean, x_var, _, _ = _batch_norm(x, scale=1, shift=0, running_mean=save_mean,
                                              running_var=save_inv_variance,
                                              eps=eps, momentum=momentum, is_training=is_training)
    batch_size = x.shape[0]
    dx_norm = scale * dout
    dvar = np.sum(dx_norm * (x - x_mean) * ((x_var + eps) ** (-3.0 / 2)) * (-1.0 / 2), axis=0)
    dmean = np.sum(dx_norm * (-1.0 / np.sqrt(x_var + eps)), axis=0) \
            + dvar * (np.sum(-2 * (x - x_mean), axis=0) * (1.0 / batch_size))
    dx = dx_norm * (1.0 / np.sqrt(x_var + eps)) + dvar * (2.0 * (x - x_mean) / batch_size) + dmean * (1.0 / batch_size)
    dgamma = np.sum(dout * x_norm, axis=0)
    dbeta = np.sum(dout, axis=0)
    return dx, dgamma, dbeta


def batch_norm_grad(dy, x, scale, save_mean, save_inv_variance):
    """Gradient of Batch Normalization over an array."""
    if dy.ndim != 2:
        batch_size = dy.shape[0]
        dy = dy.reshape(batch_size, -1)

    dx, dgamma, dbeta = _batch_norm_grad(dy, x, scale, save_mean, save_inv_variance)
    input_shape = x.shape
    dx = dx.reshape(*input_shape)
    return dx, dgamma, dbeta


def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """Rearranges a row vector back to an image."""
    if isinstance(stride, int):
        stride_h = stride
        stride_w = stride
    elif isinstance(stride, tuple) and len(stride) == 2:
        stride_h = stride[0]
        stride_w = stride[1]
    elif isinstance(stride, tuple) and len(stride) == 4:
        stride_h = stride[2]
        stride_w = stride[3]
    else:
        raise ValueError(f"The 'stride' should be an int number or "
                         f"a tuple of two or four int numbers, but got {stride}")

    if isinstance(pad, int):
        pad_top = pad
        pad_bottom = pad
        pad_left = pad
        pad_right = pad
    elif isinstance(pad, tuple) and len(pad) == 2:
        pad_top = pad[0]
        pad_bottom = pad[0]
        pad_left = pad[1]
        pad_right = pad[1]
    elif isinstance(pad, tuple) and len(pad) == 4:
        pad_top, pad_bottom, pad_left, pad_right = pad
    else:
        raise ValueError(f"The 'pad' should be an int number or "
                         f"a tuple of two or four int numbers, but got {pad}")

    batch_num, channel, height, width = input_shape
    out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1
    out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1
    col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
        .transpose(0, 3, 4, 5, 1, 2)

    img = np.zeros((batch_num,
                    channel,
                    height + pad_top + pad_bottom + stride_h - 1,
                    width + pad_left + pad_right + stride_w - 1)) \
        .astype(col.dtype)
    for y in range(filter_h):
        y_max = y + stride_h * out_h
        for x in range(filter_w):
            # Use the horizontal stride on the horizontal axis.
            x_max = x + stride_w * out_w
            img[:, :, y:y_max:stride_h, x:x_max:stride_w] += col[:, :, y, x, :, :]

    # Crop the top/left padding away to recover the original spatial extent.
    return img[:, :, pad_top:height + pad_top, pad_left:width + pad_left]
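
# Illustrative round trip (a minimal sketch, not part of the original module;
# shapes assumed for demonstration). Note col2im accumulates overlapping
# patches, so im2col followed by col2im reproduces the image only where
# windows do not overlap:
#     x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
#     col = im2col(x, 2, 2, stride=2)        # shape (4, 4)
#     img = col2im(col, x.shape, 2, 2, 2)    # equals x when stride == filter size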


def convolve(x, w, b=None, pad_mode="valid"):
    """
    Gets the discrete, linear convolution of two one-dimensional sequences.

    Args:
        x (numpy.ndarray): One-dimensional input array.
        w (numpy.ndarray): One-dimensional input array.
        b (numpy.ndarray): One-dimensional input array. Default: None.
        pad_mode (str): Padding mode which can be: "full" means returns the
                  convolution at each point of overlap, with an output shape
                  of (N+M-1,); "same" means returns output of length max(M, N);
                  and "valid" means returns output of length max(M, N) - min(M, N)
                  + 1. Default: "valid".

    Returns:
        numpy.ndarray, discrete, linear convolution of x and w, then plus b.
    """
    if pad_mode not in {"same", "valid"}:
        pad_mode = "full"
    y = np.convolve(x, w, pad_mode)
    if b is not None:
        y += b
    return y
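
# Illustrative usage (a minimal sketch, not part of the original module):
#     x = np.array([1.0, 2.0, 3.0])
#     w = np.array([0.0, 1.0])
#     convolve(x, w, pad_mode="valid")   # array([1., 2.])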


def conv2d(x, weight, bias=None, stride=1, pad=0,
           dilation=1, groups=1, padding_mode='zeros'):
    """Convolution 2D."""
    # pylint: disable=unused-argument
    validator.check_value_type('stride', stride, (int, tuple))
    if isinstance(stride, int):
        stride = (stride, stride)
    elif len(stride) == 4:
        stride = (stride[2], stride[3])
    if len(stride) != 2 or (not isinstance(stride[0], int)) or \
            (not isinstance(stride[1], int)) or \
            stride[0] < 1 or stride[1] < 1:
        raise ValueError(f"The 'stride' of 'conv2d' should be a positive int number or "
                         f"a tuple of two positive int numbers, but got {stride}")
    stride_h = stride[0]
    stride_w = stride[1]
    validator.check_value_type('dilation', dilation, (int, tuple))
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    elif len(dilation) == 4:
        dilation = (dilation[2], dilation[3])
    if len(dilation) != 2 or (not isinstance(dilation[0], int)) or \
            (not isinstance(dilation[1], int)) or \
            dilation[0] < 1 or dilation[1] < 1:
        raise ValueError(f"The 'dilation' of 'conv2d' should be a positive int number or "
                         f"a tuple of two positive int numbers, but got {dilation}")
    dilation_h = dilation[0]
    dilation_w = dilation[1]

    if isinstance(pad, int):
        pad_top = pad
        pad_bottom = pad
        pad_left = pad
        pad_right = pad
    elif isinstance(pad, tuple) and len(pad) == 2:
        pad_top = pad[0]
        pad_bottom = pad[0]
        pad_left = pad[1]
        pad_right = pad[1]
    elif isinstance(pad, tuple) and len(pad) == 4:
        pad_top, pad_bottom, pad_left, pad_right = pad
    else:
        raise ValueError(f"The 'pad' should be an int number or "
                         f"a tuple of two or four int numbers, but got {pad}")

    batch_num, _, x_h, x_w = x.shape
    filter_num, _, filter_h, filter_w = weight.shape
    out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
    out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
    col = im2col(x, filter_h, filter_w, stride, pad, dilation)
    col_w = np.reshape(weight, (filter_num, -1)).T
    out = np.dot(col, col_w)
    out = out.reshape((batch_num, out_h, out_w, -1)).transpose(0, 3, 1, 2)
    if bias is not None:
        # Reshape the per-channel bias so it broadcasts over (N, C, H, W).
        out += bias.reshape(1, -1, 1, 1)
    return out
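
# Illustrative shape check (a minimal sketch, not part of the original module;
# shapes below are assumed for demonstration):
#     x = np.random.randn(1, 3, 5, 5).astype(np.float32)
#     w = np.random.randn(4, 3, 3, 3).astype(np.float32)
#     conv2d(x, w, stride=1, pad=1).shape   # (1, 4, 5, 5)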


def conv2d_backprop_filter(dout, x, w_size, stride=1, pad=0):
    """Backpropagation filter for conv2d."""
    filter_num, channel, filter_height, filter_width = w_size
    dout = dout.transpose(0, 2, 3, 1).reshape(-1, filter_num)
    col = im2col(x, filter_height, filter_width, stride, pad)
    dw = np.dot(col.T, dout)
    dw = dw.transpose(1, 0).reshape((filter_num, channel, filter_height, filter_width))
    return dw


def conv2d_backprop_input(dout, x_size, weight, stride=1, pad=0):
    """Backpropagation input for conv2d."""
    filter_num, _, filter_h, filter_w = weight.shape
    dout = dout.transpose(0, 2, 3, 1).reshape(-1, filter_num)
    col_w = weight.reshape(filter_num, -1).T
    dcol = np.dot(dout, col_w.T)
    dx = col2im(dcol, x_size, filter_h, filter_w, stride, pad)
    return dx


def flatten(x):
    """
    Flattens an array to one dimension.

    Args:
        x (numpy.ndarray): An array to be flattened.

    Returns:
        numpy.ndarray, a flattened array in one dimension.
    """
    return x.flatten()


def flatten2(x):
    """
    Flattens an array into a row vector by reshape.

    Args:
        x (numpy.ndarray): An array to be flattened.

    Returns:
        numpy.ndarray, a flattened array of shape (1, -1).
    """
    return x.reshape(1, -1)


def flatten_batch(x):
    """
    Flattens each sample in a batch of arrays to one dimension.

    Args:
        x (numpy.ndarray): A batch of arrays to be flattened.

    Returns:
        numpy.ndarray of shape (batch, -1), with each sample flattened.
    """
    return x.reshape(x.shape[0], -1)


def flatten_grad(dout, x):
    """
    Grad of flatten: reshapes dout back to the shape of the original input.

    Args:
        dout (numpy.ndarray): The gradient of the flattened output.
        x (tuple[int]): The shape of the original input array.

    Returns:
        numpy.ndarray, dout reshaped to x.
    """
    dout = np.reshape(dout, x)
    return dout


def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
    """Rearranges an image to a row vector."""
    if isinstance(stride, int):
        stride_h = stride
        stride_w = stride
    elif isinstance(stride, tuple) and len(stride) == 2:
        stride_h = stride[0]
        stride_w = stride[1]
    elif isinstance(stride, tuple) and len(stride) == 4:
        stride_h = stride[2]
        stride_w = stride[3]
    else:
        raise ValueError(f"The 'stride' should be an int number or "
                         f"a tuple of two or four int numbers, but got {stride}")
    if isinstance(dilation, int):
        dilation_h = dilation
        dilation_w = dilation
    elif isinstance(dilation, tuple) and len(dilation) == 2:
        dilation_h = dilation[0]
        dilation_w = dilation[1]
    elif isinstance(dilation, tuple) and len(dilation) == 4:
        dilation_h = dilation[2]
        dilation_w = dilation[3]
    else:
        raise ValueError(f"The 'dilation' should be an int number or "
                         f"a tuple of two or four int numbers, but got {dilation}")

    if isinstance(pad, int):
        pad_top = pad
        pad_bottom = pad
        pad_left = pad
        pad_right = pad
    elif isinstance(pad, tuple) and len(pad) == 2:
        pad_top = pad[0]
        pad_bottom = pad[0]
        pad_left = pad[1]
        pad_right = pad[1]
    elif isinstance(pad, tuple) and len(pad) == 4:
        pad_top, pad_bottom, pad_left, pad_right = pad
    else:
        raise ValueError(f"The 'pad' should be an int number or "
                         f"a tuple of two or four int numbers, but got {pad}")

    batch_num, channel, height, width = img.shape
    out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
    out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1

    img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
    col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)

    for y in range(filter_h):
        # Dilated filter taps sample the image at dilated offsets.
        y_start = y * dilation_h
        y_max = y_start + stride_h * out_h
        for x in range(filter_w):
            # Use the horizontal stride on the horizontal axis.
            x_start = x * dilation_w
            x_max = x_start + stride_w * out_w
            col[:, :, y, x, :, :] = img[:, :, y_start:y_max:stride_h, x_start:x_max:stride_w]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(batch_num * out_h * out_w, -1)
    return col
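
# Illustrative shape check (a minimal sketch, not part of the original module):
#     x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
#     im2col(x, 2, 2, stride=2).shape   # (4, 4): one row per 2x2 patch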


def matmul(x, w, b=None):
    """
    Dot product of array x and w, then plus array b if b is not None.

    Args:
        x (numpy.ndarray): Represents the input array.
        w (numpy.ndarray): Represents the weights array.
        b (numpy.ndarray): Represents the bias array, broadcastable against
            the product of x and w. Default: None.

    Returns:
        numpy.ndarray, the result of (x*w + b).
    """
    y = np.dot(x, w)
    if b is not None:
        y += b
    return y
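
# Illustrative usage (a minimal sketch, not part of the original module):
#     x = np.ones((2, 3))
#     w = np.ones((3, 4))
#     matmul(x, w, b=np.zeros(4)).shape   # (2, 4)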


def max_pooling(x, pool_h, pool_w, stride):
    """Max pooling."""
    validator.check_positive_int(stride, "stride")
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1

    col = im2col(x, pool_h, pool_w, stride)
    col = col.reshape(-1, pool_h * pool_w)

    out = np.max(col, axis=1)
    out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)

    return out


def max_pool_grad(x, dout, pool_h, pool_w, stride):
    """Grad of max pooling."""
    dout = dout.transpose(0, 2, 3, 1)
    pool_size = pool_h * pool_w
    dmax = np.zeros((dout.size, pool_size), dout.dtype)
    col = im2col(x, pool_h, pool_w, stride)
    col = col.reshape(-1, pool_h * pool_w)
    arg_max = np.argmax(col, axis=1)
    dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()
    dmax = dmax.reshape(dout.shape + (pool_size,))
    dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
    dx = col2im(dcol, x.shape, pool_h, pool_w, stride)
    return dx


def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
    """Grad of max pooling with argmax."""
    dout = dout.transpose(0, 2, 3, 1)
    pool_size = pool_h * pool_w
    dmax = np.zeros((dout.size, pool_size), dout.dtype)
    dmax[np.arange(arg_max.size), arg_max.flatten()] = dout.flatten()
    dmax = dmax.reshape(dout.shape + (pool_size,))
    dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
    dx = col2im(dcol, x.shape, pool_h, pool_w, stride)
    return dx


def max_pool_with_argmax(x, pool_h, pool_w, stride):
    """Max pooling with argmax."""
    validator.check_positive_int(stride, "stride")
    num, channel, height, width = x.shape
    out_h = (height - pool_h) // stride + 1
    out_w = (width - pool_w) // stride + 1
    col = im2col(x, pool_h, pool_w, stride)
    col = col.reshape(-1, pool_h * pool_w)
    out = np.max(col, axis=1)
    out_argmax = np.argmax(col, axis=1)
    out = out.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
    out_argmax = out_argmax.reshape((num, out_h, out_w, channel)).transpose(0, 3, 1, 2)
    return out, out_argmax
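
# Illustrative usage (a minimal sketch, not part of the original module):
#     x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
#     out, argmax = max_pool_with_argmax(x, 2, 2, 2)
#     out[0, 0]      # [[ 5.,  7.], [13., 15.]]
#     argmax[0, 0]   # index of the max within each 2x2 window, here all 3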


def relu(x):
    """
    Rectified linear unit.

    Args:
        x (numpy.ndarray): The input array.

    Returns:
        numpy.ndarray, the array with ReLU applied.
    """
    return x * (x > 0)


def relu_grad(y):
    """
    Grad of relu.

    Args:
        y (numpy.ndarray): The input array; modified in place.

    Returns:
        numpy.ndarray, the array with the grad of ReLU applied.
    """
    y[y <= 0] = 0
    y[y > 0] = 1
    return y


def sigmoid(x):
    """
    Sigmoid activation function.

    Args:
        x (numpy.ndarray): The input array.

    Returns:
        numpy.ndarray, the array with sigmoid applied.
    """
    return 1 / (1 + np.exp(-x))


def tanh(x):
    """
    Computes hyperbolic tangent element-wise.

    Args:
        x (numpy.ndarray): The input array.

    Returns:
        numpy.ndarray, the array with tanh applied.
    """
    a = np.exp(x)
    b = np.exp(-x)
    return (a - b) / (a + b)


def softmax(x, axis=None):
    """
    Softmax function which is `softmax(x) = np.exp(x)/sum(np.exp(x))`.

    Args:
        x (numpy.ndarray): Input array.
        axis (Union[int, tuple[int]]): Axis to compute values along. Default: None.

    Returns:
        numpy.ndarray, has the same shape as x.
    """
    # Local import keeps scipy an optional dependency of this module.
    from scipy.special import softmax as scipy_softmax
    return scipy_softmax(x, axis)
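
# Illustrative usage (a minimal sketch, not part of the original module):
#     softmax(np.array([1.0, 1.0]))             # array([0.5, 0.5])
#     softmax(np.zeros((2, 3)), axis=1).sum(1)  # array([1., 1.])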


def softmax_cross_entropy_with_logits(logits, labels):
    """Computes the softmax cross entropy loss and its grad w.r.t. logits, assuming one-hot labels."""
    sample_num = labels.shape[0]
    # Softmax is taken per sample, along the class axis.
    prob = softmax(logits, axis=-1)
    log_likelihood = -np.log(prob) * labels
    loss = np.sum(log_likelihood)
    dx = prob.copy()
    dx -= labels
    return loss, dx
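
# Illustrative usage (a minimal sketch, not part of the original module):
#     logits = np.array([[2.0, 0.5, 0.3]])
#     labels = np.array([[1.0, 0.0, 0.0]])   # one-hot
#     loss, dx = softmax_cross_entropy_with_logits(logits, labels)
#     loss   # -log(softmax(logits)[0, 0]), a small positive scalar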


def shape(x):
    """
    Gets the array's dimensions.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, the shape/dimensions of the input array.
    """
    return np.array(np.shape(x))


def expand_dims(x, axis):
    """
    Expands the shape of an array.

    Args:
        x (numpy.ndarray): Input array.
        axis (int): Position in the expanded axes where the new axis is placed.

    Returns:
        numpy.ndarray, view of input array with the number of dimensions increased by one.
    """
    return np.expand_dims(x, axis)


def squeeze(x, axis):
    """
    Removes single-dimensional entries from the shape of an array.

    Args:
        x (numpy.ndarray): Input array.
        axis (Union[int, tuple[int]]): Selected subset of the single-dimensional entries in the shape.

    Returns:
        numpy.ndarray, the input numpy.ndarray, but with all or a subset of the dimensions of length
        1 removed.
    """
    # Accept a bare int as well as a sequence of ints.
    if isinstance(axis, int):
        axis = (axis,)
    return np.squeeze(x, tuple(axis))


def reshape(x, shp):
    """
    Applies a new shape to an array without changing its data.

    Args:
        x (numpy.ndarray): Input array.
        shp (tuple[int]): New shape to apply to x.

    Returns:
        numpy.ndarray, a new view object or a copy of input array.
    """
    return np.reshape(x, tuple(shp))


def rank(x):
    """
    Gets number of array dimensions.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, a zero-dimensional array holding the number of input array dimensions.
    """
    return np.array(np.ndim(x))


def logsoftmax(x):
    """
    Log softmax function.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, the result of applying log softmax on the input array.
    """
    return np.array(np.log(softmax(x)))


def transpose(x, axes=None):
    """
    Transposes an input array according to axes.

    Args:
        x (numpy.ndarray): Input array.
        axes (list): The axes to be transposed. Default: None.

    Returns:
        numpy.ndarray, transposed array.
    """
    return np.transpose(x, axes)


def invert_permutation(x):
    """
    Gets the inverse permutation of an array.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        tuple, the inverse permutation of the input array.
    """
    x = np.array(x)
    y = np.argsort(x)
    return tuple(y)
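
# Illustrative usage (a minimal sketch, not part of the original module):
#     invert_permutation([2, 0, 1])   # (1, 2, 0)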


def select(cond, x, y):
    """
    Gets elements from x or y depending on cond.

    Args:
        cond (numpy.ndarray): An array of bool; where True, yield x, otherwise yield y.
        x (numpy.ndarray): Values from which to choose.
        y (numpy.ndarray): Values from which to choose.

    Returns:
        numpy.ndarray, elements from x where condition is True, and elements from y elsewhere.
    """
    return np.where(cond, x, y)


def sum_by_axis(x, axis):
    """
    Sum of array elements over a given axis.

    Args:
        x (numpy.ndarray): Input array.
        axis (Union[int, tuple[int]]): Axis or axes along which a sum is performed.

    Returns:
        numpy.ndarray, has the same shape as input array with the specified axis removed.
    """
    return np.sum(x, axis)


def equal(x, y):
    """
    Gets (x == y) element-wise.

    Args:
        x (numpy.ndarray): Input array.
        y (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, element-wise comparison of x and y.
    """
    return np.equal(x, y)


def not_equal(x, y):
    """
    Gets (x != y) element-wise.

    Args:
        x (numpy.ndarray): Input array.
        y (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, element-wise comparison of x and y.
    """
    return np.not_equal(x, y)


def greater(x, y):
    """
    Gets the truth value of (x > y) element-wise.

    Args:
        x (numpy.ndarray): Input array.
        y (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, element-wise comparison of x and y.
    """
    return np.greater(x, y)


def less(x, y):
    """
    Gets the truth value of (x < y) element-wise.

    Args:
        x (numpy.ndarray): Input array.
        y (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, element-wise comparison of x and y.
    """
    return np.less(x, y)


def logical_not(x):
    """
    Gets the truth value of NOT x element-wise.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray of bool, has the same shape as x, the NOT operation applied to elements of x.
    """
    return np.logical_not(x)


def sqrt(x):
    """
    Gets the non-negative square-root of a numpy.ndarray, element-wise.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, has the same shape as x, containing the positive square-root of each
        element in x.
    """
    return np.sqrt(x)


def power(x, y):
    """
    First array elements raised to powers from second numpy.ndarray, element-wise.

    Args:
        x (numpy.ndarray): The bases array.
        y (numpy.ndarray): The exponents array.

    Returns:
        numpy.ndarray, the bases in x raised to the exponents in y.
    """
    return np.power(x, y)


def exp(x):
    """
    Gets the exponential of all elements in the input array.

    Args:
        x (numpy.ndarray): Input array.

    Returns:
        numpy.ndarray, element-wise exponential of x.
    """
    return np.exp(x)


def maximum(x, y):
    """
    Gets the max of x and y element-wise.

    If x > y, return x. Otherwise, return y.

    Args:
        x (numpy.ndarray): First input array.
        y (numpy.ndarray): Second input array, has the same type as x.

    Returns:
        numpy.ndarray, has the same type as x.
    """
    return np.maximum(x, y)


def minimum(x, y):
    """
    Gets the min of x and y element-wise.

    If x < y, return x. Otherwise, return y.

    Args:
        x (numpy.ndarray): First input array.
        y (numpy.ndarray): Second input array, has the same type as x.

    Returns:
        numpy.ndarray, has the same type as x.
    """
    return np.minimum(x, y)


def all_(x, axis=(), keep_dims=False):
    """
    Checks whether all array elements along a given axis evaluate to True.

    Args:
        x (numpy.ndarray): An array to be reduced.
        axis (Union[None, int, tuple[int]]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        numpy.ndarray of bool, the reduced array.
    """
    axis = None if axis == () else axis
    return np.all(x, axis, keepdims=keep_dims)


def any_(x, axis=(), keep_dims=False):
    """
    Checks whether any array element along a given axis evaluates to True.

    Args:
        x (numpy.ndarray): An array to be reduced.
        axis (Union[None, int, tuple[int]]): Dimensions of reduction.
        keep_dims (bool): Whether to keep the reduced dimensions.

    Returns:
        numpy.ndarray of bool, the reduced array.
    """
    axis = None if axis == () else axis
    return np.any(x, axis, keepdims=keep_dims)
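

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module);
    # exercises a few of the numpy VM kernels above with assumed shapes.
    _x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
    assert avg_pooling(_x, 2, 2, 2).shape == (1, 1, 2, 2)
    assert max_pooling(_x, 2, 2, 2)[0, 0, 1, 1] == 15.0
    _w = np.ones((2, 1, 3, 3), dtype=np.float32)
    assert conv2d(_x, _w, pad=1).shape == (1, 2, 4, 4)
    assert relu(np.array([-1.0, 2.0])).tolist() == [0.0, 2.0]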