/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Helper functions for templated ops.
class XRTStateHelpers {
 public:
  // The Status return value allows us to use the
  // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
  // OpKernel::Compute method.
  static Status MakeLiteral(const xla::LiteralProto& proto,
                            xla::Literal* literal) {
    TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
    return Status::OK();
  }

  // ParseTupleNode is the recursive function used to parse a recursive
  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine', i.e.
  // the tuple shape where every leaf is an existing allocation. As a
  // side-effect it fills in input_vector by looking up allocations from
  // handles in the input_tensor_list as they are referenced by nodes in the
  // proto.
  static Status ParseTupleNode(
      const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::Shape* shape, ResourceMgr* rm) {
    if (tuple_node.tuples_size() > 0) {
      // This is an internal node in the proto so descend recursively.
      xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
      std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
      *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
          xla::ShapeUtil::MakeTupleShape(subshapes);
      for (int i = 0; i < tuple_node.tuples_size(); ++i) {
        TF_RETURN_IF_ERROR(ParseTupleNode(
            tuple_node.tuples(i), input_tensor_list, input_vector,
            xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
      }
    } else {
      // This is a leaf node in the proto so look up the referenced input.
      int input_index = tuple_node.input_index();
      if (input_index < 0 || input_index >= input_vector->size()) {
        return errors::InvalidArgument("Invalid tuple input index ",
                                       input_index, ": MakeTuple has ",
                                       input_vector->size(), " inputs.");
      }
      bool release_this_input = tuple_node.release_input_handle();
      XRTTupleAllocation::ExpandedTupleInput& input =
          input_vector->at(input_index);
      if (input.allocation != nullptr &&
          (input.release_allocation_after_use || release_this_input)) {
        return errors::InvalidArgument(
            "Invalid tuple tree: input index ", input_index,
            " is repeated but release_input_handle is true.");
      }
      if (input.allocation == nullptr) {
        // We haven't dereferenced this handle yet.
        TF_RET_CHECK(
            TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
        int64 key = input_tensor_list[input_index].scalar<int64>()();
        TF_ASSIGN_OR_RETURN(input.allocation,
                            XRTMemoryManager::Get(rm)->Lookup(key));
        input.release_allocation_after_use = release_this_input;
      }
    }
    return Status::OK();
  }

  // Parses an xrt::XLATupleNode proto recursively and returns the corresponding
  // ShapeTree where each leaf is an allocation corresponding to a handle in
  // input_tensor_list. The ordinal of one of the allocations is returned in
  // device_ordinal. Since it's not possible to specify an xrt::XLATupleNode
  // with no leaves, device_ordinal will always be filled in by a successful
  // call to ParseTupleTree.
  static Status ParseTupleTree(
      const xrt::XLATupleNode& tuple_tree_root,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
      int* device_ordinal, ResourceMgr* rm) {
    // First get the shape of the 'spine' of the new tuple, where every leaf is
    // an existing allocation. As a side-effect dereference the input handles
    // into allocations in input_vector.
    xla::Shape tuple_tree_shape;
    TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
                                      input_vector, &tuple_tree_shape, rm));
    // Make the shape tree of allocations where the shape is the spine and each
    // leaf is one of the allocations looked up in input_vector. Internal nodes
    // have nullptr allocations.
    *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
        tuple_tree_shape);
    tuple_shape_tree->ForEachMutableElement(
        [&](const xla::ShapeIndex& index,
            XRTTupleAllocation::ExpandedTupleInput* element) {
          if (tuple_shape_tree->IsLeaf(index)) {
            // Find the matching leaf in the proto tree.
            const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
            for (int i = 0; i < index.size(); ++i) {
              tuple_node = &tuple_node->tuples(index[i]);
            }
            // Copy the appropriate input allocation to the leaf of the
            // tuple_shape_tree.
            int input_index = tuple_node->input_index();
            *element = input_vector->at(input_index);
            CHECK(element->release_allocation_after_use ==
                  tuple_node->release_input_handle());
            // We just need to know the device_ordinal of one of the
            // allocations. We will validate later that they are all the same.
            *device_ordinal = (*element).allocation->device_ordinal();
          }
        });
    return Status::OK();
  }
};
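
// As a concrete illustration of the tuple description these helpers consume (a
// sketch inferred from the accessors used above; the authoritative message
// definition lives in xrt.proto), the following builds an xrt::XLATupleNode
// describing a tuple whose first element is input handle 0, released after
// use, and whose second element is itself a tuple of input handles 1 and 2:
//
//   xrt::XLATupleNode root;
//   xrt::XLATupleNode* leaf0 = root.add_tuples();
//   leaf0->set_input_index(0);
//   leaf0->set_release_input_handle(true);
//   xrt::XLATupleNode* inner = root.add_tuples();
//   inner->add_tuples()->set_input_index(1);
//   inner->add_tuples()->set_input_index(2);
//
// ParseTupleTree would then produce a spine shaped (leaf0, (leaf1, leaf2)) and
// mark entry 0 of input_vector with release_allocation_after_use = true.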

// Op that allocates memory for a literal and transfers it to the device.
template <class DeviceAccessor>
class XRTAllocateOp : public OpKernel {
 public:
  explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTAllocateOp() override = default;
  XRTAllocateOp(const XRTAllocateOp&) = delete;
  XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell());

    const Tensor& allocation_info = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
                errors::Internal("allocation input should be a string scalar"));
    xrt::XLAAllocation allocation_proto;
    OP_REQUIRES(ctx,
                ParseFromTString(allocation_info.scalar<tstring>()(),
                                 &allocation_proto),
                errors::InvalidArgument(
                    "Unable to parse allocation input to XLAAllocation"));

    xla::Literal literal;
    OP_REQUIRES_OK(
        ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }
};
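
// Usage sketch for the op input (based on the fields accessed above; the
// XLAAllocation message itself is defined in xrt.proto): the op consumes a
// string scalar holding a serialized xrt::XLAAllocation whose `value` field
// carries the literal to transfer, and it returns an int64 handle registered
// with the XRTMemoryManager.
//
//   xrt::XLAAllocation alloc;
//   *alloc.mutable_value() =
//       xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}).ToProto();
//   const std::string op_input = alloc.SerializeAsString();
//   // Feed `op_input` as the op's string scalar input; keep the returned
//   // int64 handle to name this device allocation in later XRT ops.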

// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx,
                   XRTTupleAllocation::CreateUninitialized(
                       xla_shape_, memory_manager.get(), device_ref.backend(),
                       device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};
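
// For example (a sketch of the attribute-to-shape mapping performed in the
// constructor above): dtype = DT_FLOAT and shape = [128, 256] are converted by
// TensorShapeToXLAShape into the XLA shape f32[128,256]; the resulting device
// buffer stays uninitialized until a later op, such as XRTWriteLiteral or an
// XRT computation, fills it in.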

// Op that allocates memory for a tensor (with optional layout) and transfers it
// to the device, returning an allocation handle.
template <class DeviceAccessor>
class XRTAllocateFromTensorOp : public OpKernel {
 public:
  explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool make_tuple = false;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
    std::vector<int64> minor_to_major;
    if (ctx->HasAttr("layouts")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
    }
    OP_REQUIRES(
        ctx, tf_shapes_.size() == dtypes_.size(),
        errors::InvalidArgument("shapes and dtypes must be the same length"));
    std::vector<xla::Shape> xla_shapes;
    xla_shapes.reserve(tf_shapes_.size());
    for (int i = 0; i < tf_shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
      xla_shapes.push_back(std::move(xla_shape));
    }
    if (xla_shapes.size() > 1 || make_tuple) {
      shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    } else {
      shape_.Swap(&xla_shapes.front());
    }
    if (!minor_to_major.empty()) {
      xla::Shape shape_with_layouts;
      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                             /*layout_func=*/nullptr,
                                             &shape_with_layouts));
      shape_.Swap(&shape_with_layouts);
    }
  }

  ~XRTAllocateFromTensorOp() override = default;
  XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete;
  XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateFromTensorOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell());

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == tf_shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to XRTAllocateFromTensor: ",
                    values.size(), " vs. ", tf_shapes_.size()));

    std::vector<const char*> tensors_data;
    for (size_t i = 0; i < values.size(); ++i) {
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "Input tensor type and input dtype do not match"));
      // We allow the requested on-device shape to differ from the shape of the
      // input tensor, as long as they have the same number of elements.
      OP_REQUIRES(
          ctx,
          input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(),
          errors::InvalidArgument(
              "Input tensor must have the number of elements specified "
              "in the matching input shape: ",
              input_tensor.shape().num_elements(), " vs. ",
              tf_shapes_[i].num_elements(), " at index ", i));
      tensors_data.push_back(
          static_cast<const char*>(DMAHelper::base(&input_tensor)));
    }
    // Use the buffer straight out of the input tensors to create the literal.
    xla::BorrowingLiteral literal =
        shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_)
                         : xla::BorrowingLiteral(tensors_data.front(), shape_);
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  std::vector<TensorShape> tf_shapes_;
  DataTypeVector dtypes_;
  xla::Shape shape_;
};
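
// Attribute sketch for this op (attr names taken from the GetAttr calls above;
// the flattened layout convention is an assumption about GetShapeWithLayout as
// used here): with shapes = [[2, 3], [4]], dtypes = [DT_FLOAT, DT_INT32] and
// make_tuple = true, the target on-device shape is the tuple
// (f32[2,3], s32[4]). An optional "layouts" attr concatenates the
// minor-to-major order of every leaf, e.g. layouts = [0, 1, 0] requests
// minor-to-major {0, 1} for the f32[2,3] leaf and {0} for the s32[4] leaf.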

// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
// input.
template <bool discard_, class DeviceAccessor>
class XRTSubTupleOp : public OpKernel {
 public:
  explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTSubTupleOp() override = default;
  XRTSubTupleOp(const XRTSubTupleOp&) = delete;
  XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTSubTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& subtuple_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
        errors::Internal("tuple index input should be an int32 vector"));
    xla::ShapeIndex shape_index;
    for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
      shape_index.push_back(subtuple_info.vec<int32>()(i));
    }

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    XRTTupleAllocation* suballocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
                                               &suballocation, !discard_));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(suballocation);
    ctx->set_output(0, output);
  }
};
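
// Shape index example (standard xla::ShapeIndex semantics): for an allocation
// holding the tuple (a, (b, c)), the int32 vector input {1} selects the inner
// tuple (b, c), while {1, 0} selects b. The discard_ template parameter
// controls whether the parent handle is released once the sub-buffer handle
// has been created.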

// Op that assembles a new tuple allocation on the device from existing
// allocation handles, as described by an xrt::XLATupleNode proto.
template <class DeviceAccessor>
class XRTMakeTupleOp : public OpKernel {
 public:
  explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMakeTupleOp() override = default;
  XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
  XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTMakeTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell());

    const Tensor& tuple_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
        errors::Internal("tuple description input should be a string scalar"));
    xrt::XLATupleNode tuple_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(tuple_info.scalar<tstring>()(), &tuple_proto),
        errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));

    OpInputList arg_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));

    // For each input, the allocation it corresponds to and a flag indicating
    // whether or not it should be released, i.e. discarded from the resource
    // manager. One ref on each allocation is owned by this vector, and freed on
    // exit.
    std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
        arg_list.size());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
    // the allocations. It is guaranteed that there is at least one allocation
    // in any legal tree. We validate below in XRTTupleAllocation::MakeTuple
    // that all the allocations are on the same device.
    int device_ordinal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
                            tuple_proto, arg_list, &input_vector,
                            &tuple_shape_tree, &device_ordinal, rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(
        ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* output_allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
                            memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), tuple_shape_tree,
                            &output_allocation));
    RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
    for (int i = 0; i < input_vector.size(); ++i) {
      if (input_vector[i].release_allocation_after_use) {
        OP_REQUIRES_OK(ctx,
                       memory_manager->Release(arg_list[i].scalar<int64>()()));
      }
    }

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(std::move(output_ptr));
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
template <bool discard_, class DeviceAccessor>
class XRTReadLiteralOp : public OpKernel {
 public:
  explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReadLiteralOp() override = default;
  XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
  XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Literal literal(allocation->on_host_shape());
    OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
    xla::LiteralProto literal_proto = literal.ToProto();

    Tensor output(DT_STRING, TensorShape({}));
    SerializeToTString(literal_proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns its
// non-tuple leaves as one output tensor per leaf.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
 public:
  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }
  ~XRTReadToTensorOp() override = default;
  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadToTensorOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());

    const Tensor& handle_tensor = ctx->input(0);
    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
    // just scalars).
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Shape shape = allocation->on_host_shape();
    int output = 0;
    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
        &shape,
        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
          if (subshape->IsTuple()) return Status::OK();

          xla::PrimitiveType xla_type;
          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
              ctx->expected_output_dtype(output), &xla_type));
          if (xla_type != subshape->element_type()) {
            return errors::InvalidArgument(
                "Type mismatch between buffer type (", subshape->ToString(),
                ") and tensor type (",
                DataTypeString(ctx->expected_output_dtype(output)),
                ") for output tensor ", output);
          }

          TensorShape output_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));

          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(
              ctx->allocate_output(output, output_shape, &output_tensor));

          XRTTupleAllocation* sub;
          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
              allocation.get(), index, &sub, /*alias_parent_allocation=*/true));
          core::ScopedUnref sub_unref(sub);

          xla::MutableBorrowingLiteral literal;
          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
              &literal));
          TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));

          ++output;
          return Status::OK();
        });
    OP_REQUIRES_OK(ctx, status);
  }
  bool discard_;
  DataTypeVector dtypes_;
};

// Op that writes a new literal value into device-resident memory.
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(literal_info.scalar<tstring>()(), &literal_proto),
        errors::InvalidArgument(
            "Unable to parse literal input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards one or more handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
 public:
  explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReleaseAllocationOp() override = default;
  XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
  XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllocationOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    const Tensor& allocation_handle = ctx->input(0);
    auto flat_keys = allocation_handle.flat<int64>();
    for (int64 i = 0; i < flat_keys.size(); ++i) {
      int64 key = flat_keys(i);
      OP_REQUIRES_OK(ctx, memory_manager->Release(key));
      VLOG(2) << "Released allocation handle " << key;
    }
  }
};

// Op that discards all handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllAllocationsOp : public OpKernel {
 public:
  explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  ~XRTReleaseAllAllocationsOp() override = default;
  XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete;
  XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
  }
};

// Op that asks the XRTMemoryManager to compact its device allocations.
template <class DeviceAccessor>
class XRTCompactAllocationsOp : public OpKernel {
 public:
  explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTCompactAllocationsOp() override = default;
  XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete;
  XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTCompactAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
    OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
                            device_ref.backend(), device_ref.device_ordinal()));
  }
};

// Op that reports the device's free and total memory as a serialized
// xrt::MemoryInfo proto.
template <class DeviceAccessor>
class XRTMemoryInfoOp : public OpKernel {
 public:
  explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMemoryInfoOp() override = default;
  XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
  XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    auto kernel_fn = [&]() -> Status {
      VLOG(1) << "XRTMemoryInfoOp::Compute";

      class DeviceAccessor::ScopedRef device_ref;
      TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
      TF_ASSIGN_OR_RETURN(
          se::StreamExecutor * stream_executor,
          device_ref.backend()->stream_executor(device_ref.device_ordinal()));
      int64 mem_free = -1;
      int64 mem_total = -1;
      if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
        VLOG(2) << "Device " << ctx->device()->name()
                << " does not expose memory information";
      }
      xrt::MemoryInfo mem_info;
      mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
      mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);

      Tensor output(DT_STRING, TensorShape({}));
      output.scalar<tstring>()() = mem_info.SerializeAsString();
      ctx->set_output(0, output);
      return Status::OK();
    };
    OP_REQUIRES_OK(ctx, kernel_fn());
  }
};
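
// Consumer-side sketch (field names inferred from the setters used above): the
// op's string output parses back into an xrt::MemoryInfo, where a value of -1
// means the device did not report that quantity. `memory_info_output` below is
// a hypothetical string holding the op's output.
//
//   xrt::MemoryInfo info;
//   if (info.ParseFromString(memory_info_output)) {
//     VLOG(1) << "free KiB: " << info.kb_free()
//             << " total KiB: " << info.kb_total();
//   }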

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_