# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reversible residual network compatible with eager execution.

Customized basic operations.

Reference [The Reversible Residual Network: Backpropagation
Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def downsample(x, filters, strides, axis=1):
  """Downsample a feature map and zero-pad its channels up to `filters`.

  The input is average-pooled when any spatial stride exceeds 1, using the
  strides as both the pooling window and the pooling step with "VALID"
  padding. Afterwards, if the feature dimension holds fewer than `filters`
  channels, it is zero-padded to exactly `filters` channels. When the
  feature dimension already has `filters` channels or more, it is left as is.

  Args:
    x: rank 4 tensor, laid out as "NCHW" when `axis` is 1 and as "NHWC" when
      `axis` is 3
    filters: integer, number of output channels to pad up to
    strides: length 2 list/tuple strides for height and width
    axis: integer specifying feature dimension according to data format,
      either 1 or 3
  Returns:
    the pooled and channel-padded tensor
  Raises:
    AssertionError: if `x` is not rank 4, `axis` is neither 1 nor 3, or
      `strides` does not have length 2
  """

  def pad_strides(strides, axis=1):
    """Convert length 2 to length 4 strides.

    Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations
    such as `tf.nn.avg_pool` use length 4 strides.

    Args:
      strides: length 2 list/tuple strides for height and width
      axis: integer specifying feature dimension according to data format
    Returns:
      length 4 strides padded with 1 on batch and channel dimension
    """

    assert len(strides) == 2

    if axis == 1:
      return [1, 1, strides[0], strides[1]]
    return [1, strides[0], strides[1], 1]

  assert len(x.shape) == 4 and (axis == 1 or axis == 3)

  data_format = "NCHW" if axis == 1 else "NHWC"
  strides_ = pad_strides(strides, axis=axis)

  # Pool whenever either spatial stride downsamples; looking at the height
  # stride alone would skip pooling for strides such as (1, 2).
  if any(stride > 1 for stride in strides):
    x = tf.nn.avg_pool(
        x, strides_, strides_, padding="VALID", data_format=data_format)

  in_filter = x.shape[axis]
  out_filter = filters

  if in_filter < out_filter:
    # Split the missing channels as evenly as possible. The trailing side
    # receives the remainder so that an odd difference still yields exactly
    # `filters` channels instead of `filters - 1`.
    missing = out_filter - in_filter
    pad_before = missing // 2
    pad_after = missing - pad_before
    pad_size = [pad_before, pad_after]
    if axis == 1:
      x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]])
    else:
      x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size])
  # In case `tape.gradient(x, [x])` produces a list of `None`
  return x + 0.