• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
18 
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/core/framework/device_base.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/framework/tensor_shape.pb.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
26 #include "tensorflow/core/grappler/utils.h"
27 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 
// String tags exported by the constant folding optimizer.
// NOTE(review): presumably used when naming nodes / control-dependency nodes
// created by the optimizer -- confirm against constant_folding.cc.
constexpr char kConstantFoldingConst[] = "ConstantFolding";
constexpr char kConstantFoldingCtrl[] = "ConstantFoldingCtrl";
34 
35 // Constant folding optimization for a graph.
36 class ConstantFolding : public GraphOptimizer {
37  public:
38   // The size limit will only be considered if the newly created node is greater
39   // than original_size (optional).
40   static Status CreateNodeDef(const string& name, const TensorValue& tensor,
41                               NodeDef* node, size_t original_size = 0);
42   static string AddControlDependency(const string& input_name, GraphDef* graph,
43                                      NodeMap* node_map);
44 
45   explicit ConstantFolding(DeviceBase* cpu_device);
46   ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device);
47 
~ConstantFolding()48   ~ConstantFolding() override {}
49 
name()50   string name() const override { return "constant folding"; };
51 
52   Status Optimize(Cluster* cluster, const GrapplerItem& item,
53                   GraphDef* output) override;
54 
55   void Feedback(Cluster* cluster, const GrapplerItem& item,
56                 const GraphDef& optimize_output, double result) override;
57 
58  private:
59   string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
60   bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;
61 
62   bool IsReallyConstant(const NodeDef& node) const;
63 
64   Status MaterializeShapes(const GraphProperties& properties);
65 
66   Status MaterializeBroadcastGradientArgs(const NodeDef& node,
67                                           const GraphProperties& properties);
68   Status MaterializeReductionIndices(NodeDef* node,
69                                      const GraphProperties& properties);
70   Status MaterializeConstantValuedNode(NodeDef* node,
71                                        const GraphProperties& properties);
72   Status MaterializeConstants(const GraphProperties& properties);
73 
74   bool IsFoldable(const NodeDef& node) const;
75 
76   Status EvaluateNode(const NodeDef& node,
77                       const gtl::InlinedVector<TensorValue, 4>& inputs,
78                       gtl::InlinedVector<TensorValue, 4>* output) const;
79 
80   Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs,
81                              bool* result_too_large);
82 
83   Status FoldMergeNode(NodeDef* node, GraphDef* output_graph);
84   Status FoldNode(NodeDef* node, GraphDef* output_graph,
85                   bool* result_too_large);
86 
87   bool IsOnes(const NodeDef& node) const;
88   bool IsZeros(const NodeDef& node) const;
89   void ReplaceOperationWithIdentity(int input_to_forward,
90                                     const GraphProperties& properties,
91                                     NodeDef* node, GraphDef* graph);
92   void ReplaceOperationWithSnapshot(int input_to_forward,
93                                     const GraphProperties& properties,
94                                     NodeDef* node, GraphDef* graph);
95   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
96   Status ReplaceOperationWithConstant(double value,
97                                       const GraphProperties& properties,
98                                       const TensorShapeProto& shape,
99                                       NodeDef* node, GraphDef* graph,
100                                       bool* success);
101   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
102   Status FoldGraph(GraphDef* output,
103                    absl::flat_hash_set<string>* nodes_to_not_simplify);
104 
105   bool IsSimplifiableReshape(const NodeDef& node,
106                              const GraphProperties& properties) const;
107   Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
108                        GraphProperties* properties,
109                        absl::flat_hash_set<string>* nodes_to_not_simplify);
110   Status SimplifyNode(bool use_shape_info, NodeDef* node,
111                       GraphDef* optimized_graph, GraphProperties* properties);
112 
113   Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
114                              GraphDef* output);
115 
116   // Applies partial constant folding for Concat which is not commutative.
117   // Returns true if the transformation applied successfully.
118   bool PartialConcatConstFolding(GraphDef* optimized_graph,
119                                  GraphProperties* properties, NodeDef* node);
120 
121   // Applies partial constant folding for associative operators AddN and
122   // AccumulateNV2. Returns true if the transformation applied successfully.
123   bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
124                                   GraphProperties* properties, NodeDef* node);
125 
126   // Applies partial constant propagation through IdentityN operator.
127   // Returns true if the transformation applied successfully.
128   bool PartialConstPropThroughIdentityN(NodeDef* node);
129 
130   // Pushes down constants on '+' and '*' operators if applicable. Returns true
131   // the transformation applied successfully.
132   bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node);
133 
134   // Aggregate constants present around a conv operator. Returns true if the
135   // transformation was applied successfully.
136   bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
137                        const GraphProperties& properties);
138 
139   // Strength reduces floating point division by a constant Div(x, const) to
140   // multiplication by the reciprocal Mul(x, Reciprocal(const)).
141   bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node);
142 
143   // Simplifies arithmetic operations with ones or zeros. Returns the status,
144   // and updates the success input argument that denotes if any simplification
145   // was applied.
146   Status SimplifyArithmeticOperations(const GraphProperties& properties,
147                                       bool use_shape_info,
148                                       GraphDef* optimized_graph, NodeDef* node,
149                                       bool* success);
150 
151   // Simplifies a Reshape operation to an Identity operation if applicable.
152   bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
153                        NodeDef* node);
154 
155   // Returns true if theres a possibility that a Reduce node could be simplified
156   // to an Identity/Reshape.
157   bool IsReductionCandidateForSimplification(
158       const NodeDef& node, const GraphProperties& properties,
159       TensorShapeProto* input_tensor_shape,
160       TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const;
161   // Returns true iff this reduction can be reduced to an identity (i.e if the
162   // set of dimensions to reduce along is empty). This happens often in the
163   // gradient graphs.
164   bool IsReductionSimplifiableToIdentity(
165       const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
166       const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
167   // Simplifies a Reduction operation to an Identity/Reshape operation if
168   // applicable.
169   bool SimplifyReduction(GraphDef* optimized_graph,
170                          const GraphProperties& properties, NodeDef* node);
171 
172   // Switch(x, x) will always feed false to its false branch and true to
173   // its true branch. By rewriting the graph a bit, we can propagate these
174   // constants down the two output branches, and just use control dependencies
175   // to trigger the selected one at runtime. For example,
176   //
177   //     +------+
178   // x-->|Switch|-->a  (in practice there may be multiple consumers of each
179   // x-->|      |-->b   output branch.)
180   //     +------+
181   //
182   // Is rewritten as
183   //
184   //     +------+
185   // x-->|Switch|-->Identity--^>Const(false)-->a
186   // x-->|      |-->Identity--^>Const(true)-->b
187   //     +------+
188   bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node);
189 
190   // Moves constants past Enter node if applicable.
191   bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node);
192 
193   // Simplifies Pack operation if applicable.
194   bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
195 
196   // Simplifies a Squeeze operation to an Identity operation if applicable.
197   bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
198                        GraphDef* optimized_graph, NodeDef* node);
199 
200   // Simplifies a Pad operation to an Identity operation if applicable.
201   Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
202                      GraphDef* optimized_graph, NodeDef* node, bool* success);
203 
204   // Simplifies a Tile operation to an Identity operation if applicable.
205   Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
206                       GraphDef* optimized_graph, NodeDef* node, bool* success);
207 
208   // Simplifies a StridedSlice operation to an Identity operation if applicable.
209   Status SimplifyStridedSlice(const GraphProperties& properties,
210                               bool use_shape_info, GraphDef* optimized_graph,
211                               NodeDef* node, bool* success);
212 
213   // Simplifies a Slice operation to an Identity operation if applicable.
214   Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
215                        GraphDef* optimized_graph, NodeDef* node, bool* success);
216 
217   // Removes Reverse op over dimensions with size 1.
218   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
219                        GraphDef* optimized_graph, NodeDef* node, bool* success);
220 
221   // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
222   bool RemoveRandomShuffle(const GraphProperties& properties,
223                            bool use_shape_info, GraphDef* optimized_graph,
224                            NodeDef* node);
225 
226   // Removes Shuffle or Transpose op over dimensions of size 1.
227   Status RemoveShuffleOrTranspose(const GraphProperties& properties,
228                                   bool use_shape_info,
229                                   GraphDef* optimized_graph, NodeDef* node,
230                                   bool* success);
231 
232   // Removes Split or SplitV node if possible.
233   bool RemoveSplitOrSplitV(const GraphProperties& properties,
234                            GraphDef* optimized_graph, NodeDef* node);
235 
236   bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
237                    GraphDef* optimized_graph, NodeDef* node);
238 
239   Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
240                                                GraphDef* optimized_graph);
241 
242   // Points to an externally provided device or to owned_device_;
243   RewriterConfig::Toggle opt_level_;
244   DeviceBase* cpu_device_;
245   std::unique_ptr<DeviceBase> owned_device_;
246 
247   std::unique_ptr<ResourceMgr> resource_mgr_;
248   GraphDef* graph_;
249   std::unique_ptr<NodeMap> node_map_;
250   std::unordered_set<string> nodes_to_preserve_;
251   std::unordered_set<string> nodes_whitelist_;
252   std::unordered_set<string> feed_nodes_;
253   bool has_fetch_;
254   bool graph_modified_;
255   bool graph_contains_assign_or_inplace_op_;
256 };
257 
258 }  // end namespace grappler
259 }  // end namespace tensorflow
260 
261 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_
262