1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for strip_pruning_vars.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import re 21 22from tensorflow.contrib.model_pruning.python import pruning 23from tensorflow.contrib.model_pruning.python import strip_pruning_vars_lib 24from tensorflow.contrib.model_pruning.python.layers import layers 25from tensorflow.contrib.model_pruning.python.layers import rnn_cells 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import graph_util 28from tensorflow.python.framework import importer 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import random_ops 32from tensorflow.python.ops import rnn 33from tensorflow.python.ops import rnn_cell as tf_rnn_cells 34from tensorflow.python.ops import state_ops 35from tensorflow.python.ops import variable_scope 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import test 38from tensorflow.python.training import training_util 39 40 41def _get_number_pruning_vars(graph_def): 42 number_vars = 0 43 for node in graph_def.node: 44 if re.match(r"^.*(mask$)|(threshold$)", node.name): 45 number_vars += 1 46 return number_vars 47 48 49def _get_node_names(tensor_names): 50 return [ 51 strip_pruning_vars_lib._node_name(tensor_name) 52 for tensor_name in tensor_names 53 ] 54 55 56class StripPruningVarsTest(test.TestCase): 57 58 def setUp(self): 59 param_list = [ 60 "pruning_frequency=1", "begin_pruning_step=1", "end_pruning_step=10", 61 "nbins=2048", "threshold_decay=0.0" 62 ] 63 self.initial_graph = ops.Graph() 64 self.initial_graph_def = None 65 self.final_graph = ops.Graph() 66 self.final_graph_def = None 67 self.pruning_spec = ",".join(param_list) 68 with self.initial_graph.as_default(): 69 self.sparsity = variables.Variable(0.5, name="sparsity") 70 self.global_step = training_util.get_or_create_global_step() 71 self.increment_global_step = state_ops.assign_add(self.global_step, 1) 72 self.mask_update_op = None 73 74 def _build_convolutional_model(self, number_of_layers): 75 # Create a graph with several conv2d layers 76 kernel_size = 3 77 base_depth = 4 78 depth_step = 7 79 height, width = 7, 9 80 with variable_scope.variable_scope("conv_model"): 81 input_tensor = array_ops.ones((8, height, width, base_depth)) 82 top_layer = input_tensor 83 for ix in range(number_of_layers): 84 top_layer = layers.masked_conv2d( 85 top_layer, 86 base_depth + (ix + 1) * depth_step, 87 kernel_size, 88 scope="Conv_" + str(ix)) 89 90 return top_layer 91 92 def _build_fully_connected_model(self, number_of_layers): 93 base_depth = 4 94 depth_step = 7 95 96 input_tensor = array_ops.ones((8, base_depth)) 97 98 top_layer = input_tensor 99 100 with variable_scope.variable_scope("fc_model"): 101 for ix in range(number_of_layers): 102 top_layer = layers.masked_fully_connected( 103 top_layer, base_depth + (ix + 1) * depth_step) 104 105 return top_layer 106 107 def _build_lstm_model(self, number_of_layers): 108 batch_size = 8 109 dim = 10 110 inputs = variables.Variable(random_ops.random_normal([batch_size, dim])) 111 112 def lstm_cell(): 113 return rnn_cells.MaskedBasicLSTMCell( 114 dim, forget_bias=0.0, state_is_tuple=True, reuse=False) 115 116 cell = tf_rnn_cells.MultiRNNCell( 117 [lstm_cell() for _ in range(number_of_layers)], state_is_tuple=True) 118 119 outputs = rnn.static_rnn( 120 cell, [inputs], 121 initial_state=cell.zero_state(batch_size, dtypes.float32)) 122 123 return outputs 124 125 def _prune_model(self, session): 126 pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec) 127 p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity) 128 self.mask_update_op = p.conditional_mask_update_op() 129 130 variables.global_variables_initializer().run() 131 for _ in range(20): 132 session.run(self.mask_update_op) 133 session.run(self.increment_global_step) 134 135 def _get_outputs(self, session, input_graph, tensors_list, graph_prefix=None): 136 outputs = [] 137 138 for output_tensor in tensors_list: 139 if graph_prefix: 140 output_tensor = graph_prefix + "/" + output_tensor 141 outputs.append( 142 session.run(session.graph.get_tensor_by_name(output_tensor))) 143 144 return outputs 145 146 def _get_initial_outputs(self, output_tensor_names_list): 147 with self.session(graph=self.initial_graph) as sess1: 148 self._prune_model(sess1) 149 reference_outputs = self._get_outputs(sess1, self.initial_graph, 150 output_tensor_names_list) 151 152 self.initial_graph_def = graph_util.convert_variables_to_constants( 153 sess1, sess1.graph.as_graph_def(), 154 _get_node_names(output_tensor_names_list)) 155 return reference_outputs 156 157 def _get_final_outputs(self, output_tensor_names_list): 158 self.final_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn( 159 self.initial_graph_def, _get_node_names(output_tensor_names_list)) 160 _ = importer.import_graph_def(self.final_graph_def, name="final") 161 162 with self.test_session(self.final_graph) as sess2: 163 final_outputs = self._get_outputs( 164 sess2, 165 self.final_graph, 166 output_tensor_names_list, 167 graph_prefix="final") 168 return final_outputs 169 170 def _check_removal_of_pruning_vars(self, number_masked_layers): 171 self.assertEqual( 172 _get_number_pruning_vars(self.initial_graph_def), number_masked_layers) 173 self.assertEqual(_get_number_pruning_vars(self.final_graph_def), 0) 174 175 def _check_output_equivalence(self, initial_outputs, final_outputs): 176 for initial_output, final_output in zip(initial_outputs, final_outputs): 177 self.assertAllEqual(initial_output, final_output) 178 179 def testConvolutionalModel(self): 180 with self.initial_graph.as_default(): 181 number_masked_conv_layers = 5 182 top_layer = self._build_convolutional_model(number_masked_conv_layers) 183 output_tensor_names = [top_layer.name] 184 initial_outputs = self._get_initial_outputs(output_tensor_names) 185 186 # Remove pruning-related nodes. 187 with self.final_graph.as_default(): 188 final_outputs = self._get_final_outputs(output_tensor_names) 189 190 # Check that the final graph has no pruning-related vars 191 self._check_removal_of_pruning_vars(number_masked_conv_layers) 192 193 # Check that outputs remain the same after removal of pruning-related nodes 194 self._check_output_equivalence(initial_outputs, final_outputs) 195 196 def testFullyConnectedModel(self): 197 with self.initial_graph.as_default(): 198 number_masked_fc_layers = 3 199 top_layer = self._build_fully_connected_model(number_masked_fc_layers) 200 output_tensor_names = [top_layer.name] 201 initial_outputs = self._get_initial_outputs(output_tensor_names) 202 203 # Remove pruning-related nodes. 204 with self.final_graph.as_default(): 205 final_outputs = self._get_final_outputs(output_tensor_names) 206 207 # Check that the final graph has no pruning-related vars 208 self._check_removal_of_pruning_vars(number_masked_fc_layers) 209 210 # Check that outputs remain the same after removal of pruning-related nodes 211 self._check_output_equivalence(initial_outputs, final_outputs) 212 213 def testLSTMModel(self): 214 with self.initial_graph.as_default(): 215 number_masked_lstm_layers = 2 216 outputs = self._build_lstm_model(number_masked_lstm_layers) 217 output_tensor_names = [outputs[0][0].name] 218 initial_outputs = self._get_initial_outputs(output_tensor_names) 219 220 # Remove pruning-related nodes. 221 with self.final_graph.as_default(): 222 final_outputs = self._get_final_outputs(output_tensor_names) 223 224 # Check that the final graph has no pruning-related vars 225 self._check_removal_of_pruning_vars(number_masked_lstm_layers) 226 227 # Check that outputs remain the same after removal of pruning-related nodes 228 self._check_output_equivalence(initial_outputs, final_outputs) 229 230 231if __name__ == "__main__": 232 test.main() 233