/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"

namespace tensorflow {

/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
  if (tensor->NumElements() == 0) {
    return nullptr;
  }
  XlaTensor* xla_tensor =
      FromOpaquePointer(const_cast<char*>(tensor->tensor_data().data()));
  return xla_tensor;
}
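
// Illustrative usage (sketch only; `t` is an arbitrary Tensor):
//
//   if (XlaTensor* xt = XlaTensor::FromTensor(&t)) {
//     // t's backing store carries a tagged XlaTensor pointer; use xt.
//   } else {
//     // t is empty or backed by ordinary (untagged) memory.
//   }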

/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
  return tensor.RefCountIsOne();
}

/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
    const Tensor& tensor) {
  const XlaTensor* xla_tensor = FromTensor(&tensor);
  if (xla_tensor) {
    CHECK(xla_tensor->has_shaped_buffer());
    return xla_tensor->shaped_buffer().root_buffer();
  } else {
    return se::DeviceMemoryBase(const_cast<char*>(tensor.tensor_data().data()),
                                tensor.tensor_data().size());
  }
}
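
// Illustrative usage (sketch only; `t` is assumed to be a device-resident
// Tensor):
//
//   se::DeviceMemoryBase mem = XlaTensor::DeviceMemoryFromTensor(t);
//   // `mem` refers either to the ShapedBuffer's root buffer (XLA-backed
//   // tensor) or to the tensor's own flat data buffer.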

Status XlaTensor::AllocateShapedBuffer(DataType dtype,
                                       const xla::Shape& on_host_shape,
                                       xla::LocalClient* client,
                                       int device_ordinal) {
  xla::Shape on_device_shape =
      client->backend().transfer_manager()->HostShapeToDeviceShape(
          on_host_shape);

  xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
                                        client->backend().memory_allocator(),
                                        device_ordinal);
  for (auto& index_to_buffer : shaped_buffer.buffers()) {
    xla::Shape subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size =
        client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
                        client->backend().memory_allocator()->Allocate(
                            device_ordinal, size, /*retry_on_failure=*/false));
    // Move our buffer into shaped_buffer, which takes ownership of it.
    index_to_buffer.second = buffer.Forget();
  }

  VLOG(4) << shaped_buffer.ToString();

  set_shaped_buffer(std::move(shaped_buffer));
  return Status::OK();
}
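
// Illustrative usage (sketch only; `client` and `device_ordinal` are assumed
// to be supplied by the surrounding device context, and TensorShapeToXLAShape
// comes from the shape_util.h header included above):
//
//   XlaTensor xla_tensor;
//   xla::Shape host_shape;
//   TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, TensorShape({2, 3}),
//                                            &host_shape));
//   TF_RETURN_IF_ERROR(xla_tensor.AllocateShapedBuffer(DT_FLOAT, host_shape,
//                                                      client, device_ordinal));
//   // Every leaf of the on-device shape now owns device memory sized by the
//   // backend's TransferManager.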

void XlaTensor::WaitForDefinitionEventOnStream(se::Stream* stream) {
  mutex_lock lock(mu_);
  if (!definition_event_) {
    return;
  }

  // The set of defined streams is expected to be very small indeed (usually
  // 1-2), so a simple linear scan should be fast enough.
  if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
                stream) != streams_defined_on_.end()) {
    // stream is in streams_defined_on_; it doesn't need to be waited on.
    return;
  }

  stream->ThenWaitFor(definition_event_.get());
  streams_defined_on_.push_back(stream);
}

void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
                                     se::Stream* stream) {
  mutex_lock lock(mu_);
  definition_event_ = std::move(event);
  streams_defined_on_ = {stream};
}
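
// Illustrative sketch of the definition-event protocol (hypothetical
// `producer_stream`, `consumer_stream`, and `executor`):
//
//   // Producer: after enqueuing the work that defines this tensor's buffers,
//   // record an event and install it as the definition event.
//   auto event = std::make_shared<se::Event>(executor);
//   event->Init();
//   producer_stream->ThenRecordEvent(event.get());
//   xla_tensor->ResetDefinitionEvent(event, producer_stream);
//
//   // Consumer: before reading on a different stream, wait on the definition
//   // event; this is a no-op for streams already recorded above.
//   xla_tensor->WaitForDefinitionEventOnStream(consumer_stream);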

// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because CPU and GPU pointers are guaranteed to be aligned to at least two
// bytes, so their lowest bit is always zero and is free to carry the tag.
namespace {
constexpr uintptr_t kTag = 0x1ULL;
}  // namespace

/*static*/ XlaTensor* XlaTensor::FromOpaquePointer(void* ptr) {
  uintptr_t value = reinterpret_cast<uintptr_t>(ptr);
  if (value & kTag) {
    return reinterpret_cast<XlaTensor*>(value & ~kTag);
  } else {
    return nullptr;
  }
}

/*static*/ void* XlaTensor::ToOpaquePointer(XlaTensor* tensor) {
  uintptr_t value = reinterpret_cast<uintptr_t>(tensor);
  CHECK_EQ(value & kTag, 0);
  value |= kTag;
  return reinterpret_cast<void*>(value);
}
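
// Illustrative round trip of the tagging scheme (sketch only):
//
//   XlaTensor* original = ...;  // heap pointer, low bit is zero
//   void* opaque = XlaTensor::ToOpaquePointer(original);  // low bit now set
//   CHECK_EQ(XlaTensor::FromOpaquePointer(opaque), original);
//   // An untagged pointer (e.g. a plain data pointer) maps back to nullptr:
//   CHECK(XlaTensor::FromOpaquePointer(original) == nullptr);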

}  // namespace tensorflow