• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the contexts used during XLA compilation.
17 
18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
20 
21 #include <vector>
22 
23 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
24 #include "tensorflow/compiler/tf2xla/xla_expression.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/resource_mgr.h"
31 #include "tensorflow/core/platform/macros.h"
32 
33 namespace tensorflow {
34 
35 class XlaOpKernelContext;
36 
37 // The XlaContext is the data structure that holds the state of an XLA
38 // compilation, that is accessible from OpKernelContexts when compiling a
39 // subgraph of Ops using XLA.
40 class XlaContext : public ResourceBase {
41  public:
42   // Retrieves the XlaContext of the current compilation.
43   static XlaContext& Get(const OpKernelContext* ctx);
44 
45   // Creates a new XlaContext. See the documentation on the class data fields
46   // for descriptions of the arguments.
47   XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder);
48 
49   // Virtual method defined by ResourceBase.
50   string DebugString() const override;
51 
compiler()52   XlaCompiler* compiler() const { return compiler_; }
53 
54   // Returns the XlaBuilder that Ops use for compiling new expressions.
builder()55   xla::XlaBuilder* builder() { return builder_; }
56 
args()57   const std::vector<XlaExpression>& args() const { return args_; }
58   void set_args(std::vector<XlaExpression> args);
59 
retvals()60   const std::vector<XlaExpression>& retvals() { return retvals_; }
61 
62   // Sets a return value.
63   // Since we do not always know in advance how many return values there are,
64   // grows the return values vector to size index+1 if it is smaller.
65   void SetRetval(int index, const XlaExpression& expression);
66 
67   // Adds 'resource' to the set of resources owned by the context.
68   XlaResource* AddResource(std::unique_ptr<XlaResource> resource);
69 
resources()70   const std::vector<std::unique_ptr<XlaResource>>& resources() {
71     return resources_;
72   }
73 
74   // Get an XLA lambda to compute Max. This is cached in the
75   // XlaContext since it may be used by multiple Ops. There is a
76   // separate specialization of the computation for each DataType.
77   const xla::XlaComputation* GetOrCreateMax(const DataType type);
78 
79   // Get an XLA lambda to compute Min. This is cached in the
80   // XlaContext since it may be used by multiple Ops. There is a
81   // separate specialization of the computation for each DataType.
82   const xla::XlaComputation* GetOrCreateMin(const DataType type);
83 
84   // Get an XLA lambda to compute Add. This is cached in the
85   // XlaContext since it may be used by multiple Ops. There is a
86   // separate specialization of the computation for each DataType.
87   const xla::XlaComputation* GetOrCreateAdd(const DataType type);
88 
89   // Get an XLA lambda to compute Mul. This is cached in the
90   // XlaContext since it may be used by multiple Ops. There is a
91   // separate specialization of the computation for each DataType.
92   const xla::XlaComputation* GetOrCreateMul(const DataType type);
93 
94   // The name of the XlaContext resource during symbolic graph execution.
95   static const char kXlaContextResourceName[];
96 
97  private:
98   XlaCompiler* const compiler_;
99 
100   // The XlaBuilder used to construct the subgraph's compiled representation.
101   xla::XlaBuilder* builder_;
102 
103   // Arguments to the Tensorflow graph, indexed by _Arg index.
104   // Includes both compile-time constant arguments and runtime parameters.
105   std::vector<XlaExpression> args_;
106 
107   // Return values of the Tensorflow graph, indexed by _Retval index.
108   std::vector<XlaExpression> retvals_;
109 
110   // Holds ownership of resources. The resources are not ordered.
111   std::vector<std::unique_ptr<XlaResource>> resources_;
112 
113   // Cache of prebuilt computations indexed by their type.
114   using ComputationMap = std::map<DataType, xla::XlaComputation>;
115 
116   // Finds the value for the given type in out map if it already
117   // exists or makes a new value with create function and keeps it the
118   // map. The returned value != nullptr and is owned by the map.
119   const xla::XlaComputation* LookupOrCreate(
120       DataType type, ComputationMap* out,
121       const std::function<xla::XlaComputation()>& create);
122 
123   // Cached computation to compute Max of two elements, specialized by type.
124   ComputationMap max_func_;
125 
126   // Cached computation to compute Min of two elements, specialized by type.
127   ComputationMap min_func_;
128 
129   // Cached computation to compute Sum of two elements, specialized by type.
130   ComputationMap add_func_;
131 
132   // Cached computation to compute Mul of two elements, specialized by type.
133   ComputationMap mul_func_;
134 
135   // Cached computation to compute Sigmoid of an element, specialized by type.
136   ComputationMap sigmoid_func_;
137 
138   TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
139 };
140 
141 }  // namespace tensorflow
142 
143 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
144