1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_context.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/tf2xla/literal_util.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
28 #include "tensorflow/compiler/xla/client/client_library.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/client/xla_computation.h"
31 #include "tensorflow/compiler/xla/layout_util.h"
32 #include "tensorflow/compiler/xla/literal.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/core/common_runtime/dma_helper.h"
35 #include "tensorflow/core/platform/logging.h"
36
37 namespace tensorflow {
38
39 const char XlaContext::kXlaContextResourceName[] = "_xla_context";
40
41 // Looks up the context associated with the current step. It is stored
42 // in a resource container managed by the device.
Get(const OpKernelContext * ctx)43 /* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) {
44 // When an Op kernel wants to use an XLA JIT context, the
45 // per-step context is looked up in the resource manager. The
46 // JIT will prepopulate the JITContext.
47 XlaContext* context;
48 TF_CHECK_OK(ctx->resource_manager()->Lookup(
49 ctx->step_container()->name(), kXlaContextResourceName, &context));
50 // The resource manager handed us a fresh reference to 'context', but retains
51 // a reference itself so the context won't be freed. The resource manager will
52 // outlive the JIT compilation.
53 context->Unref();
54 return *context;
55 }
56
set_args(std::vector<XlaExpression> args)57 void XlaContext::set_args(std::vector<XlaExpression> args) {
58 args_ = std::move(args);
59 }
60
XlaContext(XlaCompiler * compiler,xla::XlaBuilder * builder)61 XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
62 : compiler_(compiler), builder_(builder) {}
63
DebugString() const64 string XlaContext::DebugString() const { return "XLA JIT context"; }
65
SetRetval(int index,const XlaExpression & expression)66 void XlaContext::SetRetval(int index, const XlaExpression& expression) {
67 if (retvals_.size() <= index) {
68 retvals_.resize(index + 1);
69 }
70 retvals_[index] = expression;
71 }
72
AddResource(std::unique_ptr<XlaResource> resource)73 XlaResource* XlaContext::AddResource(std::unique_ptr<XlaResource> resource) {
74 resources_.push_back(std::move(resource));
75 return resources_.back().get();
76 }
77
GetOrCreateMax(const DataType type)78 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
79 return LookupOrCreate(type, &max_func_, [type] {
80 const string type_string = DataTypeString(type);
81 VLOG(1) << "Building Max() for " << type_string;
82 xla::XlaBuilder b("max<" + type_string + ">");
83 xla::PrimitiveType xla_type;
84 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
85 auto x =
86 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
87 auto y =
88 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
89 xla::Max(x, y);
90 return b.Build().ConsumeValueOrDie();
91 });
92 }
93
GetOrCreateMin(const DataType type)94 const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
95 return LookupOrCreate(type, &min_func_, [type] {
96 const string type_string = DataTypeString(type);
97 VLOG(1) << "Building Min() for " << type_string;
98 xla::XlaBuilder b("min<" + type_string + ">");
99 xla::PrimitiveType xla_type;
100 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
101 auto x =
102 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
103 auto y =
104 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
105 xla::Min(x, y);
106 return b.Build().ConsumeValueOrDie();
107 });
108 }
109
GetOrCreateAdd(const DataType type)110 const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
111 return LookupOrCreate(type, &add_func_, [type] {
112 const string type_string = DataTypeString(type);
113 VLOG(1) << "Building Add() for " << type_string;
114 xla::XlaBuilder b("add<" + type_string + ">");
115 xla::PrimitiveType xla_type;
116 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
117 auto x =
118 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
119 auto y =
120 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
121 xla::Add(x, y);
122 return b.Build().ConsumeValueOrDie();
123 });
124 }
125
GetOrCreateMul(const DataType type)126 const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
127 return LookupOrCreate(type, &mul_func_, [type] {
128 const string type_string = DataTypeString(type);
129 VLOG(1) << "Building Mul() for " << type_string;
130 xla::XlaBuilder b("mul<" + type_string + ">");
131 xla::PrimitiveType xla_type;
132 TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
133 auto x =
134 xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
135 auto y =
136 xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
137 xla::Mul(x, y);
138 return b.Build().ConsumeValueOrDie();
139 });
140 }
141
LookupOrCreate(DataType type,ComputationMap * out,const std::function<xla::XlaComputation ()> & create)142 const xla::XlaComputation* XlaContext::LookupOrCreate(
143 DataType type, ComputationMap* out,
144 const std::function<xla::XlaComputation()>& create) {
145 {
146 const auto& entry = (*out)[type];
147 if (!entry.IsNull()) {
148 return &entry;
149 }
150 }
151 auto new_entry = create();
152 {
153 // Somebody else might have made one concurrently.
154 auto& entry = (*out)[type];
155 if (entry.IsNull()) {
156 entry = std::move(new_entry);
157 }
158 return &entry;
159 }
160 }
161
162 } // namespace tensorflow
163