1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ 18 19 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" 20 #include "tensorflow/core/grappler/optimizers/constant_folding.h" 21 #include "tensorflow/core/grappler/optimizers/model_pruner.h" 22 #include "tensorflow/core/grappler/utils/grappler_test.h" 23 #include "tensorflow/core/lib/core/status_test_util.h" 24 25 namespace tensorflow { 26 namespace grappler { 27 28 class ArithmeticOptimizerTest : public GrapplerTest { 29 protected: 30 // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no 31 // longer have any output consumers. OptimizeAndPrune(ArithmeticOptimizer * optimizer,GrapplerItem * item,GraphDef * output)32 void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, 33 GraphDef* output) { 34 TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); 35 item->graph.Swap(output); 36 output->Clear(); 37 TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); 38 } 39 40 // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. OptimizeTwice(ArithmeticOptimizer * optimizer,GrapplerItem * item,GraphDef * output)41 void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item, 42 GraphDef* output) { 43 TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); 44 item->graph.Swap(output); 45 output->Clear(); 46 TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); 47 } 48 49 // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. 50 // Optionally run a constant folding pass before pruning. 51 void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, 52 GraphDef* output, bool const_folding = false) { 53 TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); 54 55 item->graph.Swap(output); 56 output->Clear(); 57 TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); 58 59 if (const_folding) { 60 item->graph.Swap(output); 61 output->Clear(); 62 TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr) 63 .Optimize(nullptr, *item, output)); 64 } 65 66 item->graph.Swap(output); 67 output->Clear(); 68 TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); 69 } 70 71 // TODO(ezhulenev): Make private. After migration to stages each test 72 // should explicitly enable required optimization for tests isolation DisableAllStages(ArithmeticOptimizer * optimizer)73 void DisableAllStages(ArithmeticOptimizer* optimizer) { 74 ArithmeticOptimizer::ArithmeticOptimizerOptions options; 75 options.dedup_computations = false; 76 options.combine_add_to_addn = false; 77 options.convert_sqrt_div_to_rsqrt_mul = false; 78 options.convert_pow = false; 79 options.convert_log1p = false; 80 options.optimize_max_or_min_of_monotonic = false; 81 options.fold_conjugate_into_transpose = false; 82 options.fold_multiply_into_conv = false; 83 options.fold_transpose_into_matmul = false; 84 options.hoist_common_factor_out_of_aggregation = false; 85 options.hoist_cwise_unary_chains = false; 86 options.minimize_broadcasts = false; 87 options.remove_identity_transpose = false; 88 options.remove_involution = false; 89 options.remove_idempotent = false; 90 options.remove_redundant_bitcast = false; 91 options.remove_redundant_cast = false; 92 options.remove_redundant_reshape = false; 93 options.remove_negation = false; 94 options.remove_logical_not = false; 95 options.reorder_cast_like_and_value_preserving = false; 96 options.replace_mul_with_square = false; 97 options.simplify_aggregation = false; 98 options.unary_ops_composition = false; 99 optimizer->options_ = options; 100 } 101 DisableAddToAddNCombining(ArithmeticOptimizer * optimizer)102 void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) { 103 optimizer->options_.combine_add_to_addn = false; 104 } 105 EnableOnlyAddToAddNCombining(ArithmeticOptimizer * optimizer)106 void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) { 107 DisableAllStages(optimizer); 108 optimizer->options_.combine_add_to_addn = true; 109 } 110 EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer * optimizer)111 void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) { 112 DisableAllStages(optimizer); 113 optimizer->options_.fold_conjugate_into_transpose = true; 114 } 115 EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer * optimizer)116 void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) { 117 DisableAllStages(optimizer); 118 optimizer->options_.fold_multiply_into_conv = true; 119 } 120 EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer * optimizer)121 void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) { 122 DisableAllStages(optimizer); 123 optimizer->options_.fold_transpose_into_matmul = true; 124 } 125 EnableOnlyHoistCommonFactor(ArithmeticOptimizer * optimizer)126 void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) { 127 DisableAllStages(optimizer); 128 optimizer->options_.hoist_common_factor_out_of_aggregation = true; 129 } 130 EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer * optimizer)131 void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) { 132 DisableAllStages(optimizer); 133 optimizer->options_.minimize_broadcasts = true; 134 } 135 EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer * optimizer)136 void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) { 137 DisableAllStages(optimizer); 138 optimizer->options_.remove_identity_transpose = true; 139 } 140 EnableOnlyRemoveInvolution(ArithmeticOptimizer * optimizer)141 void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) { 142 DisableAllStages(optimizer); 143 optimizer->options_.remove_involution = true; 144 } 145 EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer * optimizer)146 void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { 147 DisableAllStages(optimizer); 148 optimizer->options_.remove_redundant_bitcast = true; 149 } 150 EnableOnlyRemoveRedundantCast(ArithmeticOptimizer * optimizer)151 void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) { 152 DisableAllStages(optimizer); 153 optimizer->options_.remove_redundant_cast = true; 154 } 155 EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer * optimizer)156 void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { 157 DisableAllStages(optimizer); 158 optimizer->options_.remove_redundant_reshape = true; 159 } 160 EnableOnlyRemoveNegation(ArithmeticOptimizer * optimizer)161 void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { 162 DisableAllStages(optimizer); 163 optimizer->options_.remove_negation = true; 164 } 165 EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer * optimizer)166 void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) { 167 DisableAllStages(optimizer); 168 optimizer->options_.reorder_cast_like_and_value_preserving = true; 169 } 170 EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer * optimizer)171 void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { 172 DisableAllStages(optimizer); 173 optimizer->options_.replace_mul_with_square = true; 174 } 175 EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer * optimizer)176 void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { 177 DisableAllStages(optimizer); 178 optimizer->options_.hoist_cwise_unary_chains = true; 179 } 180 EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer * optimizer)181 void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) { 182 DisableAllStages(optimizer); 183 optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; 184 } 185 EnableOnlyLogSoftmax(ArithmeticOptimizer * optimizer)186 void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) { 187 DisableAllStages(optimizer); 188 optimizer->options_.convert_log_softmax = true; 189 } 190 EnableOnlyConvertPow(ArithmeticOptimizer * optimizer)191 void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) { 192 DisableAllStages(optimizer); 193 optimizer->options_.convert_pow = true; 194 } 195 EnableOnlyFuseSquaredDiff(ArithmeticOptimizer * optimizer)196 void EnableOnlyFuseSquaredDiff(ArithmeticOptimizer* optimizer) { 197 DisableAllStages(optimizer); 198 optimizer->options_.fuse_squared_diff = true; 199 } 200 EnableOnlyRemoveIdempotent(ArithmeticOptimizer * optimizer)201 void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) { 202 DisableAllStages(optimizer); 203 optimizer->options_.remove_idempotent = true; 204 } 205 EnableOnlyRemoveLogicalNot(ArithmeticOptimizer * optimizer)206 void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { 207 DisableAllStages(optimizer); 208 optimizer->options_.remove_logical_not = true; 209 } 210 EnableOnlySimplifyAggregation(ArithmeticOptimizer * optimizer)211 void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) { 212 DisableAllStages(optimizer); 213 optimizer->options_.simplify_aggregation = true; 214 } 215 EnableOnlyLog1p(ArithmeticOptimizer * optimizer)216 void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) { 217 DisableAllStages(optimizer); 218 optimizer->options_.convert_log1p = true; 219 } 220 EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer * optimizer)221 void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) { 222 DisableAllStages(optimizer); 223 optimizer->options_.optimize_max_or_min_of_monotonic = true; 224 } 225 EnableOnlyExpm1(ArithmeticOptimizer * optimizer)226 void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { 227 DisableAllStages(optimizer); 228 optimizer->options_.convert_expm1 = true; 229 } 230 EnableOnlyUnaryOpsComposition(ArithmeticOptimizer * optimizer)231 void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) { 232 DisableAllStages(optimizer); 233 optimizer->options_.unary_ops_composition = true; 234 } 235 EnableOnlyRemoveStackStridedSliceSameAxis(ArithmeticOptimizer * optimizer)236 void EnableOnlyRemoveStackStridedSliceSameAxis( 237 ArithmeticOptimizer* optimizer) { 238 DisableAllStages(optimizer); 239 optimizer->options_.remove_stack_strided_slice_same_axis = true; 240 } 241 }; 242 243 } // end namespace grappler 244 } // end namespace tensorflow 245 246 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ 247