/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the contexts used during XLA compilation.

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_

#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;
class XlaCompiler;

// The XlaContext is the data structure that holds the state of an XLA
// compilation. It is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
class XlaContext : public ResourceBase {
 public:
  // Retrieves the XlaContext of the current compilation.
  static XlaContext& Get(const OpKernelContext* ctx);

  // Creates a new XlaContext. See the documentation on the class data fields
  // for descriptions of the arguments.
  XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
             const Graph* graph);

  // Virtual method defined by ResourceBase.
  string DebugString() const override;

  XlaCompiler* compiler() const { return compiler_; }

  // Returns the stack trace recorded for the node named `name`, or nullptr if
  // no trace was recorded for it.
  const AbstractStackTrace* StackTraceForNodeName(const std::string& name) {
    const auto& it = stack_traces_.find(name);
    if (it != stack_traces_.end()) {
      return it->second.get();
    }
    return nullptr;
  }

  // Returns the XlaBuilder that Ops use for compiling new expressions.
  xla::XlaBuilder* builder() { return builder_; }

  const std::vector<XlaExpression>& args() const { return args_; }
  void set_args(std::vector<XlaExpression> args);

  const std::vector<XlaExpression>& retvals() { return retvals_; }

  // Sets a return value.
  // Since we do not always know in advance how many return values there are,
  // grows the return values vector to size index+1 if it is smaller.
  void SetRetval(int index, const XlaExpression& expression);

  // Adds 'resource' to the set of resources owned by the context.
  XlaResource* AddResource(std::unique_ptr<XlaResource> resource);

  const std::vector<std::unique_ptr<XlaResource>>& resources() {
    return resources_;
  }
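
  // Usage sketch (not part of the original header; `op_kernel_context` and
  // `input` are hypothetical, and most op kernels go through
  // XlaOpKernelContext rather than using this class directly):
  //
  //   XlaContext& xla_context = XlaContext::Get(op_kernel_context);
  //   xla::XlaBuilder* b = xla_context.builder();
  //   xla::XlaOp one = xla::ConstantR0<float>(b, 1.0f);
  //   xla::XlaOp sum = xla::Add(input, one);
  //   xla_context.SetRetval(0, XlaExpression::XlaOp(sum, DT_FLOAT));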

  // Get an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Get an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Get an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Get an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);
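
  // Usage sketch (illustrative only; `op_kernel_context`, `input`, and `init`
  // are hypothetical): the cached computations above are typically passed to
  // XLA reduction ops, e.g.
  //
  //   XlaContext& xla_context = XlaContext::Get(op_kernel_context);
  //   const xla::XlaComputation* add = xla_context.GetOrCreateAdd(DT_FLOAT);
  //   xla::XlaOp reduced =
  //       xla::Reduce(input, init, *add, /*dimensions_to_reduce=*/{0});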

  // The name of the XlaContext resource during symbolic graph execution.
  static const char kXlaContextResourceName[];

  // Records the collective information from the nested compilation `result`.
  Status RecordCollectiveInfoFromNestedCompilationResult(
      const XlaCompilationResult& result);

  // Records the collective configurations for all the collectives in the XLA
  // cluster and returns the channel_id to be used for the next collective.
  StatusOr<int64_t> RecordCollectiveInfo(int group_key, int group_size);

  const std::optional<XlaCompilationResult::CollectiveInfo>&
  GetCollectiveInfo() {
    return collective_info_;
  }

 private:
  XlaCompiler* const compiler_;

  // The XlaBuilder used to construct the subgraph's compiled representation.
  xla::XlaBuilder* builder_;

  // Stack traces for the graph used for compilation.
  StackTracesMap stack_traces_;

  // Arguments to the TensorFlow graph, indexed by _Arg index.
  // Includes both compile-time constant arguments and runtime parameters.
  std::vector<XlaExpression> args_;

  // Return values of the TensorFlow graph, indexed by _Retval index.
  std::vector<XlaExpression> retvals_;

  // Holds ownership of resources. The resources are not ordered.
  std::vector<std::unique_ptr<XlaResource>> resources_;

  // Information about encountered collective ops. We allow only a
  // single configuration per cluster.
  std::optional<XlaCompilationResult::CollectiveInfo> collective_info_;

  // Cache of prebuilt computations indexed by their type.
  using ComputationMap = std::map<DataType, xla::XlaComputation>;

  // Returns the value for the given type in `out` if it already exists;
  // otherwise builds a new value with `create` and stores it in the map.
  // The returned value is never nullptr and is owned by the map.
  const xla::XlaComputation* LookupOrCreate(
      DataType type, ComputationMap* out,
      const std::function<xla::XlaComputation()>& create);

  // Cached computation to compute Max of two elements, specialized by type.
  ComputationMap max_func_;

  // Cached computation to compute Min of two elements, specialized by type.
  ComputationMap min_func_;

  // Cached computation to compute Sum of two elements, specialized by type.
  ComputationMap add_func_;

  // Cached computation to compute Mul of two elements, specialized by type.
  ComputationMap mul_func_;

  // Cached computation to compute Sigmoid of an element, specialized by type.
  ComputationMap sigmoid_func_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_