• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_
18 
19 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
20 #include "tensorflow/compiler/tf2xla/xla_context.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/platform/env.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/notification.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/public/version.h"
33 
34 namespace tensorflow {
35 
36 // GraphCompiler compiles the graph in topological order in the current
37 // thread. It also resolves the nondeterminism in the graph by enforcing a
38 // total order on all inputs to a node. This abstraction helps us create the
39 // same XLA computation given two structurally equivalent TensorFlow graphs.
40 // If a function call is visited during the graph traversal, it is then
41 // compiled through the xla_context into a computation and a `Call` operation
42 // is inserted to call into that computation.
43 //
44 // Note: GraphCompiler was created to remove our dependency to TF Executor in
45 // the history. There are still some todos so that we can completely decouple
46 // from Executor.
47 //
48 // TODO(yunxing): Remove usage of XlaCompilationDevice.
49 //
50 // TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now
51 // that we don't use TF Executor to pass around a tensor.
52 //
53 // TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can
54 // handle a XlaExpression directly instead of a Tensor. This may require our own
55 // op registration infrastructure instead of FunctionLibraryRuntime.
56 class GraphCompiler {
57  public:
GraphCompiler(XlaContext * xla_context,XlaCompilationDevice * device,Graph * graph,FunctionLibraryRuntime * flib,ScopedStepContainer * step_container)58   GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device,
59                 Graph* graph, FunctionLibraryRuntime* flib,
60                 ScopedStepContainer* step_container)
61       : xla_context_(xla_context),
62         device_(device),
63         graph_(graph),
64         flib_(flib),
65         step_container_(step_container) {}
66 
67   // Compiles the graph. The results are written in `xla_context` that is passed
68   // into the compiler.
69   Status Compile();
70 
71  private:
72   // Partially sets params. This partially set params can be reused
73   // across multiple nodes visit.
74   void PartiallySetupParams(OpKernelContext::Params* params);
75 
76   // Tests if a node is a functional node. A functional node represents a
77   // defined computation and should be compiled using `compiler_`.
78   bool IsFunctional(Node* n);
79 
80   // Compiles a functional node and writes result to OpkernelContext. A
81   // functional node represents a defined computation and should be compiled
82   // using `compiler_`.
83   Status CompileFunctionalNode(Node* n, OpKernelContext* op_context);
84 
85   XlaContext* xla_context_;
86   XlaCompilationDevice* device_;
87   Graph* graph_;
88   FunctionLibraryRuntime* flib_;
89   ScopedStepContainer* step_container_;
90   // A buffer to hold tensor inputs to a node, this is reused across the graph
91   // traversal.
92   gtl::InlinedVector<TensorValue, 4> tensor_inputs_;
93 };
94 
95 }  // namespace tensorflow
96 
97 #endif  // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_
98