1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_ 18 19 #include <functional> 20 #include "tensorflow/core/framework/attr_value.pb.h" 21 #include "tensorflow/core/framework/node_def.pb.h" 22 #include "tensorflow/core/grappler/op_types.h" 23 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" 24 #include "tensorflow/core/lib/gtl/inlined_vector.h" 25 #include "tensorflow/core/platform/protobuf.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 namespace fusion_utils { 30 31 // These functions are invoked with first and second function signature, 32 // should set a signature of fused second_function. 33 using SetFunctionSignatureFn = std::function<void( 34 const OpDef& first_function_signature, 35 const OpDef& second_function_signature, OpDef* fused_function_signature)>; 36 37 using StringCollection = gtl::InlinedVector<string, 2>; 38 39 // These functions are invoked with nodes from second function that were 40 // previously taking arguments as input. The `arg_num` tells which 41 // function argument node was using as an input, e.g: 42 // node(arg_1, other_node, arg_4) 43 // would be called on the first and third input with arg_num equal 1 and 4. 44 // It should set up inputs based on first function inputs or outputs or 45 // second function inputs. 46 using SetInputFn = 47 std::function<string(const StringCollection& first_function_inputs, 48 const StringCollection& second_function_inputs, 49 const StringCollection& parent_outputs, int arg_num)>; 50 51 // This function is invoked with first and second function ret. It is used to 52 // set up returns of fused function. 53 using SetOutputFn = 54 std::function<void(const protobuf::Map<string, string>& parent_ret, 55 const protobuf::Map<string, string>& second_function_ret, 56 protobuf::Map<string, string>* fused_ret)>; 57 58 using SetNodesFn = std::function<void( 59 const FunctionDef& first_function, const FunctionDef& second_function, 60 FunctionDef* fused_function, FunctionDefLibrary* library)>; 61 62 void MergeNodes(const FunctionDef& first_function, 63 const FunctionDef& second_function, FunctionDef* fused_function, 64 FunctionDefLibrary* library); 65 66 // Returns true if functions can be composed. 67 bool CanCompose(const OpDef& first_signature, const OpDef& second_signature); 68 69 void ComposeSignature(const OpDef& first_signature, 70 const OpDef& second_signature, OpDef* fused_signature); 71 72 string ComposeInput(const StringCollection& first_inputs, 73 const StringCollection& second_inputs, 74 const StringCollection& first_outputs, int arg_num); 75 76 // Sets output to the composition of first and second function: 77 // second_function(first_function(args...)). 78 void ComposeOutput(const protobuf::Map<string, string>& first_ret, 79 const protobuf::Map<string, string>& second_ret, 80 protobuf::Map<string, string>* fused_ret); 81 82 // Set input signature to `first_function_signature` and output signature 83 // to `first_function_signature` + `second_function_signature` 84 void CombineSignature(const OpDef& first_signature, 85 const OpDef& second_signature, OpDef* fused_signature); 86 87 // Apart from first function returns, return values from second function as 88 // extra returns like: 89 // return *first_function(...), *second_function(...) 90 void CombineOutput(const protobuf::Map<string, string>& first_ret, 91 const protobuf::Map<string, string>& second_ret, 92 protobuf::Map<string, string>* fused_ret); 93 94 // Returns true if both signatures have the same number of input and output 95 // args. 96 bool HasSameSignature(const OpDef& first_signature, 97 const OpDef& second_signature); 98 99 // Check if both signatures are same and copy it from `first_signature`. 100 void SameSignature(const OpDef& first_signature, const OpDef& second_signature, 101 OpDef* fused_signature); 102 103 // Take the same input as first function. 104 string SameInput(const StringCollection& first_inputs, 105 const StringCollection& second_inputs, 106 const StringCollection& first_outputs, int arg_num); 107 108 // Create a fused function that computes the short-circuit logical AND of the 109 // result of the first function and the result of the second function. 110 void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret, 111 const protobuf::Map<string, string>& second_ret, 112 protobuf::Map<string, string>* fused_ret); 113 114 void LazyConjunctionNodes(const FunctionDef& first_function, 115 const FunctionDef& second_function, 116 FunctionDef* fused_function, 117 FunctionDefLibrary* library); 118 119 // Fuse `first_function` with `second_function`, setting `fused_name_prefix` as 120 // a name prefix. The nodes from `first_function` are copied unmodified. All 121 // of the setup functions are called with a copy of second function having names 122 // that are not conflicting with first function. This means that copied nodes 123 // from second function can end up having different names. For explanation of 124 // set up functions see the documentation of the functions types. 125 FunctionDef* FuseFunctions( 126 const FunctionDef& first_function, const FunctionDef& second_function, 127 StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature, 128 const SetInputFn& set_input, const SetOutputFn& set_output, 129 const SetNodesFn& set_nodes, FunctionDefLibrary* library); 130 131 } // namespace fusion_utils 132 } // namespace grappler 133 } // namespace tensorflow 134 135 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_ 136