• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
18 
#include <map>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
29 
30 namespace tensorflow {
31 
32 // ValidateConfig returns OK iff config is valid.
33 Status ValidateConfig(const tf2xla::Config& config);
34 
35 // Modifies <graph_def> to include placeholders for each fed tensor, and
36 // update references to the fed tensors to refer to the placeholders.
37 // The existing nodes referenced by the feeds are not removed or modified
38 // (except where their input edges are modified by the replacement of other
39 // feeds).
40 Status AddPlaceholdersForFeeds(
41     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
42     std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
43 
44 // Returns in <out> a copy of <in>, pruned to only include fetches from
45 // <config>.
46 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
47                          GraphDef* out);
48 
49 // Returns node:port for the given <id>.
50 string TensorIdToString(const tf2xla::TensorId& id);
51 
52 // Updates the sharding of <n> based on the sharding of its neighbors.
53 // If <out_edges> is true, outgoing edges from <n> are considered; else incoming
54 // edges are considered.
55 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
56 
57 // Add an allowed data type to the AttrConstraint with the given name.
58 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
59                                    KernelDef* kdef);
60 
61 // Returns the next random seed to use for seeding xla rng.
62 uint32 GetXLARandomSeed();
63 
64 // Indicates how a FunctionDef is associated with a graph node (e.g. the node is
65 // a function call, or the node has function attrs).
66 class AssociatedFunctionInfo {
67  public:
68   enum AssociatedFunctionType {
69     kFunctionAttr = 0,
70     kFunctionCallNode = 1,
71     kSymbolicGradient = 2,
72   };
73 
74   // The function is an attr of the node.
FunctionAttr(const string & func_name,const AttrValueMap & attrs,const string & attr_name)75   static AssociatedFunctionInfo FunctionAttr(const string& func_name,
76                                              const AttrValueMap& attrs,
77                                              const string& attr_name) {
78     return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
79   }
80 
81   // The node is a function call.
FunctionCall(const string & func_name,const AttrValueMap & attrs)82   static AssociatedFunctionInfo FunctionCall(const string& func_name,
83                                              const AttrValueMap& attrs) {
84     // attr_name will not be used in this case.
85     return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
86                                   /*attr_name=*/"");
87   }
88 
89   // The node is a SymbolicGradient op.
SymbolicGradient(const string & func_name,const AttrValueMap & attrs)90   static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
91                                                  const AttrValueMap& attrs) {
92     // attr_name will not be used in this case.
93     return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
94                                   /*attr_name=*/"");
95   }
96 
type()97   AssociatedFunctionType type() const { return type_; }
98 
func_name()99   const string& func_name() const { return func_name_; }
100 
attr_name()101   const string& attr_name() const { return attr_name_; }
102 
attrs()103   const AttrValueMap& attrs() const { return attrs_; }
104 
105  private:
AssociatedFunctionInfo(AssociatedFunctionType type,const string & func_name,const AttrValueMap & attrs,const string & attr_name)106   AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
107                          const AttrValueMap& attrs, const string& attr_name)
108       : type_(type),
109         func_name_(func_name),
110         attrs_(attrs),
111         attr_name_(attr_name) {}
112 
113   // Available for all instances.
114   AssociatedFunctionType type_;
115   string func_name_;
116   AttrValueMap attrs_;
117 
118   // Only available if the function is defined in an attr.
119   string attr_name_;
120 };
121 
122 // Returns if the NodeDef has associated function.
123 bool HasAssociatedFunction(const NodeDef& node_def,
124                            const FunctionLibraryDefinition* fld);
125 
126 // Gets functions associated with the node. Current cases:
127 // 1. For function call node, its function name;
128 // 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
129 //    and returned attrs will be this node's attributes;
130 // 3. For nodes like XlaWhile/XlaIf, all their function attributes.
131 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
132     const Node& node, const FunctionLibraryDefinition* fld);
133 
134 // Changes associated functions for the node. Current cases:
135 // 1. For function call node, creates a new node with the new function name and
136 //    remove the old node;
137 // 2. For SymbolicGradient op, add or replace GradientDef in
138 //    FunctionLibraryDefinition;
139 // 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
140 Status RewriteAssociatedFunction(
141     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
142     const AssociatedFunctionInfo& associated_function,
143     const string& rewritten_function_name);
144 
145 // Attribute to mark nodes to be executed on host.
146 extern const char kXlaOutsideCompilationAttrName[];
147 
148 // Class to act as cache for FunctionLibraryRuntime::Handle objects.
149 class CachedFunctionHandles {
150  public:
CachedFunctionHandles(FunctionLibraryRuntime * flr)151   CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {}
152 
153   // Populates `handle` for requested function and attributes. If we have
154   // instantiated the function with the same attributes before, `handle` will be
155   // cached handle; otherwise instantiate the function and populate `handle`.
156   Status GetOrInstantiate(const string& func_name, AttrSlice attrs,
157                           FunctionLibraryRuntime::Handle* handle);
158 
159   // Releases all handles in the cache. Returns first non-OK status if any;
160   // returns OK otherwise.
161   Status ReleaseAllHandles();
162 
~CachedFunctionHandles()163   ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); }
164 
165  private:
166   FunctionLibraryRuntime* flr_;
167   std::map<string, FunctionLibraryRuntime::Handle> handles_;
168 
169   TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
170 };
171 
172 // Struct for node's output edge info.
173 struct OutEdgeInfo {
174   Node* dst;
175   int src_output, dst_input;
176 };
177 
178 // Replaces node `n` with a new node whose NodeDef is `node_def`.
179 xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);
180 
181 // Helper function that builds an Identity node.
182 xla::StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
183                                        DataType dtype, const Node* input,
184                                        absl::optional<string> requested_device);
185 
186 // For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite
187 // body functions to use the Const nodes instead of original _Arg nodes.
188 //
189 // For example, say we have the following computation:
190 //     shape = constant_op.constant([1])
191 //     return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
192 // If we do not rewrite then/else function, they will use _Arg node as shape
193 // input for tf.ones/tf.zeros. But XLA requires that shape input to be compile
194 // time constant, so XLA compilation will fail. This rewriting process will
195 // change the shape input to Const node.
196 Status PropagateConstIntoFunctionalNodes(
197     Graph* g, const FunctionLibraryDefinition* lookup_fld,
198     FunctionLibraryDefinition* fld);
199 
200 }  // namespace tensorflow
201 
202 #endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
203