/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/compiler/xla/stream_executor/stream.h"

#include <memory>
#include <utility>

#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/compiler/xla/stream_executor/blas.h"
#include "tensorflow/compiler/xla/stream_executor/lib/stacktrace.h"
#include "tensorflow/compiler/xla/stream_executor/platform.h"
#include "tensorflow/compiler/xla/stream_executor/platform/logging.h"
#include "tensorflow/compiler/xla/stream_executor/platform/port.h"
#include "tensorflow/compiler/xla/stream_executor/rng.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
31
32 namespace stream_executor {
33
34 namespace {
35 // Code to turn parameters to functions on stream into strings that
36 // will be VLOG'ed. We need overloads, instead of
37 // e.g. BatchDescriptorToVlogString(), as the code that calls these
38 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)39 std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
40 return descriptor.ToShortString();
41 }
42
ToVlogString(const dnn::FilterDescriptor & descriptor)43 std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
44 return descriptor.ToShortString();
45 }
46
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)47 std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
48 return descriptor.ToShortString();
49 }
50
ToVlogString(const dnn::PoolingDescriptor & descriptor)51 std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
52 return descriptor.ToShortString();
53 }
54
ToVlogString(const dnn::NormalizeDescriptor & descriptor)55 std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
56 return descriptor.ToShortString();
57 }
58
ToVlogString(dnn::ActivationMode mode)59 std::string ToVlogString(dnn::ActivationMode mode) {
60 return dnn::ActivationModeString(mode);
61 }
62
ToVlogString(const dnn::AlgorithmConfig & algo_config)63 std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
64 return algo_config.ToString();
65 }
66
ToVlogString(dnn::ElementwiseOperation op)67 std::string ToVlogString(dnn::ElementwiseOperation op) {
68 return dnn::ElementwiseOperationString(op);
69 }
70
ToVlogString(dnn::QuantizedActivationMode mode)71 std::string ToVlogString(dnn::QuantizedActivationMode mode) {
72 return dnn::QuantizedActivationModeString(mode);
73 }
74
ToVlogString(blas::Transpose t)75 std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
76
ToVlogString(blas::UpperLower ul)77 std::string ToVlogString(blas::UpperLower ul) {
78 return blas::UpperLowerString(ul);
79 }
80
ToVlogString(blas::Diagonal d)81 std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
82
ToVlogString(blas::Side s)83 std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
84
ToVlogString(blas::ComputationType ty)85 std::string ToVlogString(blas::ComputationType ty) {
86 return blas::ComputationTypeString(ty);
87 }
88
ToVlogString(const void * ptr)89 std::string ToVlogString(const void *ptr) {
90 if (ptr == nullptr) {
91 return "null";
92 }
93
94 // StrCat does not convert pointers to text.
95 std::ostringstream out;
96 out << ptr;
97 return out.str();
98 }
99
100 template <class T>
ToVlogString(const std::complex<T> & c)101 std::string ToVlogString(const std::complex<T> &c) {
102 // StrCat does not convert std::complex to text.
103 std::ostringstream out;
104 out << c;
105 return out.str();
106 }
107
108 template <class T>
ToVlogString(const std::function<T> & f)109 std::string ToVlogString(const std::function<T> &f) {
110 return f == nullptr ? "null" : "<non-null function>";
111 }
112
ToVlogString(const DeviceMemoryBase & memory)113 std::string ToVlogString(const DeviceMemoryBase &memory) {
114 return ToVlogString(memory.opaque());
115 }
116
ToVlogString(const DeviceMemoryBase * memory)117 std::string ToVlogString(const DeviceMemoryBase *memory) {
118 return memory == nullptr ? "null" : ToVlogString(*memory);
119 }
120
ToVlogString(const Eigen::half & h)121 std::string ToVlogString(const Eigen::half &h) {
122 return absl::StrCat(static_cast<float>(h));
123 }
124
ToVlogString(int i)125 std::string ToVlogString(int i) { return absl::StrCat(i); }
126
ToVlogString(uint32 i)127 std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
128
ToVlogString(uint64_t i)129 std::string ToVlogString(uint64_t i) { return absl::StrCat(i); }
130
ToVlogString(int64_t i)131 std::string ToVlogString(int64_t i) { return absl::StrCat(i); }
132
ToVlogString(float f)133 std::string ToVlogString(float f) { return absl::StrCat(f); }
134
ToVlogString(double d)135 std::string ToVlogString(double d) { return absl::StrCat(d); }
136
137 template <class T>
ToVlogString(port::ArraySlice<T> elements)138 std::string ToVlogString(port::ArraySlice<T> elements) { // non-absl ok
139 std::string str = absl::StrCat(
140 ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
141 elements.size(), "]{");
142 const char *separator = "";
143 size_t max_to_show = std::numeric_limits<size_t>::max();
144 if (!VLOG_IS_ON(2)) {
145 max_to_show = 5;
146 } else if (!VLOG_IS_ON(3)) {
147 max_to_show = 20;
148 } else if (!VLOG_IS_ON(11)) {
149 max_to_show = 1000;
150 }
151 for (size_t i = 0; i < elements.size(); ++i) {
152 if (i == max_to_show) {
153 str += ", ...";
154 break;
155 }
156 absl::StrAppend(&str, separator, ToVlogString(elements[i]));
157 separator = ", ";
158 }
159 str += "}";
160 return str;
161 }
162
163 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)164 std::string ToVlogString(port::MutableArraySlice<T> elements) { // non-absl ok
165 return ToVlogString(port::ArraySlice<T>(elements)); // non-absl ok
166 }
167
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)168 std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
169 switch (depth_to_space_layout) {
170 case dnn::DepthToSpaceLayout::DepthHeightWidth:
171 return "DepthToSpaceLayout::DepthHeightWidth";
172 }
173 return "unknown DepthToSpaceLayout";
174 }
175
ToVlogString(dnn::DataType data_type)176 std::string ToVlogString(dnn::DataType data_type) {
177 switch (data_type) {
178 case dnn::DataType::kFloat:
179 return "dnn::DataType::kFloat";
180 case dnn::DataType::kDouble:
181 return "dnn::DataType::kDouble";
182 case dnn::DataType::kHalf:
183 return "dnn::DataType::kHalf";
184 case dnn::DataType::kInt8:
185 return "dnn::DataType::kInt8";
186 case dnn::DataType::kInt32:
187 return "dnn::DataType::kInt32";
188 default:
189 return "unknown DataType";
190 }
191 }
192
193 // Used together with PARAM to VLOG calls made to the stream. Intended
194 // to be used like this:
195 //
196 // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
197 //
198 // where a and b are the parameters to MyFunction.
199 //
200 // See VLOG_CALL for a short-hand for this. This way of doing it saves
201 // a tremendous amount of boilerplate code given how many functions
202 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,std::string>> params)203 std::string CallStr(const char *function_name, Stream *stream,
204 std::vector<std::pair<const char *, std::string>> params) {
205 // Do not call this function unless VLOG is on since just
206 // constructing all the strings in params is expensive.
207 CHECK(VLOG_IS_ON(1));
208
209 std::string str = absl::StrCat(stream->DebugStreamPointers(),
210 " Called Stream::", function_name, "(");
211 const char *separator = "";
212 for (const auto ¶m : params) {
213 absl::StrAppend(&str, separator, param.first, "=", param.second);
214 separator = ", ";
215 }
216 absl::StrAppend(&str, ")");
217 if (VLOG_IS_ON(10)) {
218 absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
219 }
220 return str;
221 }
222
223 // Use this macro to avoid having to type every parameter twice to log
224 // it with VLOG and CallStr.
225 #define PARAM(parameter) \
226 { \
227 #parameter, ToVlogString(parameter) \
228 }
229
230 // Use this macro to avoid having to type out the name of each
231 // function and to save some boilerplate. Intended to be used like this:
232 //
233 // VLOG_CALL(PARAM(a), PARAM(b))
234 //
235 // This saves a tremendous amount of boilerplate compared to the alternative:
236 //
237 // VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
238 // << ", b=" << ToVlogString(b);
239 //
240 // Note here that most of the parameter names are not short and that
241 // most of the functions take many more than 2 parameters.
242 #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
243
244 } // namespace
245
Stream(StreamExecutor * parent)246 Stream::Stream(StreamExecutor *parent)
247 : parent_(parent),
248 implementation_(parent->implementation()->GetStreamImplementation()),
249 allocated_(false),
250 status_(port::InternalError("Uninitialized stream")),
251 temporary_memory_manager_(this) {
252 VLOG_CALL(PARAM(parent));
253 }
254
~Stream()255 Stream::~Stream() {
256 VLOG_CALL();
257
258 // Ensure the stream is completed.
259 auto status = BlockHostUntilDone();
260 if (!status.ok()) {
261 LOG(WARNING) << "Error blocking host until done in stream destructor: "
262 << status;
263 }
264 temporary_memory_manager_.ForceDeallocateAll();
265 RunAfterBlockHostUntilDoneCallbacks();
266
267 if (allocated_) {
268 parent_->DeallocateStream(this);
269 }
270 }
271
RefreshStatus()272 port::Status Stream::RefreshStatus() {
273 port::Status status = parent_->GetStatus(this);
274 // We should not put the stream in an error state, just because the GetStatus
275 // method is unimplemented.
276 if (status != port::Status(port::error::UNIMPLEMENTED,
277 "GetStatus is not supported on this executor.")) {
278 CheckStatus(status);
279 }
280 return status;
281 }
282
Init()283 Stream &Stream::Init() {
284 VLOG_CALL();
285
286 absl::MutexLock lock(&mu_);
287 CHECK_EQ(false, allocated_)
288 << "stream appears to already have been initialized";
289 CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization";
290
291 if (parent_->AllocateStream(this)) {
292 // Successful initialization!
293 allocated_ = true;
294 status_ = ::tensorflow::OkStatus();
295 } else {
296 LOG(ERROR) << "failed to allocate stream during initialization";
297 }
298
299 return *this;
300 }
301
InitTimer(Timer * timer)302 Stream &Stream::InitTimer(Timer *timer) {
303 VLOG_CALL(PARAM(timer));
304
305 CheckError(parent_->AllocateTimer(timer));
306 return *this;
307 }
308
InitWithTimer(Timer * timer)309 Stream &Stream::InitWithTimer(Timer *timer) {
310 VLOG_CALL(PARAM(timer));
311
312 return Init().InitTimer(timer);
313 }
314
ThenRecordEvent(Event * event)315 Stream &Stream::ThenRecordEvent(Event *event) {
316 VLOG_CALL(PARAM(event));
317
318 port::Status status = parent_->RecordEvent(this, event);
319 if (!status.ok()) {
320 LOG(ERROR) << "Error recording event in stream: " << status.error_message()
321 << "; not marking stream as bad, as the Event object may be "
322 << "at fault. Monitor for further errors.";
323 }
324
325 return *this;
326 }
327
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)328 Stream &Stream::ThenBatchNormalizationForward(
329 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
330 const DeviceMemory<float> &offset,
331 const DeviceMemory<float> &estimated_mean,
332 const DeviceMemory<float> &estimated_variance,
333 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
334 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
335 const double exponential_average_factor,
336 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
337 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
338 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
339 bool is_training, ScratchAllocator *reserve_space_allocator,
340 ScratchAllocator *workspace_allocator) {
341 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
342 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
343 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
344 CheckError(dnn->DoBatchNormalizationForward(
345 this, x, scale, offset, estimated_mean, estimated_variance, side_input,
346 x_desc, scale_offset_desc, epsilon, exponential_average_factor,
347 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
348 is_training, reserve_space_allocator, workspace_allocator));
349 } else {
350 SetErrorAndLogNoDnnSupport();
351 }
352 return *this;
353 }
354
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const DeviceMemory<float> & y,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<float> * side_input_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)355 Stream &Stream::ThenBatchNormalizationBackward(
356 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
357 const DeviceMemory<float> &scale, const DeviceMemory<float> &offset,
358 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
359 const DeviceMemory<float> &y, const dnn::BatchDescriptor &x_desc,
360 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
361 dnn::ActivationMode activation_mode, DeviceMemory<float> *x_backprop,
362 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
363 DeviceMemory<float> *side_input_backprop,
364 DeviceMemory<uint8> *reserve_space_data,
365 ScratchAllocator *workspace_allocator) {
366 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
367 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
368 PARAM(scale_backprop), PARAM(offset_backprop));
369 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
370 CheckError(dnn->DoBatchNormalizationBackward(
371 this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc,
372 scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop,
373 offset_backprop, side_input_backprop, reserve_space_data,
374 workspace_allocator));
375 } else {
376 SetErrorAndLogNoDnnSupport();
377 }
378 return *this;
379 }
380
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<Eigen::half> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)381 Stream &Stream::ThenBatchNormalizationForward(
382 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
383 const DeviceMemory<float> &offset,
384 const DeviceMemory<float> &estimated_mean,
385 const DeviceMemory<float> &estimated_variance,
386 const DeviceMemory<Eigen::half> &side_input,
387 const dnn::BatchDescriptor &x_desc,
388 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
389 const double exponential_average_factor,
390 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
391 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
392 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
393 bool is_training, ScratchAllocator *reserve_space_allocator,
394 ScratchAllocator *workspace_allocator) {
395 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
396 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
397 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
398 CheckError(dnn->DoBatchNormalizationForward(
399 this, x, scale, offset, estimated_mean, estimated_variance, side_input,
400 x_desc, scale_offset_desc, epsilon, exponential_average_factor,
401 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
402 is_training, reserve_space_allocator, workspace_allocator));
403 } else {
404 SetErrorAndLogNoDnnSupport();
405 }
406 return *this;
407 }
408
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const DeviceMemory<Eigen::half> & y,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<Eigen::half> * side_input_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)409 Stream &Stream::ThenBatchNormalizationBackward(
410 const DeviceMemory<Eigen::half> &y_backprop,
411 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
412 const DeviceMemory<float> &offset, const DeviceMemory<float> &mean,
413 const DeviceMemory<float> &inv_var, const DeviceMemory<Eigen::half> &y,
414 const dnn::BatchDescriptor &x_desc,
415 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
416 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *x_backprop,
417 DeviceMemory<float> *scale_backprop, DeviceMemory<float> *offset_backprop,
418 DeviceMemory<Eigen::half> *side_input_backprop,
419 DeviceMemory<uint8> *reserve_space_data,
420 ScratchAllocator *workspace_allocator) {
421 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
422 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
423 PARAM(scale_backprop), PARAM(offset_backprop));
424 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
425 CheckError(dnn->DoBatchNormalizationBackward(
426 this, y_backprop, x, scale, offset, mean, inv_var, y, x_desc,
427 scale_offset_desc, epsilon, activation_mode, x_backprop, scale_backprop,
428 offset_backprop, side_input_backprop, reserve_space_data,
429 workspace_allocator));
430
431 } else {
432 SetErrorAndLogNoDnnSupport();
433 }
434 return *this;
435 }
436
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)437 Stream &Stream::ThenConvolve(
438 const dnn::BatchDescriptor &input_descriptor,
439 const DeviceMemory<float> &input_data,
440 const dnn::FilterDescriptor &filter_descriptor,
441 const DeviceMemory<float> &filter_data,
442 const dnn::ConvolutionDescriptor &convolution_descriptor,
443 const dnn::BatchDescriptor &output_descriptor,
444 DeviceMemory<float> *output) {
445 if (ok()) {
446 CheckError(ConvolveWithAlgorithm(
447 dnn::ConvolutionKind::FORWARD, input_descriptor, input_data,
448 filter_descriptor, filter_data, output_descriptor, *output,
449 convolution_descriptor,
450 /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
451 /*output_profile_result=*/nullptr)
452 .ok());
453 }
454 return *this;
455 }
456
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)457 Stream &Stream::ThenConvolveQuantized(
458 const dnn::BatchDescriptor &input_descriptor,
459 const DeviceMemory<float> &input_data,
460 const dnn::FilterDescriptor &filter_descriptor,
461 const DeviceMemory<int8> &filter_coefficients,
462 const DeviceMemory<float> &coefficient_scales,
463 const dnn::ConvolutionDescriptor &convolution_descriptor,
464 const dnn::BatchDescriptor &output_descriptor,
465 DeviceMemory<float> *output) {
466 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
467 PARAM(filter_descriptor), PARAM(filter_coefficients),
468 PARAM(coefficient_scales), PARAM(convolution_descriptor),
469 PARAM(output_descriptor), PARAM(output));
470
471 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
472 CheckError(dnn->DoConvolveQuantized(
473 this, input_descriptor, input_data, filter_descriptor,
474 filter_coefficients, coefficient_scales, convolution_descriptor,
475 output_descriptor, output));
476 } else {
477 SetError();
478 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
479 "without DNN support";
480 }
481 return *this;
482 }
483
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)484 Stream &Stream::ThenConvolveQuantized(
485 const dnn::BatchDescriptor &input_descriptor,
486 const DeviceMemory<float> &input_data,
487 const dnn::FilterDescriptor &filter_descriptor,
488 const DeviceMemory<int16> &filter_coefficients,
489 const DeviceMemory<float> &coefficient_scales,
490 const dnn::ConvolutionDescriptor &convolution_descriptor,
491 const dnn::BatchDescriptor &output_descriptor,
492 DeviceMemory<float> *output) {
493 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
494 PARAM(filter_descriptor), PARAM(filter_coefficients),
495 PARAM(coefficient_scales), PARAM(convolution_descriptor),
496 PARAM(output_descriptor), PARAM(output));
497
498 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
499 CheckError(dnn->DoConvolveQuantized(
500 this, input_descriptor, input_data, filter_descriptor,
501 filter_coefficients, coefficient_scales, convolution_descriptor,
502 output_descriptor, output));
503 } else {
504 SetError();
505 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
506 "without DNN support";
507 }
508 return *this;
509 }
510
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)511 Stream &Stream::ThenSeparableConvolve(
512 const dnn::BatchDescriptor &batch_descriptor,
513 const DeviceMemory<float> &input_data,
514 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
515 const DeviceMemory<float> &first_weights,
516 const DeviceMemory<float> &second_weights,
517 const dnn::ConvolutionDescriptor &convolution_descriptor,
518 const dnn::BatchDescriptor &output_descriptor,
519 DeviceMemory<float> *output) {
520 VLOG_CALL(
521 PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
522 PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
523 PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
524
525 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
526 CheckError(dnn->DoSeparableConvolve(
527 this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
528 first_weights, second_weights, convolution_descriptor,
529 output_descriptor, output));
530 } else {
531 SetErrorAndLogNoDnnSupport();
532 }
533 return *this;
534 }
535
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)536 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
537 const DeviceMemory<float> &weights,
538 const dnn::BatchDescriptor &input_dimensions,
539 const dnn::BatchDescriptor &output_dimensions,
540 DeviceMemory<float> *output_data) {
541 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
542 PARAM(output_dimensions), PARAM(output_data));
543
544 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
545 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
546 output_dimensions, output_data));
547 } else {
548 SetErrorAndLogNoDnnSupport();
549 }
550 return *this;
551 }
552
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)553 Stream &Stream::ThenMatMulQuantized(
554 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
555 const DeviceMemory<float> &weight_scales,
556 const dnn::BatchDescriptor &input_dimensions,
557 const dnn::BatchDescriptor &output_dimensions,
558 DeviceMemory<float> *output_data) {
559 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
560 PARAM(input_dimensions), PARAM(output_dimensions),
561 PARAM(output_data));
562
563 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
564 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
565 input_dimensions, output_dimensions,
566 output_data));
567 } else {
568 SetErrorAndLogNoDnnSupport();
569 }
570 return *this;
571 }
572
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)573 Stream &Stream::ThenMatMulQuantized(
574 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
575 const DeviceMemory<float> &weight_scales,
576 const dnn::BatchDescriptor &input_dimensions,
577 const dnn::BatchDescriptor &output_dimensions,
578 DeviceMemory<float> *output_data) {
579 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
580 PARAM(input_dimensions), PARAM(output_dimensions),
581 PARAM(output_data));
582
583 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
584 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
585 input_dimensions, output_dimensions,
586 output_data));
587 } else {
588 SetErrorAndLogNoDnnSupport();
589 }
590 return *this;
591 }
592
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)593 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
594 const DeviceMemory<float> &biases,
595 const dnn::BatchDescriptor &dimensions,
596 DeviceMemory<float> *output_data) {
597 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
598 PARAM(output_data));
599
600 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
601 CheckError(
602 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
603 } else {
604 SetErrorAndLogNoDnnSupport();
605 }
606 return *this;
607 }
608
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)609 Stream &Stream::ThenNormalizeWithDimensions(
610 const dnn::NormalizeDescriptor &normalize_descriptor,
611 const dnn::BatchDescriptor &dimensions,
612 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
613 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
614 PARAM(output_data));
615
616 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
617 CheckError(dnn->DoNormalizeWithDimensions(
618 this, normalize_descriptor, dimensions, input_data, output_data));
619 } else {
620 SetErrorAndLogNoDnnSupport();
621 }
622 return *this;
623 }
624
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)625 Stream &Stream::ThenNormalizeBackwardWithDimensions(
626 const dnn::NormalizeDescriptor &normalize_descriptor,
627 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
628 const DeviceMemory<float> &normalized_data,
629 const DeviceMemory<float> &normalized_variable_gradient,
630 DeviceMemory<float> *raw_variable_gradient,
631 ScratchAllocator *workspace_allocator) {
632 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
633 PARAM(normalized_data), PARAM(normalized_variable_gradient),
634 PARAM(raw_variable_gradient), PARAM(workspace_allocator));
635
636 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
637 CheckError(dnn->DoNormalizeBackwardWithDimensions(
638 this, normalize_descriptor, dimensions, raw_data, normalized_data,
639 normalized_variable_gradient, raw_variable_gradient,
640 workspace_allocator));
641 } else {
642 SetErrorAndLogNoDnnSupport();
643 }
644 return *this;
645 }
646
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)647 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
648 const dnn::BatchDescriptor &dimensions,
649 const DeviceMemory<float> &input_data,
650 DeviceMemory<float> *output_data) {
651 return ThenActivateWithOptions(activation_mode, dimensions, input_data,
652 output_data, /*options=*/0);
653 }
654
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64_t options)655 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
656 const dnn::BatchDescriptor &dimensions,
657 const DeviceMemory<float> &input_data,
658 DeviceMemory<float> *output_data,
659 uint64_t options) {
660 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
661 PARAM(output_data), PARAM(options));
662
663 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
664 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
665 output_data, options));
666 } else {
667 SetErrorAndLogNoDnnSupport();
668 }
669 return *this;
670 }
671
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)672 Stream &Stream::ThenDepthConcatenate(
673 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
674 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
675 DeviceMemory<float> *output_data) {
676 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
677
678 for (size_t i = 1; i < input_dimensions.size(); ++i) {
679 if (input_dimensions[i].count() != input_dimensions[0].count() ||
680 input_dimensions[i].height() != input_dimensions[0].height() ||
681 input_dimensions[i].width() != input_dimensions[0].width()) {
682 SetError();
683 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
684 << "input_dimensions[0]: " << input_dimensions[0].ToString()
685 << "input_dimensions[" << i
686 << "]: " << input_dimensions[i].ToString();
687 return *this;
688 }
689 }
690
691 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
692 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
693 output_data));
694 } else {
695 SetErrorAndLogNoDnnSupport();
696 }
697 return *this;
698 }
699
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)700 Stream &Stream::ThenSpaceConcatenate(
701 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
702 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
703 DeviceMemory<float> *output_data,
704 dnn::SpaceConcatenateMode concat_direction) {
705 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
706
707 // Check that the input dimensions of all the other batches match those of the
708 // first batch.
709 for (size_t i = 1; i < input_dimensions.size(); ++i) {
710 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
711 (input_dimensions[i].count() != input_dimensions[0].count() ||
712 input_dimensions[i].height() != input_dimensions[0].height() ||
713 input_dimensions[i].feature_map_count() !=
714 input_dimensions[0].feature_map_count())) {
715 SetError();
716 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
717 << "input_dimensions[0]: " << input_dimensions[0].ToString()
718 << "input_dimensions[" << i
719 << "]: " << input_dimensions[i].ToString();
720 return *this;
721 }
722
723 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
724 (input_dimensions[i].count() != input_dimensions[0].count() ||
725 input_dimensions[i].width() != input_dimensions[0].width() ||
726 input_dimensions[i].feature_map_count() !=
727 input_dimensions[0].feature_map_count())) {
728 SetError();
729 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
730 << "input_dimensions[0]: " << input_dimensions[0].ToString()
731 << "input_dimensions[" << i
732 << "]: " << input_dimensions[i].ToString();
733 return *this;
734 }
735 }
736 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
737 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
738 output_data, concat_direction));
739 } else {
740 SetErrorAndLogNoDnnSupport();
741 }
742 return *this;
743 }
744
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)745 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
746 const DeviceMemory<float> &input_data,
747 const dnn::BatchDescriptor &output_dimensions,
748 DeviceMemory<float> *output_data) {
749 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
750 PARAM(output_dimensions), PARAM(output_data));
751
752 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
753 CheckError(dnn->DoReshape(this, input_dimensions, input_data,
754 output_dimensions, output_data));
755 } else {
756 SetErrorAndLogNoDnnSupport();
757 }
758 return *this;
759 }
760
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)761 Stream &Stream::ThenDepthToSpace(
762 const dnn::BatchDescriptor &input_dimensions,
763 const DeviceMemory<float> &input_data,
764 const dnn::DepthToSpaceLayout &depth_to_space_layout,
765 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
766 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
767 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
768 PARAM(output_data));
769
770 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
771 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
772 depth_to_space_layout, sqrt_depth_reduction,
773 output_data));
774 } else {
775 SetErrorAndLogNoDnnSupport();
776 }
777 return *this;
778 }
779
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)780 Stream &Stream::ThenSpaceToDepth(
781 const dnn::BatchDescriptor &input_dimensions,
782 const DeviceMemory<float> &input_data,
783 const dnn::DepthToSpaceLayout &space_to_depth_layout,
784 const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
785 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
786 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
787 PARAM(output_data));
788
789 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
790 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
791 space_to_depth_layout, sqrt_depth_increase,
792 output_data));
793 } else {
794 SetErrorAndLogNoDnnSupport();
795 }
796 return *this;
797 }
798
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)799 Stream &Stream::ThenElementwiseOperate(
800 dnn::ElementwiseOperation operation,
801 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
802 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
803 const dnn::BatchDescriptor &output_dimensions,
804 DeviceMemory<float> *output_data) {
805 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
806 PARAM(output_dimensions), PARAM(output_data));
807
808 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
809 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
810 input_data, output_dimensions,
811 output_data));
812 } else {
813 SetErrorAndLogNoDnnSupport();
814 }
815 return *this;
816 }
817
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)818 Stream &Stream::ThenElementwiseOperateScaledQuantized(
819 dnn::ElementwiseOperation operation,
820 port::ArraySlice<int> input_multiplicands, // non-absl ok
821 int output_divisor,
822 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, // non-absl ok
823 port::ArraySlice<const DeviceMemory<float> *> input_data, // non-absl ok
824 const dnn::BatchDescriptor &output_dimensions,
825 DeviceMemory<float> *output_data) {
826 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
827 PARAM(input_dimensions), PARAM(input_data),
828 PARAM(output_dimensions), PARAM(output_data));
829
830 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
831 CheckError(dnn->DoElementwiseOperateScaledQuantized(
832 this, operation, input_multiplicands, output_divisor, input_dimensions,
833 input_data, output_dimensions, output_data));
834 } else {
835 SetErrorAndLogNoDnnSupport();
836 }
837 return *this;
838 }
839
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_pad,int64_t right_pad,int64_t top_pad,int64_t bottom_pad,DeviceMemory<float> * output_data)840 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
841 const DeviceMemory<float> &input_data,
842 int64_t left_pad, int64_t right_pad, int64_t top_pad,
843 int64_t bottom_pad,
844 DeviceMemory<float> *output_data) {
845 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
846 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
847 PARAM(output_data));
848
849 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
850 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
851 top_pad, bottom_pad, output_data));
852 } else {
853 SetErrorAndLogNoDnnSupport();
854 }
855 return *this;
856 }
857
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_trim,int64_t right_trim,int64_t top_trim,int64_t bottom_trim,DeviceMemory<float> * output_data)858 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
859 const DeviceMemory<float> &input_data,
860 int64_t left_trim, int64_t right_trim,
861 int64_t top_trim, int64_t bottom_trim,
862 DeviceMemory<float> *output_data) {
863 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
864 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
865 PARAM(output_data));
866
867 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
868 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
869 right_trim, top_trim, bottom_trim, output_data));
870 } else {
871 SetErrorAndLogNoDnnSupport();
872 }
873 return *this;
874 }
875
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t replicate_x,int64_t replicate_y,DeviceMemory<float> * output_data)876 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
877 const DeviceMemory<float> &input_data,
878 int64_t replicate_x, int64_t replicate_y,
879 DeviceMemory<float> *output_data) {
880 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
881 PARAM(replicate_y), PARAM(output_data));
882
883 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
884 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
885 replicate_y, output_data));
886 } else {
887 SetErrorAndLogNoDnnSupport();
888 }
889 return *this;
890 }
891
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64_t size)892 Stream &Stream::ThenMemcpyD2HQuantized(
893 const DeviceMemory<float> &gpu_unquantized_src,
894 dnn::QuantizedActivationMode mode, void *host_dst, uint64_t size) {
895 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
896 PARAM(size));
897
898 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
899 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
900 host_dst, size));
901 } else {
902 SetErrorAndLogNoDnnSupport();
903 }
904 return *this;
905 }
906
ThenMemcpyH2DQuantized(const void * host_src,uint64_t size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)907 Stream &Stream::ThenMemcpyH2DQuantized(
908 const void *host_src, uint64_t size, dnn::QuantizedActivationMode mode,
909 DeviceMemory<float> *gpu_unquantized_dst) {
910 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
911 PARAM(gpu_unquantized_dst));
912
913 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
914 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
915 gpu_unquantized_dst));
916 } else {
917 SetErrorAndLogNoDnnSupport();
918 }
919 return *this;
920 }
921
// Returns a sub-stream owned by this stream: a cached reusable one when an
// ok() cached sub-stream exists, otherwise a freshly created one. Errored
// cached sub-streams encountered during the search are dropped.
Stream *Stream::GetOrCreateSubStream() {
  // Do not destroy bad streams when holding mu_ because ~Stream() may
  // BlockHostUntilDone and its host callbacks might attempt to acquire mu_.
  std::vector<std::unique_ptr<Stream>> bad_streams;

  absl::MutexLock lock(&mu_);

  // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
  // we encounter along the way.
  for (size_t index = 0; index < sub_streams_.size();) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.second) {
      // The sub_stream is reusable.
      Stream *sub_stream = pair.first.get();
      if (sub_stream->ok()) {
        VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
                << sub_stream->DebugStreamPointers();
        pair.second = false;  // Mark as handed out, i.e. no longer reusable.
        return sub_stream;
      }

      // The stream is reusable and not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      const int64_t last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      // Ownership moves to bad_streams so destruction happens only after mu_
      // is released (see the note at the top of this function).
      bad_streams.push_back(std::move(sub_streams_.back().first));
      sub_streams_.pop_back();
      VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
              << sub_stream->DebugStreamPointers();
    } else {
      // The sub_stream is not reusable, move on to the next one.
      ++index;
    }
  }

  // No streams are reusable; create a new stream.
  sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
                            false);
  Stream *sub_stream = sub_streams_.back().first.get();
  sub_stream->Init();
  if (!sub_stream->ok()) {
    LOG(ERROR) << "sub-stream failed to be initialized";
  }
  VLOG(1) << DebugStreamPointers() << " created new sub_stream "
          << sub_stream->DebugStreamPointers();

  return sub_stream;
}
973
// Marks `sub_stream` (previously handed out by GetOrCreateSubStream) as
// reusable again, or drops it entirely if it has entered an error state.
// LOG(FATAL)s if `sub_stream` was not created by this stream.
void Stream::ReturnSubStream(Stream *sub_stream) {
  // Do not destroy bad streams when holding mu_ because ~Stream() may
  // BlockHostUntilDone and its host callbacks might attempt to acquire mu_.
  std::unique_ptr<Stream> bad_stream;

  absl::MutexLock lock(&mu_);

  // Look for the sub-stream.
  for (int64_t index = 0, end = sub_streams_.size(); index < end; ++index) {
    std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
    if (pair.first.get() != sub_stream) {
      continue;
    }

    // Found the sub_stream.
    if (sub_stream->ok()) {
      VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
              << sub_stream->DebugStreamPointers();
      pair.second = true;  // Reusable by a later GetOrCreateSubStream call.
    } else {
      // The returned stream is not ok. Streams have a monotonic state
      // machine; the stream will remain in !ok forever. Swap it with the last
      // stream and pop it off.
      VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
              << sub_stream->DebugStreamPointers();
      const int64_t last = sub_streams_.size() - 1;
      if (index != last) {
        std::swap(pair, sub_streams_[last]);
      }
      // Ownership moves to bad_stream so destruction happens only after mu_
      // is released (see the note at the top of this function).
      std::swap(bad_stream, sub_streams_.back().first);
      sub_streams_.pop_back();
    }
    return;
  }

  LOG(FATAL) << DebugStreamPointers()
             << " did not create the returned sub-stream "
             << sub_stream->DebugStreamPointers();
}
1013
ThenStartTimer(Timer * t)1014 Stream &Stream::ThenStartTimer(Timer *t) {
1015 VLOG_CALL(PARAM(t));
1016
1017 CheckError(parent_->StartTimer(this, t));
1018 return *this;
1019 }
1020
ThenStopTimer(Timer * t)1021 Stream &Stream::ThenStopTimer(Timer *t) {
1022 VLOG_CALL(PARAM(t));
1023
1024 CheckError(parent_->StopTimer(this, t));
1025 return *this;
1026 }
1027
ThenWaitFor(Stream * other)1028 Stream &Stream::ThenWaitFor(Stream *other) {
1029 VLOG_CALL(PARAM(other));
1030
1031 CHECK(this != other) << "stream cannot wait for itself";
1032 if (ok() && other->ok()) {
1033 CheckError(parent_->CreateStreamDependency(this, other));
1034 } else {
1035 SetError();
1036 LOG(INFO) << DebugStreamPointers() << " did not wait for "
1037 << other->DebugStreamPointers();
1038 }
1039 return *this;
1040 }
1041
ThenWaitFor(Event * event)1042 Stream &Stream::ThenWaitFor(Event *event) {
1043 VLOG_CALL(PARAM(event));
1044
1045 if (ok()) {
1046 port::Status status = parent_->WaitForEvent(this, event);
1047 if (!status.ok()) {
1048 LOG(ERROR) << "Error waiting for event in stream: "
1049 << status.error_message()
1050 << "; not marking stream as bad, as the Event object may be "
1051 << "at fault. Monitor for further errors.";
1052 }
1053 } else {
1054 LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1055 }
1056 return *this;
1057 }
1058
// A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
// functions and logs for errors. The Args parameter pack mirrors the
// DoBlasXXX signature minus its leading Stream* parameter.
template <typename... Args>
struct ThenBlasImpl {
  // blas_func is the DoBlasXXX member function pointer, and args are its
  // arguments except the first one of Stream* type. Any failure is recorded
  // on the stream via CheckError.
  Stream &operator()(Stream *stream,
                     bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
                     Args... args) {
    return Run(stream, blas_func, /*record_error=*/true, args...);
  }

