1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/stream.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "third_party/eigen3/Eigen/Core"
20 #include "tensorflow/stream_executor/blas.h"
21 #include "tensorflow/stream_executor/host_or_device_scalar.h"
22 #include "tensorflow/stream_executor/lib/stacktrace.h"
23 #include "tensorflow/stream_executor/platform.h"
24 #include "tensorflow/stream_executor/platform/logging.h"
25 #include "tensorflow/stream_executor/platform/port.h"
26 #include "tensorflow/stream_executor/rng.h"
27 #include "tensorflow/stream_executor/stream_executor_internal.h"
28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
29
30 namespace stream_executor {
31
32 namespace {
33 // Code to turn parameters to functions on stream into strings that
34 // will be VLOG'ed. We need overloads, instead of
35 // e.g. BatchDescriptorToVlogString(), as the code that calls these
36 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)37 std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
38 return descriptor.ToShortString();
39 }
40
ToVlogString(const dnn::FilterDescriptor & descriptor)41 std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
42 return descriptor.ToShortString();
43 }
44
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)45 std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
46 return descriptor.ToShortString();
47 }
48
ToVlogString(const dnn::PoolingDescriptor & descriptor)49 std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
50 return descriptor.ToShortString();
51 }
52
ToVlogString(const dnn::NormalizeDescriptor & descriptor)53 std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
54 return descriptor.ToShortString();
55 }
56
ToVlogString(dnn::ActivationMode mode)57 std::string ToVlogString(dnn::ActivationMode mode) {
58 return dnn::ActivationModeString(mode);
59 }
60
ToVlogString(const dnn::AlgorithmConfig & algo_config)61 std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
62 return algo_config.ToString();
63 }
64
ToVlogString(dnn::ElementwiseOperation op)65 std::string ToVlogString(dnn::ElementwiseOperation op) {
66 return dnn::ElementwiseOperationString(op);
67 }
68
ToVlogString(dnn::QuantizedActivationMode mode)69 std::string ToVlogString(dnn::QuantizedActivationMode mode) {
70 return dnn::QuantizedActivationModeString(mode);
71 }
72
ToVlogString(blas::Transpose t)73 std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
74
ToVlogString(blas::UpperLower ul)75 std::string ToVlogString(blas::UpperLower ul) {
76 return blas::UpperLowerString(ul);
77 }
78
ToVlogString(blas::Diagonal d)79 std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
80
ToVlogString(blas::Side s)81 std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
82
ToVlogString(blas::ComputationType ty)83 std::string ToVlogString(blas::ComputationType ty) {
84 return blas::ComputationTypeString(ty);
85 }
86
// A raw pointer prints as its address, or as "null". absl::StrCat cannot
// format pointers, so the conversion goes through an ostringstream.
std::string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    std::ostringstream address;
    address << ptr;
    return address.str();
  }
  return "null";
}

// std::complex<T> uses the standard operator<< (which produces "(re,im)");
// absl::StrCat has no overload for complex numbers.
template <class T>
std::string ToVlogString(const std::complex<T> &c) {
  std::ostringstream text;
  text << c;
  return text.str();
}

