• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
17 
18 #include <functional>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/synchronization/mutex.h"
22 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
23 #include "tensorflow/core/platform/dynamic_annotations.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/platform/types.h"
27 #include "tensorflow/stream_executor/stream_executor.h"
28 
29 namespace xla {
30 namespace cpu {
31 namespace runtime {
32 
GetXfeedManager(int device_ordinal)33 XfeedManager* GetXfeedManager(int device_ordinal) {
34   static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
35   static absl::Mutex* mutex = new absl::Mutex();
36 
37   absl::MutexLock lock(mutex);
38   auto it = managers->find(device_ordinal);
39   if (it == managers->end()) {
40     it = managers->emplace(device_ordinal, new XfeedManager()).first;
41   }
42   return it->second;
43 }
44 
// Symbol names of the CPU-runtime entry points that JIT-compiled XLA code
// calls back into. Each string must match the linkage name of the
// corresponding extern "C" runtime function so the JIT can resolve it.
// All names share the kXlaCpuRuntimeSymbolNamePrefix prefix below.
extern const char* const kEigenMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenMatMulF16";
extern const char* const kEigenMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenMatMulF64";
extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32";
extern const char* const kMKLMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLMatMulF32";
extern const char* const kMKLMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLMatMulF64";
extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConvF16SymbolName =
    "__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
    "__xla_cpu_runtime_EigenConvF32";
extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft";
extern const char* const kEigenSingleThreadedFftSymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedFft";
extern const char* const kEigenSingleThreadedMatMulF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF16";
extern const char* const kEigenSingleThreadedMatMulF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
extern const char* const kEigenSingleThreadedMatMulF64SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
extern const char* const kEigenSingleThreadedConvF16SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF16";
extern const char* const kEigenSingleThreadedConvF32SymbolName =
    "__xla_cpu_runtime_EigenSingleThreadedConvF32";
extern const char* const kAcquireInfeedBufferForDequeueSymbolName =
    "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName =
    "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName =
    "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation";
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
    "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
    "__xla_cpu_runtime_ParallelForkJoin";
extern const char* const kKeyValueSortSymbolName =
    "__xla_cpu_runtime_KeyValueSort";
// Common prefix shared by every runtime symbol above.
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
90 }  // namespace runtime
91 }  // namespace cpu
92 }  // namespace xla
93 
94 namespace {
95 
ShapeString(const void * shape_ptr,xla::int32 shape_length)96 tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
97   xla::StatusOr<xla::Shape> shape =
98       xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
99   if (shape.ok()) {
100     return xla::ShapeUtil::HumanStringWithLayout(shape.ValueOrDie());
101   }
102   return "<invalid shape>";
103 }
104 
105 }  // namespace
106 
107 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireInfeedBufferForDequeue(const xla::ExecutableRunOptions * run_options,xla::int32 buffer_length,const void * shape,xla::int32 shape_length)108 __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
109     const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
110     const void* shape, xla::int32 shape_length) {
111   int device_ordinal =
112       run_options ? run_options->stream()->parent()->device_ordinal() : 0;
113 
114   VLOG(2) << "AcquireInfeedBufferForDequeue: "
115           << ShapeString(shape, shape_length) << " on stream executor "
116           << device_ordinal;
117 
118   xla::cpu::runtime::XfeedManager* xfeed =
119       xla::cpu::runtime::GetXfeedManager(device_ordinal);
120   // Wait until there's a buffer to dequeue.
121   xla::cpu::runtime::XfeedBuffer* buffer =
122       xfeed->infeed()->BlockingDequeueBuffer();
123   CHECK_EQ(buffer->length(), buffer_length)
124       << "XLA program infeed request buffer size " << buffer_length
125       << " did not match the runtime's infed buffer length " << buffer->length()
126       << "; program reports desired shape: "
127       << ShapeString(shape, shape_length);
128   return buffer->data();
129 }
130 
131 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(const xla::ExecutableRunOptions * run_options,xla::int32 buffer_length,void * buffer_ptr,const void * shape_ptr,xla::int32 shape_length)132 __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
133     const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
134     void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
135   int device_ordinal =
136       run_options ? run_options->stream()->parent()->device_ordinal() : 0;
137 
138   VLOG(2) << "ReleaseInfeedBufferAfterDeque: "
139           << ShapeString(shape_ptr, shape_length) << " on stream executor "
140           << device_ordinal;
141 
142   xla::cpu::runtime::XfeedManager* xfeed =
143       xla::cpu::runtime::GetXfeedManager(device_ordinal);
144   xla::StatusOr<xla::Shape> shape =
145       xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
146   xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
147                                         std::move(shape));
148 }
149 
150 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(const xla::ExecutableRunOptions * run_options,xla::int32 buffer_length,const void * shape_ptr,xla::int32 shape_length)151 __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
152     const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
153     const void* shape_ptr, xla::int32 shape_length) {
154   int device_ordinal =
155       run_options ? run_options->stream()->parent()->device_ordinal() : 0;
156 
157   VLOG(2) << "AcquireOutfeedBufferForPopulation: "
158           << ShapeString(shape_ptr, shape_length) << " on stream executor "
159           << device_ordinal;
160 
161   xla::cpu::runtime::XfeedManager* xfeed =
162       xla::cpu::runtime::GetXfeedManager(device_ordinal);
163   // Wait until there's a buffer to dequeue.
164   xla::cpu::runtime::XfeedBuffer* buffer =
165       xfeed->outfeed()->BlockingDequeueBuffer();
166   CHECK_EQ(buffer->length(), buffer_length)
167       << "XLA program outfeed request buffer size " << buffer_length
168       << " did not match the runtime's outfeed buffer length "
169       << buffer->length() << "; program reports outfed shape: "
170       << ShapeString(shape_ptr, shape_length);
171   return buffer->data();
172 }
173 
174 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(const xla::ExecutableRunOptions * run_options,xla::int32 buffer_length,void * buffer_ptr,const void * shape_ptr,xla::int32 shape_length)175 __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
176     const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
177     void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
178   int device_ordinal =
179       run_options ? run_options->stream()->parent()->device_ordinal() : 0;
180 
181   VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
182           << ShapeString(shape_ptr, shape_length) << " on stream executor "
183           << device_ordinal;
184 
185   xla::cpu::runtime::XfeedManager* xfeed =
186       xla::cpu::runtime::GetXfeedManager(device_ordinal);
187   xla::StatusOr<xla::Shape> shape =
188       xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
189   xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
190                                          std::move(shape));
191 }
192