1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ 18 19 #include "absl/container/flat_hash_map.h" 20 #include "absl/container/flat_hash_set.h" 21 #include "absl/types/span.h" 22 #include "tensorflow/core/framework/device_base.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/resource_mgr.h" 25 #include "tensorflow/core/framework/tensor_shape.pb.h" 26 #include "tensorflow/core/grappler/costs/graph_properties.h" 27 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" 28 #include "tensorflow/core/grappler/utils.h" 29 #include "tensorflow/core/protobuf/rewriter_config.pb.h" 30 31 namespace tensorflow { 32 namespace grappler { 33 34 const char kConstantFoldingConst[] = "ConstantFolding"; 35 const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl"; 36 extern const int64 kMaxConstantSize; 37 38 // Constant folding optimization for a graph. 39 class ConstantFolding : public GraphOptimizer { 40 public: 41 // The size limit will only be considered if the newly created node is greater 42 // than original_size (optional). 43 static Status CreateNodeDef(const string& name, const TensorValue& tensor, 44 NodeDef* node, size_t original_size = 0); 45 static string AddControlDependency(const string& input_name, GraphDef* graph, 46 NodeMap* node_map); 47 48 explicit ConstantFolding(DeviceBase* cpu_device, 49 bool disable_compressed_tensor_optimization = false, 50 bool fold_quantization_emulation = true); 51 ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device, 52 bool disable_compressed_tensor_optimization = false, 53 bool fold_quantization_emulation = true); 54 ~ConstantFolding()55 ~ConstantFolding() override {} 56 name()57 string name() const override { return "constant_folding"; }; 58 UsesFunctionLibrary()59 bool UsesFunctionLibrary() const override { return false; } 60 61 Status Optimize(Cluster* cluster, const GrapplerItem& item, 62 GraphDef* output) override; 63 64 void Feedback(Cluster* cluster, const GrapplerItem& item, 65 const GraphDef& optimize_output, double result) override; 66 67 private: 68 bool ForwardInputs(NodeDef* node, absl::Span<const int> inputs_to_forward); 69 string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const; 70 bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const; 71 72 bool IsReallyConstant(const NodeDef& node) const; 73 74 bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor); 75 76 Status MaterializeShapes(const GraphProperties& properties); 77 78 Status MaterializeBroadcastGradientArgs(const NodeDef& node, 79 const GraphProperties& properties); 80 Status MaterializeReductionIndices(NodeDef* node, 81 const GraphProperties& properties); 82 Status MaterializeConstantValuedNode(NodeDef* node, 83 const GraphProperties& properties); 84 Status MaterializeOutputValues(NodeDef* node, 85 const GraphProperties& properties); 86 Status MaterializeConstants(const GraphProperties& properties); 87 88 bool IsFoldable(const NodeDef& node, const GraphProperties* properties); 89 bool IsFoldableUncached(const NodeDef& node, 90 const GraphProperties* properties) const; 91 bool MaybeFoldable(const NodeDef& node, 92 const GraphProperties* properties) const; 93 94 Status EvaluateNode(const NodeDef& node, 95 const gtl::InlinedVector<TensorValue, 4>& inputs, 96 gtl::InlinedVector<TensorValue, 4>* output) const; 97 98 Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs, 99 bool* result_too_large); 100 101 Status FoldMergeNode(NodeDef* node, GraphDef* output_graph); 102 Status FoldNode(NodeDef* node, GraphDef* output_graph, 103 bool* result_too_large); 104 105 bool IsOnes(const NodeDef& node) const; 106 bool IsZeros(const NodeDef& node) const; 107 bool ReplaceOperationWithBroadcastTo(int input_to_broadcast, 108 const GraphProperties& properties, 109 NodeDef* node, GraphDef* graph); 110 void ReplaceOperationWithIdentity(int input_to_forward, 111 const GraphProperties& properties, 112 NodeDef* node, GraphDef* graph); 113 void ReplaceOperationWithSnapshot(int input_to_forward, 114 const GraphProperties& properties, 115 NodeDef* node, GraphDef* graph); 116 void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties, 117 GraphDef* graph); 118 void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast, 119 const GraphProperties& properties, 120 NodeDef* node, GraphDef* graph); 121 void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); 122 Status ReplaceOperationWithConstant(double value, 123 const GraphProperties& properties, 124 const TensorShapeProto& shape, 125 NodeDef* node, GraphDef* graph); 126 127 // Notice: Destroys *value. 128 Status ReplaceOperationWithConstantTensor(DataType dtype, TensorProto* value, 129 NodeDef* node, GraphDef* graph); 130 131 void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); 132 Status FoldGraph(const GraphProperties& properties, GraphDef* output, 133 absl::flat_hash_set<string>* nodes_to_not_simplify); 134 135 bool IsSimplifiableReshape(const NodeDef& node, 136 const GraphProperties& properties) const; 137 Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph, 138 GraphProperties* properties, 139 absl::flat_hash_set<string>* nodes_to_not_simplify); 140 Status SimplifyNode(bool use_shape_info, NodeDef* node, 141 GraphDef* optimized_graph, GraphProperties* properties); 142 143 Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item, 144 GraphDef* optimized_graph); 145 146 // Applies partial constant folding for Concat which is not commutative. 147 // Returns true if the transformation applied successfully. 148 bool PartialConcatConstFolding(GraphDef* optimized_graph, 149 GraphProperties* properties, NodeDef* node); 150 151 // Applies partial constant folding for associative operators AddN and 152 // AccumulateNV2. Returns true if the transformation applied successfully. 153 bool PartialAssocOpConstFolding(GraphDef* optimized_graph, 154 GraphProperties* properties, NodeDef* node); 155 156 // Applies partial constant propagation through IdentityN operator. 157 // Returns true if the transformation applied successfully. 158 bool PartialConstPropThroughIdentityN(NodeDef* node); 159 160 struct ConstantPushDownContext { 161 NodeDef* op_child; 162 NodeDef* const_child; 163 bool left_child_is_const; 164 bool right_child_is_const; 165 NodeDef* left_leaf; 166 NodeDef* right_leaf; 167 bool left_leaf_is_const; 168 bool right_leaf_is_const; 169 170 // Shape & type information. 171 const std::vector<OpInfo::TensorProperties>* parent_input_props; 172 const std::vector<OpInfo::TensorProperties>* op_child_input_props; 173 }; 174 175 // Populates ctx with pointers to the nodes in expression tree for which 176 // constant pushdown optimization is being considered, corresponding to one of 177 // the following configurations: 178 // 179 // parent parent 180 // / \ / \ 181 // op_child const_child const_child op_child 182 // / \ / \ 183 // left_leaf right_leaf left_leaf right_leaf 184 // 185 // Returns true if the expression is possible amenable for optimization. 186 // Returns false if must_have_properties is true and input properties for 187 // parent and op_child are not known. 188 bool PrepareConstantPushDown(const NodeDef& parent, 189 const GraphProperties& properties, 190 bool must_have_properties, 191 ConstantPushDownContext* ctx) const; 192 193 // Pushes down constants on '+', '-', '*', and '/' operators if applicable. 194 // Returns true if the transformation applied successfully. 195 bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph, 196 NodeDef* node); 197 198 // Pushes down constants on '+' and 'BiasAdd' operators if applicable. 199 // Returns true if the graph was modified. 200 bool ConstantPushDownBiasAdd(GraphProperties* properties, 201 GraphDef* optimized_graph, NodeDef* node); 202 203 // Aggregate constants present around a conv operator. Returns true if the 204 // transformation was applied successfully. 205 bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node, 206 const GraphProperties& properties); 207 208 // Strength reduces floating point division by a constant Div(x, const) to 209 // multiplication by the reciprocal Mul(x, Reciprocal(const)). 210 bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node); 211 212 // Simplifies arithmetic operations with ones or zeros. Returns the status, 213 // and updates the success input argument that denotes if any simplification 214 // was applied. 215 Status SimplifyArithmeticOperations(const GraphProperties& properties, 216 bool use_shape_info, 217 GraphDef* optimized_graph, NodeDef* node); 218 219 // Simplifies a Reshape operation to an Identity operation if applicable. 220 bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, 221 NodeDef* node); 222 223 // Returns true iff the node is a reduction and its reduction indices are 224 // constant. Sets *indices_is_empty to true if the set of dimensions to reduce 225 // along is empty (this happens often in the gradient graphs). 226 bool IsReductionWithConstantIndices(const NodeDef& node, 227 bool* indices_is_empty) const; 228 // Returns true if theres a possibility that a Reduce node could be simplified 229 // to an Identity/Reshape. 230 bool IsReductionCandidateForSimplification( 231 const NodeDef& node, const GraphProperties& properties, 232 TensorShapeProto* input_tensor_shape, 233 TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const; 234 // Returns true iff this reduction can be reduced to an identity (i.e if the 235 // input dimensions to reduce along are all of size 1 and keep_dims is true). 236 bool IsReductionSimplifiableToIdentity( 237 const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims, 238 const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const; 239 // Changes a reduction into an Identity op, returning true on success. 240 bool ReplaceReductionWithIdentity(NodeDef* node) const; 241 242 // Simplifies a Reduction operation to an Identity/Reshape operation if 243 // applicable. 244 bool SimplifyReduction(GraphDef* optimized_graph, 245 const GraphProperties& properties, NodeDef* node); 246 247 // Switch(x, x) will always feed false to its false branch and true to 248 // its true branch. By rewriting the graph a bit, we can propagate these 249 // constants down the two output branches, and just use control dependencies 250 // to trigger the selected one at runtime. For example, 251 // 252 // +------+ 253 // x-->|Switch|-->a (in practice there may be multiple consumers of each 254 // x-->| |-->b output branch.) 255 // +------+ 256 // 257 // Is rewritten as 258 // 259 // +------+ 260 // x-->|Switch|-->Identity--^>Const(false)-->a 261 // x-->| |-->Identity--^>Const(true)-->b 262 // +------+ 263 bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node); 264 265 // Moves constants past Enter node if applicable. 266 bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node); 267 268 // Simplifies Pack operation if applicable. 269 bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node); 270 271 // Simplifies a Squeeze operation to an Identity operation if applicable. 272 void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, 273 GraphDef* optimized_graph, NodeDef* node); 274 275 // Simplifies a Pad operation to an Identity operation if applicable. 276 Status SimplifyPad(const GraphProperties& properties, bool use_shape_info, 277 GraphDef* optimized_graph, NodeDef* node); 278 279 // Simplifies a Tile operation to an Identity operation if applicable. 280 Status SimplifyTile(const GraphProperties& properties, bool use_shape_info, 281 GraphDef* optimized_graph, NodeDef* node); 282 283 // Simplifies a StridedSlice operation to an Identity operation if applicable. 284 Status SimplifyStridedSlice(const GraphProperties& properties, 285 bool use_shape_info, GraphDef* optimized_graph, 286 NodeDef* node); 287 288 // Simplifies a Slice operation to an Identity operation if applicable. 289 Status SimplifySlice(const GraphProperties& properties, bool use_shape_info, 290 GraphDef* optimized_graph, NodeDef* node); 291 292 // Simplify a Case operation where the output_idx is known. 293 bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node); 294 295 // Simplify a Select operation where the predicates are all true or all false. 296 bool SimplifySelect(const GraphProperties& properties, 297 GraphDef* optimized_graph, NodeDef* node); 298 299 // Replaces variable updates that are effectively no-ops with NoOp nodes. 300 void RemoveRedundantVariableUpdates(GraphProperties* properties, 301 GraphDef* optimized_graph, NodeDef* node); 302 303 // Removes Reverse op over dimensions with size 1. 304 Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, 305 GraphDef* optimized_graph, NodeDef* node); 306 307 // Removes RandomShuffle op if it is scalar or first dimension is of size 1. 308 void RemoveRandomShuffle(const GraphProperties& properties, 309 bool use_shape_info, GraphDef* optimized_graph, 310 NodeDef* node); 311 312 // Removes Shuffle or Transpose op over dimensions of size 1. 313 Status RemoveShuffleOrTranspose(const GraphProperties& properties, 314 bool use_shape_info, 315 GraphDef* optimized_graph, NodeDef* node); 316 317 // Removes Split or SplitV node if possible. 318 void RemoveSplitOrSplitV(const GraphProperties& properties, 319 GraphDef* optimized_graph, NodeDef* node); 320 321 bool GetConcatAxis(const NodeDef& node, int* axis); 322 bool MergeConcat(bool use_shape_info, GraphProperties* properties, 323 GraphDef* optimized_graph, NodeDef* node); 324 325 Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node, 326 GraphDef* optimized_graph); 327 328 // Points to an externally provided device or to owned_device_; 329 RewriterConfig::Toggle opt_level_; 330 DeviceBase* cpu_device_; 331 std::unique_ptr<DeviceBase> owned_device_; 332 333 std::unique_ptr<ResourceMgr> resource_mgr_; 334 GraphDef* graph_; 335 std::unique_ptr<NodeMap> node_map_; 336 std::unordered_set<string> nodes_to_preserve_; 337 // TODO(rmlarsen): Could these be keyed on absl::string_view? 338 absl::flat_hash_set<string> nodes_allowlist_; 339 absl::flat_hash_set<string> feed_nodes_; 340 absl::flat_hash_map<string, bool> maybe_foldable_nodes_; 341 bool has_fetch_; 342 bool graph_modified_; 343 bool graph_contains_assign_or_inplace_op_; 344 bool disable_compressed_tensor_optimization_; 345 bool fold_quantization_emulation_; 346 }; 347 348 } // end namespace grappler 349 } // end namespace tensorflow 350 351 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ 352