• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
18 
19 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
20 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
21 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
22 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
23 #include "tensorflow/core/grappler/utils/grappler_test.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace tensorflow {
27 namespace grappler {
28 
29 class ArithmeticOptimizerTest : public GrapplerTest {
30  protected:
31   // Optimize a graph using optimizer and prune all the nodes that no
32   // longer have any output consumers.
OptimizeAndPrune(GraphOptimizer * optimizer,GrapplerItem * item,GraphDef * output)33   void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
34                         GraphDef* output) {
35     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
36     item->graph.Swap(output);
37     output->Clear();
38     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
39   }
40 
41   // Run optimizer twice to make sure the rewrite is idempotent.
DedupAndOptimizeTwiceAndPrune(GraphOptimizer * optimizer,GrapplerItem * item,GraphDef * output)42   void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer,
43                                      GrapplerItem* item, GraphDef* output) {
44     TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output));
45     item->graph.Swap(output);
46     output->Clear();
47     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
48     item->graph.Swap(output);
49     output->Clear();
50     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
51     item->graph.Swap(output);
52     output->Clear();
53     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
54   }
55 
56   // Run optimizer twice to make sure the rewrite is idempotent.
OptimizeTwice(GraphOptimizer * optimizer,GrapplerItem * item,GraphDef * output)57   void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item,
58                      GraphDef* output) {
59     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
60     item->graph.Swap(output);
61     output->Clear();
62     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
63   }
64 
65   // Run optimizer twice to make sure the rewrite is idempotent.
66   // Optionally run a constant folding pass before pruning.
67   void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
68                              GraphDef* output, bool const_folding = false) {
69     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
70 
71     item->graph.Swap(output);
72     output->Clear();
73     TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
74 
75     if (const_folding) {
76       item->graph.Swap(output);
77       output->Clear();
78       TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
79                        .Optimize(nullptr, *item, output));
80     }
81 
82     item->graph.Swap(output);
83     output->Clear();
84     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
85   }
86 
DisableAddToAddNCombining(ArithmeticOptimizer * optimizer)87   void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
88     optimizer->options_.combine_add_to_addn = false;
89   }
90 
EnableOnlyAddToAddNCombining(ArithmeticOptimizer * optimizer)91   void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
92     DisableAllStages(optimizer);
93     optimizer->options_.combine_add_to_addn = true;
94   }
95 
EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer * optimizer)96   void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
97     DisableAllStages(optimizer);
98     optimizer->options_.fold_conjugate_into_transpose = true;
99   }
100 
EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer * optimizer)101   void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
102     DisableAllStages(optimizer);
103     optimizer->options_.fold_multiply_into_conv = true;
104   }
105 
EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer * optimizer)106   void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
107     DisableAllStages(optimizer);
108     optimizer->options_.fold_transpose_into_matmul = true;
109   }
110 
EnableOnlyHoistCommonFactor(ArithmeticOptimizer * optimizer)111   void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
112     DisableAllStages(optimizer);
113     optimizer->options_.hoist_common_factor_out_of_aggregation = true;
114   }
115 
EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer * optimizer)116   void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
117     DisableAllStages(optimizer);
118     optimizer->options_.minimize_broadcasts = true;
119   }
120 
EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer * optimizer)121   void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
122     DisableAllStages(optimizer);
123     optimizer->options_.remove_identity_transpose = true;
124   }
125 
EnableOnlyRemoveInvolution(ArithmeticOptimizer * optimizer)126   void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) {
127     DisableAllStages(optimizer);
128     optimizer->options_.remove_involution = true;
129   }
130 
EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer * optimizer)131   void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) {
132     DisableAllStages(optimizer);
133     optimizer->options_.remove_redundant_bitcast = true;
134   }
135 
EnableOnlyRemoveRedundantCast(ArithmeticOptimizer * optimizer)136   void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) {
137     DisableAllStages(optimizer);
138     optimizer->options_.remove_redundant_cast = true;
139   }
140 
EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer * optimizer)141   void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
142     DisableAllStages(optimizer);
143     optimizer->options_.remove_redundant_reshape = true;
144   }
145 
EnableOnlyRemoveNegation(ArithmeticOptimizer * optimizer)146   void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
147     DisableAllStages(optimizer);
148     optimizer->options_.remove_negation = true;
149   }
150 
EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer * optimizer)151   void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
152     DisableAllStages(optimizer);
153     optimizer->options_.reorder_cast_like_and_value_preserving = true;
154   }
155 
EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer * optimizer)156   void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
157     DisableAllStages(optimizer);
158     optimizer->options_.replace_mul_with_square = true;
159   }
160 
EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer * optimizer)161   void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
162     DisableAllStages(optimizer);
163     optimizer->options_.hoist_cwise_unary_chains = true;
164   }
165 
EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer * optimizer)166   void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) {
167     DisableAllStages(optimizer);
168     optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
169   }
170 
EnableOnlyLogSoftmax(ArithmeticOptimizer * optimizer)171   void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) {
172     DisableAllStages(optimizer);
173     optimizer->options_.convert_log_softmax = true;
174   }
175 
EnableOnlyConvertPow(ArithmeticOptimizer * optimizer)176   void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
177     DisableAllStages(optimizer);
178     optimizer->options_.convert_pow = true;
179   }
180 
EnableOnlyFuseSquaredDiff(ArithmeticOptimizer * optimizer)181   void EnableOnlyFuseSquaredDiff(ArithmeticOptimizer* optimizer) {
182     DisableAllStages(optimizer);
183     optimizer->options_.fuse_squared_diff = true;
184   }
185 
EnableOnlyRemoveIdempotent(ArithmeticOptimizer * optimizer)186   void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
187     DisableAllStages(optimizer);
188     optimizer->options_.remove_idempotent = true;
189   }
190 
EnableOnlyRemoveLogicalNot(ArithmeticOptimizer * optimizer)191   void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) {
192     DisableAllStages(optimizer);
193     optimizer->options_.remove_logical_not = true;
194   }
195 
EnableOnlySimplifyAggregation(ArithmeticOptimizer * optimizer)196   void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
197     DisableAllStages(optimizer);
198     optimizer->options_.simplify_aggregation = true;
199   }
200 
EnableOnlyLog1p(ArithmeticOptimizer * optimizer)201   void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
202     DisableAllStages(optimizer);
203     optimizer->options_.convert_log1p = true;
204   }
205 
EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer * optimizer)206   void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
207     DisableAllStages(optimizer);
208     optimizer->options_.optimize_max_or_min_of_monotonic = true;
209   }
210 
EnableOnlyExpm1(ArithmeticOptimizer * optimizer)211   void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
212     DisableAllStages(optimizer);
213     optimizer->options_.convert_expm1 = true;
214   }
215 
EnableOnlyUnaryOpsComposition(ArithmeticOptimizer * optimizer)216   void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
217     DisableAllStages(optimizer);
218     optimizer->options_.unary_ops_composition = true;
219   }
220 
EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer * optimizer)221   void EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer* optimizer) {
222     DisableAllStages(optimizer);
223     optimizer->options_.remove_stack_slice_same_axis = true;
224   }
225 
EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer * optimizer)226   void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) {
227     DisableAllStages(optimizer);
228     optimizer->options_.simplify_embedding_lookup = true;
229   }
230 
EnableOnlyRemoveCastIntoSegmentReduction(ArithmeticOptimizer * optimizer)231   void EnableOnlyRemoveCastIntoSegmentReduction(
232       ArithmeticOptimizer* optimizer) {
233     DisableAllStages(optimizer);
234     optimizer->options_.remove_cast_into_segment_reduction = true;
235   }
236 
237  private:
DisableAllStages(ArithmeticOptimizer * optimizer)238   void DisableAllStages(ArithmeticOptimizer* optimizer) {
239     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
240     options.dedup_computations = false;
241     options.combine_add_to_addn = false;
242     options.convert_sqrt_div_to_rsqrt_mul = false;
243     options.convert_pow = false;
244     options.convert_log1p = false;
245     options.optimize_max_or_min_of_monotonic = false;
246     options.fold_conjugate_into_transpose = false;
247     options.fold_multiply_into_conv = false;
248     options.fold_transpose_into_matmul = false;
249     options.hoist_common_factor_out_of_aggregation = false;
250     options.hoist_cwise_unary_chains = false;
251     options.minimize_broadcasts = false;
252     options.remove_identity_transpose = false;
253     options.remove_involution = false;
254     options.remove_idempotent = false;
255     options.remove_redundant_bitcast = false;
256     options.remove_redundant_cast = false;
257     options.remove_redundant_reshape = false;
258     options.remove_negation = false;
259     options.remove_logical_not = false;
260     options.reorder_cast_like_and_value_preserving = false;
261     options.replace_mul_with_square = false;
262     options.simplify_aggregation = false;
263     options.unary_ops_composition = false;
264     options.simplify_embedding_lookup = false;
265     options.remove_cast_into_segment_reduction = false;
266     optimizer->options_ = options;
267   }
268 };
269 
270 }  // end namespace grappler
271 }  // end namespace tensorflow
272 
273 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
274