/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for compiling XLA computations and managing handles that refer to
// them.

#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace {

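// Converts an xrt::DeviceAssignment proto into an xla::DeviceAssignment.
// Expects num_cores_per_replica computation_devices entries, each holding
// num_replicas replica_devices; every replica's mesh coordinates must have
// four entries with the first three equal to zero, and the fourth entry is
// used as the device id for that (replica, computation) slot.
//
// Example (hypothetical values): with num_replicas=2 and
// num_cores_per_replica=1, the proto would hold one computation_devices entry
// with two replica_devices whose coordinates look like {0, 0, 0, d}, where d
// is the device ordinal assigned to that replica.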
Status GenerateXlaDeviceAssignment(
    const xrt::DeviceAssignment& xrt_device_assignment, int num_replicas,
    int num_cores_per_replica, xla::DeviceAssignment* device_assignment) {
  if (num_cores_per_replica !=
      xrt_device_assignment.computation_devices_size()) {
    return errors::InvalidArgument(
        "Device assignment does not have the correct number of "
        "computation_devices: num_cores_per_replica=",
        num_cores_per_replica, " computation_devices=",
        xrt_device_assignment.computation_devices_size());
  }
  for (int64_t c = 0; c < xrt_device_assignment.computation_devices_size();
       ++c) {
    const auto& computation_devices =
        xrt_device_assignment.computation_devices(c);
    if (num_replicas != computation_devices.replica_devices_size()) {
      return errors::InvalidArgument(
          "Device assignment does not have the correct number of "
          "replica_device_ids: num_replicas=",
          num_replicas,
          " replica_devices=", computation_devices.replica_devices_size());
    }
    for (int64_t r = 0; r < computation_devices.replica_devices_size(); ++r) {
      const auto& coords = computation_devices.replica_devices(r);
      if (coords.value_size() != 4) {
        return errors::InvalidArgument(
            "Device assignment mesh coordinates must have 4 entries, got ",
            coords.value_size());
      }
      for (int n = 0; n < 3; ++n) {
        if (coords.value(n) != 0) {
          return errors::InvalidArgument("Mesh coordinate at index ", n,
                                         " must be 0, got ", coords.value(n));
        }
      }
      (*device_assignment)(r, c) = coords.value(3);
    }
  }
  return OkStatus();
}

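// OpKernel that compiles a serialized xrt::XLAComputation into an
// xla::LocalExecutable. Compilation results are stored in the process-wide
// XRTCompilationCache, keyed by a fingerprint of the computation proto; the
// op outputs the cache handle and the serialized program shape of the
// compiled executable.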
class XRTCompileOp : public OpKernel {
 public:
  explicit XRTCompileOp(OpKernelConstruction* ctx);
  ~XRTCompileOp() override;
  XRTCompileOp(const XRTCompileOp&) = delete;
  XRTCompileOp& operator=(const XRTCompileOp&) = delete;

  void Compute(OpKernelContext* ctx) override;

 private:
  Status Compile(OpKernelContext* ctx,
                 const xrt::XLAComputation& computation_proto,
                 std::unique_ptr<xla::LocalExecutable>* program);
};

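// Computes the compilation cache key for `computation` by deterministically
// serializing the proto and taking a 64-bit fingerprint of the result.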
Status CompilationCacheKey(const xrt::XLAComputation& computation,
                           string* key) {
  const size_t size = computation.ByteSizeLong();
  auto serialized = absl::make_unique<char[]>(size);
  TF_RET_CHECK(
      SerializeToBufferDeterministic(computation, serialized.get(), size));
  uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size));
  *key = absl::StrCat(fingerprint);
  return OkStatus();
}

XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

Status XRTCompileOp::Compile(OpKernelContext* ctx,
                             const xrt::XLAComputation& computation_proto,
                             std::unique_ptr<xla::LocalExecutable>* program) {
  const xrt::XLAComputationConfig& config = computation_proto.config();
  // Sanity checks for options not yet supported.
  int num_cores_per_replica = std::max<int>(config.num_cores_per_replica(), 1);
  TF_RET_CHECK(num_cores_per_replica == 1);
  TF_RET_CHECK(config.per_core_program_shape_size() == 0);

  // The default config value is 0; treat it as 1 for convenience.
  int num_replicas = config.num_replicas() ? config.num_replicas() : 1;

  // We are guaranteed that the underlying device object won't be deleted out
  // from under us while the ScopedRef is live.
  class XRTGenericDeviceAccessor::ScopedRef device_ref;
  TF_RETURN_IF_ERROR(XRTGenericDeviceAccessor::InitScopedRef(ctx, &device_ref));

  xla::LocalClient* client = device_ref.client();

  // There is officially no way to use XLA in a client/server architecture
  // where the client and server are built from different revisions, because
  // the XLA team does not want to give any guarantees about the stability of
  // the HLO proto. For Cloud TPU this is fine, because server and client
  // versions can be assumed to be synced to the same revision. For general
  // use, the mechanism here (using a snapshot from XlaComputation) works as
  // well as the "official" XLA client/server design, which serializes the
  // same proto between client and server, so in practice it is probably fine.
  TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
                      client->LoadSnapshot(computation_proto.hlo_snapshot()));

  std::vector<xla::Shape> argument_layouts(
      config.program_shape().parameters_size());
  std::vector<const xla::Shape*> argument_layout_ptrs(
      config.program_shape().parameters_size());
  for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
    argument_layouts[i] = xla::Shape(config.program_shape().parameters(i));
    argument_layout_ptrs[i] = &argument_layouts[i];
  }
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(device_ref.device_ordinal());
  build_options.set_num_replicas(num_replicas);
  build_options.set_result_layout(xla::Shape(config.program_shape().result()));
  build_options.set_device_allocator(device_ref.allocator());
  if (config.has_debug_options()) {
    *build_options.mutable_debug_options() =
        BuildXlaDebugOptions(config.debug_options());
  }
  if (config.has_device_assignment()) {
    xla::DeviceAssignment device_assignment(num_replicas,
                                            num_cores_per_replica);
    TF_RETURN_IF_ERROR(
        GenerateXlaDeviceAssignment(config.device_assignment(), num_replicas,
                                    num_cores_per_replica, &device_assignment));
    build_options.set_device_assignment(device_assignment);
  }

  VLOG(1) << "Building executable";
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client->Compile(computation, argument_layout_ptrs, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *program = std::move(executables[0]);
  return OkStatus();
}

void XRTCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XRTCompileOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetCompileCell());

  ResourceMgr* rm;
  OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));

  const Tensor& computation_input = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
              errors::Internal("computation input should be a string scalar"));

  xrt::XLAComputation computation_proto;
  OP_REQUIRES(ctx,
              ParseFromTString(computation_input.scalar<tstring>()(),
                               &computation_proto),
              errors::InvalidArgument(
                  "Unable to parse computation input to XLAComputation"));

  string key;
  OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));

  // Process-wide cache of XLA executables.
  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
      ctx, /*max_number_of_entries=*/0);
  OP_REQUIRES_OK(ctx, cache_or.status());
  auto cache = std::move(cache_or).value();

  int64_t uid;
  OP_REQUIRES_OK(
      ctx, cache->CompileIfKeyAbsent(
               key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) {
                 VLOG(1) << "Compiling XLA executable";
                 return Compile(ctx, computation_proto, program);
               }));
  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
  OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));

  Tensor handle_output(DT_INT64, TensorShape({}));
  handle_output.scalar<int64_t>()() = uid;
  ctx->set_output(0, handle_output);

  xla::LocalExecutable* executable = entry->get().get_executable();
  xla::ProgramShapeProto program_shape = executable->executable()
                                             ->module()
                                             .config()
                                             .entry_computation_layout()
                                             .ComputeProgramShape()
                                             .ToProto();
  Tensor program_shape_output(DT_STRING, TensorShape({1}));
  program_shape_output.vec<tstring>()(0) = program_shape.SerializeAsString();
  ctx->set_output(1, program_shape_output);
}

XRTCompileOp::~XRTCompileOp() = default;

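// OpKernel that releases compilation cache references. For each int64 handle
// in its input tensor, the corresponding reference held by the process-wide
// XRTCompilationCache is released.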
class XRTReleaseCompilationRefOp : public OpKernel {
 public:
  explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
  ~XRTReleaseCompilationRefOp() override;
  XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
  XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override;
};

XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {}

XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;

void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());

  // Process-wide cache of XLA executables.
  auto cache_or = XRTGenericDeviceAccessor::GetOrCreateCompilationCache(
      ctx, /*max_number_of_entries=*/0);
  OP_REQUIRES_OK(ctx, cache_or.status());
  auto cache = std::move(cache_or).value();

  const Tensor& keys_tensor = ctx->input(0);
  auto flat_keys = keys_tensor.flat<int64_t>();
  for (int64_t i = 0; i < flat_keys.size(); ++i) {
    int64_t key = flat_keys(i);
    OP_REQUIRES_OK(ctx, cache->Release(key));
    VLOG(2) << "Released computation handle " << key;
  }
}

}  // namespace

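// Kernel registrations for the XLA_CPU and XLA_GPU compilation devices. The
// computation and handle tensors are kept in host memory.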
REGISTER_KERNEL_BUILDER(Name("XRTCompile")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("computation")
                            .HostMemory("handle"),
                        XRTCompileOp);
REGISTER_KERNEL_BUILDER(Name("XRTCompile")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("computation")
                            .HostMemory("handle"),
                        XRTCompileOp);

REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
                            .Device(DEVICE_XLA_CPU)
                            .HostMemory("handle"),
                        XRTReleaseCompilationRefOp);
REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
                            .Device(DEVICE_XLA_GPU)
                            .HostMemory("handle"),
                        XRTReleaseCompilationRefOp);

}  // namespace tensorflow