# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fast Fourier Transform ops."""
import re

import numpy as np

from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` RFFT from `input_tensor`."""
  # A TensorShape for the inner fft_rank dimensions.
  fft_shape = input_tensor.get_shape()[-fft_rank:]

  # If any dim is unknown, fall back to tensor-based math.
  if not fft_shape.is_fully_defined():
    return _array_ops.shape(input_tensor)[-fft_rank:]

  # Otherwise, return a constant.
  return _ops.convert_to_tensor(fft_shape.as_list(), _dtypes.int32)


def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `rank` IRFFT from `input_tensor`."""
  # A TensorShape for the inner fft_rank dimensions.
  fft_shape = input_tensor.get_shape()[-fft_rank:]

  # If any dim is unknown, fall back to tensor-based math.
  if not fft_shape.is_fully_defined():
    fft_length = _array_ops.unstack(_array_ops.shape(input_tensor)[-fft_rank:])
    fft_length[-1] = _math_ops.maximum(0, 2 * (fft_length[-1] - 1))
    return _array_ops.stack(fft_length)

  # Otherwise, return a constant.
  fft_length = fft_shape.as_list()
  if fft_length:
    fft_length[-1] = max(0, 2 * (fft_length[-1] - 1))
  return _ops.convert_to_tensor(fft_length, _dtypes.int32)


def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims."""
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors.
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]  # pylint: disable=invalid-unary-operand-type

    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
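      # A concrete illustration (values are ours, for exposition): in the
      # forward direction fft_length = [8] pads the inner dimension up to 8,
      # while in reverse (an IRFFT input) only 8 // 2 + 1 = 5 complex bins
      # exist, so the inner dimension is padded up to 5 instead.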
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
                              axis=1)
  return _array_ops.pad(input_tensor, paddings)


def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.float32)
      if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
        raise ValueError(
            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
            input_tensor)
      real_dtype = input_tensor.dtype
      if real_dtype == _dtypes.float32:
        complex_dtype = _dtypes.complex64
      else:
        assert real_dtype == _dtypes.float64
        complex_dtype = _dtypes.complex128
      input_tensor.shape.with_rank_at_least(fft_rank)
      if fft_length is None:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)

      fft_length_static = _tensor_util.constant_value(fft_length)
      if fft_length_static is not None:
        fft_length = fft_length_static
      return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype,
                    name=name)
  _rfft.__doc__ = re.sub(" Tcomplex.*?\n", "", fft_fn.__doc__)
  return _rfft


def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument."""

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.complex64)
      input_tensor.shape.with_rank_at_least(fft_rank)
      if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
        raise ValueError(
            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
            input_tensor)
      complex_dtype = input_tensor.dtype
      real_dtype = complex_dtype.real_dtype
      if fft_length is None:
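        # A hedged example of the inference below (shapes are illustrative):
        # a complex input whose inner dimension is 129 implies fft_length
        # 2 * (129 - 1) = 256. Odd original lengths (e.g. 255) cannot be
        # recovered this way and need an explicit fft_length.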
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                         is_reverse=True)
      fft_length_static = _tensor_util.constant_value(fft_length)
      if fft_length_static is not None:
        fft_length = fft_length_static
      return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
  _irfft.__doc__ = re.sub(" Treal.*?\n", "", ifft_fn.__doc__)
  return _irfft


# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
    dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
    dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
    dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
    dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
    dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
    dispatch.add_dispatch_support(irfft3d))


def _fft_size_for_grad(grad, rank):
  return _math_ops.reduce_prod(_array_ops.shape(grad)[-rank:])


@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  size = _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype)
  return ifft(grad) * size


@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype.real_dtype),
      grad.dtype)
  return fft(grad) * rsize


@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  size = _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype)
  return ifft2d(grad) * size


@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype.real_dtype),
      grad.dtype)
  return fft2d(grad) * rsize


@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  size = _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype)
  return ifft3d(grad) * size


@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype.real_dtype),
      grad.dtype)
  return fft3d(grad) * rsize


def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank."""
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."
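
  # A sketch of the identity implemented by `_grad` below (notation is ours,
  # for exposition): for y = rfft(x) with real x whose inner dimensions hold
  # n elements in total,
  #   dL/dx = 0.5 * (irfft(dL/dy) * n + real(extra_terms)),
  # where `extra_terms` compensates for the bins (y[..., 0] and, when
  # fft_length is even, y[..., -1]) whose conjugate-symmetric copies RFFT
  # drops.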

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = _array_ops.shape(op.inputs[0])
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _tile_for_broadcasting(matrix, t):
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Computes the outer-product mask exp(-2j * pi * a * b / length)."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of
      # RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
          _math_ops.cast(length, complex_dtype))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            complex_dtype)

    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      inner_dim = input_shape[-1]
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a
    # scaling factor, plus some additional terms to make up for the components
    # dropped due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
    the_irfft = irfft_fn(grad, fft_length)
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad


def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank."""
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length
    # FFTs and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops
    # in the graph we special-case the situation where the FFT length and
    # last dimension of the input are known at graph construction time.
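    # Illustrative values (ours, not from the op definition): fft_length = 8
    # gives an RFFT output of length 5 and mask [1., 2., 2., 2., 1.], while
    # fft_length = 7 gives length 4 and mask [1., 2., 2., 2.].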
    fft_length = op.inputs[1]
    fft_length_static = _tensor_util.constant_value(fft_length)
    if fft_length_static is not None:
      fft_length = fft_length_static
    real_dtype = grad.dtype
    if real_dtype == _dtypes.float32:
      complex_dtype = _dtypes.complex64
    elif real_dtype == _dtypes.float64:
      complex_dtype = _dtypes.complex128
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones(
            [input_last_dimension - 2 + is_odd], real_dtype),
         _array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a
    # scaling factor and a mask. The mask scales the gradient for the
    # Hermitian symmetric components of the RFFT by a factor of two, since
    # these components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None

  return _grad


@tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None):
  """Shift the zero-frequency component to the center of the spectrum.

  This function swaps half-spaces for all axes listed (defaults to all).
  Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.

  @compatibility(numpy)
  Equivalent to numpy.fft.fftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.fftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.fftshift([ 0.,  1.,  2.,  3.,  4., -5., -4., -3., -2., -1.])
  x.numpy()  # array([-5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional. Axes over which to shift. Defaults
      to None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, the shifted tensor.
  """
  with _ops.name_scope(name, "fftshift") as name:
    x = _ops.convert_to_tensor(x)
    if axes is None:
      axes = tuple(range(x.shape.ndims))
      shift = _array_ops.shape(x) // 2
    elif isinstance(axes, int):
      shift = _array_ops.shape(x)[axes] // 2
    else:
      rank = _array_ops.rank(x)
      # Allow negative axes.
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = _array_ops.gather(_array_ops.shape(x), axes) // 2

    return manip_ops.roll(x, shift, axes, name)


@tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None):
  """The inverse of fftshift.

  Although identical for even-length x, the functions differ by one sample
  for odd-length x.

  @compatibility(numpy)
  Equivalent to numpy.fft.ifftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.ifftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.ifftshift([[ 0.,  1.,  2.],[ 3.,  4., -4.],[-3., -2., -1.]])
  x.numpy()  # array([[ 4., -4.,  3.],[-2., -1., -3.],[ 1.,  2.,  0.]])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`. Axes over which to calculate. Defaults to
      None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, the shifted tensor.
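
  An odd-length example (values are illustrative, not from the op definition)
  of the one-sample difference from `fftshift`:

  ```python
  x = tf.constant([0., 1., 2., 3., 4.])
  tf.signal.fftshift(x).numpy()   # array([3., 4., 0., 1., 2.])
  tf.signal.ifftshift(x).numpy()  # array([2., 3., 4., 0., 1.])
  ```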
445 """ 446 with _ops.name_scope(name, "ifftshift") as name: 447 x = _ops.convert_to_tensor(x) 448 if axes is None: 449 axes = tuple(range(x.shape.ndims)) 450 shift = -(_array_ops.shape(x) // 2) 451 elif isinstance(axes, int): 452 shift = -(_array_ops.shape(x)[axes] // 2) 453 else: 454 rank = _array_ops.rank(x) 455 # allows negative axis 456 axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes) 457 shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2) 458 459 return manip_ops.roll(x, shift, axes, name) 460 461 462_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft)) 463_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft)) 464_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d)) 465_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d)) 466