• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/core/framework/device_base.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/resource_mgr.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/grappler/costs/graph_properties.h"
27 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 
34 const char kConstantFoldingConst[] = "ConstantFolding";
35 const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
36 extern const int64 kMaxConstantSize;
37 
38 // Constant folding optimization for a graph.
39 class ConstantFolding : public GraphOptimizer {
40  public:
41   // The size limit will only be considered if the newly created node is greater
42   // than original_size (optional).
43   static Status CreateNodeDef(const string& name, const TensorValue& tensor,
44                               NodeDef* node, size_t original_size = 0);
45   static string AddControlDependency(const string& input_name, GraphDef* graph,
46                                      NodeMap* node_map);
47 
48   explicit ConstantFolding(DeviceBase* cpu_device,
49                            bool disable_compressed_tensor_optimization = false,
50                            bool fold_quantization_emulation = true);
51   ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device,
52                   bool disable_compressed_tensor_optimization = false,
53                   bool fold_quantization_emulation = true);
54 
~ConstantFolding()55   ~ConstantFolding() override {}
56 
name()57   string name() const override { return "constant_folding"; };
58 
UsesFunctionLibrary()59   bool UsesFunctionLibrary() const override { return false; }
60 
61   Status Optimize(Cluster* cluster, const GrapplerItem& item,
62                   GraphDef* output) override;
63 
64   void Feedback(Cluster* cluster, const GrapplerItem& item,
65                 const GraphDef& optimize_output, double result) override;
66 
67  private:
68   bool ForwardInputs(NodeDef* node, absl::Span<const int> inputs_to_forward);
69   string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
70   bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;
71 
72   bool IsReallyConstant(const NodeDef& node) const;
73 
74   bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
75 
76   Status MaterializeShapes(const GraphProperties& properties);
77 
78   Status MaterializeBroadcastGradientArgs(const NodeDef& node,
79                                           const GraphProperties& properties);
80   Status MaterializeReductionIndices(NodeDef* node,
81                                      const GraphProperties& properties);
82   Status MaterializeConstantValuedNode(NodeDef* node,
83                                        const GraphProperties& properties);
84   Status MaterializeOutputValues(NodeDef* node,
85                                  const GraphProperties& properties);
86   Status MaterializeConstants(const GraphProperties& properties);
87 
88   bool IsFoldable(const NodeDef& node, const GraphProperties* properties);
89   bool IsFoldableUncached(const NodeDef& node,
90                           const GraphProperties* properties) const;
91   bool MaybeFoldable(const NodeDef& node,
92                      const GraphProperties* properties) const;
93 
94   Status EvaluateNode(const NodeDef& node,
95                       const gtl::InlinedVector<TensorValue, 4>& inputs,
96                       gtl::InlinedVector<TensorValue, 4>* output) const;
97 
98   Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs,
99                              bool* result_too_large);
100 
101   Status FoldMergeNode(NodeDef* node, GraphDef* output_graph);
102   Status FoldNode(NodeDef* node, GraphDef* output_graph,
103                   bool* result_too_large);
104 
105   bool IsOnes(const NodeDef& node) const;
106   bool IsZeros(const NodeDef& node) const;
107   bool ReplaceOperationWithBroadcastTo(int input_to_broadcast,
108                                        const GraphProperties& properties,
109                                        NodeDef* node, GraphDef* graph);
110   void ReplaceOperationWithIdentity(int input_to_forward,
111                                     const GraphProperties& properties,
112                                     NodeDef* node, GraphDef* graph);
113   void ReplaceOperationWithSnapshot(int input_to_forward,
114                                     const GraphProperties& properties,
115                                     NodeDef* node, GraphDef* graph);
116   void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties,
117                                 GraphDef* graph);
118   void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
119                                              const GraphProperties& properties,
120                                              NodeDef* node, GraphDef* graph);
121   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
122   Status ReplaceOperationWithConstant(double value,
123                                       const GraphProperties& properties,
124                                       const TensorShapeProto& shape,
125                                       NodeDef* node, GraphDef* graph);
126 
127   // Notice: Destroys *value.
128   Status ReplaceOperationWithConstantTensor(DataType dtype, TensorProto* value,
129                                             NodeDef* node, GraphDef* graph);
130 
131   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
132   Status FoldGraph(const GraphProperties& properties, GraphDef* output,
133                    absl::flat_hash_set<string>* nodes_to_not_simplify);
134 
135   bool IsSimplifiableReshape(const NodeDef& node,
136                              const GraphProperties& properties) const;
137   Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
138                        GraphProperties* properties,
139                        absl::flat_hash_set<string>* nodes_to_not_simplify);
140   Status SimplifyNode(bool use_shape_info, NodeDef* node,
141                       GraphDef* optimized_graph, GraphProperties* properties);
142 
143   Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item,
144                              GraphDef* optimized_graph);
145 
146   // Applies partial constant folding for Concat which is not commutative.
147   // Returns true if the transformation applied successfully.
148   bool PartialConcatConstFolding(GraphDef* optimized_graph,
149                                  GraphProperties* properties, NodeDef* node);
150 
151   // Applies partial constant folding for associative operators AddN and
152   // AccumulateNV2. Returns true if the transformation applied successfully.
153   bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
154                                   GraphProperties* properties, NodeDef* node);
155 
156   // Applies partial constant propagation through IdentityN operator.
157   // Returns true if the transformation applied successfully.
158   bool PartialConstPropThroughIdentityN(NodeDef* node);
159 
160   struct ConstantPushDownContext {
161     NodeDef* op_child;
162     NodeDef* const_child;
163     bool left_child_is_const;
164     bool right_child_is_const;
165     NodeDef* left_leaf;
166     NodeDef* right_leaf;
167     bool left_leaf_is_const;
168     bool right_leaf_is_const;
169 
170     // Shape & type information.
171     const std::vector<OpInfo::TensorProperties>* parent_input_props;
172     const std::vector<OpInfo::TensorProperties>* op_child_input_props;
173   };
174 
175   // Populates ctx with pointers to the nodes in expression tree for which
176   // constant pushdown optimization is being considered, corresponding to one of
177   // the following configurations:
178   //
179   //               parent                            parent
180   //               /    \                            /    \
181   //        op_child   const_child            const_child op_child
182   //         /     \                                       /     \
183   //    left_leaf  right_leaf                        left_leaf  right_leaf
184   //
185   // Returns true if the expression is possible amenable for optimization.
186   // Returns false if must_have_properties is true and input properties for
187   // parent and op_child are not known.
188   bool PrepareConstantPushDown(const NodeDef& parent,
189                                const GraphProperties& properties,
190                                bool must_have_properties,
191                                ConstantPushDownContext* ctx) const;
192 
193   // Pushes down constants on '+', '-', '*', and '/' operators if applicable.
194   // Returns true if the transformation applied successfully.
195   bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph,
196                         NodeDef* node);
197 
198   // Pushes down constants on '+' and 'BiasAdd' operators if applicable.
199   // Returns true if the graph was modified.
200   bool ConstantPushDownBiasAdd(GraphProperties* properties,
201                                GraphDef* optimized_graph, NodeDef* node);
202 
203   // Aggregate constants present around a conv operator. Returns true if the
204   // transformation was applied successfully.
205   bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
206                        const GraphProperties& properties);
207 
208   // Strength reduces floating point division by a constant Div(x, const) to
209   // multiplication by the reciprocal Mul(x, Reciprocal(const)).
210   bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);
211 
212   // Simplifies arithmetic operations with ones or zeros. Returns the status,
213   // and updates the success input argument that denotes if any simplification
214   // was applied.
215   Status SimplifyArithmeticOperations(const GraphProperties& properties,
216                                       bool use_shape_info,
217                                       GraphDef* optimized_graph, NodeDef* node);
218 
219   // Simplifies a Reshape operation to an Identity operation if applicable.
220   bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
221                        NodeDef* node);
222 
223   // Returns true iff the node is a reduction and its reduction indices are
224   // constant. Sets *indices_is_empty to true if the set of dimensions to reduce
225   // along is empty (this happens often in the gradient graphs).
226   bool IsReductionWithConstantIndices(const NodeDef& node,
227                                       bool* indices_is_empty) const;
228   // Returns true if theres a possibility that a Reduce node could be simplified
229   // to an Identity/Reshape.
230   bool IsReductionCandidateForSimplification(
231       const NodeDef& node, const GraphProperties& properties,
232       TensorShapeProto* input_tensor_shape,
233       TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const;
234   // Returns true iff this reduction can be reduced to an identity (i.e if the
235   // input dimensions to reduce along are all of size 1 and keep_dims is true).
236   bool IsReductionSimplifiableToIdentity(
237       const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
238       const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
239   // Changes a reduction into an Identity op, returning true on success.
240   bool ReplaceReductionWithIdentity(NodeDef* node) const;
241 
242   // Simplifies a Reduction operation to an Identity/Reshape operation if
243   // applicable.
244   bool SimplifyReduction(GraphDef* optimized_graph,
245                          const GraphProperties& properties, NodeDef* node);
246 
247   // Switch(x, x) will always feed false to its false branch and true to
248   // its true branch. By rewriting the graph a bit, we can propagate these
249   // constants down the two output branches, and just use control dependencies
250   // to trigger the selected one at runtime. For example,
251   //
252   //     +------+
253   // x-->|Switch|-->a  (in practice there may be multiple consumers of each
254   // x-->|      |-->b   output branch.)
255   //     +------+
256   //
257   // Is rewritten as
258   //
259   //     +------+
260   // x-->|Switch|-->Identity--^>Const(false)-->a
261   // x-->|      |-->Identity--^>Const(true)-->b
262   //     +------+
263   bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);
264 
265   // Moves constants past Enter node if applicable.
266   bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node);
267 
268   // Simplifies Pack operation if applicable.
269   bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
270 
271   // Simplifies a Squeeze operation to an Identity operation if applicable.
272   void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
273                        GraphDef* optimized_graph, NodeDef* node);
274 
275   // Simplifies a Pad operation to an Identity operation if applicable.
276   Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
277                      GraphDef* optimized_graph, NodeDef* node);
278 
279   // Simplifies a Tile operation to an Identity operation if applicable.
280   Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
281                       GraphDef* optimized_graph, NodeDef* node);
282 
283   // Simplifies a StridedSlice operation to an Identity operation if applicable.
284   Status SimplifyStridedSlice(const GraphProperties& properties,
285                               bool use_shape_info, GraphDef* optimized_graph,
286                               NodeDef* node);
287 
288   // Simplifies a Slice operation to an Identity operation if applicable.
289   Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
290                        GraphDef* optimized_graph, NodeDef* node);
291 
292   // Simplify a Case operation where the output_idx is known.
293   bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node);
294 
295   // Simplify a Select operation where the predicates are all true or all false.
296   bool SimplifySelect(const GraphProperties& properties,
297                       GraphDef* optimized_graph, NodeDef* node);
298 
299   // Replaces variable updates that are effectively no-ops with NoOp nodes.
300   void RemoveRedundantVariableUpdates(GraphProperties* properties,
301                                       GraphDef* optimized_graph, NodeDef* node);
302 
303   // Removes Reverse op over dimensions with size 1.
304   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
305                        GraphDef* optimized_graph, NodeDef* node);
306 
307   // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
308   void RemoveRandomShuffle(const GraphProperties& properties,
309                            bool use_shape_info, GraphDef* optimized_graph,
310                            NodeDef* node);
311 
312   // Removes Shuffle or Transpose op over dimensions of size 1.
313   Status RemoveShuffleOrTranspose(const GraphProperties& properties,
314                                   bool use_shape_info,
315                                   GraphDef* optimized_graph, NodeDef* node);
316 
317   // Removes Split or SplitV node if possible.
318   void RemoveSplitOrSplitV(const GraphProperties& properties,
319                            GraphDef* optimized_graph, NodeDef* node);
320 
321   bool GetConcatAxis(const NodeDef& node, int* axis);
322   bool MergeConcat(bool use_shape_info, GraphProperties* properties,
323                    GraphDef* optimized_graph, NodeDef* node);
324 
325   Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
326                                                GraphDef* optimized_graph);
327 
328   // Points to an externally provided device or to owned_device_;
329   RewriterConfig::Toggle opt_level_;
330   DeviceBase* cpu_device_;
331   std::unique_ptr<DeviceBase> owned_device_;
332 
333   std::unique_ptr<ResourceMgr> resource_mgr_;
334   GraphDef* graph_;
335   std::unique_ptr<NodeMap> node_map_;
336   std::unordered_set<string> nodes_to_preserve_;
337   // TODO(rmlarsen): Could these be keyed on absl::string_view?
338   absl::flat_hash_set<string> nodes_allowlist_;
339   absl::flat_hash_set<string> feed_nodes_;
340   absl::flat_hash_map<string, bool> maybe_foldable_nodes_;
341   bool has_fetch_;
342   bool graph_modified_;
343   bool graph_contains_assign_or_inplace_op_;
344   bool disable_compressed_tensor_optimization_;
345   bool fold_quantization_emulation_;
346 };
347 
348 }  // end namespace grappler
349 }  // end namespace tensorflow
350 
351 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
352