  // Like operator(), but only calls stream->CheckError() if record_error is
  // true. Used by ThenBlasWithProfileImpl, where failure is reported through
  // the profile result instead.
  Stream &Run(Stream *stream,
              bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
              bool record_error, Args... args);
};
1077
// Invokes blas_func on the stream's BLAS plugin. Skips the call entirely if
// the stream is already errored; marks the stream errored on failure only
// when record_error is true.
template <typename... Args>
Stream &ThenBlasImpl<Args...>::Run(
    Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
    bool record_error, Args... args) {
  if (stream->ok()) {
    bool ok;
    if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
      ok = (blas->*blas_func)(stream, args...);
    } else {
      // No BLAS plugin registered for this executor.
      LOG(WARNING)
          << "attempting to perform BLAS operation using StreamExecutor "
             "without BLAS support";
      ok = false;
    }
    // record_error is false for profiled calls, whose failure is surfaced
    // through the blas::ProfileResult instead of poisoning the stream.
    if (record_error) {
      stream->CheckError(ok);
    }
  }
  return *stream;
}
1098
ThenBlasAxpy(uint64_t elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1099 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, float alpha,
1100 const DeviceMemory<float> &x, int incx,
1101 DeviceMemory<float> *y, int incy) {
1102 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1103 PARAM(incy));
1104
1105 ThenBlasImpl<uint64_t, float, const DeviceMemory<float> &, int,
1106 DeviceMemory<float> *, int>
1107 impl;
1108 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1109 y, incy);
1110 }
1111
ThenBlasAxpy(uint64_t elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1112 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, double alpha,
1113 const DeviceMemory<double> &x, int incx,
1114 DeviceMemory<double> *y, int incy) {
1115 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1116 PARAM(incy));
1117
1118 ThenBlasImpl<uint64_t, double, const DeviceMemory<double> &, int,
1119 DeviceMemory<double> *, int>
1120 impl;
1121 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1122 y, incy);
1123 }
1124
ThenBlasAxpy(uint64_t elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1125 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, std::complex<float> alpha,
1126 const DeviceMemory<std::complex<float>> &x,
1127 int incx, DeviceMemory<std::complex<float>> *y,
1128 int incy) {
1129 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1130 PARAM(incy));
1131
1132 ThenBlasImpl<uint64_t, std::complex<float>,
1133 const DeviceMemory<std::complex<float>> &, int,
1134 DeviceMemory<std::complex<float>> *, int>
1135 impl;
1136 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1137 y, incy);
1138 }
1139
ThenBlasAxpy(uint64_t elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1140 Stream &Stream::ThenBlasAxpy(uint64_t elem_count, std::complex<double> alpha,
1141 const DeviceMemory<std::complex<double>> &x,
1142 int incx, DeviceMemory<std::complex<double>> *y,
1143 int incy) {
1144 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1145 PARAM(incy));
1146
1147 ThenBlasImpl<uint64_t, std::complex<double>,
1148 const DeviceMemory<std::complex<double>> &, int,
1149 DeviceMemory<std::complex<double>> *, int>
1150 impl;
1151 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1152 y, incy);
1153 }
1154
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1155 Stream &Stream::ThenBlasCopy(uint64_t elem_count, const DeviceMemory<float> &x,
1156 int incx, DeviceMemory<float> *y, int incy) {
1157 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1158
1159 ThenBlasImpl<uint64_t, const DeviceMemory<float> &, int,
1160 DeviceMemory<float> *, int>
1161 impl;
1162 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1163 incy);
1164 }
1165
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1166 Stream &Stream::ThenBlasCopy(uint64_t elem_count, const DeviceMemory<double> &x,
1167 int incx, DeviceMemory<double> *y, int incy) {
1168 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1169
1170 ThenBlasImpl<uint64_t, const DeviceMemory<double> &, int,
1171 DeviceMemory<double> *, int>
1172 impl;
1173 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1174 incy);
1175 }
1176
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1177 Stream &Stream::ThenBlasCopy(uint64_t elem_count,
1178 const DeviceMemory<std::complex<float>> &x,
1179 int incx, DeviceMemory<std::complex<float>> *y,
1180 int incy) {
1181 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1182
1183 ThenBlasImpl<uint64_t, const DeviceMemory<std::complex<float>> &, int,
1184 DeviceMemory<std::complex<float>> *, int>
1185 impl;
1186 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1187 incy);
1188 }
1189
ThenBlasCopy(uint64_t elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1190 Stream &Stream::ThenBlasCopy(uint64_t elem_count,
1191 const DeviceMemory<std::complex<double>> &x,
1192 int incx, DeviceMemory<std::complex<double>> *y,
1193 int incy) {
1194 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1195
1196 ThenBlasImpl<uint64_t, const DeviceMemory<std::complex<double>> &, int,
1197 DeviceMemory<std::complex<double>> *, int>
1198 impl;
1199 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1200 incy);
1201 }
1202
ThenBlasScal(uint64_t elem_count,float alpha,DeviceMemory<float> * x,int incx)1203 Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha,
1204 DeviceMemory<float> *x, int incx) {
1205 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1206
1207 ThenBlasImpl<uint64_t, float, DeviceMemory<float> *, int> impl;
1208 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1209 }
1210
ThenBlasScal(uint64_t elem_count,double alpha,DeviceMemory<double> * x,int incx)1211 Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha,
1212 DeviceMemory<double> *x, int incx) {
1213 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1214
1215 ThenBlasImpl<uint64_t, double, DeviceMemory<double> *, int> impl;
1216 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1217 }
1218
ThenBlasScal(uint64_t elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)1219 Stream &Stream::ThenBlasScal(uint64_t elem_count, float alpha,
1220 DeviceMemory<std::complex<float>> *x, int incx) {
1221 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1222
1223 ThenBlasImpl<uint64_t, float, DeviceMemory<std::complex<float>> *, int> impl;
1224 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1225 }
1226
ThenBlasScal(uint64_t elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)1227 Stream &Stream::ThenBlasScal(uint64_t elem_count, double alpha,
1228 DeviceMemory<std::complex<double>> *x, int incx) {
1229 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1230
1231 ThenBlasImpl<uint64_t, double, DeviceMemory<std::complex<double>> *, int>
1232 impl;
1233 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1234 }
1235
ThenBlasScal(uint64_t elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)1236 Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex<float> alpha,
1237 DeviceMemory<std::complex<float>> *x, int incx) {
1238 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1239
1240 ThenBlasImpl<uint64_t, std::complex<float>,
1241 DeviceMemory<std::complex<float>> *, int>
1242 impl;
1243 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1244 }
1245
ThenBlasScal(uint64_t elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)1246 Stream &Stream::ThenBlasScal(uint64_t elem_count, std::complex<double> alpha,
1247 DeviceMemory<std::complex<double>> *x, int incx) {
1248 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1249
1250 ThenBlasImpl<uint64_t, std::complex<double>,
1251 DeviceMemory<std::complex<double>> *, int>
1252 impl;
1253 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1254 }
1255
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1256 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1257 float alpha, const DeviceMemory<float> &a, int lda,
1258 const DeviceMemory<float> &x, int incx, float beta,
1259 DeviceMemory<float> *y, int incy) {
1260 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1261 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1262 PARAM(incy));
1263
1264 ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, float,
1265 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1266 int, float, DeviceMemory<float> *, int>
1267 impl;
1268 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1269 x, incx, beta, y, incy);
1270 }
1271
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1272 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1273 double alpha, const DeviceMemory<double> &a,
1274 int lda, const DeviceMemory<double> &x, int incx,
1275 double beta, DeviceMemory<double> *y, int incy) {
1276 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1277 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1278 PARAM(incy));
1279
1280 ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, double,
1281 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1282 int, double, DeviceMemory<double> *, int>
1283 impl;
1284 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1285 x, incx, beta, y, incy);
1286 }
1287
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1288 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1289 std::complex<float> alpha,
1290 const DeviceMemory<std::complex<float>> &a,
1291 int lda,
1292 const DeviceMemory<std::complex<float>> &x,
1293 int incx, std::complex<float> beta,
1294 DeviceMemory<std::complex<float>> *y, int incy) {
1295 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1296 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1297 PARAM(incy));
1298
1299 ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, std::complex<float>,
1300 const DeviceMemory<std::complex<float>> &, int,
1301 const DeviceMemory<std::complex<float>> &, int,
1302 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
1303 impl;
1304 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1305 x, incx, beta, y, incy);
1306 }
1307
ThenBlasGemv(blas::Transpose trans,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1308 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64_t m, uint64 n,
1309 std::complex<double> alpha,
1310 const DeviceMemory<std::complex<double>> &a,
1311 int lda,
1312 const DeviceMemory<std::complex<double>> &x,
1313 int incx, std::complex<double> beta,
1314 DeviceMemory<std::complex<double>> *y, int incy) {
1315 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1316 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1317 PARAM(incy));
1318
1319 ThenBlasImpl<blas::Transpose, uint64_t, uint64_t, std::complex<double>,
1320 const DeviceMemory<std::complex<double>> &, int,
1321 const DeviceMemory<std::complex<double>> &, int,
1322 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
1323 impl;
1324 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1325 x, incx, beta, y, incy);
1326 }
1327
ThenBlasSbmv(blas::UpperLower uplo,uint64_t n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1328 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
1329 float alpha, const DeviceMemory<float> &a, int lda,
1330 const DeviceMemory<float> &x, int incx, float beta,
1331 DeviceMemory<float> *y, int incy) {
1332 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
1333 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
1334
1335 ThenBlasImpl<blas::UpperLower, uint64_t, uint64_t, float,
1336 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1337 int, float, DeviceMemory<float> *, int>
1338 impl;
1339 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
1340 x, incx, beta, y, incy);
1341 }
1342
ThenBlasSbmv(blas::UpperLower uplo,uint64_t n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1343 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64_t n, uint64 k,
1344 double alpha, const DeviceMemory<double> &a,
1345 int lda, const DeviceMemory<double> &x, int incx,
1346 double beta, DeviceMemory<double> *y, int incy) {
1347 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
1348 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
1349
1350 ThenBlasImpl<blas::UpperLower, uint64_t, uint64_t, double,
1351 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1352 int, double, DeviceMemory<double> *, int>
1353 impl;
1354 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
1355 x, incx, beta, y, incy);
1356 }
1357
1358 namespace {
1359 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
1360 // blas::ProfileResult*. This functor doesn't put the stream into an error
1361 // state if the op fails and the profile result is non-null. Instead, the
1362 // error-ness is returned in the profile result itself.
1363 template <typename... Args>
1364 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon0297cec10211::ThenBlasWithProfileImpl1365 Stream &operator()(Stream *stream,
1366 bool (blas::BlasSupport::*blas_func)(
1367 Stream *, Args..., blas::ProfileResult *),
1368 Args... args, blas::ProfileResult *profile_result) {
1369 ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
1370 bool record_error = profile_result == nullptr;
1371 return Runner.Run(stream, blas_func, record_error, args..., profile_result);
1372 }
1373 };
1374 } // anonymous namespace
1375
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)1376 Stream &Stream::ThenBlasGemvWithProfiling(
1377 blas::Transpose trans, uint64_t m, uint64 n, float alpha,
1378 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
1379 int incx, float beta, DeviceMemory<float> *y, int incy,
1380 blas::ProfileResult *output_profile_result) {
1381 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1382 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1383 PARAM(incy));
1384
1385 ThenBlasWithProfileImpl<
1386 blas::Transpose, uint64_t, uint64_t, float, const DeviceMemory<float> &,
1387 int, const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
1388 impl;
1389 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1390 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1391 }
1392
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)1393 Stream &Stream::ThenBlasGemvWithProfiling(
1394 blas::Transpose trans, uint64_t m, uint64 n, double alpha,
1395 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
1396 int incx, double beta, DeviceMemory<double> *y, int incy,
1397 blas::ProfileResult *output_profile_result) {
1398 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1399 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1400 PARAM(incy));
1401
1402 ThenBlasWithProfileImpl<blas::Transpose, uint64_t, uint64_t, double,
1403 const DeviceMemory<double> &, int,
1404 const DeviceMemory<double> &, int, double,
1405 DeviceMemory<double> *, int>
1406 impl;
1407 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1408 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1409 }
1410
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)1411 Stream &Stream::ThenBlasGemvWithProfiling(
1412 blas::Transpose trans, uint64_t m, uint64 n, std::complex<float> alpha,
1413 const DeviceMemory<std::complex<float>> &a, int lda,
1414 const DeviceMemory<std::complex<float>> &x, int incx,
1415 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
1416 blas::ProfileResult *output_profile_result) {
1417 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1418 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1419 PARAM(incy));
1420
1421 ThenBlasWithProfileImpl<
1422 blas::Transpose, uint64_t, uint64_t, std::complex<float>,
1423 const DeviceMemory<std::complex<float>> &, int,
1424 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
1425 DeviceMemory<std::complex<float>> *, int>
1426 impl;
1427 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1428 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1429 }
1430
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)1431 Stream &Stream::ThenBlasGemvWithProfiling(
1432 blas::Transpose trans, uint64_t m, uint64 n, std::complex<double> alpha,
1433 const DeviceMemory<std::complex<double>> &a, int lda,
1434 const DeviceMemory<std::complex<double>> &x, int incx,
1435 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
1436 blas::ProfileResult *output_profile_result) {
1437 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1438 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1439 PARAM(incy));
1440
1441 ThenBlasWithProfileImpl<
1442 blas::Transpose, uint64_t, uint64_t, std::complex<double>,
1443 const DeviceMemory<std::complex<double>> &, int,
1444 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
1445 DeviceMemory<std::complex<double>> *, int>
1446 impl;
1447 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
1448 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
1449 }
1450
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)1451 Stream &Stream::ThenBlasGemmWithProfiling(
1452 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1453 uint64_t k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
1454 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
1455 DeviceMemory<Eigen::half> *c, int ldc,
1456 blas::ProfileResult *output_profile_result) {
1457 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1458 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1459 PARAM(beta), PARAM(c), PARAM(ldc));
1460
1461 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t,
1462 uint64_t, float, const DeviceMemory<Eigen::half> &,
1463 int, const DeviceMemory<Eigen::half> &, int, float,
1464 DeviceMemory<Eigen::half> *, int>
1465 impl;
1466 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1467 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1468 output_profile_result);
1469 }
1470
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)1471 Stream &Stream::ThenBlasGemmWithProfiling(
1472 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1473 uint64_t k, float alpha, const DeviceMemory<float> &a, int lda,
1474 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
1475 int ldc, blas::ProfileResult *output_profile_result) {
1476 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1477 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1478 PARAM(beta), PARAM(c), PARAM(ldc));
1479
1480 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t,
1481 uint64_t, float, const DeviceMemory<float> &, int,
1482 const DeviceMemory<float> &, int, float,
1483 DeviceMemory<float> *, int>
1484 impl;
1485 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1486 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1487 output_profile_result);
1488 }
1489
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)1490 Stream &Stream::ThenBlasGemmWithProfiling(
1491 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1492 uint64_t k, double alpha, const DeviceMemory<double> &a, int lda,
1493 const DeviceMemory<double> &b, int ldb, double beta,
1494 DeviceMemory<double> *c, int ldc,
1495 blas::ProfileResult *output_profile_result) {
1496 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1497 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1498 PARAM(beta), PARAM(c), PARAM(ldc));
1499
1500 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t,
1501 uint64_t, double, const DeviceMemory<double> &, int,
1502 const DeviceMemory<double> &, int, double,
1503 DeviceMemory<double> *, int>
1504 impl;
1505 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1506 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1507 output_profile_result);
1508 }
1509
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)1510 Stream &Stream::ThenBlasGemmWithProfiling(
1511 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1512 uint64_t k, std::complex<float> alpha,
1513 const DeviceMemory<std::complex<float>> &a, int lda,
1514 const DeviceMemory<std::complex<float>> &b, int ldb,
1515 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
1516 blas::ProfileResult *output_profile_result) {
1517 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1518 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1519 PARAM(beta), PARAM(c), PARAM(ldc));
1520
1521 ThenBlasWithProfileImpl<
1522 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1523 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
1524 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
1525 DeviceMemory<std::complex<float>> *, int>
1526 impl;
1527 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1528 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1529 output_profile_result);
1530 }
1531
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)1532 Stream &Stream::ThenBlasGemmWithProfiling(
1533 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1534 uint64_t k, std::complex<double> alpha,
1535 const DeviceMemory<std::complex<double>> &a, int lda,
1536 const DeviceMemory<std::complex<double>> &b, int ldb,
1537 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
1538 blas::ProfileResult *output_profile_result) {
1539 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1540 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1541 PARAM(beta), PARAM(c), PARAM(ldc));
1542
1543 ThenBlasWithProfileImpl<
1544 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1545 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
1546 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
1547 DeviceMemory<std::complex<double>> *, int>
1548 impl;
1549 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
1550 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
1551 output_profile_result);
1552 }
1553
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)1554 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1555 blas::Transpose transa, blas::Diagonal diag,
1556 uint64_t m, uint64 n, float alpha,
1557 const DeviceMemory<float> &a, int lda,
1558 DeviceMemory<float> *b, int ldb) {
1559 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1560 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1561
1562 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1563 uint64_t, uint64_t, float, const DeviceMemory<float> &, int,
1564 DeviceMemory<float> *, int>
1565 impl;
1566 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1567 n, alpha, a, lda, b, ldb);
1568 }
1569
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)1570 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1571 blas::Transpose transa, blas::Diagonal diag,
1572 uint64_t m, uint64 n, double alpha,
1573 const DeviceMemory<double> &a, int lda,
1574 DeviceMemory<double> *b, int ldb) {
1575 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1576 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1577
1578 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1579 uint64_t, uint64_t, double, const DeviceMemory<double> &, int,
1580 DeviceMemory<double> *, int>
1581 impl;
1582 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1583 n, alpha, a, lda, b, ldb);
1584 }
1585
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)1586 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1587 blas::Transpose transa, blas::Diagonal diag,
1588 uint64_t m, uint64 n, std::complex<float> alpha,
1589 const DeviceMemory<std::complex<float>> &a,
1590 int lda, DeviceMemory<std::complex<float>> *b,
1591 int ldb) {
1592 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1593 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1594
1595 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1596 uint64_t, uint64_t, std::complex<float>,
1597 const DeviceMemory<std::complex<float>> &, int,
1598 DeviceMemory<std::complex<float>> *, int>
1599 impl;
1600 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1601 n, alpha, a, lda, b, ldb);
1602 }
1603
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)1604 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
1605 blas::Transpose transa, blas::Diagonal diag,
1606 uint64_t m, uint64 n, std::complex<double> alpha,
1607 const DeviceMemory<std::complex<double>> &a,
1608 int lda, DeviceMemory<std::complex<double>> *b,
1609 int ldb) {
1610 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1611 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
1612
1613 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1614 uint64_t, uint64_t, std::complex<double>,
1615 const DeviceMemory<std::complex<double>> &, int,
1616 DeviceMemory<std::complex<double>> *, int>
1617 impl;
1618 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
1619 n, alpha, a, lda, b, ldb);
1620 }
1621
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,float alpha,const DeviceMemory<float * > & as,int lda,DeviceMemory<float * > * bs,int ldb,int batch_count)1622 Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1623 blas::Transpose transa, blas::Diagonal diag,
1624 uint64_t m, uint64 n, float alpha,
1625 const DeviceMemory<float *> &as, int lda,
1626 DeviceMemory<float *> *bs, int ldb,
1627 int batch_count) {
1628 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1629 PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1630 PARAM(ldb), PARAM(batch_count));
1631
1632 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1633 uint64_t, uint64_t, float, const DeviceMemory<float *> &, int,
1634 DeviceMemory<float *> *, int, int>
1635 impl;
1636 return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1637 diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1638 }
1639
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,double alpha,const DeviceMemory<double * > & as,int lda,DeviceMemory<double * > * bs,int ldb,int batch_count)1640 Stream &Stream::ThenBlasTrsmBatched(blas::Side side, blas::UpperLower uplo,
1641 blas::Transpose transa, blas::Diagonal diag,
1642 uint64_t m, uint64 n, double alpha,
1643 const DeviceMemory<double *> &as, int lda,
1644 DeviceMemory<double *> *bs, int ldb,
1645 int batch_count) {
1646 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1647 PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1648 PARAM(ldb), PARAM(batch_count));
1649
1650 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1651 uint64_t, uint64_t, double, const DeviceMemory<double *> &, int,
1652 DeviceMemory<double *> *, int, int>
1653 impl;
1654 return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1655 diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1656 }
1657
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float> * > & as,int lda,DeviceMemory<std::complex<float> * > * bs,int ldb,int batch_count)1658 Stream &Stream::ThenBlasTrsmBatched(
1659 blas::Side side, blas::UpperLower uplo, blas::Transpose transa,
1660 blas::Diagonal diag, uint64_t m, uint64 n, std::complex<float> alpha,
1661 const DeviceMemory<std::complex<float> *> &as, int lda,
1662 DeviceMemory<std::complex<float> *> *bs, int ldb, int batch_count) {
1663 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1664 PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1665 PARAM(ldb), PARAM(batch_count));
1666
1667 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1668 uint64_t, uint64_t, std::complex<float>,
1669 const DeviceMemory<std::complex<float> *> &, int,
1670 DeviceMemory<std::complex<float> *> *, int, int>
1671 impl;
1672 return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1673 diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1674 }
1675
ThenBlasTrsmBatched(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64_t m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double> * > & as,int lda,DeviceMemory<std::complex<double> * > * bs,int ldb,int batch_count)1676 Stream &Stream::ThenBlasTrsmBatched(
1677 blas::Side side, blas::UpperLower uplo, blas::Transpose transa,
1678 blas::Diagonal diag, uint64_t m, uint64 n, std::complex<double> alpha,
1679 const DeviceMemory<std::complex<double> *> &as, int lda,
1680 DeviceMemory<std::complex<double> *> *bs, int ldb, int batch_count) {
1681 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
1682 PARAM(n), PARAM(alpha), PARAM(as), PARAM(lda), PARAM(bs),
1683 PARAM(ldb), PARAM(batch_count));
1684
1685 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
1686 uint64_t, uint64_t, std::complex<double>,
1687 const DeviceMemory<std::complex<double> *> &, int,
1688 DeviceMemory<std::complex<double> *> *, int, int>
1689 impl;
1690 return impl(this, &blas::BlasSupport::DoBlasTrsmBatched, side, uplo, transa,
1691 diag, m, n, alpha, as, lda, bs, ldb, batch_count);
1692 }
1693
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)1694 Stream &Stream::ThenBlasGemmBatched(
1695 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1696 uint64_t k, float alpha,
1697 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, // non-absl ok
1698 int lda,
1699 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, // non-absl ok
1700 int ldb, float beta,
1701 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, // non-absl ok
1702 int ldc, int batch_count) {
1703 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1704 b, ldb, beta, c, ldc, batch_count,
1705 /*scratch_allocator=*/nullptr);
1706 }
1707
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1708 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1709 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1710 uint64_t k, float alpha,
1711 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, // non-absl ok
1712 int lda,
1713 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, // non-absl ok
1714 int ldb, float beta,
1715 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, // non-absl ok
1716 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1717 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1718 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1719 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1720
1721 ThenBlasImpl<
1722 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, float,
1723 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, // non-absl ok
1724 int,
1725 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, // non-absl ok
1726 int, float,
1727 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, // non-absl ok
1728 int, int, ScratchAllocator *>
1729 impl;
1730 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1731 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1732 scratch_allocator);
1733 }
1734
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)1735 Stream &Stream::ThenBlasGemmBatched(
1736 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1737 uint64_t k, float alpha,
1738 const port::ArraySlice<DeviceMemory<float> *> &a, // non-absl ok
1739 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, // non-absl ok
1740 int ldb, float beta,
1741 const port::ArraySlice<DeviceMemory<float> *> &c, // non-absl ok
1742 int ldc, int batch_count) {
1743 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1744 b, ldb, beta, c, ldc, batch_count,
1745 /*scratch_allocator=*/nullptr);
1746 }
1747
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1748 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1749 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1750 uint64_t k, float alpha,
1751 const port::ArraySlice<DeviceMemory<float> *> &a, // non-absl ok
1752 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, // non-absl ok
1753 int ldb, float beta,
1754 const port::ArraySlice<DeviceMemory<float> *> &c, // non-absl ok
1755 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1756 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1757 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1758 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1759
1760 ThenBlasImpl<
1761 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, float,
1762 const port::ArraySlice<DeviceMemory<float> *> &, int, // non-absl ok
1763 const port::ArraySlice<DeviceMemory<float> *> &, int, // non-absl ok
1764 float, const port::ArraySlice<DeviceMemory<float> *> &, // non-absl ok
1765 int, int, ScratchAllocator *>
1766 impl;
1767 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1768 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1769 scratch_allocator);
1770 }
1771
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)1772 Stream &Stream::ThenBlasGemmBatched(
1773 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1774 uint64_t k, double alpha,
1775 const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
1776 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
1777 int ldb, double beta,
1778 const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
1779 int ldc, int batch_count) {
1780 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1781 b, ldb, beta, c, ldc, batch_count,
1782 /*scratch_allocator=*/nullptr);
1783 }
1784
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1785 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1786 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1787 uint64_t k, double alpha,
1788 const port::ArraySlice<DeviceMemory<double> *> &a, // non-absl ok
1789 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, // non-absl ok
1790 int ldb, double beta,
1791 const port::ArraySlice<DeviceMemory<double> *> &c, // non-absl ok
1792 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1793 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1794 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1795 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1796
1797 ThenBlasImpl<
1798 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64, double,
1799 const port::ArraySlice<DeviceMemory<double> *> &, // non-absl ok
1800 int, const port::ArraySlice<DeviceMemory<double> *> &, // non-absl ok
1801 int, double,
1802 const port::ArraySlice<DeviceMemory<double> *> &, // non-absl ok
1803 int, int, ScratchAllocator *>
1804 impl;
1805 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1806 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1807 scratch_allocator);
1808 }
1809
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)1810 Stream &Stream::ThenBlasGemmBatched(
1811 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1812 uint64_t k, std::complex<float> alpha,
1813 const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
1814 &a,
1815 int lda,
1816 const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
1817 &b,
1818 int ldb, std::complex<float> beta,
1819 const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
1820 &c,
1821 int ldc, int batch_count) {
1822 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1823 b, ldb, beta, c, ldc, batch_count,
1824 /*scratch_allocator=*/nullptr);
1825 }
1826
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1827 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1828 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1829 uint64_t k, std::complex<float> alpha,
1830 const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
1831 &a,
1832 int lda,
1833 const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
1834 &b,
1835 int ldb, std::complex<float> beta,
1836 const port::ArraySlice<DeviceMemory<std::complex<float>> *> // non-absl ok
1837 &c,
1838 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1839 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1840 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1841 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1842
1843 ThenBlasImpl<
1844 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1845 std::complex<float>, const DeviceMemorySlice<std::complex<float>> &, int,
1846 const DeviceMemorySlice<std::complex<float>> &, int, std::complex<float>,
1847 const DeviceMemorySlice<std::complex<float>> &, int, int,
1848 ScratchAllocator *>
1849 impl;
1850 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1851 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1852 scratch_allocator);
1853 }
1854
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<double> alpha,const DeviceMemorySlice<std::complex<double>> & a,int lda,const DeviceMemorySlice<std::complex<double>> & b,int ldb,std::complex<double> beta,const DeviceMemorySlice<std::complex<double>> & c,int ldc,int batch_count)1855 Stream &Stream::ThenBlasGemmBatched(
1856 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1857 uint64_t k, std::complex<double> alpha,
1858 const DeviceMemorySlice<std::complex<double>> &a, int lda,
1859 const DeviceMemorySlice<std::complex<double>> &b, int ldb,
1860 std::complex<double> beta, const DeviceMemorySlice<std::complex<double>> &c,
1861 int ldc, int batch_count) {
1862 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
1863 b, ldb, beta, c, ldc, batch_count,
1864 /*scratch_allocator=*/nullptr);
1865 }
1866
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64_t m,uint64 n,uint64_t k,std::complex<double> alpha,const DeviceMemorySlice<std::complex<double>> & a,int lda,const DeviceMemorySlice<std::complex<double>> & b,int ldb,std::complex<double> beta,const DeviceMemorySlice<std::complex<double>> & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)1867 Stream &Stream::ThenBlasGemmBatchedWithScratch(
1868 blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
1869 uint64_t k, std::complex<double> alpha,
1870 const DeviceMemorySlice<std::complex<double>> &a, int lda,
1871 const DeviceMemorySlice<std::complex<double>> &b, int ldb,
1872 std::complex<double> beta,
1873 const DeviceMemorySlice<std::complex<double>> &c, // non-absl ok
1874 int ldc, int batch_count, ScratchAllocator *scratch_allocator) {
1875 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
1876 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
1877 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
1878
1879 ThenBlasImpl<
1880 blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
1881 std::complex<double>, const DeviceMemorySlice<std::complex<double>> &,
1882 int, const DeviceMemorySlice<std::complex<double>> &, int,
1883 std::complex<double>, const DeviceMemorySlice<std::complex<double>> &,
1884 int, int, ScratchAllocator *>
1885 impl;
1886 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
1887 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
1888 scratch_allocator);
1889 }
1890
ThenSetRngSeed(const uint8 * seed,uint64_t seed_bytes)1891 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes) {
1892 VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
1893
1894 if (rng::RngSupport *rng = parent_->AsRng()) {
1895 CheckError(rng->SetSeed(this, seed, seed_bytes));
1896 } else {
1897 SetError();
1898 LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
1899 }
1900 return *this;
1901 }
1902
ThenPopulateRandUniform(DeviceMemory<float> * values)1903 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
1904 VLOG_CALL(PARAM(values));
1905
1906 if (rng::RngSupport *rng = parent_->AsRng()) {
1907 CheckError(rng->DoPopulateRandUniform(this, values));
1908 } else {
1909 SetError();
1910 LOG(INFO) << DebugStreamPointers()
1911 << " attempting to perform RNG operation using StreamExecutor"
1912 " without RNG support.";
1913 }
1914 return *this;
1915 }
1916
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)1917 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
1918 DeviceMemory<float> *values) {
1919 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
1920
1921 if (rng::RngSupport *rng = parent_->AsRng()) {
1922 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
1923 } else {
1924 SetError();
1925 LOG(INFO) << DebugStreamPointers()
1926 << " attempting to perform RNG operation using StreamExecutor"
1927 " without RNG support.";
1928 }
1929 return *this;
1930 }
1931
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)1932 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
1933 DeviceMemory<double> *values) {
1934 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
1935
1936 if (rng::RngSupport *rng = parent_->AsRng()) {
1937 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
1938 } else {
1939 SetError();
1940 LOG(INFO) << DebugStreamPointers()
1941 << " attempting to perform RNG operation using StreamExecutor"
1942 " without RNG support.";
1943 }
1944 return *this;
1945 }
1946
ThenPopulateRandUniform(DeviceMemory<double> * values)1947 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
1948 VLOG_CALL(PARAM(values));
1949
1950 if (rng::RngSupport *rng = parent_->AsRng()) {
1951 CheckError(rng->DoPopulateRandUniform(this, values));
1952 } else {
1953 SetError();
1954 LOG(INFO) << DebugStreamPointers()
1955 << " attempting to perform RNG operation using StreamExecutor"
1956 " without RNG support.";
1957 }
1958 return *this;
1959 }
1960
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)1961 Stream &Stream::ThenPopulateRandUniform(
1962 DeviceMemory<std::complex<float>> *values) {
1963 VLOG_CALL(PARAM(values));
1964
1965 if (rng::RngSupport *rng = parent_->AsRng()) {
1966 CheckError(rng->DoPopulateRandUniform(this, values));
1967 } else {
1968 SetError();
1969 LOG(INFO) << DebugStreamPointers()
1970 << " attempting to perform RNG operation using StreamExecutor"
1971 " without RNG support.";
1972 }
1973 return *this;
1974 }
1975
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)1976 Stream &Stream::ThenPopulateRandUniform(
1977 DeviceMemory<std::complex<double>> *values) {
1978 VLOG_CALL(PARAM(values));
1979
1980 if (rng::RngSupport *rng = parent_->AsRng()) {
1981 CheckError(rng->DoPopulateRandUniform(this, values));
1982 } else {
1983 SetError();
1984 LOG(INFO) << DebugStreamPointers()
1985 << " attempting to perform RNG operation using StreamExecutor"
1986 " without RNG support.";
1987 }
1988 return *this;
1989 }
1990
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64_t size)1991 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
1992 uint64_t size) {
1993 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
1994
1995 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
1996 return *this;
1997 }
1998
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64_t size)1999 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
2000 uint64_t size) {
2001 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
2002
2003 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
2004 return *this;
2005 }
2006
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64_t size)2007 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
2008 const DeviceMemoryBase &gpu_src, uint64_t size) {
2009 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
2010
2011 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
2012 return *this;
2013 }
2014
ThenMemZero(DeviceMemoryBase * location,uint64_t size)2015 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64_t size) {
2016 VLOG_CALL(PARAM(location), PARAM(size));
2017
2018 CheckStatus(parent_->MemZero(this, location, size));
2019 return *this;
2020 }
2021
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64_t size)2022 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
2023 uint64_t size) {
2024 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
2025
2026 CheckStatus(parent_->Memset32(this, location, pattern, size));
2027 return *this;
2028 }
2029
// Enqueues an RNN forward pass (Eigen::half element type) on this stream by
// forwarding all arguments to DnnSupport::DoRnnForward. Requires the parent
// StreamExecutor to provide DNN support; otherwise the stream is put into an
// error state.
//
// A failed run only poisons the stream when `output_profile_result` is null:
// profiled runs (presumably autotuning — TODO confirm) are allowed to fail,
// leaving the failure to be observed through the profile result.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<Eigen::half> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<Eigen::half> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<Eigen::half> &input_c_data,
    const DeviceMemory<Eigen::half> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<Eigen::half> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<Eigen::half> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<Eigen::half> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        is_training, reserve_space_allocator, workspace_allocator,
        output_profile_result);
    // Only mark the stream errored when the caller did not request profiling.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2065