// Only whether a callback is set gets logged, never what it points to.
template <class T>
std::string ToVlogString(const std::function<T> &f) {
  if (f) {
    return "<non-null function>";
  }
  return "null";
}
110
ToVlogString(const DeviceMemoryBase & memory)111 std::string ToVlogString(const DeviceMemoryBase &memory) {
112 return ToVlogString(memory.opaque());
113 }
114
ToVlogString(const DeviceMemoryBase * memory)115 std::string ToVlogString(const DeviceMemoryBase *memory) {
116 return memory == nullptr ? "null" : ToVlogString(*memory);
117 }
118
ToVlogString(const Eigen::half & h)119 std::string ToVlogString(const Eigen::half &h) {
120 return absl::StrCat(static_cast<float>(h));
121 }
122
ToVlogString(int i)123 std::string ToVlogString(int i) { return absl::StrCat(i); }
124
ToVlogString(uint32 i)125 std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
126
ToVlogString(uint64 i)127 std::string ToVlogString(uint64 i) { return absl::StrCat(i); }
128
ToVlogString(int64_t i)129 std::string ToVlogString(int64_t i) { return absl::StrCat(i); }
130
ToVlogString(float f)131 std::string ToVlogString(float f) { return absl::StrCat(f); }
132
ToVlogString(double d)133 std::string ToVlogString(double d) { return absl::StrCat(d); }
134
135 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)136 std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
137 if (memory_or_constant.is_pointer()) {
138 return ToVlogString(memory_or_constant.pointer());
139 }
140 return ToVlogString(memory_or_constant.value());
141 }
142
143 template <class T>
ToVlogString(port::ArraySlice<T> elements)144 std::string ToVlogString(port::ArraySlice<T> elements) {
145 std::string str = absl::StrCat(
146 ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
147 elements.size(), "]{");
148 const char *separator = "";
149 size_t max_to_show = std::numeric_limits<size_t>::max();
150 if (!VLOG_IS_ON(2)) {
151 max_to_show = 5;
152 } else if (!VLOG_IS_ON(3)) {
153 max_to_show = 20;
154 } else if (!VLOG_IS_ON(11)) {
155 max_to_show = 1000;
156 }
157 for (size_t i = 0; i < elements.size(); ++i) {
158 if (i == max_to_show) {
159 str += ", ...";
160 break;
161 }
162 absl::StrAppend(&str, separator, ToVlogString(elements[i]));
163 separator = ", ";
164 }
165 str += "}";
166 return str;
167 }
168
169 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)170 std::string ToVlogString(port::MutableArraySlice<T> elements) {
171 return ToVlogString(port::ArraySlice<T>(elements));
172 }
173
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)174 std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
175 switch (depth_to_space_layout) {
176 case dnn::DepthToSpaceLayout::DepthHeightWidth:
177 return "DepthToSpaceLayout::DepthHeightWidth";
178 }
179 return "unknown DepthToSpaceLayout";
180 }
181
ToVlogString(dnn::DataType data_type)182 std::string ToVlogString(dnn::DataType data_type) {
183 switch (data_type) {
184 case dnn::DataType::kFloat:
185 return "dnn::DataType::kFloat";
186 case dnn::DataType::kDouble:
187 return "dnn::DataType::kDouble";
188 case dnn::DataType::kHalf:
189 return "dnn::DataType::kHalf";
190 case dnn::DataType::kInt8:
191 return "dnn::DataType::kInt8";
192 case dnn::DataType::kInt32:
193 return "dnn::DataType::kInt32";
194 default:
195 return "unknown DataType";
196 }
197 }
198
199 // Used together with PARAM to VLOG calls made to the stream. Intended
200 // to be used like this:
201 //
202 // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
203 //
204 // where a and b are the parameters to MyFunction.
205 //
206 // See VLOG_CALL for a short-hand for this. This way of doing it saves
207 // a tremendous amount of boilerplate code given how many functions
208 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,std::string>> params)209 std::string CallStr(const char *function_name, Stream *stream,
210 std::vector<std::pair<const char *, std::string>> params) {
211 // Do not call this function unless VLOG is on since just
212 // constructing all the strings in params is expensive.
213 CHECK(VLOG_IS_ON(1));
214
215 std::string str = absl::StrCat(stream->DebugStreamPointers(),
216 " Called Stream::", function_name, "(");
217 const char *separator = "";
218 for (const auto ¶m : params) {
219 absl::StrAppend(&str, separator, param.first, "=", param.second);
220 separator = ", ";
221 }
222 absl::StrAppend(&str, ")");
223 if (VLOG_IS_ON(10)) {
224 absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
225 }
226 return str;
227 }
228
229 // Use this macro to avoid having to type every parameter twice to log
230 // it with VLOG and CallStr.
231 #define PARAM(parameter) \
232 { #parameter, ToVlogString(parameter) }
233
234 // Use this macro to avoid having to type out the name of each
235 // function and to save some boilerplate. Intended to be used like this:
236 //
237 // VLOG_CALL(PARAM(a), PARAM(b))
238 //
239 // This saves a tremendous amount of boilerplate compared to the alternative:
240 //
241 // VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
242 // << ", b=" << ToVlogString(b);
243 //
244 // Note here that most of the parameter names are not short and that
245 // most of the functions take many more than 2 parameters.
246 #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
247
248 } // namespace
249
Stream(StreamExecutor * parent)250 Stream::Stream(StreamExecutor *parent)
251 : parent_(parent),
252 implementation_(parent->implementation()->GetStreamImplementation()),
253 allocated_(false),
254 status_(port::InternalError("Uninitialized stream")),
255 temporary_memory_manager_(this) {
256 VLOG_CALL(PARAM(parent));
257 }
258
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)259 Stream::Stream(StreamExecutor *parent,
260 internal::StreamInterface *implementation)
261 : parent_(parent),
262 implementation_(implementation),
263 allocated_(false),
264 status_(port::InternalError("Uninitialized stream")),
265 temporary_memory_manager_(this) {
266 VLOG_CALL(PARAM(parent), PARAM(implementation));
267 }
268
~Stream()269 Stream::~Stream() {
270 VLOG_CALL();
271
272 // Ensure the stream is completed.
273 auto status = BlockHostUntilDone();
274 if (!status.ok()) {
275 LOG(WARNING) << "Error blocking host until done in stream destructor: "
276 << status;
277 }
278 temporary_memory_manager_.ForceDeallocateAll();
279 RunAfterBlockHostUntilDoneCallbacks();
280
281 if (allocated_) {
282 parent_->DeallocateStream(this);
283 }
284 }
285
RefreshStatus()286 port::Status Stream::RefreshStatus() {
287 port::Status status = parent_->GetStatus(this);
288 // We should not put the stream in an error state, just because the GetStatus
289 // method is unimplemented.
290 if (status != port::Status(port::error::UNIMPLEMENTED,
291 "GetStatus is not supported on this executor.")) {
292 CheckStatus(status);
293 }
294 return status;
295 }
296
Init()297 Stream &Stream::Init() {
298 VLOG_CALL();
299
300 absl::MutexLock lock(&mu_);
301 CHECK_EQ(false, allocated_)
302 << "stream appears to already have been initialized";
303 CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization";
304
305 if (parent_->AllocateStream(this)) {
306 // Successful initialization!
307 allocated_ = true;
308 status_ = port::Status::OK();
309 } else {
310 LOG(ERROR) << "failed to allocate stream during initialization";
311 }
312
313 return *this;
314 }
315
InitTimer(Timer * timer)316 Stream &Stream::InitTimer(Timer *timer) {
317 VLOG_CALL(PARAM(timer));
318
319 CheckError(parent_->AllocateTimer(timer));
320 return *this;
321 }
322
InitWithTimer(Timer * timer)323 Stream &Stream::InitWithTimer(Timer *timer) {
324 VLOG_CALL(PARAM(timer));
325
326 return Init().InitTimer(timer);
327 }
328
ThenRecordEvent(Event * event)329 Stream &Stream::ThenRecordEvent(Event *event) {
330 VLOG_CALL(PARAM(event));
331
332 port::Status status = parent_->RecordEvent(this, event);
333 if (!status.ok()) {
334 LOG(ERROR) << "Error recording event in stream: " << status.error_message()
335 << "; not marking stream as bad, as the Event object may be "
336 << "at fault. Monitor for further errors.";
337 }
338
339 return *this;
340 }
341
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)342 Stream &Stream::ThenBatchNormalizationForward(
343 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
344 const DeviceMemory<float> &offset,
345 const DeviceMemory<float> &estimated_mean,
346 const DeviceMemory<float> &estimated_variance,
347 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
348 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
349 const double exponential_average_factor,
350 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
351 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
352 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
353 bool is_training, ScratchAllocator *reserve_space_allocator,
354 ScratchAllocator *workspace_allocator) {
355 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
356 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
357 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
358 CheckError(dnn->DoBatchNormalizationForward(
359 this, x, scale, offset, estimated_mean, estimated_variance, side_input,
360 x_desc, scale_offset_desc, epsilon, exponential_average_factor,
361 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
362 is_training, reserve_space_allocator, workspace_allocator));
363 } else {
364 SetErrorAndLogNoDnnSupport();
365 }
366 return *this;
367 }
368
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)369 Stream &Stream::ThenBatchNormalizationBackward(
370 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
371 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
372 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
373 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
374 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
375 DeviceMemory<float> *offset_backprop,
376 DeviceMemory<uint8> *reserve_space_data,
377 ScratchAllocator *workspace_allocator) {
378 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
379 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
380 PARAM(scale_backprop), PARAM(offset_backprop));
381 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
382 CheckError(dnn->DoBatchNormalizationBackward(
383 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
384 epsilon, x_backprop, scale_backprop, offset_backprop,
385 reserve_space_data, workspace_allocator));
386 } else {
387 SetErrorAndLogNoDnnSupport();
388 }
389 return *this;
390 }
391
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<Eigen::half> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)392 Stream &Stream::ThenBatchNormalizationForward(
393 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
394 const DeviceMemory<float> &offset,
395 const DeviceMemory<float> &estimated_mean,
396 const DeviceMemory<float> &estimated_variance,
397 const DeviceMemory<Eigen::half> &side_input,
398 const dnn::BatchDescriptor &x_desc,
399 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
400 const double exponential_average_factor,
401 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
402 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
403 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
404 bool is_training, ScratchAllocator *reserve_space_allocator,
405 ScratchAllocator *workspace_allocator) {
406 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
407 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
408 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
409 CheckError(dnn->DoBatchNormalizationForward(
410 this, x, scale, offset, estimated_mean, estimated_variance, side_input,
411 x_desc, scale_offset_desc, epsilon, exponential_average_factor,
412 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
413 is_training, reserve_space_allocator, workspace_allocator));
414 } else {
415 SetErrorAndLogNoDnnSupport();
416 }
417 return *this;
418 }
419
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)420 Stream &Stream::ThenBatchNormalizationBackward(
421 const DeviceMemory<Eigen::half> &y_backprop,
422 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
423 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
424 const dnn::BatchDescriptor &x_desc,
425 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
426 DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
427 DeviceMemory<float> *offset_backprop,
428 DeviceMemory<uint8> *reserve_space_data,
429 ScratchAllocator *workspace_allocator) {
430 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
431 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
432 PARAM(scale_backprop), PARAM(offset_backprop));
433 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
434 CheckError(dnn->DoBatchNormalizationBackward(
435 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
436 epsilon, x_backprop, scale_backprop, offset_backprop,
437 reserve_space_data, workspace_allocator));
438
439 } else {
440 SetErrorAndLogNoDnnSupport();
441 }
442 return *this;
443 }
444
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)445 Stream &Stream::ThenConvolve(
446 const dnn::BatchDescriptor &input_descriptor,
447 const DeviceMemory<float> &input_data,
448 const dnn::FilterDescriptor &filter_descriptor,
449 const DeviceMemory<float> &filter_data,
450 const dnn::ConvolutionDescriptor &convolution_descriptor,
451 const dnn::BatchDescriptor &output_descriptor,
452 DeviceMemory<float> *output) {
453 if (ok()) {
454 CheckError(ConvolveWithAlgorithm(
455 input_descriptor, input_data, filter_descriptor, filter_data,
456 convolution_descriptor, output_descriptor, output,
457 /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
458 /*output_profile_result=*/nullptr)
459 .ok());
460 }
461 return *this;
462 }
463
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)464 Stream &Stream::ThenConvolveQuantized(
465 const dnn::BatchDescriptor &input_descriptor,
466 const DeviceMemory<float> &input_data,
467 const dnn::FilterDescriptor &filter_descriptor,
468 const DeviceMemory<int8> &filter_coefficients,
469 const DeviceMemory<float> &coefficient_scales,
470 const dnn::ConvolutionDescriptor &convolution_descriptor,
471 const dnn::BatchDescriptor &output_descriptor,
472 DeviceMemory<float> *output) {
473 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
474 PARAM(filter_descriptor), PARAM(filter_coefficients),
475 PARAM(coefficient_scales), PARAM(convolution_descriptor),
476 PARAM(output_descriptor), PARAM(output));
477
478 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
479 CheckError(dnn->DoConvolveQuantized(
480 this, input_descriptor, input_data, filter_descriptor,
481 filter_coefficients, coefficient_scales, convolution_descriptor,
482 output_descriptor, output));
483 } else {
484 SetError();
485 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
486 "without DNN support";
487 }
488 return *this;
489 }
490
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)491 Stream &Stream::ThenConvolveQuantized(
492 const dnn::BatchDescriptor &input_descriptor,
493 const DeviceMemory<float> &input_data,
494 const dnn::FilterDescriptor &filter_descriptor,
495 const DeviceMemory<int16> &filter_coefficients,
496 const DeviceMemory<float> &coefficient_scales,
497 const dnn::ConvolutionDescriptor &convolution_descriptor,
498 const dnn::BatchDescriptor &output_descriptor,
499 DeviceMemory<float> *output) {
500 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
501 PARAM(filter_descriptor), PARAM(filter_coefficients),
502 PARAM(coefficient_scales), PARAM(convolution_descriptor),
503 PARAM(output_descriptor), PARAM(output));
504
505 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
506 CheckError(dnn->DoConvolveQuantized(
507 this, input_descriptor, input_data, filter_descriptor,
508 filter_coefficients, coefficient_scales, convolution_descriptor,
509 output_descriptor, output));
510 } else {
511 SetError();
512 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
513 "without DNN support";
514 }
515 return *this;
516 }
517
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)518 Stream &Stream::ThenSeparableConvolve(
519 const dnn::BatchDescriptor &batch_descriptor,
520 const DeviceMemory<float> &input_data,
521 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
522 const DeviceMemory<float> &first_weights,
523 const DeviceMemory<float> &second_weights,
524 const dnn::ConvolutionDescriptor &convolution_descriptor,
525 const dnn::BatchDescriptor &output_descriptor,
526 DeviceMemory<float> *output) {
527 VLOG_CALL(
528 PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
529 PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
530 PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
531
532 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
533 CheckError(dnn->DoSeparableConvolve(
534 this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
535 first_weights, second_weights, convolution_descriptor,
536 output_descriptor, output));
537 } else {
538 SetErrorAndLogNoDnnSupport();
539 }
540 return *this;
541 }
542
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)543 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
544 const DeviceMemory<float> &weights,
545 const dnn::BatchDescriptor &input_dimensions,
546 const dnn::BatchDescriptor &output_dimensions,
547 DeviceMemory<float> *output_data) {
548 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
549 PARAM(output_dimensions), PARAM(output_data));
550
551 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
552 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
553 output_dimensions, output_data));
554 } else {
555 SetErrorAndLogNoDnnSupport();
556 }
557 return *this;
558 }
559
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)560 Stream &Stream::ThenMatMulQuantized(
561 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
562 const DeviceMemory<float> &weight_scales,
563 const dnn::BatchDescriptor &input_dimensions,
564 const dnn::BatchDescriptor &output_dimensions,
565 DeviceMemory<float> *output_data) {
566 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
567 PARAM(input_dimensions), PARAM(output_dimensions),
568 PARAM(output_data));
569
570 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
571 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
572 input_dimensions, output_dimensions,
573 output_data));
574 } else {
575 SetErrorAndLogNoDnnSupport();
576 }
577 return *this;
578 }
579
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)580 Stream &Stream::ThenMatMulQuantized(
581 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
582 const DeviceMemory<float> &weight_scales,
583 const dnn::BatchDescriptor &input_dimensions,
584 const dnn::BatchDescriptor &output_dimensions,
585 DeviceMemory<float> *output_data) {
586 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
587 PARAM(input_dimensions), PARAM(output_dimensions),
588 PARAM(output_data));
589
590 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
591 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
592 input_dimensions, output_dimensions,
593 output_data));
594 } else {
595 SetErrorAndLogNoDnnSupport();
596 }
597 return *this;
598 }
599
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)600 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
601 const DeviceMemory<float> &biases,
602 const dnn::BatchDescriptor &dimensions,
603 DeviceMemory<float> *output_data) {
604 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
605 PARAM(output_data));
606
607 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
608 CheckError(
609 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
610 } else {
611 SetErrorAndLogNoDnnSupport();
612 }
613 return *this;
614 }
615
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)616 Stream &Stream::ThenPoolForward(
617 const dnn::PoolingDescriptor &pooling_dimensions,
618 const dnn::BatchDescriptor &input_dimensions,
619 const DeviceMemory<double> &input_data,
620 const dnn::BatchDescriptor &output_dimensions,
621 DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
622 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
623 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
624 PARAM(workspace_allocator));
625
626 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
627 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
628 input_data, output_dimensions, output_data,
629 workspace_allocator));
630 } else {
631 SetError();
632 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
633 "without DNN support";
634 }
635 return *this;
636 }
637
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)638 Stream &Stream::ThenPoolForward(
639 const dnn::PoolingDescriptor &pooling_dimensions,
640 const dnn::BatchDescriptor &input_dimensions,
641 const DeviceMemory<float> &input_data,
642 const dnn::BatchDescriptor &output_dimensions,
643 DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
644 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
645 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
646 PARAM(workspace_allocator));
647
648 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
649 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
650 input_data, output_dimensions, output_data,
651 workspace_allocator));
652 } else {
653 SetErrorAndLogNoDnnSupport();
654 }
655 return *this;
656 }
657
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)658 Stream &Stream::ThenPoolForward(
659 const dnn::PoolingDescriptor &pooling_dimensions,
660 const dnn::BatchDescriptor &input_dimensions,
661 const DeviceMemory<Eigen::half> &input_data,
662 const dnn::BatchDescriptor &output_dimensions,
663 DeviceMemory<Eigen::half> *output_data,
664 ScratchAllocator *workspace_allocator) {
665 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
666 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
667 PARAM(workspace_allocator));
668
669 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
670 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
671 input_data, output_dimensions, output_data,
672 workspace_allocator));
673 } else {
674 SetErrorAndLogNoDnnSupport();
675 }
676 return *this;
677 }
678
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)679 Stream &Stream::ThenPoolForward(
680 const dnn::PoolingDescriptor &pooling_dimensions,
681 const dnn::BatchDescriptor &input_dimensions,
682 const DeviceMemory<int8> &input_data,
683 const dnn::BatchDescriptor &output_dimensions,
684 DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
685 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
686 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
687 PARAM(workspace_allocator));
688
689 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
690 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
691 input_data, output_dimensions, output_data,
692 workspace_allocator));
693 } else {
694 SetErrorAndLogNoDnnSupport();
695 }
696 return *this;
697 }
698
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)699 Stream &Stream::ThenPoolBackward(
700 const dnn::PoolingDescriptor &pooling_dimensions,
701 const dnn::BatchDescriptor &input_dimensions,
702 const DeviceMemory<double> &input_data,
703 const dnn::BatchDescriptor &output_dimensions,
704 const DeviceMemory<double> &output_data,
705 const DeviceMemory<double> &input_diff_data,
706 DeviceMemory<double> *output_diff_data,
707 ScratchAllocator *workspace_allocator) {
708 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
709 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
710 PARAM(input_diff_data), PARAM(output_diff_data),
711 PARAM(workspace_allocator));
712
713 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
714 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
715 input_data, output_dimensions, output_data,
716 input_diff_data, output_diff_data,
717 workspace_allocator));
718 } else {
719 SetError();
720 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
721 "without DNN support";
722 }
723 return *this;
724 }
725
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)726 Stream &Stream::ThenPoolBackward(
727 const dnn::PoolingDescriptor &pooling_dimensions,
728 const dnn::BatchDescriptor &input_dimensions,
729 const DeviceMemory<float> &input_data,
730 const dnn::BatchDescriptor &output_dimensions,
731 const DeviceMemory<float> &output_data,
732 const DeviceMemory<float> &input_diff_data,
733 DeviceMemory<float> *output_diff_data,
734 ScratchAllocator *workspace_allocator) {
735 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
736 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
737 PARAM(input_diff_data), PARAM(output_diff_data),
738 PARAM(workspace_allocator));
739
740 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
741 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
742 input_data, output_dimensions, output_data,
743 input_diff_data, output_diff_data,
744 workspace_allocator));
745 } else {
746 SetErrorAndLogNoDnnSupport();
747 }
748 return *this;
749 }
750
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)751 Stream &Stream::ThenPoolBackward(
752 const dnn::PoolingDescriptor &pooling_dimensions,
753 const dnn::BatchDescriptor &input_dimensions,
754 const DeviceMemory<Eigen::half> &input_data,
755 const dnn::BatchDescriptor &output_dimensions,
756 const DeviceMemory<Eigen::half> &output_data,
757 const DeviceMemory<Eigen::half> &input_diff_data,
758 DeviceMemory<Eigen::half> *output_diff_data,
759 ScratchAllocator *workspace_allocator) {
760 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
761 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
762 PARAM(input_diff_data), PARAM(output_diff_data),
763 PARAM(workspace_allocator));
764
765 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
766 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
767 input_data, output_dimensions, output_data,
768 input_diff_data, output_diff_data,
769 workspace_allocator));
770 } else {
771 SetErrorAndLogNoDnnSupport();
772 }
773 return *this;
774 }
775
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)776 Stream &Stream::ThenNormalizeWithDimensions(
777 const dnn::NormalizeDescriptor &normalize_descriptor,
778 const dnn::BatchDescriptor &dimensions,
779 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
780 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
781 PARAM(output_data));
782
783 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
784 CheckError(dnn->DoNormalizeWithDimensions(
785 this, normalize_descriptor, dimensions, input_data, output_data));
786 } else {
787 SetErrorAndLogNoDnnSupport();
788 }
789 return *this;
790 }
791
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)792 Stream &Stream::ThenNormalizeBackwardWithDimensions(
793 const dnn::NormalizeDescriptor &normalize_descriptor,
794 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
795 const DeviceMemory<float> &normalized_data,
796 const DeviceMemory<float> &normalized_variable_gradient,
797 DeviceMemory<float> *raw_variable_gradient,
798 ScratchAllocator *workspace_allocator) {
799 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
800 PARAM(normalized_data), PARAM(normalized_variable_gradient),
801 PARAM(raw_variable_gradient), PARAM(workspace_allocator));
802
803 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
804 CheckError(dnn->DoNormalizeBackwardWithDimensions(
805 this, normalize_descriptor, dimensions, raw_data, normalized_data,
806 normalized_variable_gradient, raw_variable_gradient,
807 workspace_allocator));
808 } else {
809 SetErrorAndLogNoDnnSupport();
810 }
811 return *this;
812 }
813
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)814 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
815 const dnn::BatchDescriptor &dimensions,
816 const DeviceMemory<float> &input_data,
817 DeviceMemory<float> *output_data) {
818 return ThenActivateWithOptions(activation_mode, dimensions, input_data,
819 output_data, /*options=*/0);
820 }
821
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)822 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
823 const dnn::BatchDescriptor &dimensions,
824 const DeviceMemory<float> &input_data,
825 DeviceMemory<float> *output_data,
826 uint64 options) {
827 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
828 PARAM(output_data), PARAM(options));
829
830 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
831 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
832 output_data, options));
833 } else {
834 SetErrorAndLogNoDnnSupport();
835 }
836 return *this;
837 }
838
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)839 Stream &Stream::ThenDepthConcatenate(
840 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
841 port::ArraySlice<const DeviceMemory<float> *> input_data,
842 DeviceMemory<float> *output_data) {
843 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
844
845 for (size_t i = 1; i < input_dimensions.size(); ++i) {
846 if (input_dimensions[i].count() != input_dimensions[0].count() ||
847 input_dimensions[i].height() != input_dimensions[0].height() ||
848 input_dimensions[i].width() != input_dimensions[0].width()) {
849 SetError();
850 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
851 << "input_dimensions[0]: " << input_dimensions[0].ToString()
852 << "input_dimensions[" << i
853 << "]: " << input_dimensions[i].ToString();
854 return *this;
855 }
856 }
857
858 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
859 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
860 output_data));
861 } else {
862 SetErrorAndLogNoDnnSupport();
863 }
864 return *this;
865 }
866
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)867 Stream &Stream::ThenSpaceConcatenate(
868 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
869 port::ArraySlice<const DeviceMemory<float> *> input_data,
870 DeviceMemory<float> *output_data,
871 dnn::SpaceConcatenateMode concat_direction) {
872 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
873
874 // Check that the input dimensions of all the other batches match those of the
875 // first batch.
876 for (size_t i = 1; i < input_dimensions.size(); ++i) {
877 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
878 (input_dimensions[i].count() != input_dimensions[0].count() ||
879 input_dimensions[i].height() != input_dimensions[0].height() ||
880 input_dimensions[i].feature_map_count() !=
881 input_dimensions[0].feature_map_count())) {
882 SetError();
883 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
884 << "input_dimensions[0]: " << input_dimensions[0].ToString()
885 << "input_dimensions[" << i
886 << "]: " << input_dimensions[i].ToString();
887 return *this;
888 }
889
890 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
891 (input_dimensions[i].count() != input_dimensions[0].count() ||
892 input_dimensions[i].width() != input_dimensions[0].width() ||
893 input_dimensions[i].feature_map_count() !=
894 input_dimensions[0].feature_map_count())) {
895 SetError();
896 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
897 << "input_dimensions[0]: " << input_dimensions[0].ToString()
898 << "input_dimensions[" << i
899 << "]: " << input_dimensions[i].ToString();
900 return *this;
901 }
902 }
903 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
904 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
905 output_data, concat_direction));
906 } else {
907 SetErrorAndLogNoDnnSupport();
908 }
909 return *this;
910 }
911
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)912 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
913 const DeviceMemory<float> &input_data,
914 const dnn::BatchDescriptor &output_dimensions,
915 DeviceMemory<float> *output_data) {
916 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
917 PARAM(output_dimensions), PARAM(output_data));
918
919 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
920 CheckError(dnn->DoReshape(this, input_dimensions, input_data,
921 output_dimensions, output_data));
922 } else {
923 SetErrorAndLogNoDnnSupport();
924 }
925 return *this;
926 }
927
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)928 Stream &Stream::ThenDepthToSpace(
929 const dnn::BatchDescriptor &input_dimensions,
930 const DeviceMemory<float> &input_data,
931 const dnn::DepthToSpaceLayout &depth_to_space_layout,
932 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
933 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
934 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
935 PARAM(output_data));
936
937 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
938 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
939 depth_to_space_layout, sqrt_depth_reduction,
940 output_data));
941 } else {
942 SetErrorAndLogNoDnnSupport();
943 }
944 return *this;
945 }
946
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)947 Stream &Stream::ThenSpaceToDepth(
948 const dnn::BatchDescriptor &input_dimensions,
949 const DeviceMemory<float> &input_data,
950 const dnn::DepthToSpaceLayout &space_to_depth_layout,
951 const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
952 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
953 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
954 PARAM(output_data));
955
956 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
957 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
958 space_to_depth_layout, sqrt_depth_increase,
959 output_data));
960 } else {
961 SetErrorAndLogNoDnnSupport();
962 }
963 return *this;
964 }
965
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)966 Stream &Stream::ThenElementwiseOperate(
967 dnn::ElementwiseOperation operation,
968 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
969 port::ArraySlice<const DeviceMemory<float> *> input_data,
970 const dnn::BatchDescriptor &output_dimensions,
971 DeviceMemory<float> *output_data) {
972 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
973 PARAM(output_dimensions), PARAM(output_data));
974
975 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
976 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
977 input_data, output_dimensions,
978 output_data));
979 } else {
980 SetErrorAndLogNoDnnSupport();
981 }
982 return *this;
983 }
984
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)985 Stream &Stream::ThenElementwiseOperateScaledQuantized(
986 dnn::ElementwiseOperation operation,
987 port::ArraySlice<int> input_multiplicands, int output_divisor,
988 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
989 port::ArraySlice<const DeviceMemory<float> *> input_data,
990 const dnn::BatchDescriptor &output_dimensions,
991 DeviceMemory<float> *output_data) {
992 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
993 PARAM(input_dimensions), PARAM(input_data),
994 PARAM(output_dimensions), PARAM(output_data));
995
996 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
997 CheckError(dnn->DoElementwiseOperateScaledQuantized(
998 this, operation, input_multiplicands, output_divisor, input_dimensions,
999 input_data, output_dimensions, output_data));
1000 } else {
1001 SetErrorAndLogNoDnnSupport();
1002 }
1003 return *this;
1004 }
1005
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_pad,int64_t right_pad,int64_t top_pad,int64_t bottom_pad,DeviceMemory<float> * output_data)1006 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1007 const DeviceMemory<float> &input_data,
1008 int64_t left_pad, int64_t right_pad, int64_t top_pad,
1009 int64_t bottom_pad,
1010 DeviceMemory<float> *output_data) {
1011 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1012 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1013 PARAM(output_data));
1014
1015 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1016 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1017 top_pad, bottom_pad, output_data));
1018 } else {
1019 SetErrorAndLogNoDnnSupport();
1020 }
1021 return *this;
1022 }
1023
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_trim,int64_t right_trim,int64_t top_trim,int64_t bottom_trim,DeviceMemory<float> * output_data)1024 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1025 const DeviceMemory<float> &input_data,
1026 int64_t left_trim, int64_t right_trim,
1027 int64_t top_trim, int64_t bottom_trim,
1028 DeviceMemory<float> *output_data) {
1029 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1030 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1031 PARAM(output_data));
1032
1033 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1034 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1035 right_trim, top_trim, bottom_trim, output_data));
1036 } else {
1037 SetErrorAndLogNoDnnSupport();
1038 }
1039 return *this;
1040 }
1041
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t replicate_x,int64_t replicate_y,DeviceMemory<float> * output_data)1042 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1043 const DeviceMemory<float> &input_data,
1044 int64_t replicate_x, int64_t replicate_y,
1045 DeviceMemory<float> *output_data) {
1046 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1047 PARAM(replicate_y), PARAM(output_data));
1048
1049 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1050 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1051 replicate_y, output_data));
1052 } else {
1053 SetErrorAndLogNoDnnSupport();
1054 }
1055 return *this;
1056 }
1057
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1058 Stream &Stream::ThenMemcpyD2HQuantized(
1059 const DeviceMemory<float> &gpu_unquantized_src,
1060 dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1061 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1062 PARAM(size));
1063
1064 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1065 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1066 host_dst, size));
1067 } else {
1068 SetErrorAndLogNoDnnSupport();
1069 }
1070 return *this;
1071 }
1072
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1073 Stream &Stream::ThenMemcpyH2DQuantized(
1074 const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1075 DeviceMemory<float> *gpu_unquantized_dst) {
1076 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1077 PARAM(gpu_unquantized_dst));
1078
1079 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1080 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1081 gpu_unquantized_dst));
1082 } else {
1083 SetErrorAndLogNoDnnSupport();
1084 }
1085 return *this;
1086 }
1087
GetOrCreateSubStream()1088 Stream *Stream::GetOrCreateSubStream() {
1089 // Do not destroy bad streams when holding mu_ because ~Stream() may
1090 // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1091 std::vector<std::unique_ptr<Stream>> bad_streams;
1092
1093 absl::MutexLock lock(&mu_);
1094
1095 // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1096 // we encounter along the way.
1097 for (size_t index = 0; index < sub_streams_.size();) {
1098 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1099 if (pair.second) {
1100 // The sub_stream is reusable.
1101 Stream *sub_stream = pair.first.get();
1102 if (sub_stream->ok()) {
1103 VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1104 << sub_stream->DebugStreamPointers();
1105 pair.second = false;
1106 return sub_stream;
1107 }
1108
1109 // The stream is reusable and not ok. Streams have a monotonic state
1110 // machine; the stream will remain in !ok forever. Swap it with the last
1111 // stream and pop it off.
1112 const int64_t last = sub_streams_.size() - 1;
1113 if (index != last) {
1114 std::swap(pair, sub_streams_[last]);
1115 }
1116 bad_streams.push_back(std::move(sub_streams_.back().first));
1117 sub_streams_.pop_back();
1118 VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1119 << sub_stream->DebugStreamPointers();
1120 } else {
1121 // The sub_stream is not reusable, move on to the next one.
1122 ++index;
1123 }
1124 }
1125
1126 // No streams are reusable; create a new stream.
1127 sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1128 false);
1129 Stream *sub_stream = sub_streams_.back().first.get();
1130 sub_stream->Init();
1131 if (!sub_stream->ok()) {
1132 LOG(ERROR) << "sub-stream failed to be initialized";
1133 }
1134 VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1135 << sub_stream->DebugStreamPointers();
1136
1137 return sub_stream;
1138 }
1139
ReturnSubStream(Stream * sub_stream)1140 void Stream::ReturnSubStream(Stream *sub_stream) {
1141 // Do not destroy bad streams when holding mu_ because ~Stream() may
1142 // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1143 std::unique_ptr<Stream> bad_stream;
1144
1145 absl::MutexLock lock(&mu_);
1146
1147 // Look for the sub-stream.
1148 for (int64_t index = 0, end = sub_streams_.size(); index < end; ++index) {
1149 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1150 if (pair.first.get() != sub_stream) {
1151 continue;
1152 }
1153
1154 // Found the sub_stream.
1155 if (sub_stream->ok()) {
1156 VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1157 << sub_stream->DebugStreamPointers();
1158 pair.second = true;
1159 } else {
1160 // The returned stream is not ok. Streams have a monotonic state
1161 // machine; the stream will remain in !ok forever. Swap it with the last
1162 // stream and pop it off.
1163 VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1164 << sub_stream->DebugStreamPointers();
1165 const int64_t last = sub_streams_.size() - 1;
1166 if (index != last) {
1167 std::swap(pair, sub_streams_[last]);
1168 }
1169 std::swap(bad_stream, sub_streams_.back().first);
1170 sub_streams_.pop_back();
1171 }
1172 return;
1173 }
1174
1175 LOG(FATAL) << DebugStreamPointers()
1176 << " did not create the returned sub-stream "
1177 << sub_stream->DebugStreamPointers();
1178 }
1179
ThenStartTimer(Timer * t)1180 Stream &Stream::ThenStartTimer(Timer *t) {
1181 VLOG_CALL(PARAM(t));
1182
1183 CheckError(parent_->StartTimer(this, t));
1184 return *this;
1185 }
1186
ThenStopTimer(Timer * t)1187 Stream &Stream::ThenStopTimer(Timer *t) {
1188 VLOG_CALL(PARAM(t));
1189
1190 CheckError(parent_->StopTimer(this, t));
1191 return *this;
1192 }
1193
ThenWaitFor(Stream * other)1194 Stream &Stream::ThenWaitFor(Stream *other) {
1195 VLOG_CALL(PARAM(other));
1196
1197 CHECK(this != other) << "stream cannot wait for itself";
1198 if (ok() && other->ok()) {
1199 CheckError(parent_->CreateStreamDependency(this, other));
1200 } else {
1201 SetError();
1202 LOG(INFO) << DebugStreamPointers() << " did not wait for "
1203 << other->DebugStreamPointers();
1204 }
1205 return *this;
1206 }
1207
ThenWaitFor(Event * event)1208 Stream &Stream::ThenWaitFor(Event *event) {
1209 VLOG_CALL(PARAM(event));
1210
1211 if (ok()) {
1212 port::Status status = parent_->WaitForEvent(this, event);
1213 if (!status.ok()) {
1214 LOG(ERROR) << "Error waiting for event in stream: "
1215 << status.error_message()
1216 << "; not marking stream as bad, as the Event object may be "
1217 << "at fault. Monitor for further errors.";
1218 }
1219 } else {
1220 LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1221 }
1222 return *this;
1223 }
1224
1225 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1226 // functions and logs for errors.
1227 template <typename... Args>
1228 struct ThenBlasImpl {
1229 // blas_func is the DoBlasXXX member function pointer, and args are its
1230 // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl1231 Stream &operator()(Stream *stream,
1232 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1233 Args... args) {
1234 return Run(stream, blas_func, /*record_error=*/true, args...);
1235 }
1236
1237 // Like operator(), but only calls stream->CheckError() if record_error is
1238 // true.
1239 Stream &Run(Stream *stream,
1240 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1241 bool record_error, Args... args);
1242 };
1243
1244 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1245 Stream &ThenBlasImpl<Args...>::Run(
1246 Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1247 bool record_error, Args... args) {
1248 if (stream->ok()) {
1249 bool ok;
1250 if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1251 ok = (blas->*blas_func)(stream, args...);
1252 } else {
1253 LOG(WARNING)
1254 << "attempting to perform BLAS operation using StreamExecutor "
1255 "without BLAS support";
1256 ok = false;
1257 }
1258 if (record_error) {
1259 stream->CheckError(ok);
1260 }
1261 }
1262 return *stream;
1263 }
1264
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1265 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1266 int incx, DeviceMemory<float> *result) {
1267 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1268
1269 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1270 impl;
1271 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1272 result);
1273 }
1274
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1275 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1276 int incx, DeviceMemory<double> *result) {
1277 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1278
1279 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1280 DeviceMemory<double> *>
1281 impl;
1282 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1283 result);
1284 }
1285
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1286 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1287 const DeviceMemory<std::complex<float>> &x,
1288 int incx, DeviceMemory<float> *result) {
1289 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1290
1291 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1292 DeviceMemory<float> *>
1293 impl;
1294 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1295 result);
1296 }
1297
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1298 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1299 const DeviceMemory<std::complex<double>> &x,
1300 int incx, DeviceMemory<double> *result) {
1301 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1302
1303 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1304 DeviceMemory<double> *>
1305 impl;
1306 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1307 result);
1308 }
1309
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1310 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1311 const DeviceMemory<float> &x, int incx,
1312 DeviceMemory<float> *y, int incy) {
1313 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1314 PARAM(incy));
1315
1316 ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1317 DeviceMemory<float> *, int>
1318 impl;
1319 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1320 y, incy);
1321 }
1322
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1323 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1324 const DeviceMemory<double> &x, int incx,
1325 DeviceMemory<double> *y, int incy) {
1326 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1327 PARAM(incy));
1328
1329 ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1330 DeviceMemory<double> *, int>
1331 impl;
1332 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1333 y, incy);
1334 }
1335
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1336 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1337 const DeviceMemory<std::complex<float>> &x,
1338 int incx, DeviceMemory<std::complex<float>> *y,
1339 int incy) {
1340 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1341 PARAM(incy));
1342
1343 ThenBlasImpl<uint64, std::complex<float>,
1344 const DeviceMemory<std::complex<float>> &, int,
1345 DeviceMemory<std::complex<float>> *, int>
1346 impl;
1347 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1348 y, incy);
1349 }
1350
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1351 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1352 const DeviceMemory<std::complex<double>> &x,
1353 int incx, DeviceMemory<std::complex<double>> *y,
1354 int incy) {
1355 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1356 PARAM(incy));
1357
1358 ThenBlasImpl<uint64, std::complex<double>,
1359 const DeviceMemory<std::complex<double>> &, int,
1360 DeviceMemory<std::complex<double>> *, int>
1361 impl;
1362 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1363 y, incy);
1364 }
1365
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1366 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
1367 int incx, DeviceMemory<float> *y, int incy) {
1368 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1369
1370 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
1371 int>
1372 impl;
1373 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1374 incy);
1375 }
1376
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1377 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
1378 int incx, DeviceMemory<double> *y, int incy) {
1379 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1380
1381 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1382 DeviceMemory<double> *, int>
1383 impl;
1384 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1385 incy);
1386 }
1387
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1388 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1389 const DeviceMemory<std::complex<float>> &x,
1390 int incx, DeviceMemory<std::complex<float>> *y,
1391 int incy) {
1392 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1393
1394 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1395 DeviceMemory<std::complex<float>> *, int>
1396 impl;
1397 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1398 incy);
1399 }
1400
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1401 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1402 const DeviceMemory<std::complex<double>> &x,
1403 int incx, DeviceMemory<std::complex<double>> *y,
1404 int incy) {
1405 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1406
1407 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1408 DeviceMemory<std::complex<double>> *, int>
1409 impl;
1410 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1411 incy);
1412 }
1413
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)1414 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
1415 int incx, const DeviceMemory<float> &y, int incy,
1416 DeviceMemory<float> *result) {
1417 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1418 PARAM(result));
1419
1420 ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
1421 const DeviceMemory<float> &, int, DeviceMemory<float> *>
1422 impl;
1423 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1424 result);
1425 }
1426
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)1427 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
1428 int incx, const DeviceMemory<double> &y, int incy,
1429 DeviceMemory<double> *result) {
1430 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1431 PARAM(result));
1432
1433 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1434 const DeviceMemory<double> &, int, DeviceMemory<double> *>
1435 impl;
1436 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1437 result);
1438 }
1439
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1440 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1441 const DeviceMemory<std::complex<float>> &x,
1442 int incx,
1443 const DeviceMemory<std::complex<float>> &y,
1444 int incy,
1445 DeviceMemory<std::complex<float>> *result) {
1446 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1447 PARAM(result));
1448
1449 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1450 const DeviceMemory<std::complex<float>> &, int,
1451 DeviceMemory<std::complex<float>> *>
1452 impl;
1453 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1454 incy, result);
1455 }
1456
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1457 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1458 const DeviceMemory<std::complex<double>> &x,
1459 int incx,
1460 const DeviceMemory<std::complex<double>> &y,
1461 int incy,
1462 DeviceMemory<std::complex<double>> *result) {
1463 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1464 PARAM(result));
1465
1466 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1467 const DeviceMemory<std::complex<double>> &, int,
1468 DeviceMemory<std::complex<double>> *>
1469 impl;
1470 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1471 incy, result);
1472 }
1473
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1474 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1475 const DeviceMemory<std::complex<float>> &x,
1476 int incx,
1477 const DeviceMemory<std::complex<float>> &y,
1478 int incy,
1479 DeviceMemory<std::complex<float>> *result) {
1480 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1481 PARAM(result));
1482
1483 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1484 const DeviceMemory<std::complex<float>> &, int,
1485 DeviceMemory<std::complex<float>> *>
1486 impl;
1487 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1488 incy, result);
1489 }
1490
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1491 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1492 const DeviceMemory<std::complex<double>> &x,
1493 int incx,
1494 const DeviceMemory<std::complex<double>> &y,
1495 int incy,
1496 DeviceMemory<std::complex<double>> *result) {
1497 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1498 PARAM(result));
1499
1500 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1501 const DeviceMemory<std::complex<double>> &, int,
1502 DeviceMemory<std::complex<double>> *>
1503 impl;
1504 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1505 incy, result);
1506 }
1507
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1508 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
1509 int incx, DeviceMemory<float> *result) {
1510 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1511
1512 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1513 impl;
1514 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1515 result);
1516 }
1517
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1518 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
1519 int incx, DeviceMemory<double> *result) {
1520 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1521
1522 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1523 DeviceMemory<double> *>
1524 impl;
1525 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1526 result);
1527 }
1528
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1529 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1530 const DeviceMemory<std::complex<float>> &x,
1531 int incx, DeviceMemory<float> *result) {
1532 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1533
1534 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1535 DeviceMemory<float> *>
1536 impl;
1537 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1538 result);
1539 }
1540
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1541 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1542 const DeviceMemory<std::complex<double>> &x,
1543 int incx, DeviceMemory<double> *result) {
1544 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1545
1546 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1547 DeviceMemory<double> *>
1548 impl;
1549 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1550 result);
1551 }
1552
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)1553 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
1554 DeviceMemory<float> *y, int incy, float c,
1555 float s) {
1556 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1557 PARAM(c), PARAM(s));
1558
1559 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1560 float, float>
1561 impl;
1562 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1563 c, s);
1564 }
1565
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)1566 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
1567 int incx, DeviceMemory<double> *y, int incy,
1568 double c, double s) {
1569 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1570 PARAM(c), PARAM(s));
1571
1572 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1573 double, double>
1574 impl;
1575 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1576 c, s);
1577 }
1578
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)1579 Stream &Stream::ThenBlasRot(uint64 elem_count,
1580 DeviceMemory<std::complex<float>> *x, int incx,
1581 DeviceMemory<std::complex<float>> *y, int incy,
1582 float c, float s) {
1583 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1584 PARAM(c), PARAM(s));
1585
1586 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1587 DeviceMemory<std::complex<float>> *, int, float, float>
1588 impl;
1589 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1590 c, s);
1591 }
1592
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)1593 Stream &Stream::ThenBlasRot(uint64 elem_count,
1594 DeviceMemory<std::complex<double>> *x, int incx,
1595 DeviceMemory<std::complex<double>> *y, int incy,
1596 double c, double s) {
1597 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1598 PARAM(c), PARAM(s));
1599
1600 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1601 DeviceMemory<std::complex<double>> *, int, double, double>
1602 impl;
1603 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1604 c, s);
1605 }
1606
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)1607 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
1608 DeviceMemory<float> *c, DeviceMemory<float> *s) {
1609 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1610
1611 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1612 DeviceMemory<float> *, DeviceMemory<float> *>
1613 impl;
1614 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1615 }
1616
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)1617 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
1618 DeviceMemory<double> *c, DeviceMemory<double> *s) {
1619 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1620
1621 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1622 DeviceMemory<double> *, DeviceMemory<double> *>
1623 impl;
1624 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1625 }
1626
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)1627 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
1628 DeviceMemory<std::complex<float>> *b,
1629 DeviceMemory<float> *c,
1630 DeviceMemory<std::complex<float>> *s) {
1631 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1632
1633 ThenBlasImpl<DeviceMemory<std::complex<float>> *,
1634 DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
1635 DeviceMemory<std::complex<float>> *>
1636 impl;
1637 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1638 }
1639
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)1640 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
1641 DeviceMemory<std::complex<double>> *b,
1642 DeviceMemory<double> *c,
1643 DeviceMemory<std::complex<double>> *s) {
1644 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1645
1646 ThenBlasImpl<DeviceMemory<std::complex<double>> *,
1647 DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
1648 DeviceMemory<std::complex<double>> *>
1649 impl;
1650 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1651 }
1652
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)1653 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
1654 int incx, DeviceMemory<float> *y, int incy,
1655 const DeviceMemory<float> ¶m) {
1656 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1657 PARAM(param));
1658
1659 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1660 const DeviceMemory<float> &>
1661 impl;
1662 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1663 incy, param);
1664 }
1665
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)1666 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
1667 int incx, DeviceMemory<double> *y, int incy,
1668 const DeviceMemory<double> ¶m) {
1669 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1670 PARAM(param));
1671
1672 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1673 const DeviceMemory<double> &>
1674 impl;
1675 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1676 incy, param);
1677 }
1678
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)1679 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
1680 DeviceMemory<float> *x1,
1681 const DeviceMemory<float> &y1,
1682 DeviceMemory<float> *param) {
1683 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1684
1685 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1686 DeviceMemory<float> *, const DeviceMemory<float> &,
1687 DeviceMemory<float> *>
1688 impl;
1689 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1690 }
1691
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)1692 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
1693 DeviceMemory<double> *d2,
1694 DeviceMemory<double> *x1,
1695 const DeviceMemory<double> &y1,
1696 DeviceMemory<double> *param) {
1697 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1698
1699 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1700 DeviceMemory<double> *, const DeviceMemory<double> &,
1701 DeviceMemory<double> *>
1702 impl;
1703 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1704 }
1705
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)1706 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1707 DeviceMemory<float> *x, int incx) {
1708 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1709
1710 ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
1711 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1712 }
1713
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)1714 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1715 DeviceMemory<double> *x, int incx) {
1716 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1717
1718 ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
1719 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1720 }
1721
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)1722 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1723 DeviceMemory<std::complex<float>> *x, int incx) {
1724 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1725
1726 ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
1727 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1728 }
1729
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)1730 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1731 DeviceMemory<std::complex<double>> *x, int incx) {
1732 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1733
1734 ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
1735 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1736 }
1737
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)1738 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
1739 DeviceMemory<std::complex<float>> *x, int incx) {
1740 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1741
1742 ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
1743 int>
1744 impl;
1745 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1746 }
1747
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)1748 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
1749 DeviceMemory<std::complex<double>> *x, int incx) {
1750 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1751
1752 ThenBlasImpl<uint64, std::complex<double>,
1753 DeviceMemory<std::complex<double>> *, int>
1754 impl;
1755 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1756 }
1757
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)1758 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
1759 int incx, DeviceMemory<float> *y, int incy) {
1760 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1761
1762 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
1763 impl;
1764 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1765 incy);
1766 }
1767
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)1768 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
1769 int incx, DeviceMemory<double> *y, int incy) {
1770 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1771
1772 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
1773 impl;
1774 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1775 incy);
1776 }
1777
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1778 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1779 DeviceMemory<std::complex<float>> *x, int incx,
1780 DeviceMemory<std::complex<float>> *y, int incy) {
1781 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1782
1783 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1784 DeviceMemory<std::complex<float>> *, int>
1785 impl;
1786 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1787 incy);
1788 }
1789
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1790 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1791 DeviceMemory<std::complex<double>> *x, int incx,
1792 DeviceMemory<std::complex<double>> *y, int incy) {
1793 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1794
1795 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1796 DeviceMemory<std::complex<double>> *, int>
1797 impl;
1798 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1799 incy);
1800 }
1801
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)1802 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
1803 int incx, DeviceMemory<int> *result) {
1804 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1805
1806 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
1807 impl;
1808 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1809 result);
1810 }
1811
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)1812 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
1813 int incx, DeviceMemory<int> *result) {
1814 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1815
1816 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
1817 impl;
1818 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1819 result);
1820 }
1821
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)1822 Stream &Stream::ThenBlasIamax(uint64 elem_count,
1823 const DeviceMemory<std::complex<float>> &x,
1824 int incx, DeviceMemory<int> *result) {
1825 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1826
1827 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1828 DeviceMemory<int> *>
1829 impl;
1830 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1831 result);
1832 }
1833
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)1834 Stream &Stream::ThenBlasIamax(uint64 elem_count,
1835 const DeviceMemory<std::complex<double>> &x,
1836 int incx, DeviceMemory<int> *result) {
1837 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1838
1839 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1840 DeviceMemory<int> *>
1841 impl;
1842 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1843 result);
1844 }
1845
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)1846 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
1847 int incx, DeviceMemory<int> *result) {
1848 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1849
1850 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
1851 impl;
1852 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1853 result);
1854 }
1855
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)1856 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
1857 int incx, DeviceMemory<int> *result) {
1858 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1859
1860 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
1861 impl;
1862 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1863 result);
1864 }
1865
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)1866 Stream &Stream::ThenBlasIamin(uint64 elem_count,
1867 const DeviceMemory<std::complex<float>> &x,
1868 int incx, DeviceMemory<int> *result) {
1869 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1870
1871 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1872 DeviceMemory<int> *>
1873 impl;
1874 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1875 result);
1876 }
1877
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)1878 Stream &Stream::ThenBlasIamin(uint64 elem_count,
1879 const DeviceMemory<std::complex<double>> &x,
1880 int incx, DeviceMemory<int> *result) {
1881 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1882
1883 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1884 DeviceMemory<int> *>
1885 impl;
1886 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1887 result);
1888 }
1889
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1890 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1891 uint64 kl, uint64 ku, float alpha,
1892 const DeviceMemory<float> &a, int lda,
1893 const DeviceMemory<float> &x, int incx, float beta,
1894 DeviceMemory<float> *y, int incy) {
1895 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1896 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1897 PARAM(beta), PARAM(y), PARAM(incy));
1898
1899 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
1900 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1901 int, float, DeviceMemory<float> *, int>
1902 impl;
1903 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1904 a, lda, x, incx, beta, y, incy);
1905 }
1906
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1907 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1908 uint64 kl, uint64 ku, double alpha,
1909 const DeviceMemory<double> &a, int lda,
1910 const DeviceMemory<double> &x, int incx,
1911 double beta, DeviceMemory<double> *y, int incy) {
1912 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1913 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1914 PARAM(beta), PARAM(y), PARAM(incy));
1915
1916 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
1917 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1918 int, double, DeviceMemory<double> *, int>
1919 impl;
1920 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1921 a, lda, x, incx, beta, y, incy);
1922 }
1923
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1924 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1925 uint64 kl, uint64 ku, std::complex<float> alpha,
1926 const DeviceMemory<std::complex<float>> &a,
1927 int lda,
1928 const DeviceMemory<std::complex<float>> &x,
1929 int incx, std::complex<float> beta,
1930 DeviceMemory<std::complex<float>> *y, int incy) {
1931 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1932 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1933 PARAM(beta), PARAM(y), PARAM(incy));
1934
1935 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
1936 std::complex<float>, const DeviceMemory<std::complex<float>> &,
1937 int, const DeviceMemory<std::complex<float>> &, int,
1938 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
1939 impl;
1940 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1941 a, lda, x, incx, beta, y, incy);
1942 }
1943
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1944 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1945 uint64 kl, uint64 ku, std::complex<double> alpha,
1946 const DeviceMemory<std::complex<double>> &a,
1947 int lda,
1948 const DeviceMemory<std::complex<double>> &x,
1949 int incx, std::complex<double> beta,
1950 DeviceMemory<std::complex<double>> *y, int incy) {
1951 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1952 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1953 PARAM(beta), PARAM(y), PARAM(incy));
1954
1955 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
1956 std::complex<double>, const DeviceMemory<std::complex<double>> &,
1957 int, const DeviceMemory<std::complex<double>> &, int,
1958 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
1959 impl;
1960 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1961 a, lda, x, incx, beta, y, incy);
1962 }
1963
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1964 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
1965 float alpha, const DeviceMemory<float> &a, int lda,
1966 const DeviceMemory<float> &x, int incx, float beta,
1967 DeviceMemory<float> *y, int incy) {
1968 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1969 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1970 PARAM(incy));
1971
1972 ThenBlasImpl<blas::Transpose, uint64, uint64, float,
1973 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1974 int, float, DeviceMemory<float> *, int>
1975 impl;
1976 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1977 x, incx, beta, y, incy);
1978 }
1979
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1980 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
1981 double alpha, const DeviceMemory<double> &a,
1982 int lda, const DeviceMemory<double> &x, int incx,
1983 double beta, DeviceMemory<double> *y, int incy) {
1984 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1985 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1986 PARAM(incy));
1987
1988 ThenBlasImpl<blas::Transpose, uint64, uint64, double,
1989 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1990 int, double, DeviceMemory<double> *, int>
1991 impl;
1992 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1993 x, incx, beta, y, incy);
1994 }
1995
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1996 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
1997 std::complex<float> alpha,
1998 const DeviceMemory<std::complex<float>> &a,
1999 int lda,
2000 const DeviceMemory<std::complex<float>> &x,
2001 int incx, std::complex<float> beta,
2002 DeviceMemory<std::complex<float>> *y, int incy) {
2003 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2004 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2005 PARAM(incy));
2006
2007 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2008 const DeviceMemory<std::complex<float>> &, int,
2009 const DeviceMemory<std::complex<float>> &, int,
2010 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2011 impl;
2012 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2013 x, incx, beta, y, incy);
2014 }
2015
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2016 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2017 std::complex<double> alpha,
2018 const DeviceMemory<std::complex<double>> &a,
2019 int lda,
2020 const DeviceMemory<std::complex<double>> &x,
2021 int incx, std::complex<double> beta,
2022 DeviceMemory<std::complex<double>> *y, int incy) {
2023 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2024 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2025 PARAM(incy));
2026
2027 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2028 const DeviceMemory<std::complex<double>> &, int,
2029 const DeviceMemory<std::complex<double>> &, int,
2030 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2031 impl;
2032 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2033 x, incx, beta, y, incy);
2034 }
2035
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2036 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2037 const DeviceMemory<float> &x, int incx,
2038 const DeviceMemory<float> &y, int incy,
2039 DeviceMemory<float> *a, int lda) {
2040 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2041 PARAM(incy), PARAM(a), PARAM(lda));
2042
2043 ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2044 const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2045 impl;
2046 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2047 incy, a, lda);
2048 }
2049
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2050 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2051 const DeviceMemory<double> &x, int incx,
2052 const DeviceMemory<double> &y, int incy,
2053 DeviceMemory<double> *a, int lda) {
2054 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2055 PARAM(incy), PARAM(a), PARAM(lda));
2056
2057 ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2058 const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2059 impl;
2060 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2061 incy, a, lda);
2062 }
2063
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2064 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2065 const DeviceMemory<std::complex<float>> &x,
2066 int incx,
2067 const DeviceMemory<std::complex<float>> &y,
2068 int incy, DeviceMemory<std::complex<float>> *a,
2069 int lda) {
2070 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2071 PARAM(incy), PARAM(a), PARAM(lda));
2072
2073 ThenBlasImpl<uint64, uint64, std::complex<float>,
2074 const DeviceMemory<std::complex<float>> &, int,
2075 const DeviceMemory<std::complex<float>> &, int,
2076 DeviceMemory<std::complex<float>> *, int>
2077 impl;
2078 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2079 incy, a, lda);
2080 }
2081
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2082 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2083 const DeviceMemory<std::complex<double>> &x,
2084 int incx,
2085 const DeviceMemory<std::complex<double>> &y,
2086 int incy, DeviceMemory<std::complex<double>> *a,
2087 int lda) {
2088 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2089 PARAM(incy), PARAM(a), PARAM(lda));
2090
2091 ThenBlasImpl<uint64, uint64, std::complex<double>,
2092 const DeviceMemory<std::complex<double>> &, int,
2093 const DeviceMemory<std::complex<double>> &, int,
2094 DeviceMemory<std::complex<double>> *, int>
2095 impl;
2096 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2097 incy, a, lda);
2098 }
2099
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2100 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2101 const DeviceMemory<std::complex<float>> &x,
2102 int incx,
2103 const DeviceMemory<std::complex<float>> &y,
2104 int incy, DeviceMemory<std::complex<float>> *a,
2105 int lda) {
2106 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2107 PARAM(incy), PARAM(a), PARAM(lda));
2108
2109 ThenBlasImpl<uint64, uint64, std::complex<float>,
2110 const DeviceMemory<std::complex<float>> &, int,
2111 const DeviceMemory<std::complex<float>> &, int,
2112 DeviceMemory<std::complex<float>> *, int>
2113 impl;
2114 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2115 incy, a, lda);
2116 }
2117
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2118 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2119 const DeviceMemory<std::complex<double>> &x,
2120 int incx,
2121 const DeviceMemory<std::complex<double>> &y,
2122 int incy, DeviceMemory<std::complex<double>> *a,
2123 int lda) {
2124 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2125 PARAM(incy), PARAM(a), PARAM(lda));
2126
2127 ThenBlasImpl<uint64, uint64, std::complex<double>,
2128 const DeviceMemory<std::complex<double>> &, int,
2129 const DeviceMemory<std::complex<double>> &, int,
2130 DeviceMemory<std::complex<double>> *, int>
2131 impl;
2132 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2133 incy, a, lda);
2134 }
2135
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2136 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2137 std::complex<float> alpha,
2138 const DeviceMemory<std::complex<float>> &a,
2139 int lda,
2140 const DeviceMemory<std::complex<float>> &x,
2141 int incx, std::complex<float> beta,
2142 DeviceMemory<std::complex<float>> *y, int incy) {
2143 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2144 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2145
2146 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2147 const DeviceMemory<std::complex<float>> &, int,
2148 const DeviceMemory<std::complex<float>> &, int,
2149 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2150 impl;
2151 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2152 x, incx, beta, y, incy);
2153 }
2154
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2155 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2156 std::complex<double> alpha,
2157 const DeviceMemory<std::complex<double>> &a,
2158 int lda,
2159 const DeviceMemory<std::complex<double>> &x,
2160 int incx, std::complex<double> beta,
2161 DeviceMemory<std::complex<double>> *y, int incy) {
2162 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2163 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2164
2165 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2166 const DeviceMemory<std::complex<double>> &, int,
2167 const DeviceMemory<std::complex<double>> &, int,
2168 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2169 impl;
2170 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2171 x, incx, beta, y, incy);
2172 }
2173
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2174 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2175 std::complex<float> alpha,
2176 const DeviceMemory<std::complex<float>> &a,
2177 int lda,
2178 const DeviceMemory<std::complex<float>> &x,
2179 int incx, std::complex<float> beta,
2180 DeviceMemory<std::complex<float>> *y, int incy) {
2181 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2182 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2183
2184 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2185 const DeviceMemory<std::complex<float>> &, int,
2186 const DeviceMemory<std::complex<float>> &, int,
2187 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2188 impl;
2189 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2190 incx, beta, y, incy);
2191 }
2192
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2193 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2194 std::complex<double> alpha,
2195 const DeviceMemory<std::complex<double>> &a,
2196 int lda,
2197 const DeviceMemory<std::complex<double>> &x,
2198 int incx, std::complex<double> beta,
2199 DeviceMemory<std::complex<double>> *y, int incy) {
2200 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2201 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2202
2203 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2204 const DeviceMemory<std::complex<double>> &, int,
2205 const DeviceMemory<std::complex<double>> &, int,
2206 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2207 impl;
2208 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2209 incx, beta, y, incy);
2210 }
2211
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2212 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2213 const DeviceMemory<std::complex<float>> &x,
2214 int incx, DeviceMemory<std::complex<float>> *a,
2215 int lda) {
2216 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2217 PARAM(a), PARAM(lda));
2218
2219 ThenBlasImpl<blas::UpperLower, uint64, float,
2220 const DeviceMemory<std::complex<float>> &, int,
2221 DeviceMemory<std::complex<float>> *, int>
2222 impl;
2223 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2224 lda);
2225 }
2226
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2227 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2228 const DeviceMemory<std::complex<double>> &x,
2229 int incx, DeviceMemory<std::complex<double>> *a,
2230 int lda) {
2231 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2232 PARAM(a), PARAM(lda));
2233
2234 ThenBlasImpl<blas::UpperLower, uint64, double,
2235 const DeviceMemory<std::complex<double>> &, int,
2236 DeviceMemory<std::complex<double>> *, int>
2237 impl;
2238 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2239 lda);
2240 }
2241
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2242 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2243 std::complex<float> alpha,
2244 const DeviceMemory<std::complex<float>> &x,
2245 int incx,
2246 const DeviceMemory<std::complex<float>> &y,
2247 int incy, DeviceMemory<std::complex<float>> *a,
2248 int lda) {
2249 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2250 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2251
2252 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2253 const DeviceMemory<std::complex<float>> &, int,
2254 const DeviceMemory<std::complex<float>> &, int,
2255 DeviceMemory<std::complex<float>> *, int>
2256 impl;
2257 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2258 incy, a, lda);
2259 }
2260
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2261 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2262 std::complex<double> alpha,
2263 const DeviceMemory<std::complex<double>> &x,
2264 int incx,
2265 const DeviceMemory<std::complex<double>> &y,
2266 int incy, DeviceMemory<std::complex<double>> *a,
2267 int lda) {
2268 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2269 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2270
2271 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2272 const DeviceMemory<std::complex<double>> &, int,
2273 const DeviceMemory<std::complex<double>> &, int,
2274 DeviceMemory<std::complex<double>> *, int>
2275 impl;
2276 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2277 incy, a, lda);
2278 }
2279
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2280 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2281 std::complex<float> alpha,
2282 const DeviceMemory<std::complex<float>> &ap,
2283 const DeviceMemory<std::complex<float>> &x,
2284 int incx, std::complex<float> beta,
2285 DeviceMemory<std::complex<float>> *y, int incy) {
2286 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2287 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2288
2289 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2290 const DeviceMemory<std::complex<float>> &,
2291 const DeviceMemory<std::complex<float>> &, int,
2292 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2293 impl;
2294 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2295 beta, y, incy);
2296 }
2297
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2298 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2299 std::complex<double> alpha,
2300 const DeviceMemory<std::complex<double>> &ap,
2301 const DeviceMemory<std::complex<double>> &x,
2302 int incx, std::complex<double> beta,
2303 DeviceMemory<std::complex<double>> *y, int incy) {
2304 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2305 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2306
2307 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2308 const DeviceMemory<std::complex<double>> &,
2309 const DeviceMemory<std::complex<double>> &, int,
2310 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2311 impl;
2312 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2313 beta, y, incy);
2314 }
2315
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2316 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2317 const DeviceMemory<std::complex<float>> &x,
2318 int incx, DeviceMemory<std::complex<float>> *ap) {
2319 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2320 PARAM(ap));
2321
2322 ThenBlasImpl<blas::UpperLower, uint64, float,
2323 const DeviceMemory<std::complex<float>> &, int,
2324 DeviceMemory<std::complex<float>> *>
2325 impl;
2326 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2327 }
2328
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2329 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2330 const DeviceMemory<std::complex<double>> &x,
2331 int incx, DeviceMemory<std::complex<double>> *ap) {
2332 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2333 PARAM(ap));
2334
2335 ThenBlasImpl<blas::UpperLower, uint64, double,
2336 const DeviceMemory<std::complex<double>> &, int,
2337 DeviceMemory<std::complex<double>> *>
2338 impl;
2339 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2340 }
2341
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2342 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2343 std::complex<float> alpha,
2344 const DeviceMemory<std::complex<float>> &x,
2345 int incx,
2346 const DeviceMemory<std::complex<float>> &y,
2347 int incy, DeviceMemory<std::complex<float>> *ap) {
2348 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2349 PARAM(y), PARAM(incy), PARAM(ap));
2350
2351 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2352 const DeviceMemory<std::complex<float>> &, int,
2353 const DeviceMemory<std::complex<float>> &, int,
2354 DeviceMemory<std::complex<float>> *>
2355 impl;
2356 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2357 incy, ap);
2358 }
2359
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2360 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2361 std::complex<double> alpha,
2362 const DeviceMemory<std::complex<double>> &x,
2363 int incx,
2364 const DeviceMemory<std::complex<double>> &y,
2365 int incy, DeviceMemory<std::complex<double>> *ap) {
2366 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2367 PARAM(y), PARAM(incy), PARAM(ap));
2368
2369 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2370 const DeviceMemory<std::complex<double>> &, int,
2371 const DeviceMemory<std::complex<double>> &, int,
2372 DeviceMemory<std::complex<double>> *>
2373 impl;
2374 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2375 incy, ap);
2376 }
2377
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2378 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2379 float alpha, const DeviceMemory<float> &a, int lda,
2380 const DeviceMemory<float> &x, int incx, float beta,
2381 DeviceMemory<float> *y, int incy) {
2382 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2383 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2384
2385 ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2386 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2387 int, float, DeviceMemory<float> *, int>
2388 impl;
2389 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2390 x, incx, beta, y, incy);
2391 }
2392
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2393 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2394 double alpha, const DeviceMemory<double> &a,
2395 int lda, const DeviceMemory<double> &x, int incx,
2396 double beta, DeviceMemory<double> *y, int incy) {
2397 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2398 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2399
2400 ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2401 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2402 int, double, DeviceMemory<double> *, int>
2403 impl;
2404 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2405 x, incx, beta, y, incy);
2406 }
2407
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2408 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
2409 const DeviceMemory<float> &ap,
2410 const DeviceMemory<float> &x, int incx, float beta,
2411 DeviceMemory<float> *y, int incy) {
2412 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2413 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2414
2415 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2416 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
2417 int>
2418 impl;
2419 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2420 beta, y, incy);
2421 }
2422
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2423 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
2424 const DeviceMemory<double> &ap,
2425 const DeviceMemory<double> &x, int incx,
2426 double beta, DeviceMemory<double> *y, int incy) {
2427 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2428 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2429
2430 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2431 const DeviceMemory<double> &, int, double,
2432 DeviceMemory<double> *, int>
2433 impl;
2434 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2435 beta, y, incy);
2436 }
2437
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)2438 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
2439 const DeviceMemory<float> &x, int incx,
2440 DeviceMemory<float> *ap) {
2441 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2442 PARAM(ap));
2443
2444 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2445 int, DeviceMemory<float> *>
2446 impl;
2447 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2448 }
2449
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)2450 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
2451 const DeviceMemory<double> &x, int incx,
2452 DeviceMemory<double> *ap) {
2453 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2454 PARAM(ap));
2455
2456 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2457 int, DeviceMemory<double> *>
2458 impl;
2459 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2460 }
2461
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)2462 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
2463 const DeviceMemory<float> &x, int incx,
2464 const DeviceMemory<float> &y, int incy,
2465 DeviceMemory<float> *ap) {
2466 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2467 PARAM(y), PARAM(incy), PARAM(ap));
2468
2469 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2470 int, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2471 impl;
2472 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2473 incy, ap);
2474 }
2475
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)2476 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
2477 const DeviceMemory<double> &x, int incx,
2478 const DeviceMemory<double> &y, int incy,
2479 DeviceMemory<double> *ap) {
2480 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2481 PARAM(y), PARAM(incy), PARAM(ap));
2482
2483 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2484 int, const DeviceMemory<double> &, int, DeviceMemory<double> *>
2485 impl;
2486 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2487 incy, ap);
2488 }
2489
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2490 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
2491 const DeviceMemory<float> &a, int lda,
2492 const DeviceMemory<float> &x, int incx, float beta,
2493 DeviceMemory<float> *y, int incy) {
2494 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2495 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2496
2497 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2498 int, const DeviceMemory<float> &, int, float,
2499 DeviceMemory<float> *, int>
2500 impl;
2501 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2502 incx, beta, y, incy);
2503 }
2504
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2505 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
2506 const DeviceMemory<double> &a, int lda,
2507 const DeviceMemory<double> &x, int incx,
2508 double beta, DeviceMemory<double> *y, int incy) {
2509 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2510 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2511
2512 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2513 int, const DeviceMemory<double> &, int, double,
2514 DeviceMemory<double> *, int>
2515 impl;
2516 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2517 incx, beta, y, incy);
2518 }
2519
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)2520 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
2521 const DeviceMemory<float> &x, int incx,
2522 DeviceMemory<float> *a, int lda) {
2523 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2524 PARAM(a), PARAM(lda));
2525
2526 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2527 int, DeviceMemory<float> *, int>
2528 impl;
2529 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2530 lda);
2531 }
2532
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)2533 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
2534 const DeviceMemory<double> &x, int incx,
2535 DeviceMemory<double> *a, int lda) {
2536 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2537 PARAM(a), PARAM(lda));
2538
2539 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2540 int, DeviceMemory<double> *, int>
2541 impl;
2542 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2543 lda);
2544 }
2545
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2546 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
2547 const DeviceMemory<float> &x, int incx,
2548 const DeviceMemory<float> &y, int incy,
2549 DeviceMemory<float> *a, int lda) {
2550 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2551 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2552
2553 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2554 int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2555 int>
2556 impl;
2557 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2558 incy, a, lda);
2559 }
2560
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2561 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
2562 const DeviceMemory<double> &x, int incx,
2563 const DeviceMemory<double> &y, int incy,
2564 DeviceMemory<double> *a, int lda) {
2565 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2566 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2567
2568 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2569 int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
2570 int>
2571 impl;
2572 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2573 incy, a, lda);
2574 }
2575
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2576 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2577 blas::Diagonal diag, uint64 n, uint64 k,
2578 const DeviceMemory<float> &a, int lda,
2579 DeviceMemory<float> *x, int incx) {
2580 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2581 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2582
2583 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2584 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2585 int>
2586 impl;
2587 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2588 lda, x, incx);
2589 }
2590
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2591 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2592 blas::Diagonal diag, uint64 n, uint64 k,
2593 const DeviceMemory<double> &a, int lda,
2594 DeviceMemory<double> *x, int incx) {
2595 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2596 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2597
2598 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2599 uint64, const DeviceMemory<double> &, int,
2600 DeviceMemory<double> *, int>
2601 impl;
2602 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2603 lda, x, incx);
2604 }
2605
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2606 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2607 blas::Diagonal diag, uint64 n, uint64 k,
2608 const DeviceMemory<std::complex<float>> &a,
2609 int lda, DeviceMemory<std::complex<float>> *x,
2610 int incx) {
2611 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2612 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2613
2614 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2615 uint64, const DeviceMemory<std::complex<float>> &, int,
2616 DeviceMemory<std::complex<float>> *, int>
2617 impl;
2618 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2619 lda, x, incx);
2620 }
2621
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2622 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2623 blas::Diagonal diag, uint64 n, uint64 k,
2624 const DeviceMemory<std::complex<double>> &a,
2625 int lda, DeviceMemory<std::complex<double>> *x,
2626 int incx) {
2627 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2628 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2629
2630 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2631 uint64, const DeviceMemory<std::complex<double>> &, int,
2632 DeviceMemory<std::complex<double>> *, int>
2633 impl;
2634 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2635 lda, x, incx);
2636 }
2637
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2638 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2639 blas::Diagonal diag, uint64 n, uint64 k,
2640 const DeviceMemory<float> &a, int lda,
2641 DeviceMemory<float> *x, int incx) {
2642 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2643 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2644
2645 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2646 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2647 int>
2648 impl;
2649 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2650 lda, x, incx);
2651 }
2652
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2653 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2654 blas::Diagonal diag, uint64 n, uint64 k,
2655 const DeviceMemory<double> &a, int lda,
2656 DeviceMemory<double> *x, int incx) {
2657 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2658 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2659
2660 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2661 uint64, const DeviceMemory<double> &, int,
2662 DeviceMemory<double> *, int>
2663 impl;
2664 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2665 lda, x, incx);
2666 }
2667
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2668 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2669 blas::Diagonal diag, uint64 n, uint64 k,
2670 const DeviceMemory<std::complex<float>> &a,
2671 int lda, DeviceMemory<std::complex<float>> *x,
2672 int incx) {
2673 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2674 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2675
2676 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2677 uint64, const DeviceMemory<std::complex<float>> &, int,
2678 DeviceMemory<std::complex<float>> *, int>
2679 impl;
2680 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2681 lda, x, incx);
2682 }
2683
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2684 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2685 blas::Diagonal diag, uint64 n, uint64 k,
2686 const DeviceMemory<std::complex<double>> &a,
2687 int lda, DeviceMemory<std::complex<double>> *x,
2688 int incx) {
2689 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2690 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2691
2692 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2693 uint64, const DeviceMemory<std::complex<double>> &, int,
2694 DeviceMemory<std::complex<double>> *, int>
2695 impl;
2696 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2697 lda, x, incx);
2698 }
2699
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2700 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2701 blas::Diagonal diag, uint64 n,
2702 const DeviceMemory<float> &ap,
2703 DeviceMemory<float> *x, int incx) {
2704 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2705 PARAM(x), PARAM(incx));
2706
2707 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2708 const DeviceMemory<float> &, DeviceMemory<float> *, int>
2709 impl;
2710 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2711 incx);
2712 }
2713
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2714 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2715 blas::Diagonal diag, uint64 n,
2716 const DeviceMemory<double> &ap,
2717 DeviceMemory<double> *x, int incx) {
2718 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2719 PARAM(x), PARAM(incx));
2720
2721 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2722 const DeviceMemory<double> &, DeviceMemory<double> *, int>
2723 impl;
2724 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2725 incx);
2726 }
2727
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2728 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2729 blas::Diagonal diag, uint64 n,
2730 const DeviceMemory<std::complex<float>> &ap,
2731 DeviceMemory<std::complex<float>> *x, int incx) {
2732 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2733 PARAM(x), PARAM(incx));
2734
2735 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2736 const DeviceMemory<std::complex<float>> &,
2737 DeviceMemory<std::complex<float>> *, int>
2738 impl;
2739 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2740 incx);
2741 }
2742
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)2743 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2744 blas::Diagonal diag, uint64 n,
2745 const DeviceMemory<std::complex<double>> &ap,
2746 DeviceMemory<std::complex<double>> *x, int incx) {
2747 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2748 PARAM(x), PARAM(incx));
2749
2750 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2751 const DeviceMemory<std::complex<double>> &,
2752 DeviceMemory<std::complex<double>> *, int>
2753 impl;
2754 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2755 incx);
2756 }
2757
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2758 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2759 blas::Diagonal diag, uint64 n,
2760 const DeviceMemory<float> &ap,
2761 DeviceMemory<float> *x, int incx) {
2762 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2763 PARAM(x), PARAM(incx));
2764
2765 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2766 const DeviceMemory<float> &, DeviceMemory<float> *, int>
2767 impl;
2768 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2769 incx);
2770 }
2771
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2772 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2773 blas::Diagonal diag, uint64 n,
2774 const DeviceMemory<double> &ap,
2775 DeviceMemory<double> *x, int incx) {
2776 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2777 PARAM(x), PARAM(incx));
2778
2779 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2780 const DeviceMemory<double> &, DeviceMemory<double> *, int>
2781 impl;
2782 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2783 incx);
2784 }
2785
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2786 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2787 blas::Diagonal diag, uint64 n,
2788 const DeviceMemory<std::complex<float>> &ap,
2789 DeviceMemory<std::complex<float>> *x, int incx) {
2790 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2791 PARAM(x), PARAM(incx));
2792
2793 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2794 const DeviceMemory<std::complex<float>> &,
2795 DeviceMemory<std::complex<float>> *, int>
2796 impl;
2797 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2798 incx);
2799 }
2800
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)2801 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2802 blas::Diagonal diag, uint64 n,
2803 const DeviceMemory<std::complex<double>> &ap,
2804 DeviceMemory<std::complex<double>> *x, int incx) {
2805 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2806 PARAM(x), PARAM(incx));
2807
2808 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2809 const DeviceMemory<std::complex<double>> &,
2810 DeviceMemory<std::complex<double>> *, int>
2811 impl;
2812 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2813 incx);
2814 }
2815
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2816 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2817 blas::Diagonal diag, uint64 n,
2818 const DeviceMemory<float> &a, int lda,
2819 DeviceMemory<float> *x, int incx) {
2820 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2821 PARAM(lda), PARAM(x), PARAM(incx));
2822
2823 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2824 const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2825 impl;
2826 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2827 lda, x, incx);
2828 }
2829
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2830 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2831 blas::Diagonal diag, uint64 n,
2832 const DeviceMemory<double> &a, int lda,
2833 DeviceMemory<double> *x, int incx) {
2834 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2835 PARAM(lda), PARAM(x), PARAM(incx));
2836
2837 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2838 const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2839 impl;
2840 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2841 lda, x, incx);
2842 }
2843
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2844 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2845 blas::Diagonal diag, uint64 n,
2846 const DeviceMemory<std::complex<float>> &a,
2847 int lda, DeviceMemory<std::complex<float>> *x,
2848 int incx) {
2849 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2850 PARAM(lda), PARAM(x), PARAM(incx));
2851
2852 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2853 const DeviceMemory<std::complex<float>> &, int,
2854 DeviceMemory<std::complex<float>> *, int>
2855 impl;
2856 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2857 lda, x, incx);
2858 }
2859
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2860 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2861 blas::Diagonal diag, uint64 n,
2862 const DeviceMemory<std::complex<double>> &a,
2863 int lda, DeviceMemory<std::complex<double>> *x,
2864 int incx) {
2865 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2866 PARAM(lda), PARAM(x), PARAM(incx));
2867
2868 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2869 const DeviceMemory<std::complex<double>> &, int,
2870 DeviceMemory<std::complex<double>> *, int>
2871 impl;
2872 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2873 lda, x, incx);
2874 }
2875
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2876 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2877 blas::Diagonal diag, uint64 n,
2878 const DeviceMemory<float> &a, int lda,
2879 DeviceMemory<float> *x, int incx) {
2880 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2881 PARAM(lda), PARAM(x), PARAM(incx));
2882
2883 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2884 const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2885 impl;
2886 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2887 lda, x, incx);
2888 }
2889
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2890 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2891 blas::Diagonal diag, uint64 n,
2892 const DeviceMemory<double> &a, int lda,
2893 DeviceMemory<double> *x, int incx) {
2894 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2895 PARAM(lda), PARAM(x), PARAM(incx));
2896
2897 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2898 const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2899 impl;
2900 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2901 lda, x, incx);
2902 }
2903
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2904 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2905 blas::Diagonal diag, uint64 n,
2906 const DeviceMemory<std::complex<float>> &a,
2907 int lda, DeviceMemory<std::complex<float>> *x,
2908 int incx) {
2909 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2910 PARAM(lda), PARAM(x), PARAM(incx));
2911
2912 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2913 const DeviceMemory<std::complex<float>> &, int,
2914 DeviceMemory<std::complex<float>> *, int>
2915 impl;
2916 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2917 lda, x, incx);
2918 }
2919
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2920 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2921 blas::Diagonal diag, uint64 n,
2922 const DeviceMemory<std::complex<double>> &a,
2923 int lda, DeviceMemory<std::complex<double>> *x,
2924 int incx) {
2925 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2926 PARAM(lda), PARAM(x), PARAM(incx));
2927
2928 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2929 const DeviceMemory<std::complex<double>> &, int,
2930 DeviceMemory<std::complex<double>> *, int>
2931 impl;
2932 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2933 lda, x, incx);
2934 }
2935
2936 namespace {
2937 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
2938 // blas::ProfileResult*. This functor doesn't put the stream into an error
2939 // state if the op fails and the profile result is non-null. Instead, the
2940 // error-ness is returned in the profile result itself.
2941 template <typename... Args>
2942 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon9675edea0211::ThenBlasWithProfileImpl2943 Stream &operator()(Stream *stream,
2944 bool (blas::BlasSupport::*blas_func)(
2945 Stream *, Args..., blas::ProfileResult *),
2946 Args... args, blas::ProfileResult *profile_result) {
2947 ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
2948 bool record_error = profile_result == nullptr;
2949 return Runner.Run(stream, blas_func, record_error, args..., profile_result);
2950 }
2951 };
2952 } // anonymous namespace
2953
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)2954 Stream &Stream::ThenBlasGemvWithProfiling(
2955 blas::Transpose trans, uint64 m, uint64 n, float alpha,
2956 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
2957 int incx, float beta, DeviceMemory<float> *y, int incy,
2958 blas::ProfileResult *output_profile_result) {
2959 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2960 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2961 PARAM(incy));
2962
2963 ThenBlasWithProfileImpl<
2964 blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
2965 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
2966 impl;
2967 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
2968 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
2969 }
2970
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)2971 Stream &Stream::ThenBlasGemvWithProfiling(
2972 blas::Transpose trans, uint64 m, uint64 n, double alpha,
2973 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
2974 int incx, double beta, DeviceMemory<double> *y, int incy,
2975 blas::ProfileResult *output_profile_result) {
2976 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2977 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2978 PARAM(incy));
2979
2980 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
2981 const DeviceMemory<double> &, int,
2982 const DeviceMemory<double> &, int, double,
2983 DeviceMemory<double> *, int>
2984 impl;
2985 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
2986 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
2987 }
2988
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)2989 Stream &Stream::ThenBlasGemvWithProfiling(
2990 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
2991 const DeviceMemory<std::complex<float>> &a, int lda,
2992 const DeviceMemory<std::complex<float>> &x, int incx,
2993 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
2994 blas::ProfileResult *output_profile_result) {
2995 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2996 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2997 PARAM(incy));
2998
2999 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3000 const DeviceMemory<std::complex<float>> &, int,
3001 const DeviceMemory<std::complex<float>> &, int,
3002 std::complex<float>,
3003 DeviceMemory<std::complex<float>> *, int>
3004 impl;
3005 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3006 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3007 }
3008
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3009 Stream &Stream::ThenBlasGemvWithProfiling(
3010 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3011 const DeviceMemory<std::complex<double>> &a, int lda,
3012 const DeviceMemory<std::complex<double>> &x, int incx,
3013 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3014 blas::ProfileResult *output_profile_result) {
3015 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3016 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3017 PARAM(incy));
3018
3019 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3020 const DeviceMemory<std::complex<double>> &, int,
3021 const DeviceMemory<std::complex<double>> &, int,
3022 std::complex<double>,
3023 DeviceMemory<std::complex<double>> *, int>
3024 impl;
3025 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3026 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3027 }
3028
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3029 Stream &Stream::ThenBlasGemmWithProfiling(
3030 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3031 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3032 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3033 DeviceMemory<Eigen::half> *c, int ldc,
3034 blas::ProfileResult *output_profile_result) {
3035 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3036 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3037 PARAM(beta), PARAM(c), PARAM(ldc));
3038
3039 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3040 uint64, float, const DeviceMemory<Eigen::half> &, int,
3041 const DeviceMemory<Eigen::half> &, int, float,
3042 DeviceMemory<Eigen::half> *, int>
3043 impl;
3044 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3045 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3046 output_profile_result);
3047 }
3048
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3049 Stream &Stream::ThenBlasGemmWithProfiling(
3050 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3051 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3052 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3053 int ldc, blas::ProfileResult *output_profile_result) {
3054 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3055 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3056 PARAM(beta), PARAM(c), PARAM(ldc));
3057
3058 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3059 uint64, float, const DeviceMemory<float> &, int,
3060 const DeviceMemory<float> &, int, float,
3061 DeviceMemory<float> *, int>
3062 impl;
3063 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3064 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3065 output_profile_result);
3066 }
3067
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3068 Stream &Stream::ThenBlasGemmWithProfiling(
3069 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3070 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3071 const DeviceMemory<double> &b, int ldb, double beta,
3072 DeviceMemory<double> *c, int ldc,
3073 blas::ProfileResult *output_profile_result) {
3074 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3075 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3076 PARAM(beta), PARAM(c), PARAM(ldc));
3077
3078 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3079 uint64, double, const DeviceMemory<double> &, int,
3080 const DeviceMemory<double> &, int, double,
3081 DeviceMemory<double> *, int>
3082 impl;
3083 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3084 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3085 output_profile_result);
3086 }
3087
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3088 Stream &Stream::ThenBlasGemmWithProfiling(
3089 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3090 uint64 k, std::complex<float> alpha,
3091 const DeviceMemory<std::complex<float>> &a, int lda,
3092 const DeviceMemory<std::complex<float>> &b, int ldb,
3093 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3094 blas::ProfileResult *output_profile_result) {
3095 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3096 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3097 PARAM(beta), PARAM(c), PARAM(ldc));
3098
3099 ThenBlasWithProfileImpl<
3100 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3101 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3102 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3103 DeviceMemory<std::complex<float>> *, int>
3104 impl;
3105 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3106 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3107 output_profile_result);
3108 }
3109
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3110 Stream &Stream::ThenBlasGemmWithProfiling(
3111 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3112 uint64 k, std::complex<double> alpha,
3113 const DeviceMemory<std::complex<double>> &a, int lda,
3114 const DeviceMemory<std::complex<double>> &b, int ldb,
3115 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3116 blas::ProfileResult *output_profile_result) {
3117 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3118 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3119 PARAM(beta), PARAM(c), PARAM(ldc));
3120
3121 ThenBlasWithProfileImpl<
3122 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3123 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3124 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3125 DeviceMemory<std::complex<double>> *, int>
3126 impl;
3127 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3128 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3129 output_profile_result);
3130 }
3131
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3132 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3133 uint64 n, std::complex<float> alpha,
3134 const DeviceMemory<std::complex<float>> &a,
3135 int lda,
3136 const DeviceMemory<std::complex<float>> &b,
3137 int ldb, std::complex<float> beta,
3138 DeviceMemory<std::complex<float>> *c, int ldc) {
3139 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3140 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3141 PARAM(ldc));
3142
3143 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3144 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3145 int, const DeviceMemory<std::complex<float>> &, int,
3146 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3147 impl;
3148 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3149 lda, b, ldb, beta, c, ldc);
3150 }
3151
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3152 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3153 uint64 n, std::complex<double> alpha,
3154 const DeviceMemory<std::complex<double>> &a,
3155 int lda,
3156 const DeviceMemory<std::complex<double>> &b,
3157 int ldb, std::complex<double> beta,
3158 DeviceMemory<std::complex<double>> *c, int ldc) {
3159 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3160 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3161 PARAM(ldc));
3162
3163 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3164 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3165 int, const DeviceMemory<std::complex<double>> &, int,
3166 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3167 impl;
3168 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3169 lda, b, ldb, beta, c, ldc);
3170 }
3171
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3172 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3173 uint64 n, uint64 k, float alpha,
3174 const DeviceMemory<std::complex<float>> &a,
3175 int lda, float beta,
3176 DeviceMemory<std::complex<float>> *c, int ldc) {
3177 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3178 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3179
3180 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3181 const DeviceMemory<std::complex<float>> &, int, float,
3182 DeviceMemory<std::complex<float>> *, int>
3183 impl;
3184 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3185 lda, beta, c, ldc);
3186 }
3187
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3188 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3189 uint64 n, uint64 k, double alpha,
3190 const DeviceMemory<std::complex<double>> &a,
3191 int lda, double beta,
3192 DeviceMemory<std::complex<double>> *c, int ldc) {
3193 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3194 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3195
3196 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3197 const DeviceMemory<std::complex<double>> &, int, double,
3198 DeviceMemory<std::complex<double>> *, int>
3199 impl;
3200 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3201 lda, beta, c, ldc);
3202 }
3203
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3204 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3205 uint64 n, uint64 k, std::complex<float> alpha,
3206 const DeviceMemory<std::complex<float>> &a,
3207 int lda,
3208 const DeviceMemory<std::complex<float>> &b,
3209 int ldb, float beta,
3210 DeviceMemory<std::complex<float>> *c, int ldc) {
3211 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3212 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3213 PARAM(ldc));
3214
3215 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3216 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3217 int, const DeviceMemory<std::complex<float>> &, int, float,
3218 DeviceMemory<std::complex<float>> *, int>
3219 impl;
3220 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3221 a, lda, b, ldb, beta, c, ldc);
3222 }
3223
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3224 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3225 uint64 n, uint64 k, std::complex<double> alpha,
3226 const DeviceMemory<std::complex<double>> &a,
3227 int lda,
3228 const DeviceMemory<std::complex<double>> &b,
3229 int ldb, double beta,
3230 DeviceMemory<std::complex<double>> *c, int ldc) {
3231 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3232 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3233 PARAM(ldc));
3234
3235 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3236 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3237 int, const DeviceMemory<std::complex<double>> &, int, double,
3238 DeviceMemory<std::complex<double>> *, int>
3239 impl;
3240 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3241 a, lda, b, ldb, beta, c, ldc);
3242 }
3243
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3244 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3245 uint64 n, float alpha,
3246 const DeviceMemory<float> &a, int lda,
3247 const DeviceMemory<float> &b, int ldb, float beta,
3248 DeviceMemory<float> *c, int ldc) {
3249 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3250 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3251 PARAM(ldc));
3252
3253 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
3254 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3255 int, float, DeviceMemory<float> *, int>
3256 impl;
3257 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3258 lda, b, ldb, beta, c, ldc);
3259 }
3260
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3261 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3262 uint64 n, double alpha,
3263 const DeviceMemory<double> &a, int lda,
3264 const DeviceMemory<double> &b, int ldb,
3265 double beta, DeviceMemory<double> *c, int ldc) {
3266 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3267 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3268 PARAM(ldc));
3269
3270 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
3271 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3272 int, double, DeviceMemory<double> *, int>
3273 impl;
3274 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3275 lda, b, ldb, beta, c, ldc);
3276 }
3277
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3278 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3279 uint64 n, std::complex<float> alpha,
3280 const DeviceMemory<std::complex<float>> &a,
3281 int lda,
3282 const DeviceMemory<std::complex<float>> &b,
3283 int ldb, std::complex<float> beta,
3284 DeviceMemory<std::complex<float>> *c, int ldc) {
3285 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3286 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3287 PARAM(ldc));
3288
3289 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3290 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3291 int, const DeviceMemory<std::complex<float>> &, int,
3292 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3293 impl;
3294 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3295 lda, b, ldb, beta, c, ldc);
3296 }
3297
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3298 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3299 uint64 n, std::complex<double> alpha,
3300 const DeviceMemory<std::complex<double>> &a,
3301 int lda,
3302 const DeviceMemory<std::complex<double>> &b,
3303 int ldb, std::complex<double> beta,
3304 DeviceMemory<std::complex<double>> *c, int ldc) {
3305 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3306 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3307 PARAM(ldc));
3308
3309 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3310 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3311 int, const DeviceMemory<std::complex<double>> &, int,
3312 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3313 impl;
3314 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3315 lda, b, ldb, beta, c, ldc);
3316 }
3317
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)3318 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3319 uint64 n, uint64 k, float alpha,
3320 const DeviceMemory<float> &a, int lda, float beta,
3321 DeviceMemory<float> *c, int ldc) {
3322 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3323 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3324
3325 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3326 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3327 int>
3328 impl;
3329 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3330 lda, beta, c, ldc);
3331 }
3332
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)3333 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3334 uint64 n, uint64 k, double alpha,
3335 const DeviceMemory<double> &a, int lda,
3336 double beta, DeviceMemory<double> *c, int ldc) {
3337 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3338 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3339
3340 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3341 const DeviceMemory<double> &, int, double,
3342 DeviceMemory<double> *, int>
3343 impl;
3344 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3345 lda, beta, c, ldc);
3346 }
3347
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3348 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3349 uint64 n, uint64 k, std::complex<float> alpha,
3350 const DeviceMemory<std::complex<float>> &a,
3351 int lda, std::complex<float> beta,
3352 DeviceMemory<std::complex<float>> *c, int ldc) {
3353 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3354 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3355
3356 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3357 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3358 int, std::complex<float>, DeviceMemory<std::complex<float>> *,
3359 int>
3360 impl;
3361 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3362 lda, beta, c, ldc);
3363 }
3364
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3365 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3366 uint64 n, uint64 k, std::complex<double> alpha,
3367 const DeviceMemory<std::complex<double>> &a,
3368 int lda, std::complex<double> beta,
3369 DeviceMemory<std::complex<double>> *c, int ldc) {
3370 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3371 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3372
3373 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3374 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3375 int, std::complex<double>, DeviceMemory<std::complex<double>> *,
3376 int>
3377 impl;
3378 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3379 lda, beta, c, ldc);
3380 }
3381
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3382 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3383 uint64 n, uint64 k, float alpha,
3384 const DeviceMemory<float> &a, int lda,
3385 const DeviceMemory<float> &b, int ldb, float beta,
3386 DeviceMemory<float> *c, int ldc) {
3387 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3388 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3389 PARAM(ldc));
3390
3391 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3392 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3393 int, float, DeviceMemory<float> *, int>
3394 impl;
3395 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3396 a, lda, b, ldb, beta, c, ldc);
3397 }
3398
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3399 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3400 uint64 n, uint64 k, double alpha,
3401 const DeviceMemory<double> &a, int lda,
3402 const DeviceMemory<double> &b, int ldb,
3403 double beta, DeviceMemory<double> *c, int ldc) {
3404 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3405 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3406 PARAM(ldc));
3407
3408 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3409 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3410 int, double, DeviceMemory<double> *, int>
3411 impl;
3412 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3413 a, lda, b, ldb, beta, c, ldc);
3414 }
3415
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3416 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3417 uint64 n, uint64 k, std::complex<float> alpha,
3418 const DeviceMemory<std::complex<float>> &a,
3419 int lda,
3420 const DeviceMemory<std::complex<float>> &b,
3421 int ldb, std::complex<float> beta,
3422 DeviceMemory<std::complex<float>> *c, int ldc) {
3423 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3424 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3425 PARAM(ldc));
3426
3427 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3428 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3429 int, const DeviceMemory<std::complex<float>> &, int,
3430 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3431 impl;
3432 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3433 a, lda, b, ldb, beta, c, ldc);
3434 }
3435
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3436 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3437 uint64 n, uint64 k, std::complex<double> alpha,
3438 const DeviceMemory<std::complex<double>> &a,
3439 int lda,
3440 const DeviceMemory<std::complex<double>> &b,
3441 int ldb, std::complex<double> beta,
3442 DeviceMemory<std::complex<double>> *c, int ldc) {
3443 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3444 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3445 PARAM(ldc));
3446
3447 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3448 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3449 int, const DeviceMemory<std::complex<double>> &, int,
3450 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3451 impl;
3452 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3453 a, lda, b, ldb, beta, c, ldc);
3454 }
3455
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3456 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3457 blas::Transpose transa, blas::Diagonal diag,
3458 uint64 m, uint64 n, float alpha,
3459 const DeviceMemory<float> &a, int lda,
3460 DeviceMemory<float> *b, int ldb) {
3461 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3462 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3463
3464 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3465 uint64, uint64, float, const DeviceMemory<float> &, int,
3466 DeviceMemory<float> *, int>
3467 impl;
3468 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3469 n, alpha, a, lda, b, ldb);
3470 }
3471
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3472 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3473 blas::Transpose transa, blas::Diagonal diag,
3474 uint64 m, uint64 n, double alpha,
3475 const DeviceMemory<double> &a, int lda,
3476 DeviceMemory<double> *b, int ldb) {
3477 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3478 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3479
3480 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3481 uint64, uint64, double, const DeviceMemory<double> &, int,
3482 DeviceMemory<double> *, int>
3483 impl;
3484 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3485 n, alpha, a, lda, b, ldb);
3486 }
3487
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)3488 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3489 blas::Transpose transa, blas::Diagonal diag,
3490 uint64 m, uint64 n, std::complex<float> alpha,
3491 const DeviceMemory<std::complex<float>> &a,
3492 int lda, DeviceMemory<std::complex<float>> *b,
3493 int ldb) {
3494 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3495 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3496
3497 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3498 uint64, uint64, std::complex<float>,
3499 const DeviceMemory<std::complex<float>> &, int,
3500 DeviceMemory<std::complex<float>> *, int>
3501 impl;
3502 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3503 n, alpha, a, lda, b, ldb);
3504 }
3505
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)3506 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3507 blas::Transpose transa, blas::Diagonal diag,
3508 uint64 m, uint64 n, std::complex<double> alpha,
3509 const DeviceMemory<std::complex<double>> &a,
3510 int lda, DeviceMemory<std::complex<double>> *b,
3511 int ldb) {
3512 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3513 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3514
3515 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3516 uint64, uint64, std::complex<double>,
3517 const DeviceMemory<std::complex<double>> &, int,
3518 DeviceMemory<std::complex<double>> *, int>
3519 impl;
3520 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3521 n, alpha, a, lda, b, ldb);
3522 }
3523
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3524 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3525 blas::Transpose transa, blas::Diagonal diag,
3526 uint64 m, uint64 n, float alpha,
3527 const DeviceMemory<float> &a, int lda,
3528 DeviceMemory<float> *b, int ldb) {
3529 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3530 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3531
3532 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3533 uint64, uint64, float, const DeviceMemory<float> &, int,
3534 DeviceMemory<float> *, int>
3535 impl;
3536 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3537 n, alpha, a, lda, b, ldb);
3538 }
3539
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3540 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3541 blas::Transpose transa, blas::Diagonal diag,
3542 uint64 m, uint64 n, double alpha,
3543 const DeviceMemory<double> &a, int lda,
3544 DeviceMemory<double> *b, int ldb) {
3545 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3546 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3547
3548 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3549 uint64, uint64, double, const DeviceMemory<double> &, int,
3550 DeviceMemory<double> *, int>
3551 impl;
3552 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3553 n, alpha, a, lda, b, ldb);
3554 }
3555
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)3556 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3557 blas::Transpose transa, blas::Diagonal diag,
3558 uint64 m, uint64 n, std::complex<float> alpha,
3559 const DeviceMemory<std::complex<float>> &a,
3560 int lda, DeviceMemory<std::complex<float>> *b,
3561 int ldb) {
3562 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3563 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3564
3565 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3566 uint64, uint64, std::complex<float>,
3567 const DeviceMemory<std::complex<float>> &, int,
3568 DeviceMemory<std::complex<float>> *, int>
3569 impl;
3570 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3571 n, alpha, a, lda, b, ldb);
3572 }
3573
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)3574 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3575 blas::Transpose transa, blas::Diagonal diag,
3576 uint64 m, uint64 n, std::complex<double> alpha,
3577 const DeviceMemory<std::complex<double>> &a,
3578 int lda, DeviceMemory<std::complex<double>> *b,
3579 int ldb) {
3580 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3581 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3582
3583 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3584 uint64, uint64, std::complex<double>,
3585 const DeviceMemory<std::complex<double>> &, int,
3586 DeviceMemory<std::complex<double>> *, int>
3587 impl;
3588 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3589 n, alpha, a, lda, b, ldb);
3590 }
3591
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)3592 Stream &Stream::ThenBlasGemmBatched(
3593 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3594 uint64 k, float alpha,
3595 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
3596 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
3597 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
3598 int batch_count) {
3599 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3600 b, ldb, beta, c, ldc, batch_count,
3601 /*scratch_allocator=*/nullptr);
3602 }
3603
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3604 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3605 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3606 uint64 k, float alpha,
3607 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
3608 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
3609 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
3610 int batch_count, ScratchAllocator *scratch_allocator) {
3611 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3612 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3613 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3614
3615 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3616 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
3617 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
3618 float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
3619 int, int, ScratchAllocator *>
3620 impl;
3621 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3622 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3623 scratch_allocator);
3624 }
3625
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)3626 Stream &Stream::ThenBlasGemmBatched(
3627 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3628 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
3629 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
3630 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
3631 int batch_count) {
3632 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3633 b, ldb, beta, c, ldc, batch_count,
3634 /*scratch_allocator=*/nullptr);
3635 }
3636
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3637 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3638 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3639 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
3640 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
3641 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
3642 int batch_count, ScratchAllocator *scratch_allocator) {
3643 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3644 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3645 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3646
3647 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3648 const port::ArraySlice<DeviceMemory<float> *> &, int,
3649 const port::ArraySlice<DeviceMemory<float> *> &, int, float,
3650 const port::ArraySlice<DeviceMemory<float> *> &, int, int,
3651 ScratchAllocator *>
3652 impl;
3653 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3654 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3655 scratch_allocator);
3656 }
3657
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)3658 Stream &Stream::ThenBlasGemmBatched(
3659 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3660 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
3661 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
3662 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
3663 int batch_count) {
3664 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3665 b, ldb, beta, c, ldc, batch_count,
3666 /*scratch_allocator=*/nullptr);
3667 }
3668
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3669 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3670 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3671 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
3672 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
3673 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
3674 int batch_count, ScratchAllocator *scratch_allocator) {
3675 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3676 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3677 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3678
3679 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3680 const port::ArraySlice<DeviceMemory<double> *> &, int,
3681 const port::ArraySlice<DeviceMemory<double> *> &, int, double,
3682 const port::ArraySlice<DeviceMemory<double> *> &, int, int,
3683 ScratchAllocator *>
3684 impl;
3685 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3686 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3687 scratch_allocator);
3688 }
3689
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)3690 Stream &Stream::ThenBlasGemmBatched(
3691 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3692 uint64 k, std::complex<float> alpha,
3693 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
3694 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
3695 std::complex<float> beta,
3696 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
3697 int batch_count) {
3698 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3699 b, ldb, beta, c, ldc, batch_count,
3700 /*scratch_allocator=*/nullptr);
3701 }
3702
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3703 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3704 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3705 uint64 k, std::complex<float> alpha,
3706 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
3707 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
3708 std::complex<float> beta,
3709 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
3710 int batch_count, ScratchAllocator *scratch_allocator) {
3711 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3712 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3713 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3714
3715 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3716 std::complex<float>,
3717 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
3718 int,
3719 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
3720 int, std::complex<float>,
3721 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
3722 int, int, ScratchAllocator *>
3723 impl;
3724 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3725 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3726 scratch_allocator);
3727 }
3728
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)3729 Stream &Stream::ThenBlasGemmBatched(
3730 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3731 uint64 k, std::complex<double> alpha,
3732 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
3733 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
3734 std::complex<double> beta,
3735 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
3736 int batch_count) {
3737 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3738 b, ldb, beta, c, ldc, batch_count,
3739 /*scratch_allocator=*/nullptr);
3740 }
3741
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3742 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3743 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3744 uint64 k, std::complex<double> alpha,
3745 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
3746 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
3747 std::complex<double> beta,
3748 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
3749 int batch_count, ScratchAllocator *scratch_allocator) {
3750 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3751 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3752 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3753
3754 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3755 std::complex<double>,
3756 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
3757 int,
3758 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
3759 int, std::complex<double>,
3760 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
3761 int, int, ScratchAllocator *>
3762 impl;
3763 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3764 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3765 scratch_allocator);
3766 }
3767
3768 template <typename ABType, typename CType>
ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan * plan,const HostOrDeviceScalar<CType> & alpha,const DeviceMemory<ABType> & a,const DeviceMemory<ABType> & b,const HostOrDeviceScalar<CType> & beta,DeviceMemory<CType> * c,ScratchAllocator * scratch_allocator,const blas::IBlasLtMatmulAlgorithm * algorithm,const DeviceMemory<CType> & bias,blas::ProfileResult * output_profile_result)3769 Stream &Stream::ThenBlasLtMatmulImpl(
3770 const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
3771 const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
3772 const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
3773 ScratchAllocator *scratch_allocator,
3774 const blas::IBlasLtMatmulAlgorithm *algorithm,
3775 const DeviceMemory<CType> &bias,
3776 blas::ProfileResult *output_profile_result) {
3777 VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
3778 PARAM(c), PARAM(algorithm), PARAM(bias));
3779
3780 ThenBlasWithProfileImpl<
3781 const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
3782 const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
3783 const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
3784 ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3785 const DeviceMemory<CType> &>
3786 impl;
3787 return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
3788 c, scratch_allocator, algorithm, bias, output_profile_result);
3789 }
3790
3791 // Explicit template instantiations for each supported type combination.
3792 template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
3793 const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
3794 const DeviceMemory<int8> &, const DeviceMemory<int8> &,
3795 const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
3796 ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3797 const DeviceMemory<int32> &, blas::ProfileResult *);
3798
3799 template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
3800 const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
3801 const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
3802 const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3803 ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3804 const DeviceMemory<Eigen::half> &, blas::ProfileResult *);
3805
3806 template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
3807 const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
3808 const DeviceMemory<float> &, const DeviceMemory<float> &,
3809 const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
3810 ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3811 const DeviceMemory<float> &, blas::ProfileResult *);
3812
3813 template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
3814 const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
3815 const DeviceMemory<double> &, const DeviceMemory<double> &,
3816 const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
3817 ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3818 const DeviceMemory<double> &, blas::ProfileResult *);
3819
3820 template Stream &
3821 Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
3822 const blas::IBlasLtMatmulPlan *,
3823 const HostOrDeviceScalar<std::complex<float>> &,
3824 const DeviceMemory<std::complex<float>> &,
3825 const DeviceMemory<std::complex<float>> &,
3826 const HostOrDeviceScalar<std::complex<float>> &,
3827 DeviceMemory<std::complex<float>> *, ScratchAllocator *,
3828 const blas::IBlasLtMatmulAlgorithm *,
3829 const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);
3830
3831 template Stream &
3832 Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
3833 const blas::IBlasLtMatmulPlan *,
3834 const HostOrDeviceScalar<std::complex<double>> &,
3835 const DeviceMemory<std::complex<double>> &,
3836 const DeviceMemory<std::complex<double>> &,
3837 const HostOrDeviceScalar<std::complex<double>> &,
3838 DeviceMemory<std::complex<double>> *, ScratchAllocator *,
3839 const blas::IBlasLtMatmulAlgorithm *,
3840 const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);
3841
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)3842 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
3843 VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
3844
3845 if (rng::RngSupport *rng = parent_->AsRng()) {
3846 CheckError(rng->SetSeed(this, seed, seed_bytes));
3847 } else {
3848 SetError();
3849 LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
3850 }
3851 return *this;
3852 }
3853
ThenPopulateRandUniform(DeviceMemory<float> * values)3854 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
3855 VLOG_CALL(PARAM(values));
3856
3857 if (rng::RngSupport *rng = parent_->AsRng()) {
3858 CheckError(rng->DoPopulateRandUniform(this, values));
3859 } else {
3860 SetError();
3861 LOG(INFO) << DebugStreamPointers()
3862 << " attempting to perform RNG operation using StreamExecutor"
3863 " without RNG support.";
3864 }
3865 return *this;
3866 }
3867
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)3868 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
3869 DeviceMemory<float> *values) {
3870 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
3871
3872 if (rng::RngSupport *rng = parent_->AsRng()) {
3873 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
3874 } else {
3875 SetError();
3876 LOG(INFO) << DebugStreamPointers()
3877 << " attempting to perform RNG operation using StreamExecutor"
3878 " without RNG support.";
3879 }
3880 return *this;
3881 }
3882
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)3883 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
3884 DeviceMemory<double> *values) {
3885 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
3886
3887 if (rng::RngSupport *rng = parent_->AsRng()) {
3888 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
3889 } else {
3890 SetError();
3891 LOG(INFO) << DebugStreamPointers()
3892 << " attempting to perform RNG operation using StreamExecutor"
3893 " without RNG support.";
3894 }
3895 return *this;
3896 }
3897
ThenPopulateRandUniform(DeviceMemory<double> * values)3898 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
3899 VLOG_CALL(PARAM(values));
3900
3901 if (rng::RngSupport *rng = parent_->AsRng()) {
3902 CheckError(rng->DoPopulateRandUniform(this, values));
3903 } else {
3904 SetError();
3905 LOG(INFO) << DebugStreamPointers()
3906 << " attempting to perform RNG operation using StreamExecutor"
3907 " without RNG support.";
3908 }
3909 return *this;
3910 }
3911
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)3912 Stream &Stream::ThenPopulateRandUniform(
3913 DeviceMemory<std::complex<float>> *values) {
3914 VLOG_CALL(PARAM(values));
3915
3916 if (rng::RngSupport *rng = parent_->AsRng()) {
3917 CheckError(rng->DoPopulateRandUniform(this, values));
3918 } else {
3919 SetError();
3920 LOG(INFO) << DebugStreamPointers()
3921 << " attempting to perform RNG operation using StreamExecutor"
3922 " without RNG support.";
3923 }
3924 return *this;
3925 }
3926
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)3927 Stream &Stream::ThenPopulateRandUniform(
3928 DeviceMemory<std::complex<double>> *values) {
3929 VLOG_CALL(PARAM(values));
3930
3931 if (rng::RngSupport *rng = parent_->AsRng()) {
3932 CheckError(rng->DoPopulateRandUniform(this, values));
3933 } else {
3934 SetError();
3935 LOG(INFO) << DebugStreamPointers()
3936 << " attempting to perform RNG operation using StreamExecutor"
3937 " without RNG support.";
3938 }
3939 return *this;
3940 }
3941
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)3942 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
3943 uint64 size) {
3944 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
3945
3946 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
3947 return *this;
3948 }
3949
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)3950 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
3951 uint64 size) {
3952 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
3953
3954 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
3955 return *this;
3956 }
3957
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)3958 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
3959 const DeviceMemoryBase &gpu_src, uint64 size) {
3960 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
3961
3962 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
3963 return *this;
3964 }
3965
ThenMemZero(DeviceMemoryBase * location,uint64 size)3966 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
3967 VLOG_CALL(PARAM(location), PARAM(size));
3968
3969 CheckStatus(parent_->MemZero(this, location, size));
3970 return *this;
3971 }
3972
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)3973 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
3974 uint64 size) {
3975 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
3976
3977 CheckStatus(parent_->Memset32(this, location, pattern, size));
3978 return *this;
3979 }
3980
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)3981 Stream &Stream::ThenRnnForward(
3982 const dnn::RnnDescriptor &rnn_desc,
3983 const dnn::RnnSequenceTensorDescriptor &input_desc,
3984 const DeviceMemory<Eigen::half> &input_data,
3985 const DeviceMemory<int> &seq_lengths_data,
3986 const dnn::RnnStateTensorDescriptor &input_h_desc,
3987 const DeviceMemory<Eigen::half> &input_h_data,
3988 const dnn::RnnStateTensorDescriptor &input_c_desc,
3989 const DeviceMemory<Eigen::half> &input_c_data,
3990 const DeviceMemory<Eigen::half> ¶ms,
3991 const dnn::RnnSequenceTensorDescriptor &output_desc,
3992 DeviceMemory<Eigen::half> *output_data,
3993 const dnn::RnnStateTensorDescriptor &output_h_desc,
3994 DeviceMemory<Eigen::half> *output_h_data,
3995 const dnn::RnnStateTensorDescriptor &output_c_desc,
3996 DeviceMemory<Eigen::half> *output_c_data, bool is_training,
3997 ScratchAllocator *reserve_space_allocator,
3998 ScratchAllocator *workspace_allocator,
3999 dnn::ProfileResult *output_profile_result) {
4000 // TODO(zhengxq): add VLOG PARAM calls.
4001 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4002 auto status = dnn->DoRnnForward(
4003 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4004 input_h_data, input_c_desc, input_c_data, params, output_desc,
4005 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4006 is_training, reserve_space_allocator, workspace_allocator,
4007 output_profile_result);
4008 if (!status && !output_profile_result) {
4009 SetError();
4010 }
4011 } else {
4012 SetErrorAndLogNoDnnSupport();
4013 }
4014 return *this;
4015 }
4016
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4017 Stream &Stream::ThenRnnForward(
4018 const dnn::RnnDescriptor &rnn_desc,
4019 const dnn::RnnSequenceTensorDescriptor &input_desc,
4020 const DeviceMemory<float> &input_data,
4021 const DeviceMemory<int> &seq_lengths_data,
4022 const dnn::RnnStateTensorDescriptor &input_h_desc,
4023 const DeviceMemory<float> &input_h_data,
4024 const dnn::RnnStateTensorDescriptor &input_c_desc,
4025 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4026 const dnn::RnnSequenceTensorDescriptor &output_desc,
4027 DeviceMemory<float> *output_data,
4028 const dnn::RnnStateTensorDescriptor &output_h_desc,
4029 DeviceMemory<float> *output_h_data,
4030 const dnn::RnnStateTensorDescriptor &output_c_desc,
4031 DeviceMemory<float> *output_c_data, bool is_training,
4032 ScratchAllocator *reserve_space_allocator,
4033 ScratchAllocator *workspace_allocator,
4034 dnn::ProfileResult *output_profile_result) {
4035 // TODO(zhengxq): add VLOG PARAM calls.
4036 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4037 auto status = dnn->DoRnnForward(
4038 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4039 input_h_data, input_c_desc, input_c_data, params, output_desc,
4040 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4041 is_training, reserve_space_allocator, workspace_allocator,
4042 output_profile_result);
4043 if (!status && !output_profile_result) {
4044 SetError();
4045 }
4046 } else {
4047 SetErrorAndLogNoDnnSupport();
4048 }
4049 return *this;
4050 }
4051
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4052 Stream &Stream::ThenRnnForward(
4053 const dnn::RnnDescriptor &rnn_desc,
4054 const dnn::RnnSequenceTensorDescriptor &input_desc,
4055 const DeviceMemory<double> &input_data,
4056 const DeviceMemory<int> &seq_lengths_data,
4057 const dnn::RnnStateTensorDescriptor &input_h_desc,
4058 const DeviceMemory<double> &input_h_data,
4059 const dnn::RnnStateTensorDescriptor &input_c_desc,
4060 const DeviceMemory<double> &input_c_data,
4061 const DeviceMemory<double> ¶ms,
4062 const dnn::RnnSequenceTensorDescriptor &output_desc,
4063 DeviceMemory<double> *output_data,
4064 const dnn::RnnStateTensorDescriptor &output_h_desc,
4065 DeviceMemory<double> *output_h_data,
4066 const dnn::RnnStateTensorDescriptor &output_c_desc,
4067 DeviceMemory<double> *output_c_data, bool is_training,
4068 ScratchAllocator *reserve_space_allocator,
4069 ScratchAllocator *workspace_allocator,
4070 dnn::ProfileResult *output_profile_result) {
4071 // TODO(zhengxq): add VLOG PARAM calls.
4072 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4073 auto status = dnn->DoRnnForward(
4074 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4075 input_h_data, input_c_desc, input_c_data, params, output_desc,
4076 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4077 is_training, reserve_space_allocator, workspace_allocator,
4078 output_profile_result);
4079 if (!status && !output_profile_result) {
4080 SetError();
4081 }
4082 } else {
4083 SetErrorAndLogNoDnnSupport();
4084 }
4085 return *this;
4086 }
4087
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4088 Stream &Stream::ThenRnnBackward(
4089 const dnn::RnnDescriptor &rnn_desc,
4090 const dnn::RnnSequenceTensorDescriptor &input_desc,
4091 const DeviceMemory<Eigen::half> &input_data,
4092 const DeviceMemory<int> &seq_lengths_data,
4093 const dnn::RnnStateTensorDescriptor &input_h_desc,
4094 const DeviceMemory<Eigen::half> &input_h_data,
4095 const dnn::RnnStateTensorDescriptor &input_c_desc,
4096 const DeviceMemory<Eigen::half> &input_c_data,
4097 const DeviceMemory<Eigen::half> ¶ms,
4098 const dnn::RnnSequenceTensorDescriptor &output_desc,
4099 const DeviceMemory<Eigen::half> &output_data,
4100 const dnn::RnnStateTensorDescriptor &output_h_desc,
4101 const DeviceMemory<Eigen::half> &output_h_data,
4102 const dnn::RnnStateTensorDescriptor &output_c_desc,
4103 const DeviceMemory<Eigen::half> &output_c_data,
4104 const DeviceMemory<Eigen::half> &output_backprop_data,
4105 const DeviceMemory<Eigen::half> &output_h_backprop_data,
4106 const DeviceMemory<Eigen::half> &output_c_backprop_data,
4107 DeviceMemory<Eigen::half> *input_backprop_data,
4108 DeviceMemory<Eigen::half> *input_h_backprop_data,
4109 DeviceMemory<Eigen::half> *input_c_backprop_data,
4110 DeviceMemory<Eigen::half> *params_backprop_data,
4111 DeviceMemory<uint8> *reserve_space_data,
4112 ScratchAllocator *workspace_allocator,
4113 dnn::ProfileResult *output_profile_result) {
4114 // TODO(zhengxq): add VLOG PARAM calls.
4115 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4116 auto status = dnn->DoRnnBackward(
4117 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4118 input_h_data, input_c_desc, input_c_data, params, output_desc,
4119 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4120 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4121 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4122 params_backprop_data, reserve_space_data, workspace_allocator,
4123 output_profile_result);
4124 if (!status && !output_profile_result) {
4125 SetError();
4126 }
4127 } else {
4128 SetError();
4129 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4130 }
4131 return *this;
4132 }
4133
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4134 Stream &Stream::ThenRnnBackward(
4135 const dnn::RnnDescriptor &rnn_desc,
4136 const dnn::RnnSequenceTensorDescriptor &input_desc,
4137 const DeviceMemory<float> &input_data,
4138 const DeviceMemory<int> &seq_lengths_data,
4139 const dnn::RnnStateTensorDescriptor &input_h_desc,
4140 const DeviceMemory<float> &input_h_data,
4141 const dnn::RnnStateTensorDescriptor &input_c_desc,
4142 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4143 const dnn::RnnSequenceTensorDescriptor &output_desc,
4144 const DeviceMemory<float> &output_data,
4145 const dnn::RnnStateTensorDescriptor &output_h_desc,
4146 const DeviceMemory<float> &output_h_data,
4147 const dnn::RnnStateTensorDescriptor &output_c_desc,
4148 const DeviceMemory<float> &output_c_data,
4149 const DeviceMemory<float> &output_backprop_data,
4150 const DeviceMemory<float> &output_h_backprop_data,
4151 const DeviceMemory<float> &output_c_backprop_data,
4152 DeviceMemory<float> *input_backprop_data,
4153 DeviceMemory<float> *input_h_backprop_data,
4154 DeviceMemory<float> *input_c_backprop_data,
4155 DeviceMemory<float> *params_backprop_data,
4156 DeviceMemory<uint8> *reserve_space_data,
4157 ScratchAllocator *workspace_allocator,
4158 dnn::ProfileResult *output_profile_result) {
4159 // TODO(zhengxq): add VLOG PARAM calls.
4160 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4161 auto status = dnn->DoRnnBackward(
4162 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4163 input_h_data, input_c_desc, input_c_data, params, output_desc,
4164 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4165 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4166 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4167 params_backprop_data, reserve_space_data, workspace_allocator,
4168 output_profile_result);
4169 if (!status && !output_profile_result) {
4170 SetError();
4171 }
4172 } else {
4173 SetError();
4174 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4175 }
4176 return *this;
4177 }
4178
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4179 Stream &Stream::ThenRnnBackward(
4180 const dnn::RnnDescriptor &rnn_desc,
4181 const dnn::RnnSequenceTensorDescriptor &input_desc,
4182 const DeviceMemory<double> &input_data,
4183 const DeviceMemory<int> &seq_lengths_data,
4184 const dnn::RnnStateTensorDescriptor &input_h_desc,
4185 const DeviceMemory<double> &input_h_data,
4186 const dnn::RnnStateTensorDescriptor &input_c_desc,
4187 const DeviceMemory<double> &input_c_data,
4188 const DeviceMemory<double> ¶ms,
4189 const dnn::RnnSequenceTensorDescriptor &output_desc,
4190 const DeviceMemory<double> &output_data,
4191 const dnn::RnnStateTensorDescriptor &output_h_desc,
4192 const DeviceMemory<double> &output_h_data,
4193 const dnn::RnnStateTensorDescriptor &output_c_desc,
4194 const DeviceMemory<double> &output_c_data,
4195 const DeviceMemory<double> &output_backprop_data,
4196 const DeviceMemory<double> &output_h_backprop_data,
4197 const DeviceMemory<double> &output_c_backprop_data,
4198 DeviceMemory<double> *input_backprop_data,
4199 DeviceMemory<double> *input_h_backprop_data,
4200 DeviceMemory<double> *input_c_backprop_data,
4201 DeviceMemory<double> *params_backprop_data,
4202 DeviceMemory<uint8> *reserve_space_data,
4203 ScratchAllocator *workspace_allocator,
4204 dnn::ProfileResult *output_profile_result) {
4205 // TODO(zhengxq): add VLOG PARAM calls.
4206 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4207 auto status = dnn->DoRnnBackward(
4208 this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4209 input_h_data, input_c_desc, input_c_data, params, output_desc,
4210 output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4211 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4212 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4213 params_backprop_data, reserve_space_data, workspace_allocator,
4214 output_profile_result);
4215 if (!status && !output_profile_result) {
4216 SetError();
4217 }
4218 } else {
4219 SetError();
4220 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4221 }
4222 return *this;
4223 }
4224
ThenCtcLoss(const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<float> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<float> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<float> * grads_data,ScratchAllocator * workspace_allocator)4225 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
4226 const DeviceMemory<float> &probs_data,
4227 absl::Span<const int> labels_data,
4228 absl::Span<const int> labels_lengths_data,
4229 absl::Span<const int> input_lengths_data,
4230 DeviceMemory<float> *costs_data,
4231 const dnn::RnnStateTensorDescriptor &grads_desc,
4232 DeviceMemory<float> *grads_data,
4233 ScratchAllocator *workspace_allocator) {
4234 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4235 DeviceMemory<uint8> scratch_memory;
4236 int ctc_loss_algo_id;
4237 auto status =
4238 dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
4239 labels_data, labels_lengths_data,
4240 input_lengths_data, workspace_allocator,
4241 &scratch_memory, &ctc_loss_algo_id)
4242 .ok();
4243 if (status) {
4244 status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
4245 labels_lengths_data, input_lengths_data,
4246 costs_data, grads_desc, grads_data,
4247 &scratch_memory, ctc_loss_algo_id);
4248 }
4249 if (!status) {
4250 SetError();
4251 }
4252 } else {
4253 SetErrorAndLogNoDnnSupport();
4254 }
4255 return *this;
4256 }
4257
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)4258 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
4259 dnn::DataType input_type,
4260 const DeviceMemoryBase &input_data,
4261 const dnn::BatchDescriptor &output_desc,
4262 dnn::DataType output_type, float scale,
4263 DeviceMemoryBase *output_data) {
4264 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
4265 PARAM(output_desc), PARAM(output_type), PARAM(scale),
4266 PARAM(output_data));
4267 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4268 CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data,
4269 output_desc, output_type, scale,
4270 output_data));
4271 } else {
4272 SetErrorAndLogNoDnnSupport();
4273 }
4274 return *this;
4275 }
4276
ThenDoHostCallback(std::function<void ()> callback)4277 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
4278 VLOG_CALL(PARAM(callback));
4279
4280 if (!ok()) {
4281 LOG(INFO) << DebugStreamPointers()
4282 << " was in error state before adding host callback";
4283 }
4284 CheckError(parent_->HostCallback(this, std::move(callback)));
4285 return *this;
4286 }
4287
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)4288 Stream &Stream::ThenDoHostCallbackWithStatus(
4289 std::function<port::Status()> callback) {
4290 VLOG_CALL(PARAM(callback));
4291
4292 if (!ok()) {
4293 LOG(INFO) << DebugStreamPointers()
4294 << " was in error state before adding host callback";
4295 }
4296 CheckError(parent_->HostCallback(this, std::move(callback)));
4297 return *this;
4298 }
4299
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)4300 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
4301 std::function<void()> callback) {
4302 VLOG_CALL(PARAM(callback));
4303
4304 if (!ok()) {
4305 LOG(INFO) << DebugStreamPointers()
4306 << " was in error state before adding callback to be run after "
4307 "next block-host-until-done.";
4308 }
4309 absl::MutexLock lock(&mu_);
4310 after_block_host_until_done_callbacks_.push_back(std::move(callback));
4311 return *this;
4312 }
4313
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)4314 Stream &Stream::ThenFft(fft::Plan *plan,
4315 const DeviceMemory<std::complex<float>> &input,
4316 DeviceMemory<std::complex<float>> *output) {
4317 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4318
4319 if (fft::FftSupport *fft = parent_->AsFft()) {
4320 CheckError(fft->DoFft(this, plan, input, output));
4321 } else {
4322 SetError();
4323 LOG(INFO) << DebugStreamPointers()
4324 << " attempting to perform FFT operation using StreamExecutor"
4325 " without FFT support";
4326 }
4327 return *this;
4328 }
4329
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)4330 Stream &Stream::ThenFft(fft::Plan *plan,
4331 const DeviceMemory<std::complex<double>> &input,
4332 DeviceMemory<std::complex<double>> *output) {
4333 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4334
4335 if (fft::FftSupport *fft = parent_->AsFft()) {
4336 CheckError(fft->DoFft(this, plan, input, output));
4337 } else {
4338 SetError();
4339 LOG(INFO) << DebugStreamPointers()
4340 << " attempting to perform FFT operation using StreamExecutor"
4341 " without FFT support";
4342 }
4343 return *this;
4344 }
4345
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)4346 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
4347 DeviceMemory<std::complex<float>> *output) {
4348 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4349
4350 if (fft::FftSupport *fft = parent_->AsFft()) {
4351 CheckError(fft->DoFft(this, plan, input, output));
4352 } else {
4353 SetError();
4354 LOG(INFO) << DebugStreamPointers()
4355 << " attempting to perform FFT operation using StreamExecutor"
4356 " without FFT support";
4357 }
4358 return *this;
4359 }
4360
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)4361 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
4362 DeviceMemory<std::complex<double>> *output) {
4363 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4364
4365 if (fft::FftSupport *fft = parent_->AsFft()) {
4366 CheckError(fft->DoFft(this, plan, input, output));
4367 } else {
4368 SetError();
4369 LOG(INFO) << DebugStreamPointers()
4370 << " attempting to perform FFT operation using StreamExecutor"
4371 " without FFT support";
4372 }
4373 return *this;
4374 }
4375
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)4376 Stream &Stream::ThenFft(fft::Plan *plan,
4377 const DeviceMemory<std::complex<float>> &input,
4378 DeviceMemory<float> *output) {
4379 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4380
4381 if (fft::FftSupport *fft = parent_->AsFft()) {
4382 CheckError(fft->DoFft(this, plan, input, output));
4383 } else {
4384 SetError();
4385 LOG(INFO) << DebugStreamPointers()
4386 << " attempting to perform FFT operation using StreamExecutor"
4387 " without FFT support";
4388 }
4389 return *this;
4390 }
4391
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)4392 Stream &Stream::ThenFft(fft::Plan *plan,
4393 const DeviceMemory<std::complex<double>> &input,
4394 DeviceMemory<double> *output) {
4395 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4396
4397 if (fft::FftSupport *fft = parent_->AsFft()) {
4398 CheckError(fft->DoFft(this, plan, input, output));
4399 } else {
4400 SetError();
4401 LOG(INFO) << DebugStreamPointers()
4402 << " attempting to perform FFT operation using StreamExecutor"
4403 " without FFT support";
4404 }
4405 return *this;
4406 }
4407
4408 // It looks confusing, but all this is doing is inserting a callback at the
4409 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)4410 Stream &Stream::ThenEnqueueOnBackgroundThread(
4411 std::function<void(StreamExecutor *)> task) {
4412 VLOG_CALL(PARAM(task));
4413
4414 StreamExecutor *stream_executor = this->parent_;
4415 std::function<void()> bound_task = std::bind(task, stream_executor);
4416
4417 return ThenDoHostCallback([stream_executor, bound_task]() {
4418 stream_executor->EnqueueOnBackgroundThread(bound_task);
4419 });
4420 }
4421
BlockHostUntilDone()4422 port::Status Stream::BlockHostUntilDone() {
4423 VLOG_CALL();
4424
4425 if (!ok()) {
4426 port::Status status = port::Status(
4427 port::error::INTERNAL,
4428 "stream did not block host until done; was already in an error state");
4429 LOG(INFO) << DebugStreamPointers() << " " << status;
4430 return status;
4431 }
4432
4433 temporary_memory_manager_.DeallocateFinalizedTemporaries();
4434
4435 port::Status error = parent_->BlockHostUntilDone(this);
4436 CheckError(error.ok());
4437
4438 RunAfterBlockHostUntilDoneCallbacks();
4439 return error;
4440 }
4441
RunAfterBlockHostUntilDoneCallbacks()4442 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
4443 std::vector<std::function<void()>> callbacks;
4444 {
4445 absl::MutexLock lock(&mu_);
4446 std::swap(callbacks, after_block_host_until_done_callbacks_);
4447 }
4448 for (const auto &fn : callbacks) {
4449 fn();
4450 }
4451 }
4452
DebugStreamPointers() const4453 std::string Stream::DebugStreamPointers() const {
4454 // Relies on the ToVlogString(const void*) overload above.
4455 return absl::StrCat("[stream=", ToVlogString(this),
4456 ",impl=", ToVlogString(implementation_.get()), "]");
4457 }
4458
CheckStatus(port::Status status)4459 void Stream::CheckStatus(port::Status status) {
4460 if (status.ok()) {
4461 return;
4462 }
4463 LOG(ERROR) << status;
4464 absl::MutexLock lock(&mu_);
4465 status_ = status;
4466 }
4467
4468 } // namespace stream_executor
4469