# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""General shape ops for frames."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.signal import util_ops
from tensorflow.python.util.tf_export import tf_export


def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
  """Infers the static shape of the return value of `frame`.

  Args:
    signal: The `Tensor` to be framed; its rank and dimensions may be unknown.
    frame_length: Scalar `Tensor` frame length; may be non-constant.
    frame_step: Scalar `Tensor` frame hop size; may be non-constant.
    pad_end: Python bool, whether `frame` pads the end of `signal`.
    axis: Scalar `Tensor` axis to frame; may be non-constant.

  Returns:
    A shape list with `None` for every dimension that cannot be determined
    statically, or `None` if the rank of `signal` itself is unknown.
  """
  # Fold scalar Tensors to Python values where possible; each becomes None
  # when the value is not statically known.
  frame_length = tensor_util.constant_value(frame_length)
  frame_step = tensor_util.constant_value(frame_step)
  axis = tensor_util.constant_value(axis)
  if signal.shape.ndims is None:
    # Unknown rank: nothing can be inferred.
    return None
  if axis is None:
    # Rank is known (framing adds one dimension) but the position of the
    # framed axis is not, so every dimension is unknown.
    return [None] * (signal.shape.ndims + 1)

  signal_shape = signal.shape.as_list()
  num_frames = None
  # A negative constant `axis` indexes from the end here, matching the
  # dynamic range-based normalization performed in `frame` below.
  frame_axis = signal_shape[axis]
  outer_dimensions = signal_shape[:axis]
  inner_dimensions = signal_shape[axis:][1:]
  if signal_shape and frame_axis is not None:
    if frame_step is not None and pad_end:
      # Double negative is so that we round up.
      num_frames = max(0, -(-frame_axis // frame_step))
    elif frame_step is not None and frame_length is not None:
      # Without padding only fully-overlapping windows are produced.
      assert not pad_end
      num_frames = max(
          0, (frame_axis - frame_length + frame_step) // frame_step)
  # `num_frames` and/or `frame_length` may still be None (unknown).
  return outer_dimensions + [num_frames, frame_length] + inner_dimensions


@tf_export("signal.frame")
def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
          name=None):
  """Expands `signal`'s `axis` dimension into frames of `frame_length`.

  Slides a window of size `frame_length` over `signal`'s `axis` dimension
  with a stride of `frame_step`, replacing the `axis` dimension with
  `[frames, frame_length]` frames.

  If `pad_end` is True, window positions that are past the end of the `axis`
  dimension are padded with `pad_value` until the window moves fully past the
  end of the dimension. Otherwise, only window positions that fully overlap the
  `axis` dimension are produced.

  For example:

  ```python
  # A batch size 3 tensor of 9152 audio samples.
  audio = tf.random.normal([3, 9152])

  # Compute overlapping frames of length 512 with a step of 180 (frames overlap
  # by 332 samples). By default, only 50 frames are generated since the last
  # 152 samples do not form a full frame.
  frames = tf.signal.frame(audio, 512, 180)
  frames.shape.assert_is_compatible_with([3, 50, 512])

  # When pad_end is enabled, the final frame is kept (padded with zeros).
  frames = tf.signal.frame(audio, 512, 180, pad_end=True)
  frames.shape.assert_is_compatible_with([3, 51, 512])
  ```

  Args:
    signal: A `[..., samples, ...]` `Tensor`. The rank and dimensions
      may be unknown. Rank must be at least 1.
    frame_length: The frame length in samples. An integer or scalar `Tensor`.
    frame_step: The frame hop size in samples. An integer or scalar `Tensor`.
    pad_end: Whether to pad the end of `signal` with `pad_value`.
    pad_value: An optional scalar `Tensor` to use where the input signal
      does not exist when `pad_end` is True.
    axis: A scalar integer `Tensor` indicating the axis to frame. Defaults to
      the last axis. Supports negative values for indexing from the end.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of frames with shape `[..., frames, frame_length, ...]`.

  Raises:
    ValueError: If `frame_length`, `frame_step`, `pad_value`, or `axis` are not
      scalar.
  """
  with ops.name_scope(name, "frame", [signal, frame_length, frame_step,
                                      pad_value]):
    signal = ops.convert_to_tensor(signal, name="signal")
    frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    axis = ops.convert_to_tensor(axis, name="axis")

    # Static validation: signal must have rank >= 1 and the rest scalar.
    signal.shape.with_rank_at_least(1)
    frame_length.shape.assert_has_rank(0)
    frame_step.shape.assert_has_rank(0)
    axis.shape.assert_has_rank(0)

    # Compute the static output shape up front, before `signal` is padded or
    # reshaped below, so it can be re-applied to the result at the end.
    result_shape = _infer_frame_shape(signal, frame_length, frame_step, pad_end,
                                      axis)

    # Axis can be negative. Convert it to positive.
    signal_rank = array_ops.rank(signal)
    axis = math_ops.range(signal_rank)[axis]

    # Split the dynamic shape around the framed axis:
    # dims before `axis`, the axis length itself, and dims after `axis`.
    signal_shape = array_ops.shape(signal)
    outer_dimensions, length_samples, inner_dimensions = array_ops.split(
        signal_shape, [axis, 1, signal_rank - 1 - axis])
    length_samples = array_ops.reshape(length_samples, [])
    num_outer_dimensions = array_ops.size(outer_dimensions)
    num_inner_dimensions = array_ops.size(inner_dimensions)

    # If padding is requested, pad the input signal tensor with pad_value.
    if pad_end:
      pad_value = ops.convert_to_tensor(pad_value, signal.dtype)
      pad_value.shape.assert_has_rank(0)

      # Calculate number of frames, using double negatives to round up.
      num_frames = -(-length_samples // frame_step)

      # Pad the signal by up to frame_length samples based on how many samples
      # are remaining starting from last_frame_position.
      pad_samples = math_ops.maximum(
          0, frame_length + frame_step * (num_frames - 1) - length_samples)

      # Pad the inner dimension of signal by pad_samples.
      paddings = array_ops.concat(
          [array_ops.zeros([num_outer_dimensions, 2], dtype=pad_samples.dtype),
           [[0, pad_samples]],
           array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)],
          0)
      signal = array_ops.pad(signal, paddings, constant_values=pad_value)

      # Re-read the dynamic shape: the framed axis has grown by pad_samples.
      signal_shape = array_ops.shape(signal)
      length_samples = signal_shape[axis]
    else:
      # No padding: count only windows that fully overlap the axis.
      num_frames = math_ops.maximum(
          0, 1 + (length_samples - frame_length) // frame_step)

    # Decompose the axis into "subframes" of length gcd(frame_length,
    # frame_step) so that every frame is a whole number of subframes and
    # consecutive frames start a whole number of subframes apart. This lets
    # frames be assembled with a single gather instead of overlapping slices.
    subframe_length = util_ops.gcd(frame_length, frame_step)
    subframes_per_frame = frame_length // subframe_length
    subframes_per_hop = frame_step // subframe_length
    num_subframes = length_samples // subframe_length

    # Trim the axis to a whole number of subframes, then reshape it into
    # [num_subframes, subframe_length].
    slice_shape = array_ops.concat([outer_dimensions,
                                    [num_subframes * subframe_length],
                                    inner_dimensions], 0)
    subframe_shape = array_ops.concat([outer_dimensions,
                                       [num_subframes, subframe_length],
                                       inner_dimensions], 0)
    subframes = array_ops.reshape(array_ops.strided_slice(
        signal, array_ops.zeros_like(signal_shape),
        slice_shape), subframe_shape)

    # frame_selector is a [num_frames, subframes_per_frame] tensor
    # that indexes into the appropriate frame in subframes. For example:
    # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]]
    frame_selector = array_ops.reshape(
        math_ops.range(num_frames) * subframes_per_hop, [num_frames, 1])

    # subframe_selector is a [num_frames, subframes_per_frame] tensor
    # that indexes into the appropriate subframe within a frame. For example:
    # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
    subframe_selector = array_ops.reshape(
        math_ops.range(subframes_per_frame), [1, subframes_per_frame])

    # Adding the 2 selector tensors together produces a [num_frames,
    # subframes_per_frame] tensor of indices to use with tf.gather to select
    # subframes from subframes. We then reshape the inner-most
    # subframes_per_frame dimension to stitch the subframes together into
    # frames. For example: [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]].
    selector = frame_selector + subframe_selector

    # Gather along the (now positive) framed axis and merge the
    # [num_frames, subframes_per_frame, subframe_length] result back into
    # [num_frames, frame_length].
    frames = array_ops.reshape(
        array_ops.gather(subframes, selector, axis=axis),
        array_ops.concat([outer_dimensions, [num_frames, frame_length],
                          inner_dimensions], 0))

    # Re-apply whatever static shape information was inferred earlier.
    if result_shape:
      frames.set_shape(result_shape)
    return frames