• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_context.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/tf2xla/literal_util.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/xla/client/client_library.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/client/xla_computation.h"
31 #include "tensorflow/compiler/xla/layout_util.h"
32 #include "tensorflow/compiler/xla/literal.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/core/common_runtime/dma_helper.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace tensorflow {
38 
39 const char XlaContext::kXlaContextResourceName[] = "_xla_context";
40 
41 // Looks up the context associated with the current step. It is stored
42 // in a resource container managed by the device.
Get(const OpKernelContext * ctx)43 /* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) {
44   // When an Op kernel wants to use an XLA JIT context, the
45   // per-step context is looked up in the resource manager. The
46   // JIT will prepopulate the JITContext.
47   XlaContext* context;
48   TF_CHECK_OK(ctx->resource_manager()->Lookup(
49       ctx->step_container()->name(), kXlaContextResourceName, &context));
50   // The resource manager handed us a fresh reference to 'context', but retains
51   // a reference itself so the context won't be freed. The resource manager will
52   // outlive the JIT compilation.
53   context->Unref();
54   return *context;
55 }
56 
set_args(std::vector<XlaExpression> args)57 void XlaContext::set_args(std::vector<XlaExpression> args) {
58   args_ = std::move(args);
59 }
60 
XlaContext(XlaCompiler * compiler,xla::XlaBuilder * builder)61 XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
62     : compiler_(compiler), builder_(builder) {}
63 
DebugString() const64 string XlaContext::DebugString() const { return "XLA JIT context"; }
65 
SetRetval(int index,const XlaExpression & expression)66 void XlaContext::SetRetval(int index, const XlaExpression& expression) {
67   if (retvals_.size() <= index) {
68     retvals_.resize(index + 1);
69   }
70   retvals_[index] = expression;
71 }
72 
AddResource(std::unique_ptr<XlaResource> resource)73 XlaResource* XlaContext::AddResource(std::unique_ptr<XlaResource> resource) {
74   resources_.push_back(std::move(resource));
75   return resources_.back().get();
76 }
77 
GetOrCreateMax(const DataType type)78 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
79   return LookupOrCreate(type, &max_func_, [type] {
80     const string type_string = DataTypeString(type);
81     VLOG(1) << "Building Max() for " << type_string;
82     xla::XlaBuilder b("max<" + type_string + ">");
83     xla::PrimitiveType xla_type;
84     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
85     auto x =
86         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
87     auto y =
88         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
89     xla::Max(x, y);
90     return b.Build().ConsumeValueOrDie();
91   });
92 }
93 
GetOrCreateMin(const DataType type)94 const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
95   return LookupOrCreate(type, &min_func_, [type] {
96     const string type_string = DataTypeString(type);
97     VLOG(1) << "Building Min() for " << type_string;
98     xla::XlaBuilder b("min<" + type_string + ">");
99     xla::PrimitiveType xla_type;
100     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
101     auto x =
102         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
103     auto y =
104         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
105     xla::Min(x, y);
106     return b.Build().ConsumeValueOrDie();
107   });
108 }
109 
GetOrCreateAdd(const DataType type)110 const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
111   return LookupOrCreate(type, &add_func_, [type] {
112     const string type_string = DataTypeString(type);
113     VLOG(1) << "Building Add() for " << type_string;
114     xla::XlaBuilder b("add<" + type_string + ">");
115     xla::PrimitiveType xla_type;
116     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
117     auto x =
118         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
119     auto y =
120         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
121     xla::Add(x, y);
122     return b.Build().ConsumeValueOrDie();
123   });
124 }
125 
GetOrCreateMul(const DataType type)126 const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
127   return LookupOrCreate(type, &mul_func_, [type] {
128     const string type_string = DataTypeString(type);
129     VLOG(1) << "Building Mul() for " << type_string;
130     xla::XlaBuilder b("mul<" + type_string + ">");
131     xla::PrimitiveType xla_type;
132     TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
133     auto x =
134         xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
135     auto y =
136         xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
137     xla::Mul(x, y);
138     return b.Build().ConsumeValueOrDie();
139   });
140 }
141 
LookupOrCreate(DataType type,ComputationMap * out,const std::function<xla::XlaComputation ()> & create)142 const xla::XlaComputation* XlaContext::LookupOrCreate(
143     DataType type, ComputationMap* out,
144     const std::function<xla::XlaComputation()>& create) {
145   {
146     const auto& entry = (*out)[type];
147     if (!entry.IsNull()) {
148       return &entry;
149     }
150   }
151   auto new_entry = create();
152   {
153     // Somebody else might have made one concurrently.
154     auto& entry = (*out)[type];
155     if (entry.IsNull()) {
156       entry = std::move(new_entry);
157     }
158     return &entry;
159   }
160 }
161 
162 }  // namespace tensorflow
163