1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ 18 19 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" 20 #include "tensorflow/compiler/tf2xla/xla_context.h" 21 #include "tensorflow/compiler/xla/client/local_client.h" 22 #include "tensorflow/core/common_runtime/device.h" 23 #include "tensorflow/core/common_runtime/device_mgr.h" 24 #include "tensorflow/core/common_runtime/function.h" 25 #include "tensorflow/core/framework/function.h" 26 #include "tensorflow/core/framework/op_kernel.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/platform/env.h" 29 #include "tensorflow/core/platform/mutex.h" 30 #include "tensorflow/core/platform/notification.h" 31 #include "tensorflow/core/platform/thread_annotations.h" 32 #include "tensorflow/core/public/version.h" 33 34 namespace tensorflow { 35 36 // GraphCompiler compiles the graph in topological order in the current 37 // thread. It also resolves the nondeterminism in the graph by enforcing a 38 // total order on all inputs to a node. This abstraction helps us create the 39 // same XLA computation given two structurally equivalent TensorFlow graphs. 40 // If a function call is visited during the graph traversal, it is then 41 // compiled through the xla_context into a computation and a `Call` operation 42 // is inserted to call into that computation. 43 // 44 // Note: GraphCompiler was created to remove our dependency to TF Executor in 45 // the history. There are still some todos so that we can completely decouple 46 // from Executor. 47 // 48 // TODO(yunxing): Remove usage of XlaCompilationDevice. 49 // 50 // TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now 51 // that we don't use TF Executor to pass around a tensor. 52 // 53 // TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can 54 // handle a XlaExpression directly instead of a Tensor. This may require our own 55 // op registration infrastructure instead of FunctionLibraryRuntime. 56 class GraphCompiler { 57 public: GraphCompiler(XlaContext * xla_context,XlaCompilationDevice * device,Graph * graph,FunctionLibraryRuntime * flib,ScopedStepContainer * step_container)58 GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, 59 Graph* graph, FunctionLibraryRuntime* flib, 60 ScopedStepContainer* step_container) 61 : xla_context_(xla_context), 62 device_(device), 63 graph_(graph), 64 flib_(flib), 65 step_container_(step_container) {} 66 67 // Compiles the graph. The results are written in `xla_context` that is passed 68 // into the compiler. 69 Status Compile(); 70 71 private: 72 // Partially sets params. This partially set params can be reused 73 // across multiple nodes visit. 74 void PartiallySetupParams(OpKernelContext::Params* params); 75 76 // Tests if a node is a functional node. A functional node represents a 77 // defined computation and should be compiled using `compiler_`. 78 bool IsFunctional(Node* n); 79 80 // Compiles a functional node and writes result to OpkernelContext. A 81 // functional node represents a defined computation and should be compiled 82 // using `compiler_`. 83 Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); 84 85 XlaContext* xla_context_; 86 XlaCompilationDevice* device_; 87 Graph* graph_; 88 FunctionLibraryRuntime* flib_; 89 ScopedStepContainer* step_container_; 90 // A buffer to hold tensor inputs to a node, this is reused across the graph 91 // traversal. 92 gtl::InlinedVector<TensorValue, 4> tensor_inputs_; 93 }; 94 95 } // namespace tensorflow 96 97 #endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ 98