/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the contexts used during XLA compilation.

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_

#include <functional>
#include <map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class XlaOpKernelContext;
class XlaCompiler;

// The XlaContext is the data structure that holds the state of an XLA
// compilation. It is accessible from OpKernelContexts when compiling a
// subgraph of Ops using XLA.
class XlaContext : public ResourceBase {
 public:
  // Retrieves the XlaContext of the current compilation.
  static XlaContext& Get(const OpKernelContext* ctx);

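  // A minimal usage sketch (an illustration, not part of this header): inside
  // an XLA op kernel, the context of the current compilation can be fetched
  // through the OpKernelContext and used to reach the builder, e.g.
  //
  //   XlaContext& xla_context = XlaContext::Get(op_kernel_context);
  //   xla::XlaBuilder* b = xla_context.builder();
  //
  // where `op_kernel_context` is assumed to be the OpKernelContext passed to
  // the kernel.
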
  // Creates a new XlaContext. See the documentation on the class data fields
  // for descriptions of the arguments.
  XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
             const Graph* graph);

  // Virtual method defined by ResourceBase.
  string DebugString() const override;

  XlaCompiler* compiler() const { return compiler_; }

  // Returns the stack trace recorded for the node with the given name, or
  // nullptr if no stack trace is known for it.
  const AbstractStackTrace* StackTraceForNodeName(const std::string& name) {
    const auto& it = stack_traces_.find(name);
    if (it != stack_traces_.end()) {
      return it->second.get();
    }
    return nullptr;
  }

  // Returns the XlaBuilder that Ops use for compiling new expressions.
  xla::XlaBuilder* builder() { return builder_; }

  const std::vector<XlaExpression>& args() const { return args_; }
  void set_args(std::vector<XlaExpression> args);

  const std::vector<XlaExpression>& retvals() { return retvals_; }

  // Sets a return value.
  // Since the number of return values is not always known in advance, this
  // grows the return values vector to size index+1 if it is smaller.
  void SetRetval(int index, const XlaExpression& expression);

  // Adds 'resource' to the set of resources owned by the context.
  XlaResource* AddResource(std::unique_ptr<XlaResource> resource);

  const std::vector<std::unique_ptr<XlaResource>>& resources() {
    return resources_;
  }

  // Get an XLA lambda to compute Max. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMax(const DataType type);

  // Get an XLA lambda to compute Min. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMin(const DataType type);

  // Get an XLA lambda to compute Add. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateAdd(const DataType type);

  // Get an XLA lambda to compute Mul. This is cached in the
  // XlaContext since it may be used by multiple Ops. There is a
  // separate specialization of the computation for each DataType.
  const xla::XlaComputation* GetOrCreateMul(const DataType type);

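  // A hedged usage sketch (an illustration, not part of this header): an op
  // that lowers to a sum reduction could reuse the cached Add computation
  // rather than rebuilding it, e.g.
  //
  //   XlaContext& xla_context = XlaContext::Get(op_kernel_context);
  //   const xla::XlaComputation* add = xla_context.GetOrCreateAdd(DT_FLOAT);
  //   xla::XlaOp sum =
  //       xla::Reduce(input, zero, *add, /*dimensions_to_reduce=*/{0});
  //
  // where `input` and `zero` are XlaOps assumed to have been built elsewhere,
  // and the exact xla::Reduce overload may differ between XLA versions.
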
  // The name of the XlaContext resource during symbolic graph execution.
  static const char kXlaContextResourceName[];

  // Records the group key and group size of a CollectiveReduceV2 op
  // encountered during compilation. Returns an error if an op with a
  // different configuration was already recorded for this cluster.
  Status RecordCollectiveReduceV2OpInfo(int group_key, int group_size) {
    if (!collective_reduce_info_) {
      collective_reduce_info_ = {group_key, group_size};
    } else if (collective_reduce_info_->group_key != group_key ||
               collective_reduce_info_->group_size != group_size) {
      return errors::InvalidArgument(
          "Only a single configuration of CollectiveReduceV2Op is ",
          "supported in a given cluster. Recorded group_key=",
          collective_reduce_info_->group_key,
          " while attempting to insert group_key=", group_key);
    }
    return Status::OK();
  }

  const absl::optional<XlaCompilationResult::CollectiveReduceV2OpInfo>&
  GetCollectiveReduceV2OpInfo() {
    return collective_reduce_info_;
  }

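  // Sketch of intended use (an assumption, not taken from this file): a
  // CollectiveReduceV2 lowering might record its configuration with
  //
  //   TF_RETURN_IF_ERROR(XlaContext::Get(op_kernel_context)
  //                          .RecordCollectiveReduceV2OpInfo(group_key,
  //                                                          group_size));
  //
  // and a later call with a different group_key or group_size in the same
  // cluster yields an InvalidArgument error.
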
 private:
  XlaCompiler* const compiler_;

  // The XlaBuilder used to construct the subgraph's compiled representation.
  xla::XlaBuilder* builder_;

  // Stack traces for the graph used for compilation.
  StackTracesMap stack_traces_;

  // Arguments to the Tensorflow graph, indexed by _Arg index.
  // Includes both compile-time constant arguments and runtime parameters.
  std::vector<XlaExpression> args_;

  // Return values of the Tensorflow graph, indexed by _Retval index.
  std::vector<XlaExpression> retvals_;

  // Holds ownership of resources. The resources are not ordered.
  std::vector<std::unique_ptr<XlaResource>> resources_;

  // Information about encountered CollectiveReduceV2 ops. We allow only a
  // single configuration per cluster.
  absl::optional<XlaCompilationResult::CollectiveReduceV2OpInfo>
      collective_reduce_info_;

  // Cache of prebuilt computations indexed by their type.
  using ComputationMap = std::map<DataType, xla::XlaComputation>;

  // Finds the value for the given type in the 'out' map if it already exists,
  // or builds a new value with the 'create' function and stores it in the
  // map. The returned value is never nullptr and is owned by the map.
  const xla::XlaComputation* LookupOrCreate(
      DataType type, ComputationMap* out,
      const std::function<xla::XlaComputation()>& create);
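  // The caching behind LookupOrCreate is roughly the following (a sketch
  // under the assumption that the .cc implementation follows the usual
  // lookup-or-insert pattern):
  //
  //   auto it = out->find(type);
  //   if (it == out->end()) {
  //     it = out->emplace(type, create()).first;
  //   }
  //   return &it->second;
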

  // Cached computation to compute Max of two elements, specialized by type.
  ComputationMap max_func_;

  // Cached computation to compute Min of two elements, specialized by type.
  ComputationMap min_func_;

  // Cached computation to compute Sum of two elements, specialized by type.
  ComputationMap add_func_;

  // Cached computation to compute Mul of two elements, specialized by type.
  ComputationMap mul_func_;

  // Cached computation to compute Sigmoid of an element, specialized by type.
  ComputationMap sigmoid_func_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_