/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

#include <unordered_map>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const tf2xla::Config& config);

// Modifies <graph_def> to include placeholders for each fed tensor, and
// updates references to the fed tensors to refer to the placeholders.
// The existing nodes referenced by the feeds are not removed or modified
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
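
// Example (hypothetical usage sketch; `config` and `graph_def` are assumed to
// be populated by the caller):
//   std::unordered_map<string, string> feed_remapping;
//   TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(
//       config, OpRegistry::Global(), &feed_remapping, &graph_def));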

// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out);

// Returns node:port for the given <id>, e.g. a TensorId with node_name "foo"
// and output_index 2 yields "foo:2".
string TensorIdToString(const tf2xla::TensorId& id);

// Updates the sharding of <n> based on the sharding of its neighbors.
// If <out_edges> is true, outgoing edges from <n> are considered; else
// incoming edges are considered.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);

// Adds an allowed data type to the AttrConstraint with the given name.
void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef);
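
// Example (hypothetical sketch; assumes `kdef` points to a mutable KernelDef
// whose "T" constraint should also allow DT_FLOAT):
//   AddDtypeToKernelDefConstraint("T", DT_FLOAT, kdef);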

// Returns the next random seed to use for seeding the XLA RNG.
uint32 GetXLARandomSeed();

// Indicates how a FunctionDef is associated with a graph node (e.g. the node
// is a function call, or the node has function attrs).
class AssociatedFunctionInfo {
 public:
  enum AssociatedFunctionType {
    kFunctionAttr = 0,
    kFunctionCallNode = 1,
    kSymbolicGradient = 2,
  };

  // The function is an attr of the node.
  static AssociatedFunctionInfo FunctionAttr(const string& func_name,
                                             const AttrValueMap& attrs,
                                             const string& attr_name) {
    return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
  }

  // The node is a function call.
  static AssociatedFunctionInfo FunctionCall(const string& func_name,
                                             const AttrValueMap& attrs) {
    // attr_name will not be used in this case.
    return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
                                  /*attr_name=*/"");
  }

  // The node is a SymbolicGradient op.
  static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
                                                 const AttrValueMap& attrs) {
    // attr_name will not be used in this case.
    return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
                                  /*attr_name=*/"");
  }

  AssociatedFunctionType type() const { return type_; }

  const string& func_name() const { return func_name_; }

  const string& attr_name() const { return attr_name_; }

  const AttrValueMap& attrs() const { return attrs_; }

 private:
  AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
                         const AttrValueMap& attrs, const string& attr_name)
      : type_(type),
        func_name_(func_name),
        attrs_(attrs),
        attr_name_(attr_name) {}

  // Available for all instances.
  AssociatedFunctionType type_;
  string func_name_;
  AttrValueMap attrs_;

  // Only available if the function is defined in an attr.
  string attr_name_;
};
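
// Example (hypothetical sketch; assumes `node` stores a function named
// "my_body_func" in its "body" function attr):
//   AssociatedFunctionInfo info = AssociatedFunctionInfo::FunctionAttr(
//       "my_body_func", node->def().attr(), "body");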

// Returns true if the NodeDef has an associated function.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld);

// Gets functions associated with the node. Current cases:
// 1. For a function call node, its function name;
// 2. For a SymbolicGradient op, the returned func_name will be
//    "SymbolicGradient" and the returned attrs will be this node's attributes;
// 3. For nodes like XlaWhile/XlaIf, all of their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld);

// Changes associated functions for the node. Current cases:
// 1. For a function call node, creates a new node with the new function name
//    and removes the old node;
// 2. For a SymbolicGradient op, adds or replaces the GradientDef in the
//    FunctionLibraryDefinition;
// 3. For nodes like XlaWhile/XlaIf, modifies their function attributes.
Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name);
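
// Example (hypothetical sketch; assumes `graph`, `node`, and `fld` come from
// the surrounding pass, and that a rewritten copy of each function was already
// added to `fld` under a "_rewritten" suffix; note that for a function call
// node the rewrite replaces `node` itself):
//   for (const AssociatedFunctionInfo& info :
//        GetAssociatedFunctions(*node, fld)) {
//     TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
//         graph, node, fld, info,
//         absl::StrCat(info.func_name(), "_rewritten")));
//   }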

// Attribute to mark nodes to be executed on host.
extern const char kXlaOutsideCompilationAttrName[];

// Class to act as a cache for FunctionLibraryRuntime::Handle objects.
class CachedFunctionHandles {
 public:
  CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {}

  // Populates `handle` for the requested function and attributes. If we have
  // instantiated the function with the same attributes before, `handle` will
  // be the cached handle; otherwise instantiates the function and populates
  // `handle`.
  Status GetOrInstantiate(const string& func_name, AttrSlice attrs,
                          FunctionLibraryRuntime::Handle* handle);

  // Releases all handles in the cache. Returns the first non-OK status if
  // any; otherwise returns OK.
  Status ReleaseAllHandles();

  ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); }

 private:
  FunctionLibraryRuntime* flr_;
  std::map<string, FunctionLibraryRuntime::Handle> handles_;

  TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
};
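
// Example (hypothetical sketch; assumes `flr` is a valid
// FunctionLibraryRuntime* and `node` is the call node whose attributes
// instantiate the function):
//   CachedFunctionHandles cached_handles(flr);
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(
//       cached_handles.GetOrInstantiate("my_func", node->attrs(), &handle));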

// Struct for a node's output edge info.
struct OutEdgeInfo {
  Node* dst;
  int src_output, dst_input;
};

// Replaces node `n` with a new node whose NodeDef is `node_def`.
StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);

// Helper function that builds an Identity node.
StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
                                  DataType dtype, const Node* input,
                                  absl::optional<string> requested_device);
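
// Example (hypothetical sketch; forwards `input` through a new Identity node
// with no device constraint):
//   TF_ASSIGN_OR_RETURN(
//       Node* identity,
//       BuildIdentityNode(graph, "my_identity", DT_FLOAT, input,
//                         /*requested_device=*/absl::nullopt));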

// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite
// body functions to use the Const nodes instead of the original _Arg nodes.
//
// For example, say we have the following computation:
//   shape = constant_op.constant([1])
//   return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
// If we do not rewrite the then/else functions, they will use an _Arg node as
// the shape input for tf.ones/tf.zeros. But XLA requires that shape input to
// be a compile-time constant, so XLA compilation would fail. This rewriting
// process changes the shape input to a Const node.
Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld);

// Prunes unreachable FunctionDefs from FunctionLibraryDefinition.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld);

// Finds the following pattern in the graph:
// 1) EmptyTensorList -> forward While op -> backward While op;
// 2) in the forward While op, a Const node is pushed;
// 3) in the backward While op, data is popped from the tensor list;
// and rewrites the backward While op to use the Const node instead of the
// TensorListPopBack result.
// TODO(b/128633174) remove the TensorList and related TensorList ops.
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld);

// Attribute marking nodes that belong to a TPU replication cluster.
extern const char kTpuReplicateAttrName[];

// Returns true if the node's op type can be traversed when propagating
// constants: Identity, IdentityN, or While.
inline bool IsConstTraversableOpType(const Node* node) {
  return node->type_string() == "Identity" ||
         node->type_string() == "IdentityN" || node->IsWhileNode();
}

// Determines whether a loop body is invariant for the given argument index.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld);
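
// Example (hypothetical sketch; checks whether the first argument of a while
// loop's body function is passed through unchanged):
//   TF_ASSIGN_OR_RETURN(bool invariant,
//                       IsLoopInvariant(loop_body, /*index=*/0, lookup_fld));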

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_