/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Helper functions for templated ops.
class XRTStateHelpers {
 public:
  // The Status return value allows us to use the
  // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
  // OpKernel::Compute method.
  static Status MakeLiteral(const xla::LiteralProto& proto,
                            xla::Literal* literal) {
    TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
    return Status::OK();
  }

  // ParseTupleNode is the recursive function used to parse a recursive
  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine' i.e. the
  // tuple shape where every leaf is an existing allocation. As a side-effect it
  // fills in input_vector by looking up allocations from handles in the
  // input_tensor_list as they are referenced by nodes in the proto.
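  //
  // For illustration only (a schematic sketch, not a full description of the
  // proto), a tuple node describing the nested tuple ((h0, h1), h2) over three
  // input handles could be written as the textproto:
  //
  //   tuples { tuples { input_index: 0 } tuples { input_index: 1 } }
  //   tuples { input_index: 2 release_input_handle: true }
  //
  // where input_index refers to a position in input_tensor_list and
  // release_input_handle requests that the referenced handle be released once
  // the new tuple has been built.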
  static Status ParseTupleNode(
      const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::Shape* shape, ResourceMgr* rm) {
    if (tuple_node.tuples_size() > 0) {
      // This is an internal node in the proto so descend recursively.
      xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
      std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
      *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
          xla::ShapeUtil::MakeTupleShape(subshapes);
      for (int i = 0; i < tuple_node.tuples_size(); ++i) {
        TF_RETURN_IF_ERROR(ParseTupleNode(
            tuple_node.tuples(i), input_tensor_list, input_vector,
            xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
      }
    } else {
      // This is a leaf node in the proto so look up the referenced input.
      int input_index = tuple_node.input_index();
      if (input_index < 0 || input_index >= input_vector->size()) {
        return errors::InvalidArgument("Invalid tuple input index ",
                                       input_index, ": MakeTuple has ",
                                       input_vector->size(), " inputs.");
      }
      bool release_this_input = tuple_node.release_input_handle();
      XRTTupleAllocation::ExpandedTupleInput& input =
          input_vector->at(input_index);
      if (input.allocation != nullptr &&
          (input.release_allocation_after_use || release_this_input)) {
        return errors::InvalidArgument(
            "Invalid tuple tree: input index ", input_index,
            " is repeated but release_input_handle is true.");
      }
      if (input.allocation == nullptr) {
        // We haven't dereferenced this handle yet.
        TF_RET_CHECK(
            TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
        int64_t key = input_tensor_list[input_index].scalar<int64>()();
        TF_ASSIGN_OR_RETURN(input.allocation,
                            XRTMemoryManager::Get(rm)->Lookup(key));
        input.release_allocation_after_use = release_this_input;
      }
    }
    return Status::OK();
  }

  // Parses a xrt::XLATupleNode proto recursively and returns the corresponding
  // ShapeTree where each leaf is an allocation corresponding to a handle in
  // input_tensor_list. The ordinal of one of the allocations is returned in
  // device_ordinal. Since it's not possible to specify a xrt::XLATupleNode with
  // no leaves, device_ordinal will always be filled in by a successful call to
  // ParseTupleTree.
  static Status ParseTupleTree(
      const xrt::XLATupleNode& tuple_tree_root,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
      int* device_ordinal, ResourceMgr* rm) {
    // First get the shape of the 'spine' of the new tuple, where every leaf is
    // an existing allocation. As a side-effect dereference the input handles
    // into allocations in input_vector.
    xla::Shape tuple_tree_shape;
    TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
                                      input_vector, &tuple_tree_shape, rm));
    // Make the shape tree of allocations where the shape is the spine and each
    // leaf is one of the allocations looked up in input_vector. Internal nodes
    // have nullptr allocations.
    *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
        tuple_tree_shape);
    tuple_shape_tree->ForEachMutableElement(
        [&](const xla::ShapeIndex& index,
            XRTTupleAllocation::ExpandedTupleInput* element) {
          if (tuple_shape_tree->IsLeaf(index)) {
            // Find the matching leaf in the proto tree.
            const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
            for (int i = 0; i < index.size(); ++i) {
              tuple_node = &tuple_node->tuples(index[i]);
            }
            // Copy the appropriate input allocation to the leaf of the
            // tuple_shape_tree.
            int input_index = tuple_node->input_index();
            *element = input_vector->at(input_index);
            CHECK(element->release_allocation_after_use ==
                  tuple_node->release_input_handle());
            // We just need to know the device_ordinal of one of the
            // allocations. We will validate later that they are all the same.
            *device_ordinal = (*element).allocation->device_ordinal();
          }
        });
    return Status::OK();
  }
};

// Op that allocates memory for a literal and transfers it to the device.
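//
// The op's single string input is a serialized xrt::XLAAllocation proto whose
// `value` field holds the xla.LiteralProto to transfer, and the output is an
// int64 handle registered with the XRTMemoryManager. As a rough, illustrative
// sketch only (field names taken from xla.LiteralProto, not a normative
// example), a 2-element f32 vector could be described as:
//
//   value {
//     shape { element_type: F32 dimensions: 2 }
//     f32s: 1.0
//     f32s: 2.0
//   }
//
// The returned handle can later be consumed by XRTSubTupleOp, XRTReadLiteralOp,
// XRTWriteLiteralOp or XRTReleaseAllocationOp below.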
template <class DeviceAccessor>
class XRTAllocateOp : public OpKernel {
 public:
  explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTAllocateOp() override = default;
  XRTAllocateOp(const XRTAllocateOp&) = delete;
  XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell());

    const Tensor& allocation_info = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
                errors::Internal("allocation input should be a string scalar"));
    xrt::XLAAllocation allocation_proto;
    OP_REQUIRES(ctx,
                ParseFromTString(allocation_info.scalar<tstring>()(),
                                 &allocation_proto),
                errors::InvalidArgument(
                    "Unable to parse allocation input to XLAAllocation"));

    xla::Literal literal;
    OP_REQUIRES_OK(
        ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation,
                            device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }
};

// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateUninitialized(
                            xla_shape_, memory_manager.get(),
                            device_ref.backend(), device_ref.device_ordinal(),
                            &allocation, device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};

// Op that allocates memory for a tensor (with optional layout) and transfers it
// to the device, returning an allocation handle.
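//
// The attrs mirror the constructor below: "dtypes"/"shapes" describe the input
// tensors, "make_tuple" forces a tuple allocation even for a single input, and
// the optional "layouts" attr supplies minor-to-major dimension orders. As an
// illustrative sketch only (it assumes the flat layouts vector is applied leaf
// by leaf when GetShapeWithLayout consumes it): for two f32 inputs of shapes
// [2,3] and [4] with make_tuple=true, layouts = {1, 0, 0} would request the
// default row-major layout for both tuple leaves.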
template <class DeviceAccessor>
class XRTAllocateFromTensorOp : public OpKernel {
 public:
  explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool make_tuple = false;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
    std::vector<int64> minor_to_major;
    if (ctx->HasAttr("layouts")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
    }
    OP_REQUIRES(
        ctx, tf_shapes_.size() == dtypes_.size(),
        errors::InvalidArgument("shapes and dtypes must be the same length"));
    std::vector<xla::Shape> xla_shapes;
    xla_shapes.reserve(tf_shapes_.size());
    for (int i = 0; i < tf_shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
      xla_shapes.push_back(std::move(xla_shape));
    }
    if (xla_shapes.size() > 1 || make_tuple) {
      shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    } else {
      shape_.Swap(&xla_shapes.front());
    }
    if (!minor_to_major.empty()) {
      xla::Shape shape_with_layouts;
      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                             /*layout_func=*/nullptr,
                                             &shape_with_layouts));
      shape_.Swap(&shape_with_layouts);
    }
  }

  ~XRTAllocateFromTensorOp() override = default;
  XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete;
  XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateFromTensorOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell());

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == tf_shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to XRTAllocateFromTensor: ",
                    values.size(), " vs. ", tf_shapes_.size()));

    std::vector<const char*> tensors_data;
    for (size_t i = 0; i < values.size(); ++i) {
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "Input tensor type and input dtype do not match"));
      // We allow the requested on-device shape to differ from the shape of the
      // input tensor, as long as they have the same number of elements.
      OP_REQUIRES(
          ctx,
          input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(),
          errors::InvalidArgument(
              "Input tensor must have the number of elements specified "
              "in the matching input shape: ",
              input_tensor.shape().num_elements(), " vs. ",
              tf_shapes_[i].num_elements(), " at index ", i));
      tensors_data.push_back(
          static_cast<const char*>(DMAHelper::base(&input_tensor)));
    }
    // Use the buffer straight out of the input tensors to create the literal.
    xla::BorrowingLiteral literal =
        shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_)
                         : xla::BorrowingLiteral(tensors_data.front(), shape_);
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation,
                            device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  std::vector<TensorShape> tf_shapes_;
  DataTypeVector dtypes_;
  xla::Shape shape_;
};

// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
// input.
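//
// The second input is a vector of int32 indices forming an xla::ShapeIndex
// into the tuple. For example (illustrative only), for an allocation of shape
// ((a, b), c) the index vector {0, 1} selects b and {1} selects c.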
template <bool discard_, class DeviceAccessor>
class XRTSubTupleOp : public OpKernel {
 public:
  explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTSubTupleOp() override = default;
  XRTSubTupleOp(const XRTSubTupleOp&) = delete;
  XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTSubTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& subtuple_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
        errors::Internal("tuple index input should be an int32 vector"));
    xla::ShapeIndex shape_index;
    for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
      shape_index.push_back(subtuple_info.vec<int32>()(i));
    }

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    XRTTupleAllocation* suballocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
                                               &suballocation, !discard_));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(suballocation);
    ctx->set_output(0, output);
  }
};

// Op that assembles a new device-resident tuple from existing allocation
// handles, as described by an xrt::XLATupleNode tree.
template <class DeviceAccessor>
class XRTMakeTupleOp : public OpKernel {
 public:
  explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMakeTupleOp() override = default;
  XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
  XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTMakeTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell());

    const Tensor& tuple_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
        errors::Internal("tuple description input should be a string scalar"));
    xrt::XLATupleNode tuple_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(tuple_info.scalar<tstring>()(), &tuple_proto),
        errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));

    OpInputList arg_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));

    // For each input, the allocation it corresponds to and a flag indicating
    // whether or not it should be released, i.e. discarded from the resource
    // manager. One ref on each allocation is owned by this vector, and freed on
    // exit.
    std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
        arg_list.size());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
    // the allocations. It is guaranteed that there is at least one allocation
    // in any legal tree. We validate below in XRTTupleAllocation::MakeTuple
    // that all the allocations are on the same device.
    int device_ordinal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
                            tuple_proto, arg_list, &input_vector,
                            &tuple_shape_tree, &device_ordinal, rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(
        ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* output_allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
                            memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), tuple_shape_tree,
                            &output_allocation, device_ref.allocator()));
    RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
    for (int i = 0; i < input_vector.size(); ++i) {
      if (input_vector[i].release_allocation_after_use) {
        OP_REQUIRES_OK(ctx,
                       memory_manager->Release(arg_list[i].scalar<int64>()()));
      }
    }

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(std::move(output_ptr));
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
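//
// The output is a string scalar holding a serialized xla.LiteralProto; a
// client can rebuild the literal with xla::Literal::CreateFromProto (the same
// call MakeLiteral above uses in the opposite direction).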
template <bool discard_, class DeviceAccessor>
class XRTReadLiteralOp : public OpKernel {
 public:
  explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReadLiteralOp() override = default;
  XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
  XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Literal literal(allocation->on_host_shape());
    OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
    xla::LiteralProto literal_proto = literal.ToProto();

    Tensor output(DT_STRING, TensorShape({}));
    SerializeToTString(literal_proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as one
// output tensor per (non-tuple) leaf of its shape.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
 public:
  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }
  ~XRTReadToTensorOp() override = default;
  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadToTensorOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());

    const Tensor& handle_tensor = ctx->input(0);
    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
    // just scalars.)
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Shape shape = allocation->on_host_shape();
    int output = 0;
    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
        &shape,
        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
          if (subshape->IsTuple()) return Status::OK();

          xla::PrimitiveType xla_type;
          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
              ctx->expected_output_dtype(output), &xla_type));
          if (xla_type != subshape->element_type()) {
            return errors::InvalidArgument(
                "Type mismatch between buffer type (", subshape->ToString(),
                ") and tensor type (",
                DataTypeString(ctx->expected_output_dtype(output)),
                ") for output tensor ", output);
          }

          TensorShape output_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));

          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(
              ctx->allocate_output(output, output_shape, &output_tensor));

          XRTTupleAllocation* sub;
          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
              allocation.get(), index, &sub, /*alias_parent_allocation=*/true));
          core::ScopedUnref sub_unref(sub);

          xla::MutableBorrowingLiteral literal;
          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
              &literal));
          TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));

          ++output;
          return Status::OK();
        });
    OP_REQUIRES_OK(ctx, status);
  }
  bool discard_;
  DataTypeVector dtypes_;
};

// Op that writes a new literal value into device-resident memory.
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(literal_info.scalar<tstring>()(), &literal_proto),
        errors::InvalidArgument(
            "Unable to parse allocation input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards one or more handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
 public:
  explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReleaseAllocationOp() override = default;
  XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
  XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllocationOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    const Tensor& allocation_handle = ctx->input(0);
    auto flat_keys = allocation_handle.flat<int64>();
    for (int64_t i = 0; i < flat_keys.size(); ++i) {
      int64_t key = flat_keys(i);
      OP_REQUIRES_OK(ctx, memory_manager->Release(key));
      VLOG(2) << "Released allocation handle " << key;
    }
  }
};

// Op that discards all handles to device memory currently held by the
// XRTMemoryManager.
template <class DeviceAccessor>
class XRTReleaseAllAllocationsOp : public OpKernel {
 public:
  explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  ~XRTReleaseAllAllocationsOp() override = default;
  XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete;
  XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
  }
};

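// Op that asks the XRTMemoryManager to compact its currently registered
// allocations on the device (see XRTMemoryManager::CompactAllocations).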
template <class DeviceAccessor>
class XRTCompactAllocationsOp : public OpKernel {
 public:
  explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTCompactAllocationsOp() override = default;
  XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete;
  XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTCompactAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
    OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
                            device_ref.backend(), device_ref.device_ordinal(),
                            device_ref.allocator()));
  }
};

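// Op that reports device memory usage as a serialized xrt::MemoryInfo proto
// (kb_total/kb_free, or -1 when the device does not expose the information).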
template <class DeviceAccessor>
class XRTMemoryInfoOp : public OpKernel {
 public:
  explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMemoryInfoOp() override = default;
  XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
  XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    auto kernel_fn = [&]() -> Status {
      VLOG(1) << "XRTMemoryInfoOp::Compute";

      class DeviceAccessor::ScopedRef device_ref;
      TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
      TF_ASSIGN_OR_RETURN(
          se::StreamExecutor * stream_executor,
          device_ref.backend()->stream_executor(device_ref.device_ordinal()));
      int64_t mem_free = -1;
      int64_t mem_total = -1;
      if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
        VLOG(2) << "Device " << ctx->device()->name()
                << " does not expose memory information";
      }
      xrt::MemoryInfo mem_info;
      mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
      mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);

      Tensor output(DT_STRING, TensorShape({}));
      output.scalar<tstring>()() = mem_info.SerializeAsString();
      ctx->set_output(0, output);
      return Status::OK();
    };
    OP_REQUIRES_OK(ctx, kernel_fn());
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_