/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Helper functions for templated ops.
class XRTStateHelpers {
 public:
  // The Status return value allows us to use the
  // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
  // OpKernel::Compute method.
  static Status MakeLiteral(const xla::LiteralProto& proto,
                            xla::Literal* literal) {
    TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
    return Status::OK();
  }

  // ParseTupleNode is the recursive function used to parse a recursive
  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine', i.e.
  // the tuple shape where every leaf is an existing allocation. As a
  // side-effect it fills in input_vector by looking up allocations from
  // handles in the input_tensor_list as they are referenced by nodes in the
  // proto.
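  //
  // For illustration only (a textproto sketch, not normative; see xrt.proto
  // for the authoritative schema), a node describing the tuple ((h0, h1), h2)
  // built from three input handles could look like:
  //
  //   tuples { tuples { input_index: 0 } tuples { input_index: 1 } }
  //   tuples { input_index: 2 release_input_handle: true }
  //
  // where input_index refers to a position in input_tensor_list and
  // release_input_handle marks handles to be discarded once the new tuple has
  // been built.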
  static Status ParseTupleNode(
      const xrt::XLATupleNode& tuple_node,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::Shape* shape, ResourceMgr* rm) {
    if (tuple_node.tuples_size() > 0) {
      // This is an internal node in the proto so descend recursively.
      xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
      std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
      *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
          xla::ShapeUtil::MakeTupleShape(subshapes);
      for (int i = 0; i < tuple_node.tuples_size(); ++i) {
        TF_RETURN_IF_ERROR(ParseTupleNode(
            tuple_node.tuples(i), input_tensor_list, input_vector,
            xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
      }
    } else {
      // This is a leaf node in the proto so look up the referenced input.
      int input_index = tuple_node.input_index();
      if (input_index < 0 || input_index >= input_vector->size()) {
        return errors::InvalidArgument("Invalid tuple input index ",
                                       input_index, ": MakeTuple has ",
                                       input_vector->size(), " inputs.");
      }
      bool release_this_input = tuple_node.release_input_handle();
      XRTTupleAllocation::ExpandedTupleInput& input =
          input_vector->at(input_index);
      if (input.allocation != nullptr &&
          (input.release_allocation_after_use || release_this_input)) {
        return errors::InvalidArgument(
            "Invalid tuple tree: input index ", input_index,
            " is repeated but release_input_handle is true.");
      }
      if (input.allocation == nullptr) {
        // We haven't dereferenced this handle yet.
        TF_RET_CHECK(
            TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
        int64_t key = input_tensor_list[input_index].scalar<int64>()();
        TF_ASSIGN_OR_RETURN(input.allocation,
                            XRTMemoryManager::Get(rm)->Lookup(key));
        input.release_allocation_after_use = release_this_input;
      }
    }
    return Status::OK();
  }

  // Parses a xrt::XLATupleNode proto recursively and returns the corresponding
  // ShapeTree where each leaf is an allocation corresponding to a handle in
  // input_tensor_list. The ordinal of one of the allocations is returned in
  // device_ordinal. Since it's not possible to specify a xrt::XLATupleNode with
  // no leaves, device_ordinal will always be filled in by a successful call to
  // ParseTupleTree.
  static Status ParseTupleTree(
      const xrt::XLATupleNode& tuple_tree_root,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
      int* device_ordinal, ResourceMgr* rm) {
    // First get the shape of the 'spine' of the new tuple, where every leaf is
    // an existing allocation. As a side-effect dereference the input handles
    // into allocations in input_vector.
    xla::Shape tuple_tree_shape;
    TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
                                      input_vector, &tuple_tree_shape, rm));
    // Make the shape tree of allocations where the shape is the spine and each
    // leaf is one of the allocations looked up in input_vector. Internal nodes
    // have nullptr allocations.
    *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
        tuple_tree_shape);
    tuple_shape_tree->ForEachMutableElement(
        [&](const xla::ShapeIndex& index,
            XRTTupleAllocation::ExpandedTupleInput* element) {
          if (tuple_shape_tree->IsLeaf(index)) {
            // Find the matching leaf in the proto tree.
            const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
            for (int i = 0; i < index.size(); ++i) {
              tuple_node = &tuple_node->tuples(index[i]);
            }
            // Copy the appropriate input allocation to the leaf of the
            // tuple_shape_tree.
            int input_index = tuple_node->input_index();
            *element = input_vector->at(input_index);
            CHECK(element->release_allocation_after_use ==
                  tuple_node->release_input_handle());
            // We just need to know the device_ordinal of one of the
            // allocations. We will validate later that they are all the same.
            *device_ordinal = (*element).allocation->device_ordinal();
          }
        });
    return Status::OK();
  }
};

// Op that allocates memory for a literal and transfers it to the device.
template <class DeviceAccessor>
class XRTAllocateOp : public OpKernel {
 public:
  explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTAllocateOp() override = default;
  XRTAllocateOp(const XRTAllocateOp&) = delete;
  XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell());

    const Tensor& allocation_info = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
                errors::Internal("allocation input should be a string scalar"));
    xrt::XLAAllocation allocation_proto;
    OP_REQUIRES(ctx,
                ParseFromTString(allocation_info.scalar<tstring>()(),
                                 &allocation_proto),
                errors::InvalidArgument(
                    "Unable to parse allocation input to XLAAllocation"));

    xla::Literal literal;
    OP_REQUIRES_OK(
        ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation,
                            device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }
};

// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
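// The on-device shape is derived from the "dtype" and "shape" attrs via
// TensorShapeToXLAShape; the contents of the returned allocation are
// undefined until they are written (e.g. by XRTWriteLiteral).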
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateUninitialized(
                            xla_shape_, memory_manager.get(),
                            device_ref.backend(), device_ref.device_ordinal(),
                            &allocation, device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};

// Op that allocates memory for a tensor (with optional layout) and transfers
// it to the device, returning an allocation handle.
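// For example (sketch): with dtypes=[DT_FLOAT], shapes=[[2, 3]] and
// make_tuple=false the op stages a single f32[2,3] buffer, while
// make_tuple=true wraps the inputs in a tuple. The optional "layouts" attr,
// when present, supplies a flattened minor-to-major list that is applied to
// the resulting shape via GetShapeWithLayout.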
template <class DeviceAccessor>
class XRTAllocateFromTensorOp : public OpKernel {
 public:
  explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool make_tuple = false;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
    std::vector<int64> minor_to_major;
    if (ctx->HasAttr("layouts")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
    }
    OP_REQUIRES(
        ctx, tf_shapes_.size() == dtypes_.size(),
        errors::InvalidArgument("shapes and dtypes must be the same length"));
    std::vector<xla::Shape> xla_shapes;
    xla_shapes.reserve(tf_shapes_.size());
    for (int i = 0; i < tf_shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
      xla_shapes.push_back(std::move(xla_shape));
    }
    if (xla_shapes.size() > 1 || make_tuple) {
      shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    } else {
      shape_.Swap(&xla_shapes.front());
    }
    if (!minor_to_major.empty()) {
      xla::Shape shape_with_layouts;
      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                             /*layout_func=*/nullptr,
                                             &shape_with_layouts));
      shape_.Swap(&shape_with_layouts);
    }
  }

  ~XRTAllocateFromTensorOp() override = default;
  XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete;
  XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateFromTensorOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell());

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == tf_shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to XRTAllocateFromTensor: ",
                    values.size(), " vs. ", tf_shapes_.size()));

    std::vector<const char*> tensors_data;
    for (size_t i = 0; i < values.size(); ++i) {
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "Input tensor type and input dtype do not match"));
      // We allow the requested on-device shape to differ from the shape of the
      // input tensor, as long as they have the same number of elements.
      OP_REQUIRES(
          ctx,
          input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(),
          errors::InvalidArgument(
              "Input tensor must have the number of elements specified "
              "in the matching input shape: ",
              input_tensor.shape().num_elements(), " vs. ",
              tf_shapes_[i].num_elements(), " at index ", i));
      tensors_data.push_back(
          static_cast<const char*>(DMAHelper::base(&input_tensor)));
    }
    // Use the buffer straight out of the input tensors to create the literal.
    xla::BorrowingLiteral literal =
        shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_)
                         : xla::BorrowingLiteral(tensors_data.front(), shape_);
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation,
                            device_ref.allocator()));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  std::vector<TensorShape> tf_shapes_;
  DataTypeVector dtypes_;
  xla::Shape shape_;
};

// Op that takes a tuple handle input and returns a handle to a sub-tuple of
// the input.
template <bool discard_, class DeviceAccessor>
class XRTSubTupleOp : public OpKernel {
 public:
  explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTSubTupleOp() override = default;
  XRTSubTupleOp(const XRTSubTupleOp&) = delete;
  XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTSubTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& subtuple_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
        errors::Internal("tuple index input should be an int32 vector"));
    xla::ShapeIndex shape_index;
    for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
      shape_index.push_back(subtuple_info.vec<int32>()(i));
    }

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    XRTTupleAllocation* suballocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
                                               &suballocation, !discard_));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(suballocation);
    ctx->set_output(0, output);
  }
};

// Op that assembles a tuple from existing allocation handles, as described by
// an xrt::XLATupleNode tree, and returns a handle to the new tuple.
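// For example (sketch), to build the tuple (h_a, h_b) from two existing
// handles, input 0 would carry a serialized xrt::XLATupleNode such as
//
//   tuples { input_index: 0 }
//   tuples { input_index: 1 release_input_handle: true }
//
// with input_handles = [h_a, h_b]; h_b is then released once the new tuple
// has been registered.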
template <class DeviceAccessor>
class XRTMakeTupleOp : public OpKernel {
 public:
  explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMakeTupleOp() override = default;
  XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
  XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTMakeTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell());

    const Tensor& tuple_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
        errors::Internal("tuple description input should be a string scalar"));
    xrt::XLATupleNode tuple_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(tuple_info.scalar<tstring>()(), &tuple_proto),
        errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));

    OpInputList arg_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));

    // For each input, the allocation it corresponds to and a flag indicating
    // whether or not it should be released, i.e. discarded from the resource
    // manager. One ref on each allocation is owned by this vector, and freed
    // on exit.
    std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
        arg_list.size());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
    // the allocations. It is guaranteed that there is at least one allocation
    // in any legal tree. We validate below in XRTTupleAllocation::MakeTuple
    // that all the allocations are on the same device.
    int device_ordinal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
                            tuple_proto, arg_list, &input_vector,
                            &tuple_shape_tree, &device_ordinal, rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(
        ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* output_allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
                            memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), tuple_shape_tree,
                            &output_allocation, device_ref.allocator()));
    RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
    for (int i = 0; i < input_vector.size(); ++i) {
      if (input_vector[i].release_allocation_after_use) {
        OP_REQUIRES_OK(ctx,
                       memory_manager->Release(arg_list[i].scalar<int64>()()));
      }
    }

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(std::move(output_ptr));
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
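// The output is the serialized xla::LiteralProto; callers can reconstruct the
// literal with xla::Literal::CreateFromProto. When discard_ is true the
// handle is released as part of the read.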
template <bool discard_, class DeviceAccessor>
class XRTReadLiteralOp : public OpKernel {
 public:
  explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReadLiteralOp() override = default;
  XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
  XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Literal literal(allocation->on_host_shape());
    OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
    xla::LiteralProto literal_proto = literal.ToProto();

    Tensor output(DT_STRING, TensorShape({}));
    SerializeToTString(literal_proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as one
// output tensor per leaf buffer.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
 public:
  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }
  ~XRTReadToTensorOp() override = default;
  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadToTensorOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());

    const Tensor& handle_tensor = ctx->input(0);
    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
    // just scalars.)
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Shape shape = allocation->on_host_shape();
    int output = 0;
    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
        &shape,
        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
          if (subshape->IsTuple()) return Status::OK();

          xla::PrimitiveType xla_type;
          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
              ctx->expected_output_dtype(output), &xla_type));
          if (xla_type != subshape->element_type()) {
            return errors::InvalidArgument(
                "Type mismatch between buffer type (", subshape->ToString(),
                ") and tensor type (",
                DataTypeString(ctx->expected_output_dtype(output)),
                ") for output tensor ", output);
          }

          TensorShape output_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));

          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(
              ctx->allocate_output(output, output_shape, &output_tensor));

          XRTTupleAllocation* sub;
          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
              allocation.get(), index, &sub, /*alias_parent_allocation=*/true));
          core::ScopedUnref sub_unref(sub);

          xla::MutableBorrowingLiteral literal;
          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
              &literal));
          TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));

          ++output;
          return Status::OK();
        });
    OP_REQUIRES_OK(ctx, status);
  }
  bool discard_;
  DataTypeVector dtypes_;
};

// Op that writes a new literal value into device-resident memory.
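// Input 0 is the allocation handle and input 1 a serialized xla::LiteralProto;
// the op writes the literal into the existing allocation and returns the same
// handle, so a write can be chained directly into a later read or execution.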
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64_t allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(literal_info.scalar<tstring>()(), &literal_proto),
        errors::InvalidArgument(
            "Unable to parse literal input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards one or more handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
 public:
  explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReleaseAllocationOp() override = default;
  XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
  XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllocationOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    const Tensor& allocation_handle = ctx->input(0);
    auto flat_keys = allocation_handle.flat<int64>();
    for (int64_t i = 0; i < flat_keys.size(); ++i) {
      int64_t key = flat_keys(i);
      OP_REQUIRES_OK(ctx, memory_manager->Release(key));
      VLOG(2) << "Released allocation handle " << key;
    }
  }
};

// Op that discards all handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllAllocationsOp : public OpKernel {
 public:
  explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  ~XRTReleaseAllAllocationsOp() override = default;
  XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete;
  XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
  }
};

// Op that asks the XRTMemoryManager to compact the allocations resident on
// the device.
template <class DeviceAccessor>
class XRTCompactAllocationsOp : public OpKernel {
 public:
  explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTCompactAllocationsOp() override = default;
  XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete;
  XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTCompactAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
    OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
                            device_ref.backend(), device_ref.device_ordinal(),
                            device_ref.allocator()));
  }
};

// Op that reports the free and total device memory, as a serialized
// xrt::MemoryInfo proto.
template <class DeviceAccessor>
class XRTMemoryInfoOp : public OpKernel {
 public:
  explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMemoryInfoOp() override = default;
  XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
  XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    auto kernel_fn = [&]() -> Status {
      VLOG(1) << "XRTMemoryInfoOp::Compute";

      class DeviceAccessor::ScopedRef device_ref;
      TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
      TF_ASSIGN_OR_RETURN(
          se::StreamExecutor * stream_executor,
          device_ref.backend()->stream_executor(device_ref.device_ordinal()));
      int64_t mem_free = -1;
      int64_t mem_total = -1;
      if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
        VLOG(2) << "Device " << ctx->device()->name()
                << " does not expose memory information";
      }
      xrt::MemoryInfo mem_info;
      mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
      mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);

      Tensor output(DT_STRING, TensorShape({}));
      output.scalar<tstring>()() = mem_info.SerializeAsString();
      ctx->set_output(0, output);
      return Status::OK();
    };
    OP_REQUIRES_OK(ctx, kernel_fn());
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_