/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_

#include <functional>
#include <memory>
#include <string>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_metrics.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/percentile_sampler.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Helper functions for templated ops.
class XRTStateHelpers {
 public:
  // The Status return value allows us to use the
  // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
  // OpKernel::Compute method.
  static Status MakeLiteral(const xla::LiteralProto& proto,
                            xla::Literal* literal) {
    TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
    return Status::OK();
  }

  // ParseTupleNode is the recursive function used to parse a recursive
  // xrt::XLATupleNode proto and generate the xla::Shape of the 'spine' i.e. the
  // tuple shape where every leaf is an existing allocation. As a side-effect it
  // fills in input_vector by looking up allocations from handles in the
  // input_tensor_list as they are referenced by nodes in the proto.
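  //
  // As a rough illustration (field names taken from the XLATupleNode message;
  // the handle values are hypothetical), the nested tuple ((h0, h1), h2) would
  // be described by a proto such as:
  //
  //   tuples {
  //     tuples { input_index: 0 }
  //     tuples { input_index: 1 release_input_handle: true }
  //   }
  //   tuples { input_index: 2 }
  //
  // where each input_index refers into input_tensor_list, whose elements are
  // int64 scalars holding allocation handles (here h0, h1 and h2).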
  static Status ParseTupleNode(
      const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::Shape* shape, ResourceMgr* rm) {
    if (tuple_node.tuples_size() > 0) {
      // This is an internal node in the proto so descend recursively.
      xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
      std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
      *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
          xla::ShapeUtil::MakeTupleShape(subshapes);
      for (int i = 0; i < tuple_node.tuples_size(); ++i) {
        TF_RETURN_IF_ERROR(ParseTupleNode(
            tuple_node.tuples(i), input_tensor_list, input_vector,
            xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
      }
    } else {
      // This is a leaf node in the proto so look up the referenced input.
      int input_index = tuple_node.input_index();
      if (input_index < 0 || input_index >= input_vector->size()) {
        return errors::InvalidArgument("Invalid tuple input index ",
                                       input_index, ": MakeTuple has ",
                                       input_vector->size(), " inputs.");
      }
      bool release_this_input = tuple_node.release_input_handle();
      XRTTupleAllocation::ExpandedTupleInput& input =
          input_vector->at(input_index);
      if (input.allocation != nullptr &&
          (input.release_allocation_after_use || release_this_input)) {
        return errors::InvalidArgument(
            "Invalid tuple tree: input index ", input_index,
            " is repeated but release_input_handle is true.");
      }
      if (input.allocation == nullptr) {
        // We haven't dereferenced this handle yet.
        TF_RET_CHECK(
            TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
        int64 key = input_tensor_list[input_index].scalar<int64>()();
        TF_ASSIGN_OR_RETURN(input.allocation,
                            XRTMemoryManager::Get(rm)->Lookup(key));
        input.release_allocation_after_use = release_this_input;
      }
    }
    return Status::OK();
  }

  // Parses a xrt::XLATupleNode proto recursively and returns the corresponding
  // ShapeTree where each leaf is an allocation corresponding to a handle in
  // input_tensor_list. The ordinal of one of the allocations is returned in
  // device_ordinal. Since it's not possible to specify a xrt::XLATupleNode with
  // no leaves, device_ordinal will always be filled in by a successful call to
  // ParseTupleTree.
  static Status ParseTupleTree(
      const xrt::XLATupleNode& tuple_tree_root,
      const OpInputList& input_tensor_list,
      std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
      xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
      int* device_ordinal, ResourceMgr* rm) {
    // First get the shape of the 'spine' of the new tuple, where every leaf is
    // an existing allocation. As a side-effect dereference the input handles
    // into allocations in input_vector.
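    // (Only the tuple structure of this spine matters: ParseTupleNode leaves
    // placeholder scalar shapes at the leaves, while the actual device buffers
    // come from the allocations stored in input_vector.)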
    xla::Shape tuple_tree_shape;
    TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
                                      input_vector, &tuple_tree_shape, rm));
    // Make the shape tree of allocations where the shape is the spine and each
    // leaf is one of the allocations looked up in input_vector. Internal nodes
    // have nullptr allocations.
    *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
        tuple_tree_shape);
    tuple_shape_tree->ForEachMutableElement(
        [&](const xla::ShapeIndex& index,
            XRTTupleAllocation::ExpandedTupleInput* element) {
          if (tuple_shape_tree->IsLeaf(index)) {
            // Find the matching leaf in the proto tree.
            const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
            for (int i = 0; i < index.size(); ++i) {
              tuple_node = &tuple_node->tuples(index[i]);
            }
            // Copy the appropriate input allocation to the leaf of the
            // tuple_shape_tree.
            int input_index = tuple_node->input_index();
            *element = input_vector->at(input_index);
            CHECK(element->release_allocation_after_use ==
                  tuple_node->release_input_handle());
            // We just need to know the device_ordinal of one of the
            // allocations. We will validate later that they are all the same.
            *device_ordinal = (*element).allocation->device_ordinal();
          }
        });
    return Status::OK();
  }
};

// Op that allocates memory for a literal and transfers it to the device.
template <class DeviceAccessor>
class XRTAllocateOp : public OpKernel {
 public:
  explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTAllocateOp() override = default;
  XRTAllocateOp(const XRTAllocateOp&) = delete;
  XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetAllocateCell());

    const Tensor& allocation_info = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
                errors::Internal("allocation input should be a string scalar"));
    xrt::XLAAllocation allocation_proto;
    OP_REQUIRES(ctx,
                ParseFromTString(allocation_info.scalar<tstring>()(),
                                 &allocation_proto),
                errors::InvalidArgument(
                    "Unable to parse allocation input to XLAAllocation"));

    xla::Literal literal;
    OP_REQUIRES_OK(
        ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }
};

// Op that allocates uninitialized memory on the device for a tensor of
// a particular shape.
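// For example (attribute values are hypothetical), with
//   dtype = DT_FLOAT, shape = [128, 256]
// the op reserves an uninitialized f32[128,256] buffer on the device and
// returns an int64 handle; the contents can be filled in later, e.g. via
// XRTWriteLiteral.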
template <class DeviceAccessor>
class XRTAllocateUninitializedOp : public OpKernel {
 public:
  explicit XRTAllocateUninitializedOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tf_shape_));
    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, tf_shape_, &xla_shape_));
  }
  ~XRTAllocateUninitializedOp() override = default;
  XRTAllocateUninitializedOp(const XRTAllocateUninitializedOp&) = delete;
  XRTAllocateUninitializedOp& operator=(const XRTAllocateUninitializedOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateUninitializedOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateUninitializedCell());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx,
                   XRTTupleAllocation::CreateUninitialized(
                       xla_shape_, memory_manager.get(), device_ref.backend(),
                       device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  DataType dtype_;
  TensorShape tf_shape_;
  xla::Shape xla_shape_;
};

// Op that allocates memory for a tensor (with optional layout) and transfers it
// to the device, returning an allocation handle.
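// For example (attribute values are hypothetical), with
//   shapes = [[2, 3], [4]], dtypes = [DT_FLOAT, DT_INT32], make_tuple = true
// the op copies the two input tensors into a single device allocation of XLA
// shape (f32[2,3], s32[4]) and returns its int64 handle. If the "layouts"
// attribute is set, it supplies a flattened minor-to-major ordering used to
// lay out the on-device buffers.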
template <class DeviceAccessor>
class XRTAllocateFromTensorOp : public OpKernel {
 public:
  explicit XRTAllocateFromTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    bool make_tuple = false;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &tf_shapes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("make_tuple", &make_tuple));
    std::vector<int64> minor_to_major;
    if (ctx->HasAttr("layouts")) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("layouts", &minor_to_major));
    }
    OP_REQUIRES(
        ctx, tf_shapes_.size() == dtypes_.size(),
        errors::InvalidArgument("shapes and dtypes must be the same length"));
    std::vector<xla::Shape> xla_shapes;
    xla_shapes.reserve(tf_shapes_.size());
    for (int i = 0; i < tf_shapes_.size(); i++) {
      xla::Shape xla_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(dtypes_[i], tf_shapes_[i], &xla_shape));
      xla_shapes.push_back(std::move(xla_shape));
    }
    if (xla_shapes.size() > 1 || make_tuple) {
      shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes);
    } else {
      shape_.Swap(&xla_shapes.front());
    }
    if (!minor_to_major.empty()) {
      xla::Shape shape_with_layouts;
      OP_REQUIRES_OK(ctx, GetShapeWithLayout(shape_, minor_to_major,
                                             /*layout_func=*/nullptr,
                                             &shape_with_layouts));
      shape_.Swap(&shape_with_layouts);
    }
  }

  ~XRTAllocateFromTensorOp() override = default;
  XRTAllocateFromTensorOp(const XRTAllocateFromTensorOp&) = delete;
  XRTAllocateFromTensorOp& operator=(const XRTAllocateFromTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTAllocateFromTensorOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetAllocateFromTensorCell());

    OpInputList values;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &values));
    OP_REQUIRES(ctx, values.size() == tf_shapes_.size(),
                errors::InvalidArgument(
                    "Wrong number of inputs to XRTAllocateFromTensor: ",
                    values.size(), " vs. ", tf_shapes_.size()));

    std::vector<const char*> tensors_data;
    for (size_t i = 0; i < values.size(); ++i) {
      const Tensor& input_tensor = values[i];
      OP_REQUIRES(ctx, input_tensor.dtype() == dtypes_[i],
                  errors::InvalidArgument(
                      "Input tensor type and input dtype do not match"));
      // We allow the requested on-device shape to differ from the shape of the
      // input tensor, as long as they have the same number of elements.
      OP_REQUIRES(
          ctx,
          input_tensor.shape().num_elements() == tf_shapes_[i].num_elements(),
          errors::InvalidArgument(
              "Input tensor must have the number of elements specified "
              "in the matching input shape: ",
              input_tensor.shape().num_elements(), " vs. ",
              tf_shapes_[i].num_elements(), " at index ", i));
      tensors_data.push_back(
          static_cast<const char*>(DMAHelper::base(&input_tensor)));
    }
    // Use the buffer straight out of the input tensors to create the literal.
    xla::BorrowingLiteral literal =
        shape_.IsTuple() ? xla::BorrowingLiteral(tensors_data, shape_)
                         : xla::BorrowingLiteral(tensors_data.front(), shape_);
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
                            literal, memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), &allocation));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(allocation);
    ctx->set_output(0, output);
  }

 private:
  std::vector<TensorShape> tf_shapes_;
  DataTypeVector dtypes_;
  xla::Shape shape_;
};

// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
// input.
template <bool discard_, class DeviceAccessor>
class XRTSubTupleOp : public OpKernel {
 public:
  explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTSubTupleOp() override = default;
  XRTSubTupleOp(const XRTSubTupleOp&) = delete;
  XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTSubTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetSubTupleCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& subtuple_info = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
        errors::Internal("tuple index input should be an int32 vector"));
    xla::ShapeIndex shape_index;
    for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
      shape_index.push_back(subtuple_info.vec<int32>()(i));
    }

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    XRTTupleAllocation* suballocation;
    OP_REQUIRES_OK(
        ctx, XRTTupleAllocation::MakeSubBuffer(allocation.get(), shape_index,
                                               &suballocation, !discard_));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(suballocation);
    ctx->set_output(0, output);
  }
};

// Op that assembles a new tuple allocation on the device from a tree of
// existing allocation handles, as described by an xrt::XLATupleNode proto.
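// For illustration (handle names are hypothetical), given
//   input_handles = [h0, h1]
// and a tuple description proto
//   tuples { input_index: 0 }
//   tuples { input_index: 1 release_input_handle: true }
// the op returns a handle to a new 2-tuple assembled from the two existing
// allocations, and releases h1 as a side effect while h0 remains valid.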
template <class DeviceAccessor>
class XRTMakeTupleOp : public OpKernel {
 public:
  explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMakeTupleOp() override = default;
  XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
  XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTMakeTupleOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetMakeTupleCell());

    const Tensor& tuple_info = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
        errors::Internal("tuple description input should be a string scalar"));
    xrt::XLATupleNode tuple_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(tuple_info.scalar<tstring>()(), &tuple_proto),
        errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));

    OpInputList arg_list;
    OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));

    // For each input, the allocation it corresponds to and a flag indicating
    // whether or not it should be released, i.e. discarded from the resource
    // manager. One ref on each allocation is owned by this vector, and freed on
    // exit.
    std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
        arg_list.size());
    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
    // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
    // the allocations. It is guaranteed that there is at least one allocation
    // in any legal tree. We validate below in XRTTupleAllocation::MakeTuple
    // that all the allocations are on the same device.
    int device_ordinal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
                            tuple_proto, arg_list, &input_vector,
                            &tuple_shape_tree, &device_ordinal, rm));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(
        ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    XRTTupleAllocation* output_allocation;
    OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
                            memory_manager.get(), device_ref.backend(),
                            device_ref.device_ordinal(), tuple_shape_tree,
                            &output_allocation));
    RefPtr<XRTTupleAllocation> output_ptr(output_allocation);
    for (int i = 0; i < input_vector.size(); ++i) {
      if (input_vector[i].release_allocation_after_use) {
        OP_REQUIRES_OK(ctx,
                       memory_manager->Release(arg_list[i].scalar<int64>()()));
      }
    }

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = memory_manager->Register(std::move(output_ptr));
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns it as a
// literal.
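// The output is a DT_STRING scalar holding a serialized xla::LiteralProto.
// When the discard_ template parameter is true, the input handle is released
// as a side effect, so the device memory can be reclaimed once the read has
// completed.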
template <bool discard_, class DeviceAccessor>
class XRTReadLiteralOp : public OpKernel {
 public:
  explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReadLiteralOp() override = default;
  XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
  XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Literal literal(allocation->on_host_shape());
    OP_REQUIRES_OK(ctx, allocation->ToLiteral(device_ref.backend(), &literal));
    xla::LiteralProto literal_proto = literal.ToProto();

    Tensor output(DT_STRING, TensorShape({}));
    SerializeToTString(literal_proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
  }
};

// Op that reads a device-resident tuple to host memory and returns its leaf
// buffers as separate output tensors.
template <class DeviceAccessor>
class XRTReadToTensorOp : public OpKernel {
 public:
  explicit XRTReadToTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("release_handles", &discard_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_));
  }
  ~XRTReadToTensorOp() override = default;
  XRTReadToTensorOp(const XRTReadToTensorOp&) = delete;
  XRTReadToTensorOp& operator=(const XRTReadToTensorOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReadToTensorOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReadToTensorCell());

    const Tensor& handle_tensor = ctx->input(0);
    // TODO(phawkins,dlibenzi): accept multiple handles (i.e., vectors, not
    // just scalars.)
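    // Each non-tuple leaf of the allocation's on-host shape becomes one output
    // tensor, in the order the subshapes are visited below; its element type
    // must match the corresponding entry of the "dtypes" attribute.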
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    if (discard_) {
      VLOG(2) << "Releasing handle " << allocation_handle;
      OP_REQUIRES_OK(ctx, memory_manager->Release(allocation_handle));
    }

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));

    xla::Shape shape = allocation->on_host_shape();
    int output = 0;
    Status status = xla::ShapeUtil::ForEachMutableSubshapeWithStatus(
        &shape,
        [&](xla::Shape* subshape, const xla::ShapeIndex& index) -> Status {
          if (subshape->IsTuple()) return Status::OK();

          xla::PrimitiveType xla_type;
          TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(
              ctx->expected_output_dtype(output), &xla_type));
          if (xla_type != subshape->element_type()) {
            return errors::InvalidArgument(
                "Type mismatch between buffer type (", subshape->ToString(),
                ") and tensor type (",
                DataTypeString(ctx->expected_output_dtype(output)),
                ") for output tensor ", output);
          }

          TensorShape output_shape;
          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(*subshape, &output_shape));

          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(
              ctx->allocate_output(output, output_shape, &output_tensor));

          XRTTupleAllocation* sub;
          TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
              allocation.get(), index, &sub, /*alias_parent_allocation=*/true));
          core::ScopedUnref sub_unref(sub);

          xla::MutableBorrowingLiteral literal;
          TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(
              xla::LayoutUtil::GetWithDefaultLayout(*subshape), output_tensor,
              &literal));
          TF_RETURN_IF_ERROR(sub->ToLiteral(device_ref.backend(), &literal));

          ++output;
          return Status::OK();
        });
    OP_REQUIRES_OK(ctx, status);
  }
  bool discard_;
  DataTypeVector dtypes_;
};

// Op that writes a new literal value into device-resident memory.
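// Takes an int64 allocation handle and a string scalar holding a serialized
// xla::LiteralProto, writes the literal into the existing device allocation,
// and returns the same handle. The literal is expected to match the shape of
// the allocation it overwrites.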
template <class DeviceAccessor>
class XRTWriteLiteralOp : public OpKernel {
 public:
  explicit XRTWriteLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTWriteLiteralOp() override = default;
  XRTWriteLiteralOp(const XRTWriteLiteralOp&) = delete;
  XRTWriteLiteralOp& operator=(const XRTWriteLiteralOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTWriteLiteralOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetWriteLiteralCell());

    const Tensor& handle_tensor = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
        errors::Internal("computation input should be an int64 scalar"));
    int64 allocation_handle = handle_tensor.scalar<int64>()();

    const Tensor& literal_info = ctx->input(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(literal_info.shape()),
                errors::Internal("literal input should be a string scalar"));
    xla::LiteralProto literal_proto;
    OP_REQUIRES(
        ctx, ParseFromTString(literal_info.scalar<tstring>()(), &literal_proto),
        errors::InvalidArgument(
            "Unable to parse allocation input to LiteralProto"));
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, XRTStateHelpers::MakeLiteral(literal_proto, &literal));

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    RefPtr<XRTTupleAllocation> allocation;
    OP_REQUIRES_OK(ctx, memory_manager->Lookup(allocation_handle, &allocation));

    // We are guaranteed that the underlying device object won't be deleted out
    // from under us, while the ScopedRef is live.
    typename DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                            ctx, allocation->device_ordinal(), &device_ref));
    OP_REQUIRES_OK(ctx,
                   allocation->WriteLiteral(device_ref.backend(), literal));

    Tensor output(DT_INT64, TensorShape({}));
    output.scalar<int64>()() = allocation_handle;
    ctx->set_output(0, output);
  }
};

// Op that discards one or more handles to device memory.
template <class DeviceAccessor>
class XRTReleaseAllocationOp : public OpKernel {
 public:
  explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTReleaseAllocationOp() override = default;
  XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
  XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllocationOp::Compute";
    auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseAllocationCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));

    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    const Tensor& allocation_handle = ctx->input(0);
    auto flat_keys = allocation_handle.flat<int64>();
    for (int64 i = 0; i < flat_keys.size(); ++i) {
      int64 key = flat_keys(i);
      OP_REQUIRES_OK(ctx, memory_manager->Release(key));
      VLOG(2) << "Released allocation handle " << key;
    }
  }
};

// Op that discards all handles to device memory.
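// Unlike XRTReleaseAllocation, this op takes no inputs and produces no
// outputs: it drops every live allocation handle held by the memory manager,
// invalidating any handle values previously returned to clients.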
template <class DeviceAccessor>
class XRTReleaseAllAllocationsOp : public OpKernel {
 public:
  explicit XRTReleaseAllAllocationsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  ~XRTReleaseAllAllocationsOp() override = default;
  XRTReleaseAllAllocationsOp(const XRTReleaseAllAllocationsOp&) = delete;
  XRTReleaseAllAllocationsOp& operator=(const XRTReleaseAllAllocationsOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTReleaseAllAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetReleaseAllAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    XRTMemoryManager::Get(rm)->ReleaseAllAllocations();
  }
};

// Op that asks the memory manager to compact the live allocations on the
// device.
template <class DeviceAccessor>
class XRTCompactAllocationsOp : public OpKernel {
 public:
  explicit XRTCompactAllocationsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTCompactAllocationsOp() override = default;
  XRTCompactAllocationsOp(const XRTCompactAllocationsOp&) = delete;
  XRTCompactAllocationsOp& operator=(const XRTCompactAllocationsOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    VLOG(1) << "XRTCompactAllocationsOp::Compute";
    auto timed =
        monitoring::MakeTimed(xrt_metrics::GetCompactAllocationsCell());

    ResourceMgr* rm;
    OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
    RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
    class DeviceAccessor::ScopedRef device_ref;
    OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(ctx, &device_ref));
    OP_REQUIRES_OK(ctx, memory_manager->CompactAllocations(
                            device_ref.backend(), device_ref.device_ordinal()));
  }
};

// Op that reports the free and total device memory (in KB, or -1 when the
// device does not expose the information) as a serialized xrt::MemoryInfo
// proto in a string scalar output.
template <class DeviceAccessor>
class XRTMemoryInfoOp : public OpKernel {
 public:
  explicit XRTMemoryInfoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  ~XRTMemoryInfoOp() override = default;
  XRTMemoryInfoOp(const XRTMemoryInfoOp&) = delete;
  XRTMemoryInfoOp& operator=(const XRTMemoryInfoOp&) = delete;

  void Compute(OpKernelContext* ctx) override {
    auto kernel_fn = [&]() -> Status {
      VLOG(1) << "XRTMemoryInfoOp::Compute";

      class DeviceAccessor::ScopedRef device_ref;
      TF_RETURN_IF_ERROR(DeviceAccessor::InitScopedRef(ctx, &device_ref));
      TF_ASSIGN_OR_RETURN(
          se::StreamExecutor * stream_executor,
          device_ref.backend()->stream_executor(device_ref.device_ordinal()));
      int64 mem_free = -1;
      int64 mem_total = -1;
      if (!stream_executor->DeviceMemoryUsage(&mem_free, &mem_total)) {
        VLOG(2) << "Device " << ctx->device()->name()
                << " does not expose memory information";
      }
      xrt::MemoryInfo mem_info;
      mem_info.set_kb_total((mem_total >= 0) ? mem_total / 1024 : -1);
      mem_info.set_kb_free((mem_free >= 0) ? mem_free / 1024 : -1);

      Tensor output(DT_STRING, TensorShape({}));
      output.scalar<tstring>()() = mem_info.SerializeAsString();
      ctx->set_output(0, output);
      return Status::OK();
    };
    OP_REQUIRES_OK(ctx, kernel_fn());
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_