1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the contexts used during XLA compilation. 17 18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ 19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ 20 21 #include <vector> 22 23 #include "tensorflow/compiler/tf2xla/xla_expression.h" 24 #include "tensorflow/compiler/xla/client/xla_builder.h" 25 #include "tensorflow/compiler/xla/client/xla_computation.h" 26 #include "tensorflow/compiler/xla/status_macros.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/resource_mgr.h" 30 #include "tensorflow/core/graph/graph.h" 31 #include "tensorflow/core/platform/macros.h" 32 33 namespace tensorflow { 34 35 class XlaOpKernelContext; 36 class XlaCompiler; 37 38 // The XlaContext is the data structure that holds the state of an XLA 39 // compilation, that is accessible from OpKernelContexts when compiling a 40 // subgraph of Ops using XLA. 41 class XlaContext : public ResourceBase { 42 public: 43 // Retrieves the XlaContext of the current compilation. 44 static XlaContext& Get(const OpKernelContext* ctx); 45 46 // Creates a new XlaContext. See the documentation on the class data fields 47 // for descriptions of the arguments. 48 XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, 49 const Graph* graph); 50 51 // Virtual method defined by ResourceBase. 52 string DebugString() const override; 53 compiler()54 XlaCompiler* compiler() const { return compiler_; } 55 StackTraceForNodeName(const std::string & name)56 const AbstractStackTrace* StackTraceForNodeName(const std::string& name) { 57 const auto& it = stack_traces_.find(name); 58 if (it != stack_traces_.end()) { 59 return it->second.get(); 60 } 61 return nullptr; 62 } 63 64 // Returns the XlaBuilder that Ops use for compiling new expressions. builder()65 xla::XlaBuilder* builder() { return builder_; } 66 args()67 const std::vector<XlaExpression>& args() const { return args_; } 68 void set_args(std::vector<XlaExpression> args); 69 retvals()70 const std::vector<XlaExpression>& retvals() { return retvals_; } 71 72 // Sets a return value. 73 // Since we do not always know in advance how many return values there are, 74 // grows the return values vector to size index+1 if it is smaller. 75 void SetRetval(int index, const XlaExpression& expression); 76 77 // Adds 'resource' to the set of resources owned by the context. 78 XlaResource* AddResource(std::unique_ptr<XlaResource> resource); 79 resources()80 const std::vector<std::unique_ptr<XlaResource>>& resources() { 81 return resources_; 82 } 83 84 // Get an XLA lambda to compute Max. This is cached in the 85 // XlaContext since it may be used by multiple Ops. There is a 86 // separate specialization of the computation for each DataType. 87 const xla::XlaComputation* GetOrCreateMax(const DataType type); 88 89 // Get an XLA lambda to compute Min. This is cached in the 90 // XlaContext since it may be used by multiple Ops. There is a 91 // separate specialization of the computation for each DataType. 92 const xla::XlaComputation* GetOrCreateMin(const DataType type); 93 94 // Get an XLA lambda to compute Add. This is cached in the 95 // XlaContext since it may be used by multiple Ops. There is a 96 // separate specialization of the computation for each DataType. 97 const xla::XlaComputation* GetOrCreateAdd(const DataType type); 98 99 // Get an XLA lambda to compute Mul. This is cached in the 100 // XlaContext since it may be used by multiple Ops. There is a 101 // separate specialization of the computation for each DataType. 102 const xla::XlaComputation* GetOrCreateMul(const DataType type); 103 104 // The name of the XlaContext resource during symbolic graph execution. 105 static const char kXlaContextResourceName[]; 106 107 private: 108 XlaCompiler* const compiler_; 109 110 // The XlaBuilder used to construct the subgraph's compiled representation. 111 xla::XlaBuilder* builder_; 112 113 // Stack traces for the graph used for compilation. 114 StackTracesMap stack_traces_; 115 116 // Arguments to the Tensorflow graph, indexed by _Arg index. 117 // Includes both compile-time constant arguments and runtime parameters. 118 std::vector<XlaExpression> args_; 119 120 // Return values of the Tensorflow graph, indexed by _Retval index. 121 std::vector<XlaExpression> retvals_; 122 123 // Holds ownership of resources. The resources are not ordered. 124 std::vector<std::unique_ptr<XlaResource>> resources_; 125 126 // Cache of prebuilt computations indexed by their type. 127 using ComputationMap = std::map<DataType, xla::XlaComputation>; 128 129 // Finds the value for the given type in out map if it already 130 // exists or makes a new value with create function and keeps it the 131 // map. The returned value != nullptr and is owned by the map. 132 const xla::XlaComputation* LookupOrCreate( 133 DataType type, ComputationMap* out, 134 const std::function<xla::XlaComputation()>& create); 135 136 // Cached computation to compute Max of two elements, specialized by type. 137 ComputationMap max_func_; 138 139 // Cached computation to compute Min of two elements, specialized by type. 140 ComputationMap min_func_; 141 142 // Cached computation to compute Sum of two elements, specialized by type. 143 ComputationMap add_func_; 144 145 // Cached computation to compute Mul of two elements, specialized by type. 146 ComputationMap mul_func_; 147 148 // Cached computation to compute Sigmoid of an element, specialized by type. 149 ComputationMap sigmoid_func_; 150 151 TF_DISALLOW_COPY_AND_ASSIGN(XlaContext); 152 }; 153 154 } // namespace tensorflow 155 156 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ 157