1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ 18 19 #include "absl/container/flat_hash_set.h" 20 #include "tensorflow/core/framework/device_base.h" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/resource_mgr.h" 23 #include "tensorflow/core/framework/tensor_shape.pb.h" 24 #include "tensorflow/core/grappler/costs/graph_properties.h" 25 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" 26 #include "tensorflow/core/grappler/utils.h" 27 #include "tensorflow/core/protobuf/rewriter_config.pb.h" 28 29 namespace tensorflow { 30 namespace grappler { 31 32 const char kConstantFoldingConst[] = "ConstantFolding"; 33 const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl"; 34 35 // Constant folding optimization for a graph. 36 class ConstantFolding : public GraphOptimizer { 37 public: 38 // The size limit will only be considered if the newly created node is greater 39 // than original_size (optional). 40 static Status CreateNodeDef(const string& name, const TensorValue& tensor, 41 NodeDef* node, size_t original_size = 0); 42 static string AddControlDependency(const string& input_name, GraphDef* graph, 43 NodeMap* node_map); 44 45 explicit ConstantFolding(DeviceBase* cpu_device); 46 ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device); 47 ~ConstantFolding()48 ~ConstantFolding() override {} 49 name()50 string name() const override { return "constant folding"; }; 51 52 Status Optimize(Cluster* cluster, const GrapplerItem& item, 53 GraphDef* output) override; 54 55 void Feedback(Cluster* cluster, const GrapplerItem& item, 56 const GraphDef& optimize_output, double result) override; 57 58 private: 59 string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const; 60 bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const; 61 62 bool IsReallyConstant(const NodeDef& node) const; 63 64 Status MaterializeShapes(const GraphProperties& properties); 65 66 Status MaterializeBroadcastGradientArgs(const NodeDef& node, 67 const GraphProperties& properties); 68 Status MaterializeReductionIndices(NodeDef* node, 69 const GraphProperties& properties); 70 Status MaterializeConstantValuedNode(NodeDef* node, 71 const GraphProperties& properties); 72 Status MaterializeConstants(const GraphProperties& properties); 73 74 bool IsFoldable(const NodeDef& node) const; 75 76 Status EvaluateNode(const NodeDef& node, 77 const gtl::InlinedVector<TensorValue, 4>& inputs, 78 gtl::InlinedVector<TensorValue, 4>* output) const; 79 80 Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs, 81 bool* result_too_large); 82 83 Status FoldMergeNode(NodeDef* node, GraphDef* output_graph); 84 Status FoldNode(NodeDef* node, GraphDef* output_graph, 85 bool* result_too_large); 86 87 bool IsOnes(const NodeDef& node) const; 88 bool IsZeros(const NodeDef& node) const; 89 void ReplaceOperationWithIdentity(int input_to_forward, 90 const GraphProperties& properties, 91 NodeDef* node, GraphDef* graph); 92 void ReplaceOperationWithSnapshot(int input_to_forward, 93 const GraphProperties& properties, 94 NodeDef* node, GraphDef* graph); 95 void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); 96 Status ReplaceOperationWithConstant(double value, 97 const GraphProperties& properties, 98 const TensorShapeProto& shape, 99 NodeDef* node, GraphDef* graph, 100 bool* success); 101 void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); 102 Status FoldGraph(GraphDef* output, 103 absl::flat_hash_set<string>* nodes_to_not_simplify); 104 105 bool IsSimplifiableReshape(const NodeDef& node, 106 const GraphProperties& properties) const; 107 Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph, 108 GraphProperties* properties, 109 absl::flat_hash_set<string>* nodes_to_not_simplify); 110 Status SimplifyNode(bool use_shape_info, NodeDef* node, 111 GraphDef* optimized_graph, GraphProperties* properties); 112 113 Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, 114 GraphDef* output); 115 116 // Applies partial constant folding for Concat which is not commutative. 117 // Returns true if the transformation applied successfully. 118 bool PartialConcatConstFolding(GraphDef* optimized_graph, 119 GraphProperties* properties, NodeDef* node); 120 121 // Applies partial constant folding for associative operators AddN and 122 // AccumulateNV2. Returns true if the transformation applied successfully. 123 bool PartialAssocOpConstFolding(GraphDef* optimized_graph, 124 GraphProperties* properties, NodeDef* node); 125 126 // Applies partial constant propagation through IdentityN operator. 127 // Returns true if the transformation applied successfully. 128 bool PartialConstPropThroughIdentityN(NodeDef* node); 129 130 // Pushes down constants on '+' and '*' operators if applicable. Returns true 131 // the transformation applied successfully. 132 bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node); 133 134 // Aggregate constants present around a conv operator. Returns true if the 135 // transformation was applied successfully. 136 bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node, 137 const GraphProperties& properties); 138 139 // Strength reduces floating point division by a constant Div(x, const) to 140 // multiplication by the reciprocal Mul(x, Reciprocal(const)). 141 bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node); 142 143 // Simplifies arithmetic operations with ones or zeros. Returns the status, 144 // and updates the success input argument that denotes if any simplification 145 // was applied. 146 Status SimplifyArithmeticOperations(const GraphProperties& properties, 147 bool use_shape_info, 148 GraphDef* optimized_graph, NodeDef* node, 149 bool* success); 150 151 // Simplifies a Reshape operation to an Identity operation if applicable. 152 bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, 153 NodeDef* node); 154 155 // Returns true if theres a possibility that a Reduce node could be simplified 156 // to an Identity/Reshape. 157 bool IsReductionCandidateForSimplification( 158 const NodeDef& node, const GraphProperties& properties, 159 TensorShapeProto* input_tensor_shape, 160 TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const; 161 // Returns true iff this reduction can be reduced to an identity (i.e if the 162 // set of dimensions to reduce along is empty). This happens often in the 163 // gradient graphs. 164 bool IsReductionSimplifiableToIdentity( 165 const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims, 166 const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const; 167 // Simplifies a Reduction operation to an Identity/Reshape operation if 168 // applicable. 169 bool SimplifyReduction(GraphDef* optimized_graph, 170 const GraphProperties& properties, NodeDef* node); 171 172 // Switch(x, x) will always feed false to its false branch and true to 173 // its true branch. By rewriting the graph a bit, we can propagate these 174 // constants down the two output branches, and just use control dependencies 175 // to trigger the selected one at runtime. For example, 176 // 177 // +------+ 178 // x-->|Switch|-->a (in practice there may be multiple consumers of each 179 // x-->| |-->b output branch.) 180 // +------+ 181 // 182 // Is rewritten as 183 // 184 // +------+ 185 // x-->|Switch|-->Identity--^>Const(false)-->a 186 // x-->| |-->Identity--^>Const(true)-->b 187 // +------+ 188 bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node); 189 190 // Moves constants past Enter node if applicable. 191 bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node); 192 193 // Simplifies Pack operation if applicable. 194 bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node); 195 196 // Simplifies a Squeeze operation to an Identity operation if applicable. 197 bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, 198 GraphDef* optimized_graph, NodeDef* node); 199 200 // Simplifies a Pad operation to an Identity operation if applicable. 201 Status SimplifyPad(const GraphProperties& properties, bool use_shape_info, 202 GraphDef* optimized_graph, NodeDef* node, bool* success); 203 204 // Simplifies a Tile operation to an Identity operation if applicable. 205 Status SimplifyTile(const GraphProperties& properties, bool use_shape_info, 206 GraphDef* optimized_graph, NodeDef* node, bool* success); 207 208 // Simplifies a StridedSlice operation to an Identity operation if applicable. 209 Status SimplifyStridedSlice(const GraphProperties& properties, 210 bool use_shape_info, GraphDef* optimized_graph, 211 NodeDef* node, bool* success); 212 213 // Simplifies a Slice operation to an Identity operation if applicable. 214 Status SimplifySlice(const GraphProperties& properties, bool use_shape_info, 215 GraphDef* optimized_graph, NodeDef* node, bool* success); 216 217 // Removes Reverse op over dimensions with size 1. 218 Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, 219 GraphDef* optimized_graph, NodeDef* node, bool* success); 220 221 // Removes RandomShuffle op if it is scalar or first dimension is of size 1. 222 bool RemoveRandomShuffle(const GraphProperties& properties, 223 bool use_shape_info, GraphDef* optimized_graph, 224 NodeDef* node); 225 226 // Removes Shuffle or Transpose op over dimensions of size 1. 227 Status RemoveShuffleOrTranspose(const GraphProperties& properties, 228 bool use_shape_info, 229 GraphDef* optimized_graph, NodeDef* node, 230 bool* success); 231 232 // Removes Split or SplitV node if possible. 233 bool RemoveSplitOrSplitV(const GraphProperties& properties, 234 GraphDef* optimized_graph, NodeDef* node); 235 236 bool MergeConcat(const GraphProperties& properties, bool use_shape_info, 237 GraphDef* optimized_graph, NodeDef* node); 238 239 Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node, 240 GraphDef* optimized_graph); 241 242 // Points to an externally provided device or to owned_device_; 243 RewriterConfig::Toggle opt_level_; 244 DeviceBase* cpu_device_; 245 std::unique_ptr<DeviceBase> owned_device_; 246 247 std::unique_ptr<ResourceMgr> resource_mgr_; 248 GraphDef* graph_; 249 std::unique_ptr<NodeMap> node_map_; 250 std::unordered_set<string> nodes_to_preserve_; 251 std::unordered_set<string> nodes_whitelist_; 252 std::unordered_set<string> feed_nodes_; 253 bool has_fetch_; 254 bool graph_modified_; 255 bool graph_contains_assign_or_inplace_op_; 256 }; 257 258 } // end namespace grappler 259 } // end namespace tensorflow 260 261 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ 262