# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fast Fourier transform ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import numpy as np

from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `fft_rank` RFFT from `input_tensor`."""
  # A TensorShape for the inner fft_rank dimensions.
  fft_shape = input_tensor.get_shape()[-fft_rank:]

  # If any dim is unknown, fall back to tensor-based math.
  if not fft_shape.is_fully_defined():
    return _array_ops.shape(input_tensor)[-fft_rank:]

  # Otherwise, return a constant.
  return _ops.convert_to_tensor(fft_shape.as_list(), _dtypes.int32)


def _infer_fft_length_for_irfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for a `fft_rank` IRFFT from `input_tensor`."""
  # A TensorShape for the inner fft_rank dimensions.
  fft_shape = input_tensor.get_shape()[-fft_rank:]

  # If any dim is unknown, fall back to tensor-based math.
  if not fft_shape.is_fully_defined():
    fft_length = _array_ops.unstack(_array_ops.shape(input_tensor)[-fft_rank:])
    fft_length[-1] = _math_ops.maximum(0, 2 * (fft_length[-1] - 1))
    return _array_ops.stack(fft_length)

  # Otherwise, return a constant.
  fft_length = fft_shape.as_list()
  if fft_length:
    fft_length[-1] = max(0, 2 * (fft_length[-1] - 1))
  return _ops.convert_to_tensor(fft_length, _dtypes.int32)
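

# Illustrative sketch, not part of the original module: `_demo_irfft_length`
# is a hypothetical helper showing the inference above. For a rank-1 IRFFT,
# an inner complex dimension of size n implies a default real output length
# of 2 * (n - 1), since an RFFT of length m keeps only m // 2 + 1 unique bins.
def _demo_irfft_length():
  spectrum = _array_ops.zeros([4, 5], _dtypes.complex64)
  # The shape is fully defined, so this returns the constant [8] = 2 * (5 - 1).
  return _infer_fft_length_for_irfft(spectrum, 1)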


def _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length, is_reverse=False):
  """Pads `input_tensor` to `fft_length` on its inner-most `fft_rank` dims."""
  fft_shape = _tensor_util.constant_value_as_shape(fft_length)

  # Edge case: skip padding empty tensors.
  if (input_tensor.shape.ndims is not None and
      any(dim.value == 0 for dim in input_tensor.shape.dims)):
    return input_tensor

  # If we know the shapes ahead of time, we can either skip or pre-compute the
  # appropriate paddings. Otherwise, fall back to computing paddings in
  # TensorFlow.
  if fft_shape.is_fully_defined() and input_tensor.shape.ndims is not None:
    # Slice the last FFT-rank dimensions from input_tensor's shape.
    input_fft_shape = input_tensor.shape[-fft_shape.ndims:]  # pylint: disable=invalid-unary-operand-type

    if input_fft_shape.is_fully_defined():
      # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
      if is_reverse:
        fft_shape = fft_shape[:-1].concatenate(
            fft_shape.dims[-1].value // 2 + 1)

      paddings = [[0, max(fft_dim.value - input_dim.value, 0)]
                  for fft_dim, input_dim in zip(
                      fft_shape.dims, input_fft_shape.dims)]
      if any(pad > 0 for _, pad in paddings):
        outer_paddings = [[0, 0]] * max((input_tensor.shape.ndims -
                                         fft_shape.ndims), 0)
        return _array_ops.pad(input_tensor, outer_paddings + paddings)
      return input_tensor

  # If we can't determine the paddings ahead of time, then we have to pad. If
  # the paddings end up as zero, tf.pad has a special-case that does no work.
  input_rank = _array_ops.rank(input_tensor)
  input_fft_shape = _array_ops.shape(input_tensor)[-fft_rank:]
  outer_dims = _math_ops.maximum(0, input_rank - fft_rank)
  outer_paddings = _array_ops.zeros([outer_dims], fft_length.dtype)
  # In reverse, we only pad the inner-most dimension to fft_length / 2 + 1.
  if is_reverse:
    fft_length = _array_ops.concat([fft_length[:-1],
                                    fft_length[-1:] // 2 + 1], 0)
  fft_paddings = _math_ops.maximum(0, fft_length - input_fft_shape)
  paddings = _array_ops.concat([outer_paddings, fft_paddings], 0)
  paddings = _array_ops.stack([_array_ops.zeros_like(paddings), paddings],
                              axis=1)
  return _array_ops.pad(input_tensor, paddings)


def _rfft_wrapper(fft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""

  def _rfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.rfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.float32)
      if input_tensor.dtype not in (_dtypes.float32, _dtypes.float64):
        raise ValueError(
            "RFFT requires tf.float32 or tf.float64 inputs, got: %s" %
            input_tensor)
      real_dtype = input_tensor.dtype
      if real_dtype == _dtypes.float32:
        complex_dtype = _dtypes.complex64
      else:
        assert real_dtype == _dtypes.float64
        complex_dtype = _dtypes.complex128
      input_tensor.shape.with_rank_at_least(fft_rank)
      if fft_length is None:
        fft_length = _infer_fft_length_for_rfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length)

      fft_length_static = _tensor_util.constant_value(fft_length)
      if fft_length_static is not None:
        fft_length = fft_length_static
      return fft_fn(input_tensor, fft_length, Tcomplex=complex_dtype,
                    name=name)
  _rfft.__doc__ = re.sub(" Tcomplex.*?\n", "", fft_fn.__doc__)
  return _rfft
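

# Illustrative sketch, not part of the original module: `_demo_rfft_padding`
# is a hypothetical helper showing the wrapper's fft_length handling. A
# length-4 signal with fft_length=[8] is zero-padded on the right by
# `_maybe_pad_for_rfft` before the RFFT, producing 8 // 2 + 1 = 5 bins.
def _demo_rfft_padding():
  signal = _ops.convert_to_tensor([1., 2., 3., 4.], _dtypes.float32)
  # `rfft` (assigned below from this wrapper) returns a complex64 tensor of
  # shape [5].
  return rfft(signal, fft_length=[8])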


def _irfft_wrapper(ifft_fn, fft_rank, default_name):
  """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument."""

  def _irfft(input_tensor, fft_length=None, name=None):
    """Wrapper around gen_spectral_ops.irfft* that infers fft_length argument."""
    with _ops.name_scope(name, default_name,
                         [input_tensor, fft_length]) as name:
      input_tensor = _ops.convert_to_tensor(input_tensor,
                                            preferred_dtype=_dtypes.complex64)
      input_tensor.shape.with_rank_at_least(fft_rank)
      if input_tensor.dtype not in (_dtypes.complex64, _dtypes.complex128):
        raise ValueError(
            "IRFFT requires tf.complex64 or tf.complex128 inputs, got: %s" %
            input_tensor)
      complex_dtype = input_tensor.dtype
      real_dtype = complex_dtype.real_dtype
      if fft_length is None:
        fft_length = _infer_fft_length_for_irfft(input_tensor, fft_rank)
      else:
        fft_length = _ops.convert_to_tensor(fft_length, _dtypes.int32)
      input_tensor = _maybe_pad_for_rfft(input_tensor, fft_rank, fft_length,
                                         is_reverse=True)
      fft_length_static = _tensor_util.constant_value(fft_length)
      if fft_length_static is not None:
        fft_length = fft_length_static
      return ifft_fn(input_tensor, fft_length, Treal=real_dtype, name=name)
  _irfft.__doc__ = re.sub(" Treal.*?\n", "", ifft_fn.__doc__)
  return _irfft


# FFT/IFFT 1/2/3D are exported via
# third_party/tensorflow/core/api_def/python_api/
fft = gen_spectral_ops.fft
ifft = gen_spectral_ops.ifft
fft2d = gen_spectral_ops.fft2d
ifft2d = gen_spectral_ops.ifft2d
fft3d = gen_spectral_ops.fft3d
ifft3d = gen_spectral_ops.ifft3d
rfft = _rfft_wrapper(gen_spectral_ops.rfft, 1, "rfft")
tf_export("signal.rfft", v1=["signal.rfft", "spectral.rfft"])(
    dispatch.add_dispatch_support(rfft))
irfft = _irfft_wrapper(gen_spectral_ops.irfft, 1, "irfft")
tf_export("signal.irfft", v1=["signal.irfft", "spectral.irfft"])(
    dispatch.add_dispatch_support(irfft))
rfft2d = _rfft_wrapper(gen_spectral_ops.rfft2d, 2, "rfft2d")
tf_export("signal.rfft2d", v1=["signal.rfft2d", "spectral.rfft2d"])(
    dispatch.add_dispatch_support(rfft2d))
irfft2d = _irfft_wrapper(gen_spectral_ops.irfft2d, 2, "irfft2d")
tf_export("signal.irfft2d", v1=["signal.irfft2d", "spectral.irfft2d"])(
    dispatch.add_dispatch_support(irfft2d))
rfft3d = _rfft_wrapper(gen_spectral_ops.rfft3d, 3, "rfft3d")
tf_export("signal.rfft3d", v1=["signal.rfft3d", "spectral.rfft3d"])(
    dispatch.add_dispatch_support(rfft3d))
irfft3d = _irfft_wrapper(gen_spectral_ops.irfft3d, 3, "irfft3d")
tf_export("signal.irfft3d", v1=["signal.irfft3d", "spectral.irfft3d"])(
    dispatch.add_dispatch_support(irfft3d))


def _fft_size_for_grad(grad, rank):
  return _math_ops.reduce_prod(_array_ops.shape(grad)[-rank:])
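

# Illustrative sketch, not part of the original module: `_demo_fft_size` is a
# hypothetical helper showing what `_fft_size_for_grad` computes. The gradient
# functions below scale by the product of the innermost `rank` dimensions of
# the incoming gradient; for a [2, 3, 4] gradient and rank 2 that is 3 * 4 = 12.
def _demo_fft_size():
  g = _array_ops.zeros([2, 3, 4], _dtypes.complex64)
  return _fft_size_for_grad(g, 2)  # ==> 12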


@_ops.RegisterGradient("FFT")
def _fft_grad(_, grad):
  size = _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype)
  return ifft(grad) * size


@_ops.RegisterGradient("IFFT")
def _ifft_grad(_, grad):
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 1), grad.dtype.real_dtype),
      grad.dtype)
  return fft(grad) * rsize


@_ops.RegisterGradient("FFT2D")
def _fft2d_grad(_, grad):
  size = _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype)
  return ifft2d(grad) * size


@_ops.RegisterGradient("IFFT2D")
def _ifft2d_grad(_, grad):
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 2), grad.dtype.real_dtype),
      grad.dtype)
  return fft2d(grad) * rsize


@_ops.RegisterGradient("FFT3D")
def _fft3d_grad(_, grad):
  size = _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype)
  return ifft3d(grad) * size


@_ops.RegisterGradient("IFFT3D")
def _ifft3d_grad(_, grad):
  rsize = _math_ops.cast(
      1. / _math_ops.cast(_fft_size_for_grad(grad, 3), grad.dtype.real_dtype),
      grad.dtype)
  return fft3d(grad) * rsize


def _rfft_grad_helper(rank, irfft_fn):
  """Returns a gradient function for an RFFT of the provided rank."""
  # Can't happen because we don't register a gradient for RFFT3D.
  assert rank in (1, 2), "Gradient for RFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for RFFT with the provided `rank` and `irfft_fn`."""
    fft_length = op.inputs[1]
    complex_dtype = grad.dtype
    real_dtype = complex_dtype.real_dtype
    input_shape = _array_ops.shape(op.inputs[0])
    is_even = _math_ops.cast(1 - (fft_length[-1] % 2), complex_dtype)

    def _tile_for_broadcasting(matrix, t):
      expanded = _array_ops.reshape(
          matrix,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(t) - 2], _dtypes.int32),
              _array_ops.shape(matrix)
          ], 0))
      return _array_ops.tile(
          expanded, _array_ops.concat([_array_ops.shape(t)[:-2], [1, 1]], 0))

    def _mask_matrix(length):
      """Computes the DFT matrix exp(-2j * pi * n * m / length) for n, m < length."""
      # TODO(rjryan): Speed up computation of twiddle factors using the
      # following recurrence relation and cache them across invocations of
      # RFFT.
      #
      # t_n = exp(sqrt(-1) * pi * n^2 / line_len)
      # for n = 0, 1,..., line_len-1.
      # For n > 2, use t_n = t_{n-1}^2 / t_{n-2} * t_1^2
      a = _array_ops.tile(
          _array_ops.expand_dims(_math_ops.range(length), 0), (length, 1))
      b = _array_ops.transpose(a, [1, 0])
      return _math_ops.exp(
          -2j * np.pi * _math_ops.cast(a * b, complex_dtype) /
          _math_ops.cast(length, complex_dtype))

    def _ymask(length):
      """A sequence of [1+0j, -1+0j, 1+0j, -1+0j, ...] with length `length`."""
      return _math_ops.cast(1 - 2 * (_math_ops.range(length) % 2),
                            complex_dtype)

    y0 = grad[..., 0:1]
    if rank == 1:
      ym = grad[..., -1:]
      extra_terms = y0 + is_even * ym * _ymask(input_shape[-1])
    elif rank == 2:
      # Create a mask matrix for y0 and ym.
      base_mask = _mask_matrix(input_shape[-2])

      # Tile base_mask to match y0 in shape so that we can batch-matmul the
      # inner 2 dimensions.
      tiled_mask = _tile_for_broadcasting(base_mask, y0)

      y0_term = _math_ops.matmul(tiled_mask, _math_ops.conj(y0))
      extra_terms = y0_term

      ym = grad[..., -1:]
      ym_term = _math_ops.matmul(tiled_mask, _math_ops.conj(ym))

      inner_dim = input_shape[-1]
      ym_term = _array_ops.tile(
          ym_term,
          _array_ops.concat([
              _array_ops.ones([_array_ops.rank(grad) - 1], _dtypes.int32),
              [inner_dim]
          ], 0)) * _ymask(inner_dim)

      extra_terms += is_even * ym_term

    # The gradient of RFFT is the IRFFT of the incoming gradient times a
    # scaling factor, plus some additional terms to make up for the components
    # dropped due to Hermitian symmetry.
    input_size = _math_ops.cast(
        _fft_size_for_grad(op.inputs[0], rank), real_dtype)
    the_irfft = irfft_fn(grad, fft_length)
    return 0.5 * (the_irfft * input_size + _math_ops.real(extra_terms)), None

  return _grad
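

# Illustrative sketch, not part of the original module: `_demo_rfft_round_trip`
# is a hypothetical helper exercising the relationship the gradient above is
# built on. When fft_length matches the (even) input length, irfft(rfft(x))
# reconstructs x up to floating-point error.
def _demo_rfft_round_trip():
  x = _ops.convert_to_tensor([0., 1., 2., 3.], _dtypes.float32)
  return irfft(rfft(x, [4]), [4])  # approximately [0., 1., 2., 3.]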


def _irfft_grad_helper(rank, rfft_fn):
  """Returns a gradient function for an IRFFT of the provided rank."""
  # Can't happen because we don't register a gradient for IRFFT3D.
  assert rank in (1, 2), "Gradient for IRFFT3D is not implemented."

  def _grad(op, grad):
    """A gradient function for IRFFT with the provided `rank` and `rfft_fn`."""
    # Generate a simple mask like [1.0, 2.0, ..., 2.0, 1.0] for even-length
    # FFTs and [1.0, 2.0, ..., 2.0] for odd-length FFTs. To reduce extra ops
    # in the graph we special-case the situation where the FFT length and
    # last dimension of the input are known at graph construction time.
    fft_length = op.inputs[1]
    fft_length_static = _tensor_util.constant_value(fft_length)
    if fft_length_static is not None:
      fft_length = fft_length_static
    real_dtype = grad.dtype
    if real_dtype == _dtypes.float32:
      complex_dtype = _dtypes.complex64
    elif real_dtype == _dtypes.float64:
      complex_dtype = _dtypes.complex128
    is_odd = _math_ops.mod(fft_length[-1], 2)
    input_last_dimension = _array_ops.shape(op.inputs[0])[-1]
    mask = _array_ops.concat(
        [[1.0], 2.0 * _array_ops.ones(
            [input_last_dimension - 2 + is_odd], real_dtype),
         _array_ops.ones([1 - is_odd], real_dtype)], 0)

    rsize = _math_ops.reciprocal(_math_ops.cast(
        _fft_size_for_grad(grad, rank), real_dtype))

    # The gradient of IRFFT is the RFFT of the incoming gradient times a
    # scaling factor and a mask. The mask scales the gradient for the
    # Hermitian symmetric components of the RFFT by a factor of two, since
    # these components are de-duplicated in the RFFT.
    the_rfft = rfft_fn(grad, fft_length)
    return the_rfft * _math_ops.cast(rsize * mask, complex_dtype), None

  return _grad


@tf_export("signal.fftshift")
@dispatch.add_dispatch_support
def fftshift(x, axes=None, name=None):
  """Shift the zero-frequency component to the center of the spectrum.

  This function swaps half-spaces for all axes listed (defaults to all).
  Note that ``y[0]`` is the Nyquist component only if ``len(x)`` is even.

  @compatibility(numpy)
  Equivalent to numpy.fft.fftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.fftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.fftshift([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.])
  x.numpy() # array([-5., -4., -3., -2., -1., 0., 1., 2., 3., 4.])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional. Axes over which to shift. Defaults
      to None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, the shifted tensor.
  """
  with _ops.name_scope(name, "fftshift") as name:
    x = _ops.convert_to_tensor(x)
    if axes is None:
      axes = tuple(range(x.shape.ndims))
      shift = _array_ops.shape(x) // 2
    elif isinstance(axes, int):
      shift = _array_ops.shape(x)[axes] // 2
    else:
      rank = _array_ops.rank(x)
      # allows negative axis
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = _array_ops.gather(_array_ops.shape(x), axes) // 2

    return manip_ops.roll(x, shift, axes, name)
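

# Illustrative sketch, not part of the original module: `_demo_shift_inverse`
# is a hypothetical helper checking that `ifftshift` (defined below) undoes
# `fftshift`, including odd lengths where the two shifts differ by one sample.
def _demo_shift_inverse():
  x = _ops.convert_to_tensor([0., 1., 2., 3., 4.])
  return ifftshift(fftshift(x))  # ==> [0., 1., 2., 3., 4.]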


@tf_export("signal.ifftshift")
@dispatch.add_dispatch_support
def ifftshift(x, axes=None, name=None):
  """The inverse of fftshift.

  Although identical for even-length x, the functions differ by one sample
  for odd-length x.

  @compatibility(numpy)
  Equivalent to numpy.fft.ifftshift.
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.fft.ifftshift.html
  @end_compatibility

  For example:

  ```python
  x = tf.signal.ifftshift([[ 0., 1., 2.],[ 3., 4., -4.],[-3., -2., -1.]])
  x.numpy() # array([[ 4., -4., 3.],[-2., -1., -3.],[ 1., 2., 0.]])
  ```

  Args:
    x: `Tensor`, input tensor.
    axes: `int` or shape `tuple`, optional. Axes over which to calculate.
      Defaults to None, which shifts all axes.
    name: An optional name for the operation.

  Returns:
    A `Tensor`, the shifted tensor.
  """
  with _ops.name_scope(name, "ifftshift") as name:
    x = _ops.convert_to_tensor(x)
    if axes is None:
      axes = tuple(range(x.shape.ndims))
      shift = -(_array_ops.shape(x) // 2)
    elif isinstance(axes, int):
      shift = -(_array_ops.shape(x)[axes] // 2)
    else:
      rank = _array_ops.rank(x)
      # allows negative axis
      axes = _array_ops.where(_math_ops.less(axes, 0), axes + rank, axes)
      shift = -(_array_ops.gather(_array_ops.shape(x), axes) // 2)

    return manip_ops.roll(x, shift, axes, name)


_ops.RegisterGradient("RFFT")(_rfft_grad_helper(1, irfft))
_ops.RegisterGradient("IRFFT")(_irfft_grad_helper(1, rfft))
_ops.RegisterGradient("RFFT2D")(_rfft_grad_helper(2, irfft2d))
_ops.RegisterGradient("IRFFT2D")(_irfft_grad_helper(2, rfft2d))
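

# Illustrative sketch, not part of the original module: with the gradients
# registered above, RFFT/IRFFT ops are differentiable. `_demo_rfft_gradient`
# is a hypothetical helper; it assumes eager execution and uses the internal
# `backprop.GradientTape`, the class behind the public `tf.GradientTape`.
def _demo_rfft_gradient():
  from tensorflow.python.eager import backprop  # local import; demo only
  x = _ops.convert_to_tensor([1., 2., 3., 4.], _dtypes.float32)
  with backprop.GradientTape() as tape:
    tape.watch(x)
    loss = _math_ops.reduce_sum(_math_ops.abs(rfft(x)))
  return tape.gradient(loss, x)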