# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for adding pruning related ops to the graph."""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope


def weight_mask_variable(var, scope):
  """Create a mask for the weights.

  This function adds a variable 'mask' to the graph.

  Args:
    var: the weight variable that needs to be masked
    scope: The variable scope of the variable var

  Returns:
    the mask variable of the same size and shape as var, initialized to all 1s.
  """
  with variable_scope.variable_scope(scope):
    mask = variable_scope.get_variable(
        'mask',
        var.get_shape(),
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=var.dtype)
  return mask


def weight_threshold_variable(var, scope):
  """Create a scalar threshold for the weights.

  This function adds a variable 'threshold' to the graph.

  Args:
    var: The weight variable that needs to be masked
    scope: The variable scope of the variable var

  Returns:
    A scalar threshold variable initialized to 0.
  """
  with variable_scope.variable_scope(scope):
    threshold = variable_scope.get_variable(
        'threshold', [],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=var.dtype)
  return threshold


def kronecker_product(mat1, mat2):
  """Computes the Kronecker product of two matrices mat1 and mat2.

  Args:
    mat1: A matrix of size m x n
    mat2: A matrix of size p x q

  Returns:
    Kronecker product of matrices mat1 and mat2 of size mp x nq
  """
  m1, n1 = mat1.get_shape().as_list()
  mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
  m2, n2 = mat2.get_shape().as_list()
  mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
  return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
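
# A minimal usage sketch for `kronecker_product` (illustrative only, not part
# of the library): for rank-2 tensors with static shapes it mirrors `np.kron`.
# The input values below are assumptions for the example, and a TF 1.x
# graph-mode session is assumed for evaluation:
#
#   mat1 = constant_op.constant([[1., 2.], [3., 4.]])
#   mat2 = constant_op.constant([[0., 1.], [1., 0.]])
#   prod = kronecker_product(mat1, mat2)  # shape [4, 4]
#   # session.run(prod) would match
#   # np.kron([[1., 2.], [3., 4.]], [[0., 1.], [1., 0.]])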

def expand_tensor(tensor, block_dims):
  """Expands a 2D tensor by replicating the tensor values.

  This is equivalent to the kronecker product of the tensor and a matrix of
  ones of size block_dims.

  Example:

  tensor = [[1,2]
            [3,4]]
  block_dims = [2,2]

  result = [[1 1 2 2]
            [1 1 2 2]
            [3 3 4 4]
            [3 3 4 4]]

  Args:
    tensor: A 2D tensor that needs to be expanded.
    block_dims: List of integers specifying the expansion factor.

  Returns:
    The expanded tensor

  Raises:
    ValueError: if tensor is not rank 2 or block_dims does not have 2
      elements.
  """
  if tensor.get_shape().ndims != 2:
    raise ValueError('Input tensor must be rank 2')

  if len(block_dims) != 2:
    raise ValueError('block_dims must have 2 elements')

  block_height, block_width = block_dims

  def _tile_rows(tensor, multiple):
    """Create a new tensor by tiling the tensor along rows."""
    return array_ops.tile(tensor, [multiple, 1])

  def _generate_indices(num_rows, block_dim):
    indices = np.zeros(shape=[num_rows * block_dim, 1], dtype=np.int32)
    for k in range(block_dim):
      for r in range(num_rows):
        indices[k * num_rows + r] = r * block_dim + k
    return indices

  def _replicate_rows(tensor, multiple):
    tensor_shape = tensor.shape.as_list()
    expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
    indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple))
    return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple),
                                expanded_shape)

  expanded_tensor = tensor

  # Expand rows by a factor of block_height.
  if block_height > 1:
    expanded_tensor = _replicate_rows(tensor, block_height)

  # Transpose, expand the rows by a factor of block_width, then transpose back.
  if block_width > 1:
    expanded_tensor = array_ops.transpose(
        _replicate_rows(array_ops.transpose(expanded_tensor), block_width))

  return expanded_tensor
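
# A small sketch of `expand_tensor` (illustrative only): each scalar entry is
# replicated into a block_height x block_width block, matching the docstring
# example above. The input values are assumptions for the example:
#
#   weights = constant_op.constant([[1., 2.], [3., 4.]])
#   expanded = expand_tensor(weights, [2, 2])
#   # Evaluating `expanded` yields
#   # [[1. 1. 2. 2.]
#   #  [1. 1. 2. 2.]
#   #  [3. 3. 4. 4.]
#   #  [3. 3. 4. 4.]]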

def factorized_pool(input_tensor,
                    window_shape,
                    pooling_type,
                    strides,
                    padding,
                    name=None):
  """Performs m x n pooling through a combination of 1xm and 1xn pooling.

  Args:
    input_tensor: Input tensor. Must be rank 2
    window_shape: Pooling window shape
    pooling_type: Either 'MAX' or 'AVG'
    strides: The stride of the pooling window
    padding: 'SAME' or 'VALID'
    name: Name of the op

  Returns:
    A rank 2 tensor containing the pooled output

  Raises:
    ValueError: if the input tensor is not rank 2
  """
  if input_tensor.get_shape().ndims != 2:
    raise ValueError('factorized_pool() accepts tensors of rank 2 only')

  [height, width] = input_tensor.get_shape()
  with ops.name_scope(name, 'factorized_pool'):
    input_tensor_aligned = array_ops.reshape(
        input_tensor, [1, 1, height, width],
        name=input_tensor.op.name + '_aligned')

    height_pooling = nn_ops.pool(
        input_tensor_aligned,
        window_shape=[1, window_shape[0]],
        pooling_type=pooling_type,
        strides=[1, strides[0]],
        padding=padding)
    swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])

    width_pooling = nn_ops.pool(
        swap_height_width,
        window_shape=[1, window_shape[1]],
        pooling_type=pooling_type,
        strides=[1, strides[1]],
        padding=padding)

    return array_ops.squeeze(
        array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]), axis=[0, 1])


def determine_partitioned_axis(partitioned_variable):
  """Returns the axis along which the given variable was partitioned."""
  partitioned_axis = 0
  concatenated_variable_shape = partitioned_variable.get_shape()
  for partition in partitioned_variable:
    partition_shape = partition.get_shape()
    maybe_partitioned_axis = np.less(partition_shape,
                                     concatenated_variable_shape)
    # Sanity check: make sure the number of partitioned axes == 1.
    if np.count_nonzero(maybe_partitioned_axis) != 1:
      raise ValueError('Number of partitioned axes %s not equal to 1' %
                       np.count_nonzero(maybe_partitioned_axis))
    partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
  return partitioned_axis


def variable_assign(var, new_value):
  """Returns an op that assigns new_value to var."""
  return state_ops.assign(var, new_value, name=var.op.name + '_assign')


def partitioned_variable_assign(partitioned_var, new_value):
  """Assign op for partitioned variables.

  Args:
    partitioned_var: A partitioned tensorflow variable
    new_value: Value to be assigned to the variable var

  Returns:
    A tensorflow op that groups the assign ops for each of the variable slices
  """
  # Determine which axis was used to partition the variable. Currently
  # tensorflow allows partitioning a variable only along 1 axis.
  axis = 0 if len(partitioned_var) == 1 else determine_partitioned_axis(
      partitioned_var)

  partition_sizes = np.array(
      [partition.get_shape()[axis] for partition in partitioned_var])
  new_partitioned_values = array_ops.split(
      new_value,
      ops.convert_to_tensor(partition_sizes, dtype=dtypes.int32),
      axis=axis)
  op_list = []
  for partition in partitioned_var:
    op_list.append(
        variable_assign(partition, new_partitioned_values[len(op_list)]))
  return control_flow_ops.group(
      *op_list, name=partitioned_var.name + '_group_assign')
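
# A usage sketch for `factorized_pool` above (illustrative only): block-max
# pooling of a 4x4 matrix with non-overlapping 2x2 windows. The values and
# parameters are assumptions for the example, and TF 1.x graph mode is
# assumed so that `input_tensor.op.name` is defined:
#
#   weights = constant_op.constant(
#       [[1., 2., 3., 4.],
#        [5., 6., 7., 8.],
#        [9., 10., 11., 12.],
#        [13., 14., 15., 16.]])
#   pooled = factorized_pool(
#       weights,
#       window_shape=[2, 2],
#       pooling_type='MAX',
#       strides=[2, 2],
#       padding='SAME')
#   # Evaluating `pooled` yields the rank-2 tensor
#   # [[ 6.  8.]
#   #  [14. 16.]]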