# mypy: allow-untyped-defs """This module converts objects into numpy array.""" import numpy as np import torch def make_np(x): """ Convert an object into numpy array. Args: x: An instance of torch tensor Returns: numpy.array: Numpy array """ if isinstance(x, np.ndarray): return x if np.isscalar(x): return np.array([x]) if isinstance(x, torch.Tensor): return _prepare_pytorch(x) raise NotImplementedError( f"Got {type(x)}, but numpy array or torch tensor are expected." ) def _prepare_pytorch(x): if x.dtype == torch.bfloat16: x = x.to(torch.float16) x = x.detach().cpu().numpy() return x