• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_tensor.h"
17 #include "tensorflow/compiler/tf2xla/shape_util.h"
18 
19 namespace tensorflow {
20 
FromTensor(const Tensor * tensor)21 /*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) {
22   if (tensor->NumElements() == 0) {
23     return nullptr;
24   }
25   XlaTensor* xla_tensor =
26       FromOpaquePointer(const_cast<char*>(tensor->tensor_data().data()));
27   return xla_tensor;
28 }
29 
DeviceMemoryFromTensor(const Tensor & tensor)30 /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
31     const Tensor& tensor) {
32   const XlaTensor* xla_tensor = FromTensor(&tensor);
33   if (xla_tensor) {
34     CHECK(xla_tensor->has_shaped_buffer());
35     return xla_tensor->shaped_buffer().root_buffer();
36   } else {
37     return se::DeviceMemoryBase(const_cast<char*>(tensor.tensor_data().data()),
38                                 tensor.tensor_data().size());
39   }
40 }
41 
AllocateShapedBuffer(DataType dtype,const xla::Shape & on_host_shape,xla::LocalClient * client,int device_ordinal)42 Status XlaTensor::AllocateShapedBuffer(DataType dtype,
43                                        const xla::Shape& on_host_shape,
44                                        xla::LocalClient* client,
45                                        int device_ordinal) {
46   xla::Shape on_device_shape =
47       client->backend().transfer_manager()->HostShapeToDeviceShape(
48           on_host_shape);
49 
50   xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
51                                         client->backend().memory_allocator(),
52                                         device_ordinal);
53   for (auto& index_to_buffer : shaped_buffer.buffers()) {
54     xla::Shape subshape =
55         xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
56     uint64 size =
57         client->backend().transfer_manager()->GetByteSizeRequirement(subshape);
58     TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer,
59                         client->backend().memory_allocator()->Allocate(
60                             device_ordinal, size, /*retry_on_failure=*/false,
61                             subshape.layout().memory_space()));
62     // Move our buffer into shaped_buffer, which takes ownership of it.
63     index_to_buffer.second = buffer.Release();
64   }
65 
66   VLOG(4) << shaped_buffer.ToString();
67 
68   set_shaped_buffer(std::move(shaped_buffer));
69   return Status::OK();
70 }
71 
WaitForDefinitionEventOnStream(se::Stream * stream)72 void XlaTensor::WaitForDefinitionEventOnStream(se::Stream* stream) {
73   mutex_lock lock(mu_);
74   if (!definition_event_) {
75     return;
76   }
77 
78   // The set of defined streams is expected to be very small indeed (usually
79   // 1-2), so a simple linear scan should be fast enough.
80   if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
81                 stream) != streams_defined_on_.end()) {
82     // stream is in streams_defined_on_; it doesn't need to be waited on.
83     return;
84   }
85 
86   stream->ThenWaitFor(definition_event_.get());
87   streams_defined_on_.push_back(stream);
88 }
89 
ResetDefinitionEvent(std::shared_ptr<se::Event> event,se::Stream * stream)90 void XlaTensor::ResetDefinitionEvent(std::shared_ptr<se::Event> event,
91                                      se::Stream* stream) {
92   mutex_lock lock(mu_);
93   definition_event_ = std::move(event);
94   streams_defined_on_ = {stream};
95 }
96 
RefreshStatusOfStreams()97 Status XlaTensor::RefreshStatusOfStreams() {
98   mutex_lock lock(mu_);
99   Status status;
100   for (se::Stream* stream : streams_defined_on_) {
101     status.Update(stream->RefreshStatus());
102   }
103   return status;
104 }
105 
// Tag bit OR-ed into an XlaTensor's address (see ToOpaquePointer). It
// distinguishes that address from the plain CPU or GPU memory pointers that
// back other device tensors. This relies on those pointers being aligned to
// at least 2 bytes, so a genuine data pointer never has its lowest bit set.
namespace {
constexpr uintptr_t kTag = uintptr_t{1};
}  // namespace
112 
FromOpaquePointer(void * ptr)113 /*static*/ XlaTensor* XlaTensor::FromOpaquePointer(void* ptr) {
114   uintptr_t value = reinterpret_cast<uintptr_t>(ptr);
115   if (value & kTag) {
116     return reinterpret_cast<XlaTensor*>(value & ~kTag);
117   } else {
118     return nullptr;
119   }
120 }
121 
ToOpaquePointer(XlaTensor * tensor)122 /*static*/ void* XlaTensor::ToOpaquePointer(XlaTensor* tensor) {
123   uintptr_t value = reinterpret_cast<uintptr_t>(tensor);
124   CHECK_EQ(value & kTag, 0);
125   value |= kTag;
126   return reinterpret_cast<XlaTensor*>(value);
127 }
128 
129 }  // namespace tensorflow
130