• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_tensor.h"
17 #include "tensorflow/compiler/tf2xla/shape_util.h"
18 
19 namespace tensorflow {
20 
FromTensor(const Tensor * tensor)21 /*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
22   if (tensor->NumElements() == 0) {
23     return nullptr;
24   }
25   XlaTensor* xla_tensor =
26       FromOpaquePointer(const_cast<char*>(tensor->tensor_data().data()));
27   return xla_tensor;
28 }
29 
RefCountIsOne(const Tensor & tensor)30 /*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) {
31   return tensor.RefCountIsOne();
32 }
33 
DeviceMemoryFromTensor(const Tensor & tensor)34 /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
35     const Tensor& tensor) {
36   const XlaTensor* xla_tensor = FromTensor(&tensor);
37   if (xla_tensor) {
38     CHECK(xla_tensor->has_shaped_buffer());
39     return xla_tensor->shaped_buffer().root_buffer();
40   } else {
41     return se::DeviceMemoryBase(const_cast<char*>(tensor.tensor_data().data()),
42                                 tensor.tensor_data().size());
43   }
44 }
45 
AllocateShapedBuffer(DataType dtype,const xla::Shape & on_host_shape,xla::LocalClient * client,int device_ordinal)46 Status XlaTensor::AllocateShapedBuffer(DataType dtype,
47                                        const xla::Shape& on_host_shape,
48                                        xla::LocalClient* client,
49                                        int device_ordinal) {
50   xla::Shape on_device_shape =
51       client->backend().transfer_manager()->HostShapeToDeviceShape(
52           on_host_shape);
53 
54   xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
55                                         client->backend().memory_allocator(),
56                                         device_ordinal);
57   for (auto& index_to_buffer : shaped_buffer.buffers()) {
58     xla::Shape subshape =
59         xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
60     uint64 size =
61         client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
62     TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
63                         client->backend().memory_allocator()->Allocate(
64                             device_ordinal, size, /*retry_on_failure=*/false));
65     // Move our buffer into shaped_buffer, which takes ownership of it.
66     index_to_buffer.second = buffer.Forget();
67   }
68 
69   VLOG(4) << shaped_buffer.ToString();
70 
71   set_shaped_buffer(std::move(shaped_buffer));
72   return Status::OK();
73 }
74 
WaitForDefinitionEventOnStream(se::Stream * stream)75 void XlaTensor::WaitForDefinitionEventOnStream(se::Stream* stream) {
76   mutex_lock lock(mu_);
77   if (!definition_event_) {
78     return;
79   }
80 
81   // The set of defined streams is expected to be very small indeed (usually
82   // 1-2), so a simple linear scan should be fast enough.
83   if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
84                 stream) != streams_defined_on_.end()) {
85     // stream is in streams_defined_on_; it doesn't need to be waited on.
86     return;
87   }
88 
89   stream->ThenWaitFor(definition_event_.get());
90   streams_defined_on_.push_back(stream);
91 }
92 
ResetDefinitionEvent(std::shared_ptr<se::Event> event,se::Stream * stream)93 void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
94                                      se::Stream* stream) {
95   mutex_lock lock(mu_);
96   definition_event_ = std::move(event);
97   streams_defined_on_ = {stream};
98 }
99 
// Tag bit OR-ed into an XlaTensor's address when it is handed out as an opaque
// pointer, so that it can be told apart from device-side tensors, which hold
// plain CPU or GPU memory pointers. This relies on CPU and GPU pointers being
// aligned such that their lowest bit is always zero.
namespace {
constexpr uintptr_t kTag = 1;
}  // namespace
106 
FromOpaquePointer(void * ptr)107 /*static*/ XlaTensor* XlaTensor::FromOpaquePointer(void* ptr) {
108   uintptr_t value = reinterpret_cast<uintptr_t>(ptr);
109   if (value & kTag) {
110     return reinterpret_cast<XlaTensor*>(value & ~kTag);
111   } else {
112     return nullptr;
113   }
114 }
115 
ToOpaquePointer(XlaTensor * tensor)116 /*static*/ void* XlaTensor::ToOpaquePointer(XlaTensor* tensor) {
117   uintptr_t value = reinterpret_cast<uintptr_t>(tensor);
118   CHECK_EQ(value & kTag, 0);
119   value |= kTag;
120   return reinterpret_cast<XlaTensor*>(value);
121 }
122 
123 }  // namespace tensorflow
124