// Single-precision (float) overload of the RNN forward pass; see the
// Eigen::half overload above for the error-handling contract. All arguments
// are forwarded unchanged to DnnSupport::DoRnnForward.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<float> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<float> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<float> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<float> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<float> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        is_training, reserve_space_allocator, workspace_allocator,
        output_profile_result);
    // Only mark the stream errored when the caller did not request profiling.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2100
// Double-precision overload of the RNN forward pass; see the Eigen::half
// overload above for the error-handling contract. All arguments are forwarded
// unchanged to DnnSupport::DoRnnForward.
Stream &Stream::ThenRnnForward(
    const dnn::RnnDescriptor &rnn_desc,
    const dnn::RnnSequenceTensorDescriptor &input_desc,
    const DeviceMemory<double> &input_data,
    const DeviceMemory<int> &seq_lengths_data,
    const dnn::RnnStateTensorDescriptor &input_h_desc,
    const DeviceMemory<double> &input_h_data,
    const dnn::RnnStateTensorDescriptor &input_c_desc,
    const DeviceMemory<double> &input_c_data,
    const DeviceMemory<double> &params,
    const dnn::RnnSequenceTensorDescriptor &output_desc,
    DeviceMemory<double> *output_data,
    const dnn::RnnStateTensorDescriptor &output_h_desc,
    DeviceMemory<double> *output_h_data,
    const dnn::RnnStateTensorDescriptor &output_c_desc,
    DeviceMemory<double> *output_c_data, bool is_training,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator,
    dnn::ProfileResult *output_profile_result) {
  // TODO(zhengxq): add VLOG PARAM calls.
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    auto status = dnn->DoRnnForward(
        this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
        input_h_data, input_c_desc, input_c_data, params, output_desc,
        output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
        is_training, reserve_space_allocator, workspace_allocator,
        output_profile_result);
    // Only mark the stream errored when the caller did not request profiling.
    if (!status && !output_profile_result) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2136
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2137 Stream &Stream::ThenRnnBackward(
2138 const dnn::RnnDescriptor &rnn_desc,
2139 const dnn::RnnSequenceTensorDescriptor &input_desc,
2140 const DeviceMemory<Eigen::half> &input_data,
2141 const DeviceMemory<int> &seq_lengths_data,
2142 const dnn::RnnStateTensorDescriptor &input_h_desc,
2143 const DeviceMemory<Eigen::half> &input_h_data,
2144 const dnn::RnnStateTensorDescriptor &input_c_desc,
2145 const DeviceMemory<Eigen::half> &input_c_data,
2146 const DeviceMemory<Eigen::half> ¶ms,
2147 const dnn::RnnSequenceTensorDescriptor &output_desc,
2148 const DeviceMemory<Eigen::half> &output_data,
2149 const dnn::RnnStateTensorDescriptor &output_h_desc,
2150 const DeviceMemory<Eigen::half> &output_h_data,
2151 const dnn::RnnStateTensorDescriptor &output_c_desc,
2152 const DeviceMemory<Eigen::half> &output_c_data,
2153 const DeviceMemory<Eigen::half> &output_backprop_data,
2154 const DeviceMemory<Eigen::half> &output_h_backprop_data,
2155 const DeviceMemory<Eigen::half> &output_c_backprop_data,
2156 DeviceMemory<Eigen::half> *input_backprop_data,
2157 DeviceMemory<Eigen::half> *input_h_backprop_data,
2158 DeviceMemory<Eigen::half> *input_c_backprop_data,
2159 DeviceMemory<Eigen::half> *params_backprop_data,
2160 DeviceMemory<uint8> *reserve_space_data,
2161 ScratchAllocator *workspace_allocator,
2162 dnn::ProfileResult *output_profile_result) {
2163 // TODO(zhengxq): add VLOG PARAM calls.
2164 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2165 auto status = dnn->DoRnnBackward(
2166 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
2167 input_h_data, input_c_desc, input_c_data, params, output_desc,
2168 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
2169 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2170 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2171 params_backprop_data, reserve_space_data, workspace_allocator,
2172 output_profile_result);
2173 if (!status && !output_profile_result) {
2174 SetError();
2175 }
2176 } else {
2177 SetError();
2178 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
2179 }
2180 return *this;
2181 }
2182
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2183 Stream &Stream::ThenRnnBackward(
2184 const dnn::RnnDescriptor &rnn_desc,
2185 const dnn::RnnSequenceTensorDescriptor &input_desc,
2186 const DeviceMemory<float> &input_data,
2187 const DeviceMemory<int> &seq_lengths_data,
2188 const dnn::RnnStateTensorDescriptor &input_h_desc,
2189 const DeviceMemory<float> &input_h_data,
2190 const dnn::RnnStateTensorDescriptor &input_c_desc,
2191 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
2192 const dnn::RnnSequenceTensorDescriptor &output_desc,
2193 const DeviceMemory<float> &output_data,
2194 const dnn::RnnStateTensorDescriptor &output_h_desc,
2195 const DeviceMemory<float> &output_h_data,
2196 const dnn::RnnStateTensorDescriptor &output_c_desc,
2197 const DeviceMemory<float> &output_c_data,
2198 const DeviceMemory<float> &output_backprop_data,
2199 const DeviceMemory<float> &output_h_backprop_data,
2200 const DeviceMemory<float> &output_c_backprop_data,
2201 DeviceMemory<float> *input_backprop_data,
2202 DeviceMemory<float> *input_h_backprop_data,
2203 DeviceMemory<float> *input_c_backprop_data,
2204 DeviceMemory<float> *params_backprop_data,
2205 DeviceMemory<uint8> *reserve_space_data,
2206 ScratchAllocator *workspace_allocator,
2207 dnn::ProfileResult *output_profile_result) {
2208 // TODO(zhengxq): add VLOG PARAM calls.
2209 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2210 auto status = dnn->DoRnnBackward(
2211 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
2212 input_h_data, input_c_desc, input_c_data, params, output_desc,
2213 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
2214 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2215 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2216 params_backprop_data, reserve_space_data, workspace_allocator,
2217 output_profile_result);
2218 if (!status && !output_profile_result) {
2219 SetError();
2220 }
2221 } else {
2222 SetError();
2223 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
2224 }
2225 return *this;
2226 }
2227
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)2228 Stream &Stream::ThenRnnBackward(
2229 const dnn::RnnDescriptor &rnn_desc,
2230 const dnn::RnnSequenceTensorDescriptor &input_desc,
2231 const DeviceMemory<double> &input_data,
2232 const DeviceMemory<int> &seq_lengths_data,
2233 const dnn::RnnStateTensorDescriptor &input_h_desc,
2234 const DeviceMemory<double> &input_h_data,
2235 const dnn::RnnStateTensorDescriptor &input_c_desc,
2236 const DeviceMemory<double> &input_c_data,
2237 const DeviceMemory<double> ¶ms,
2238 const dnn::RnnSequenceTensorDescriptor &output_desc,
2239 const DeviceMemory<double> &output_data,
2240 const dnn::RnnStateTensorDescriptor &output_h_desc,
2241 const DeviceMemory<double> &output_h_data,
2242 const dnn::RnnStateTensorDescriptor &output_c_desc,
2243 const DeviceMemory<double> &output_c_data,
2244 const DeviceMemory<double> &output_backprop_data,
2245 const DeviceMemory<double> &output_h_backprop_data,
2246 const DeviceMemory<double> &output_c_backprop_data,
2247 DeviceMemory<double> *input_backprop_data,
2248 DeviceMemory<double> *input_h_backprop_data,
2249 DeviceMemory<double> *input_c_backprop_data,
2250 DeviceMemory<double> *params_backprop_data,
2251 DeviceMemory<uint8> *reserve_space_data,
2252 ScratchAllocator *workspace_allocator,
2253 dnn::ProfileResult *output_profile_result) {
2254 // TODO(zhengxq): add VLOG PARAM calls.
2255 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2256 auto status = dnn->DoRnnBackward(
2257 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
2258 input_h_data, input_c_desc, input_c_data, params, output_desc,
2259 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
2260 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
2261 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
2262 params_backprop_data, reserve_space_data, workspace_allocator,
2263 output_profile_result);
2264 if (!status && !output_profile_result) {
2265 SetError();
2266 }
2267 } else {
2268 SetError();
2269 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
2270 }
2271 return *this;
2272 }
2273
// Enqueues a CTC-loss computation on this stream. The work happens in two
// phases against the DNN backend: PrepareForCtcLoss selects an algorithm and
// obtains scratch memory (from `workspace_allocator`), then DoCtcLoss runs
// the selected algorithm, writing per-example losses to `costs_data` and
// gradients to `grads_data`. A failure in either phase poisons the stream;
// a missing DNN backend does too.
Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
                            const DeviceMemory<float> &probs_data,
                            absl::Span<const int> labels_data,
                            absl::Span<const int> labels_lengths_data,
                            absl::Span<const int> input_lengths_data,
                            DeviceMemory<float> *costs_data,
                            const dnn::RnnStateTensorDescriptor &grads_desc,
                            DeviceMemory<float> *grads_data,
                            ScratchAllocator *workspace_allocator) {
  if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    DeviceMemory<uint8> scratch_memory;
    int ctc_loss_algo_id;
    // Phase 1: pick an algorithm and allocate scratch for it.
    auto status =
        dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
                               labels_data, labels_lengths_data,
                               input_lengths_data, workspace_allocator,
                               &scratch_memory, &ctc_loss_algo_id)
            .ok();
    // Phase 2: run the loss computation only if preparation succeeded.
    if (status) {
      status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
                              labels_lengths_data, input_lengths_data,
                              costs_data, grads_desc, grads_data,
                              &scratch_memory, ctc_loss_algo_id);
    }
    if (!status) {
      SetError();
    }
  } else {
    SetErrorAndLogNoDnnSupport();
  }
  return *this;
}
2306
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)2307 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
2308 dnn::DataType input_type,
2309 const DeviceMemoryBase &input_data,
2310 const dnn::BatchDescriptor &output_desc,
2311 dnn::DataType output_type, float scale,
2312 DeviceMemoryBase *output_data) {
2313 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
2314 PARAM(output_desc), PARAM(output_type), PARAM(scale),
2315 PARAM(output_data));
2316 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
2317 CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data,
2318 output_desc, output_type, scale,
2319 output_data));
2320 } else {
2321 SetErrorAndLogNoDnnSupport();
2322 }
2323 return *this;
2324 }
2325
ThenDoHostCallback(std::function<void ()> callback)2326 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
2327 VLOG_CALL(PARAM(callback));
2328
2329 if (!ok()) {
2330 LOG(INFO) << DebugStreamPointers()
2331 << " was in error state before adding host callback";
2332 }
2333 CheckError(parent_->HostCallback(this, std::move(callback)));
2334 return *this;
2335 }
2336
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)2337 Stream &Stream::ThenDoHostCallbackWithStatus(
2338 std::function<port::Status()> callback) {
2339 VLOG_CALL(PARAM(callback));
2340
2341 if (!ok()) {
2342 LOG(INFO) << DebugStreamPointers()
2343 << " was in error state before adding host callback";
2344 }
2345 CheckError(parent_->HostCallback(this, std::move(callback)));
2346 return *this;
2347 }
2348
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)2349 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
2350 std::function<void()> callback) {
2351 VLOG_CALL(PARAM(callback));
2352
2353 if (!ok()) {
2354 LOG(INFO) << DebugStreamPointers()
2355 << " was in error state before adding callback to be run after "
2356 "next block-host-until-done.";
2357 }
2358 absl::MutexLock lock(&mu_);
2359 after_block_host_until_done_callbacks_.push_back(std::move(callback));
2360 return *this;
2361 }
2362
CheckError(bool operation_retcode)2363 void Stream::CheckError(bool operation_retcode) {
2364 if (operation_retcode) {
2365 return;
2366 }
2367 absl::MutexLock lock(&mu_);
2368 status_ = port::InternalError("Unknown error");
2369 }
2370
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)2371 Stream &Stream::ThenFft(fft::Plan *plan,
2372 const DeviceMemory<std::complex<float>> &input,
2373 DeviceMemory<std::complex<float>> *output) {
2374 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2375
2376 if (fft::FftSupport *fft = parent_->AsFft()) {
2377 CheckError(fft->DoFft(this, plan, input, output));
2378 } else {
2379 SetError();
2380 LOG(INFO) << DebugStreamPointers()
2381 << " attempting to perform FFT operation using StreamExecutor"
2382 " without FFT support";
2383 }
2384 return *this;
2385 }
2386
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)2387 Stream &Stream::ThenFft(fft::Plan *plan,
2388 const DeviceMemory<std::complex<double>> &input,
2389 DeviceMemory<std::complex<double>> *output) {
2390 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2391
2392 if (fft::FftSupport *fft = parent_->AsFft()) {
2393 CheckError(fft->DoFft(this, plan, input, output));
2394 } else {
2395 SetError();
2396 LOG(INFO) << DebugStreamPointers()
2397 << " attempting to perform FFT operation using StreamExecutor"
2398 " without FFT support";
2399 }
2400 return *this;
2401 }
2402
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)2403 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
2404 DeviceMemory<std::complex<float>> *output) {
2405 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2406
2407 if (fft::FftSupport *fft = parent_->AsFft()) {
2408 CheckError(fft->DoFft(this, plan, input, output));
2409 } else {
2410 SetError();
2411 LOG(INFO) << DebugStreamPointers()
2412 << " attempting to perform FFT operation using StreamExecutor"
2413 " without FFT support";
2414 }
2415 return *this;
2416 }
2417
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)2418 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
2419 DeviceMemory<std::complex<double>> *output) {
2420 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2421
2422 if (fft::FftSupport *fft = parent_->AsFft()) {
2423 CheckError(fft->DoFft(this, plan, input, output));
2424 } else {
2425 SetError();
2426 LOG(INFO) << DebugStreamPointers()
2427 << " attempting to perform FFT operation using StreamExecutor"
2428 " without FFT support";
2429 }
2430 return *this;
2431 }
2432
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)2433 Stream &Stream::ThenFft(fft::Plan *plan,
2434 const DeviceMemory<std::complex<float>> &input,
2435 DeviceMemory<float> *output) {
2436 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2437
2438 if (fft::FftSupport *fft = parent_->AsFft()) {
2439 CheckError(fft->DoFft(this, plan, input, output));
2440 } else {
2441 SetError();
2442 LOG(INFO) << DebugStreamPointers()
2443 << " attempting to perform FFT operation using StreamExecutor"
2444 " without FFT support";
2445 }
2446 return *this;
2447 }
2448
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)2449 Stream &Stream::ThenFft(fft::Plan *plan,
2450 const DeviceMemory<std::complex<double>> &input,
2451 DeviceMemory<double> *output) {
2452 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
2453
2454 if (fft::FftSupport *fft = parent_->AsFft()) {
2455 CheckError(fft->DoFft(this, plan, input, output));
2456 } else {
2457 SetError();
2458 LOG(INFO) << DebugStreamPointers()
2459 << " attempting to perform FFT operation using StreamExecutor"
2460 " without FFT support";
2461 }
2462 return *this;
2463 }
2464
2465 // It looks confusing, but all this is doing is inserting a callback at the
2466 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)2467 Stream &Stream::ThenEnqueueOnBackgroundThread(
2468 std::function<void(StreamExecutor *)> task) {
2469 VLOG_CALL(PARAM(task));
2470
2471 StreamExecutor *stream_executor = this->parent_;
2472 std::function<void()> bound_task = std::bind(task, stream_executor);
2473
2474 return ThenDoHostCallback([stream_executor, bound_task]() {
2475 stream_executor->EnqueueOnBackgroundThread(bound_task);
2476 });
2477 }
2478
// Blocks the calling host thread until all work enqueued on this stream has
// completed (or fails). If the stream is already in an error state, returns
// an INTERNAL error immediately without blocking. Callbacks registered via
// ThenRunAfterNextBlockHostUntilDone are run after the wait, regardless of
// whether the wait succeeded.
port::Status Stream::BlockHostUntilDone() {
  VLOG_CALL();

  if (!ok()) {
    absl::MutexLock lock(&mu_);
    LOG(INFO) << status_.ToString();
    port::Status status = port::Status(
        port::error::INTERNAL,
        "stream did not block host until done; was already in an error state");
    LOG(INFO) << DebugStreamPointers() << " " << status;
    return status;
  }

  // Release temporary allocations that have been finalized before blocking.
  temporary_memory_manager_.DeallocateFinalizedTemporaries();

  port::Status error = parent_->BlockHostUntilDone(this);
  // A failed wait also poisons the stream's own status.
  CheckError(error.ok());

  RunAfterBlockHostUntilDoneCallbacks();
  return error;
}
2500
RunAfterBlockHostUntilDoneCallbacks()2501 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
2502 std::vector<std::function<void()>> callbacks;
2503 {
2504 absl::MutexLock lock(&mu_);
2505 std::swap(callbacks, after_block_host_until_done_callbacks_);
2506 }
2507 for (const auto &fn : callbacks) {
2508 fn();
2509 }
2510 }
2511
DebugStreamPointers() const2512 std::string Stream::DebugStreamPointers() const {
2513 // Relies on the ToVlogString(const void*) overload above.
2514 return absl::StrCat("[stream=", ToVlogString(this),
2515 ",impl=", ToVlogString(implementation_.get()), "]");
2516 }
2517
CheckStatus(port::Status status)2518 void Stream::CheckStatus(port::Status status) {
2519 if (status.ok()) {
2520 return;
2521 }
2522 LOG(ERROR) << status;
2523 absl::MutexLock lock(&mu_);
2524 status_ = status;
2525 }
2526
2527 } // namespace stream_executor
2528