/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/stream.h"

#include "tensorflow/stream_executor/platform/port.h"

#include "absl/strings/str_cat.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace stream_executor {

namespace {
// Code to turn parameters to functions on stream into strings that
// will be VLOG'ed. We need overloads, instead of
// e.g. BatchDescriptorToVlogString(), as the code that calls these
// functions does not know what the type of the parameter is.
string ToVlogString(const dnn::BatchDescriptor &descriptor) {
  return descriptor.ToShortString();
}

string ToVlogString(const dnn::FilterDescriptor &descriptor) {
  return descriptor.ToShortString();
}

string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
  return descriptor.ToShortString();
}

string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
  return descriptor.ToShortString();
}

string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
  return descriptor.ToShortString();
}

string ToVlogString(dnn::ActivationMode mode) {
  return dnn::ActivationModeString(mode);
}

string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
  return algo_config.ToString();
}

string ToVlogString(dnn::ElementwiseOperation op) {
  return dnn::ElementwiseOperationString(op);
}

string ToVlogString(dnn::QuantizedActivationMode mode) {
  return dnn::QuantizedActivationModeString(mode);
}

string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }

string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }

string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }

string ToVlogString(blas::Side s) { return blas::SideString(s); }

string ToVlogString(blas::ComputationType ty) {
  return blas::ComputationTypeString(ty);
}

string ToVlogString(const void *ptr) {
  if (ptr == nullptr) {
    return "null";
  }

  // StrCat does not convert pointers to text.
  std::ostringstream out;
  out << ptr;
  return out.str();
}

template <class T>
string ToVlogString(const std::complex<T> &c) {
  // StrCat does not convert std::complex to text.
  std::ostringstream out;
  out << c;
  return out.str();
}

template <class T>
string ToVlogString(const std::function<T> &f) {
  return f == nullptr ? "null" : "<non-null function>";
}

string ToVlogString(const DeviceMemoryBase &memory) {
  return ToVlogString(memory.opaque());
}

string ToVlogString(const DeviceMemoryBase *memory) {
  return memory == nullptr ? "null" : ToVlogString(*memory);
}

string ToVlogString(const Eigen::half &h) {
  return absl::StrCat(static_cast<float>(h));
}

string ToVlogString(int i) { return absl::StrCat(i); }

string ToVlogString(uint32 i) { return absl::StrCat(i); }

string ToVlogString(uint64 i) { return absl::StrCat(i); }

string ToVlogString(int64 i) { return absl::StrCat(i); }

string ToVlogString(float f) { return absl::StrCat(f); }

string ToVlogString(double d) { return absl::StrCat(d); }

template <typename T>
string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
  if (memory_or_constant.is_pointer()) {
    return ToVlogString(memory_or_constant.pointer());
  }
  return ToVlogString(memory_or_constant.value());
}

template <class T>
string ToVlogString(port::ArraySlice<T> elements) {
  string str = absl::StrCat(
      ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
      elements.size(), "]{");
  const char *separator = "";
  size_t max_to_show = std::numeric_limits<size_t>::max();
  if (!VLOG_IS_ON(2)) {
    max_to_show = 5;
  } else if (!VLOG_IS_ON(3)) {
    max_to_show = 20;
  } else if (!VLOG_IS_ON(11)) {
    max_to_show = 1000;
  }
  for (size_t i = 0; i < elements.size(); ++i) {
    if (i == max_to_show) {
      str += ", ...";
      break;
    }
    absl::StrAppend(&str, separator, ToVlogString(elements[i]));
    separator = ", ";
  }
  str += "}";
  return str;
}
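
// Illustrative note (not part of the logging path): at the default verbosity
// an ArraySlice of eight ints would render roughly as
//   0xdeadbeef[8]{1, 2, 3, 4, 5, ...}
// because only the first five elements are shown unless VLOG(2) or higher is
// enabled; the pointer value above is made up purely for illustration.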

template <class T>
string ToVlogString(port::MutableArraySlice<T> elements) {
  return ToVlogString(port::ArraySlice<T>(elements));
}

string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
  switch (depth_to_space_layout) {
    case dnn::DepthToSpaceLayout::DepthHeightWidth:
      return "DepthToSpaceLayout::DepthHeightWidth";
  }
  return "unknown DepthToSpaceLayout";
}

string ToVlogString(dnn::DataType data_type) {
  switch (data_type) {
    case dnn::DataType::kFloat:
      return "dnn::DataType::kFloat";
    case dnn::DataType::kDouble:
      return "dnn::DataType::kDouble";
    case dnn::DataType::kHalf:
      return "dnn::DataType::kHalf";
    case dnn::DataType::kInt8:
      return "dnn::DataType::kInt8";
    case dnn::DataType::kInt32:
      return "dnn::DataType::kInt32";
    default:
      return "unknown DataType";
  }
}

// Used together with PARAM to VLOG calls made to the stream. Intended
// to be used like this:
//
//   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
//
// where a and b are the parameters to MyFunction.
//
// See VLOG_CALL for a short-hand for this. This way of doing it saves
// a tremendous amount of boilerplate code given how many functions
// there are on Stream and how many parameters they each have.
string CallStr(const char *function_name, Stream *stream,
               std::vector<std::pair<const char *, string>> params) {
  // Do not call this function unless VLOG is on since just
  // constructing all the strings in params is expensive.
  CHECK(VLOG_IS_ON(1));

  string str = absl::StrCat(stream->DebugStreamPointers(),
                            " Called Stream::", function_name, "(");
  const char *separator = "";
  for (const auto &param : params) {
    absl::StrAppend(&str, separator, param.first, "=", param.second);
    separator = ", ";
  }
  absl::StrAppend(&str, ")");
  if (VLOG_IS_ON(10)) {
    absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
  }
  return str;
}

// Use this macro to avoid having to type every parameter twice to log
// it with VLOG and CallStr.
#define PARAM(parameter) \
  { #parameter, ToVlogString(parameter) }

// Use this macro to avoid having to type out the name of each
// function and to save some boilerplate. Intended to be used like this:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// This saves a tremendous amount of boilerplate compared to the alternative:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// Note here that most of the parameter names are not short and that
// most of the functions take many more than 2 parameters.
#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
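
// As a concrete (purely illustrative) example, inside a member function
// Foo(int a, float b) the line
//   VLOG_CALL(PARAM(a), PARAM(b));
// expands to roughly
//   VLOG(1) << CallStr("Foo", this,
//                      {{"a", ToVlogString(a)}, {"b", ToVlogString(b)}});
// and logs something like
//   ... Called Stream::Foo(a=3, b=1.5)
// prefixed by whatever DebugStreamPointers() returns for this stream.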

}  // namespace

Stream::Stream(StreamExecutor *parent)
    : parent_(parent),
      implementation_(parent->implementation()->GetStreamImplementation()),
      allocated_(false),
      ok_(false),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent));
}

Stream::Stream(StreamExecutor *parent,
               internal::StreamInterface *implementation)
    : parent_(parent),
      implementation_(implementation),
      allocated_(false),
      ok_(false),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent), PARAM(implementation));
}

Stream::~Stream() {
  VLOG_CALL();

  // Ensure the stream is completed.
  auto status = BlockHostUntilDone();
  if (!status.ok()) {
    LOG(WARNING) << "Error blocking host until done in stream destructor: "
                 << status;
  }
  temporary_memory_manager_.ForceDeallocateAll();
  RunAfterBlockHostUntilDoneCallbacks();

  if (allocated_) {
    parent_->DeallocateStream(this);
  }
}

port::Status Stream::RefreshStatus() {
  port::Status status = parent_->GetStatus(this);
  CheckStatus(status);
  return status;
}

Stream &Stream::Init() {
  VLOG_CALL();

  absl::MutexLock lock(&mu_);
  CHECK_EQ(false, allocated_)
      << "stream appears to already have been initialized";
  CHECK(!ok_) << "stream should be in !ok() state pre-initialization";

  if (parent_->AllocateStream(this)) {
    // Successful initialization!
    allocated_ = true;
    ok_ = true;
  } else {
    LOG(ERROR) << "failed to allocate stream during initialization";
  }

  return *this;
}
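
// A minimal usage sketch (illustrative only; "executor" stands for an
// already-acquired StreamExecutor, the device/host buffers are assumed to
// exist, and error handling is elided):
//
//   Stream stream(executor);
//   stream.Init();
//   if (stream.ok()) {
//     stream.ThenMemcpy(&device_dst, host_src, size_bytes);
//     port::Status status = stream.BlockHostUntilDone();
//   }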

Stream &Stream::InitTimer(Timer *timer) {
  VLOG_CALL(PARAM(timer));

  if (ok()) {
    CheckError(parent_->AllocateTimer(timer));
  } else {
    LOG(INFO) << "did not allocate timer: " << timer;
  }
  return *this;
}

Stream &Stream::InitWithTimer(Timer *timer) {
  VLOG_CALL(PARAM(timer));

  return Init().InitTimer(timer);
}

Stream &Stream::ThenRecordEvent(Event *event) {
  VLOG_CALL(PARAM(event));

  port::Status status = parent_->RecordEvent(this, event);
  if (!status.ok()) {
    LOG(ERROR) << "Error recording event in stream: " << status.error_message()
               << "; not marking stream as bad, as the Event object may be "
               << "at fault. Monitor for further errors.";
  }

  return *this;
}
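
// Sketch of event-based ordering between two streams (illustrative only;
// assumes both streams and the event were created and initialized against
// the same StreamExecutor, and that the waiting overload is available on
// this build):
//
//   Event event(executor);
//   event.Init();
//   producer_stream.ThenRecordEvent(&event);
//   consumer_stream.ThenWaitFor(&event);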

Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training,
    std::function<const DeviceMemory<float> &()> var_to_inv_var,
    std::function<void()> inv_var_to_var,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoBatchNormalizationForward(
          this, x, scale, offset, estimated_mean, estimated_variance,
          side_input, x_desc, scale_offset_desc, epsilon,
          exponential_average_factor, activation_mode, y, batch_mean,
          batch_var, saved_mean, saved_inv_var, is_training,
          reserve_space_allocator, workspace_allocator,
          std::move(var_to_inv_var), std::move(inv_var_to_var)));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
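
// The Then* DNN wrappers in the rest of this file share the shape used above:
// log the call via VLOG_CALL, then, if the stream is still ok(), dispatch to
// the StreamExecutor's dnn::DnnSupport implementation; when the platform has
// no DNN support, the stream is marked as failed via
// SetErrorAndLogNoDnnSupport().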

Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoBatchNormalizationBackward(
          this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
          epsilon, x_backprop, scale_backprop, offset_backprop,
          reserve_space_data, workspace_allocator));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenBatchNormalizationForward(
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    const double exponential_average_factor,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
    DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
    DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
    bool is_training,
    std::function<const DeviceMemory<float> &()> var_to_inv_var,
    std::function<void()> inv_var_to_var,
    ScratchAllocator *reserve_space_allocator,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoBatchNormalizationForward(
          this, x, scale, offset, estimated_mean, estimated_variance,
          side_input, x_desc, scale_offset_desc, epsilon,
          exponential_average_factor, activation_mode, y, batch_mean,
          batch_var, saved_mean, saved_inv_var, is_training,
          reserve_space_allocator, workspace_allocator,
          std::move(var_to_inv_var), std::move(inv_var_to_var)));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenBatchNormalizationBackward(
    const DeviceMemory<Eigen::half> &y_backprop,
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop,
    DeviceMemory<uint8> *reserve_space_data,
    ScratchAllocator *workspace_allocator) {
  VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
            PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
            PARAM(scale_backprop), PARAM(offset_backprop));
  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoBatchNormalizationBackward(
          this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
          epsilon, x_backprop, scale_backprop, offset_backprop,
          reserve_space_data, workspace_allocator));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<double> &conv_input_data, double conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<double> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<double> &side_input_data, double side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
            PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<float> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
            PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<Eigen::half> &biases,
    dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<int8> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<float> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
            PARAM(conv_input_scale), PARAM(filter_descriptor),
            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
            PARAM(side_input_data), PARAM(side_input_scale),
            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      auto status = dnn->DoFusedConvolve(
          this, conv_input_descriptor, conv_input_data, conv_input_scale,
          filter_descriptor, filter_data, convolution_descriptor,
          side_input_data, side_input_scale, bias_descriptor, biases,
          activation_mode, output_descriptor, output, scratch_allocator,
          algorithm_config, output_profile_result);
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

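// Note on the ThenConvolve*WithAlgorithm overloads below: the DNN backend is
// first asked to select an algorithm and size a scratch buffer
// (PrepareForConvolution), and only then to run the convolution
// (DoConvolve*). When a ProfileResult is requested, a failed run does not
// mark the stream as failed, so callers may probe algorithms that are not
// supported for the given configuration.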
Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<double> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<double> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
                 input_data, filter_descriptor, filter_data, output_descriptor,
                 *output, convolution_descriptor, algorithm_config,
                 scratch_allocator, &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolve(
            this, input_descriptor, input_data, filter_descriptor, filter_data,
            convolution_descriptor, output_descriptor, output, algorithm_desc,
            &scratch_memory, output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
                 input_data, filter_descriptor, filter_data, output_descriptor,
                 *output, convolution_descriptor, algorithm_config,
                 scratch_allocator, &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolve(
            this, input_descriptor, input_data, filter_descriptor, filter_data,
            convolution_descriptor, output_descriptor, output, algorithm_desc,
            &scratch_memory, output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
                 input_data, filter_descriptor, filter_data, output_descriptor,
                 *output, convolution_descriptor, algorithm_config,
                 scratch_allocator, &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolve(
            this, input_descriptor, input_data, filter_descriptor, filter_data,
            convolution_descriptor, output_descriptor, output, algorithm_desc,
            &scratch_memory, output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<int8> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
                 input_data, filter_descriptor, filter_data, output_descriptor,
                 *output, convolution_descriptor, algorithm_config,
                 scratch_allocator, &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolve(
            this, input_descriptor, input_data, filter_descriptor, filter_data,
            convolution_descriptor, output_descriptor, output, algorithm_desc,
            &scratch_memory, output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<int8> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_data),
            PARAM(convolution_descriptor), PARAM(output_descriptor),
            PARAM(output), PARAM(algorithm_config));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      DeviceMemory<uint8> scratch_memory;
      dnn::AlgorithmDesc algorithm_desc;
      auto status =
          dnn->PrepareForConvolution(
                 dnn::ConvolutionKind::FORWARD, this, input_descriptor,
                 input_data, filter_descriptor, filter_data, output_descriptor,
                 *output, convolution_descriptor, algorithm_config,
                 scratch_allocator, &algorithm_desc, &scratch_memory)
              .ok();
      if (status) {
        status = dnn->DoConvolve(
            this, input_descriptor, input_data, filter_descriptor, filter_data,
            convolution_descriptor, output_descriptor, output, algorithm_desc,
            &scratch_memory, output_profile_result);
      }
      if (!status && !output_profile_result) {
        SetError();
      }
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}

Stream &Stream::ThenConvolve(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  return ThenConvolveWithAlgorithm(
      input_descriptor, input_data, filter_descriptor, filter_data,
      convolution_descriptor, output_descriptor, output,
      /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
      /*output_profile_result=*/nullptr);
}

Stream &Stream::ThenConvolveQuantized(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_coefficients,
    const DeviceMemory<float> &coefficient_scales,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_coefficients),
            PARAM(coefficient_scales), PARAM(convolution_descriptor),
            PARAM(output_descriptor), PARAM(output));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoConvolveQuantized(
          this, input_descriptor, input_data, filter_descriptor,
          filter_coefficients, coefficient_scales, convolution_descriptor,
          output_descriptor, output));
    } else {
      SetError();
      LOG(WARNING)
          << "attempting to perform DNN operation using StreamExecutor "
             "without DNN support";
    }
  }
  return *this;
}

Stream &Stream::ThenConvolveQuantized(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int16> &filter_coefficients,
    const DeviceMemory<float> &coefficient_scales,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
            PARAM(filter_descriptor), PARAM(filter_coefficients),
            PARAM(coefficient_scales), PARAM(convolution_descriptor),
            PARAM(output_descriptor), PARAM(output));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoConvolveQuantized(
          this, input_descriptor, input_data, filter_descriptor,
          filter_coefficients, coefficient_scales, convolution_descriptor,
          output_descriptor, output));
    } else {
      SetError();
      LOG(WARNING)
          << "attempting to perform DNN operation using StreamExecutor "
             "without DNN support";
    }
  }
  return *this;
}

Stream &Stream::ThenSeparableConvolve(
    const dnn::BatchDescriptor &batch_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    const DeviceMemory<float> &first_weights,
    const DeviceMemory<float> &second_weights,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output) {
  VLOG_CALL(
      PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
      PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
      PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));

  if (ok()) {
    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
      CheckError(dnn->DoSeparableConvolve(
          this, batch_descriptor, input_data, filter_descriptor,
          depth_multiplier, first_weights, second_weights,
          convolution_descriptor, output_descriptor, output));
    } else {
      SetErrorAndLogNoDnnSupport();
    }
  }
  return *this;
}
962
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<double> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)963 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
964 const dnn::FilterDescriptor &filter_descriptor,
965 const DeviceMemory<double> &filter_data,
966 const dnn::BatchDescriptor &output_descriptor,
967 DeviceMemory<double> backward_output_data,
968 const dnn::ConvolutionDescriptor &convolution_descriptor,
969 const dnn::BatchDescriptor &input_descriptor,
970 DeviceMemory<double> *backward_input_data,
971 ScratchAllocator *scratch_allocator,
972 const dnn::AlgorithmConfig &algorithm_config,
973 dnn::ProfileResult *output_profile_result) {
974 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
975 PARAM(output_descriptor), PARAM(backward_output_data),
976 PARAM(convolution_descriptor), PARAM(input_descriptor),
977 PARAM(backward_input_data));
978
979 if (ok()) {
980 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
981 DeviceMemory<uint8> scratch_memory;
982 dnn::AlgorithmDesc algorithm_desc;
983 auto status =
984 dnn->PrepareForConvolution(
985 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
986 *backward_input_data, filter_descriptor, filter_data,
987 output_descriptor, backward_output_data,
988 convolution_descriptor, algorithm_config, scratch_allocator,
989 &algorithm_desc, &scratch_memory)
990 .ok();
991 if (status) {
992 status = dnn->DoConvolveBackwardData(
993 this, filter_descriptor, filter_data, output_descriptor,
994 backward_output_data, convolution_descriptor, input_descriptor,
995 backward_input_data, algorithm_desc, &scratch_memory,
996 output_profile_result);
997 }
998 if (!status && !output_profile_result) {
999 SetError();
1000 }
1001 } else {
1002 SetErrorAndLogNoDnnSupport();
1003 }
1004 }
1005 return *this;
1006 }
1007
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1008 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
1009 const dnn::FilterDescriptor &filter_descriptor,
1010 const DeviceMemory<float> &filter_data,
1011 const dnn::BatchDescriptor &output_descriptor,
1012 DeviceMemory<float> backward_output_data,
1013 const dnn::ConvolutionDescriptor &convolution_descriptor,
1014 const dnn::BatchDescriptor &input_descriptor,
1015 DeviceMemory<float> *backward_input_data,
1016 ScratchAllocator *scratch_allocator,
1017 const dnn::AlgorithmConfig &algorithm_config,
1018 dnn::ProfileResult *output_profile_result) {
1019 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
1020 PARAM(output_descriptor), PARAM(backward_output_data),
1021 PARAM(convolution_descriptor), PARAM(input_descriptor),
1022 PARAM(backward_input_data));
1023
1024 if (ok()) {
1025 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1026 DeviceMemory<uint8> scratch_memory;
1027 dnn::AlgorithmDesc algorithm_desc;
1028 auto status =
1029 dnn->PrepareForConvolution(
1030 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
1031 *backward_input_data, filter_descriptor, filter_data,
1032 output_descriptor, backward_output_data,
1033 convolution_descriptor, algorithm_config, scratch_allocator,
1034 &algorithm_desc, &scratch_memory)
1035 .ok();
1036 if (status) {
1037 status = dnn->DoConvolveBackwardData(
1038 this, filter_descriptor, filter_data, output_descriptor,
1039 backward_output_data, convolution_descriptor, input_descriptor,
1040 backward_input_data, algorithm_desc, &scratch_memory,
1041 output_profile_result);
1042 }
1043 if (!status && !output_profile_result) {
1044 SetError();
1045 }
1046 } else {
1047 SetErrorAndLogNoDnnSupport();
1048 }
1049 }
1050 return *this;
1051 }
1052
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1053 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
1054 const dnn::FilterDescriptor &filter_descriptor,
1055 const DeviceMemory<Eigen::half> &filter_data,
1056 const dnn::BatchDescriptor &output_descriptor,
1057 DeviceMemory<Eigen::half> backward_output_data,
1058 const dnn::ConvolutionDescriptor &convolution_descriptor,
1059 const dnn::BatchDescriptor &input_descriptor,
1060 DeviceMemory<Eigen::half> *backward_input_data,
1061 ScratchAllocator *scratch_allocator,
1062 const dnn::AlgorithmConfig &algorithm_config,
1063 dnn::ProfileResult *output_profile_result) {
1064 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
1065 PARAM(output_descriptor), PARAM(backward_output_data),
1066 PARAM(convolution_descriptor), PARAM(input_descriptor),
1067 PARAM(backward_input_data), PARAM(algorithm_config));
1068
1069 if (ok()) {
1070 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1071 DeviceMemory<uint8> scratch_memory;
1072 dnn::AlgorithmDesc algorithm_desc;
1073 auto status =
1074 dnn->PrepareForConvolution(
1075 dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
1076 *backward_input_data, filter_descriptor, filter_data,
1077 output_descriptor, backward_output_data,
1078 convolution_descriptor, algorithm_config, scratch_allocator,
1079 &algorithm_desc, &scratch_memory)
1080 .ok();
1081 if (status) {
1082 status = dnn->DoConvolveBackwardData(
1083 this, filter_descriptor, filter_data, output_descriptor,
1084 backward_output_data, convolution_descriptor, input_descriptor,
1085 backward_input_data, algorithm_desc, &scratch_memory,
1086 output_profile_result);
1087 }
1088 if (!status && !output_profile_result) {
1089 SetError();
1090 }
1091 } else {
1092 SetErrorAndLogNoDnnSupport();
1093 }
1094 }
1095 return *this;
1096 }
1097
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<double> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1098 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1099 const dnn::BatchDescriptor &input_descriptor,
1100 const DeviceMemory<double> &input_data,
1101 const dnn::BatchDescriptor &output_descriptor,
1102 DeviceMemory<double> backward_output_data,
1103 const dnn::ConvolutionDescriptor &convolution_descriptor,
1104 const dnn::FilterDescriptor &filter_descriptor,
1105 DeviceMemory<double> *backward_filter_data,
1106 ScratchAllocator *scratch_allocator,
1107 const dnn::AlgorithmConfig &algorithm_config,
1108 dnn::ProfileResult *output_profile_result) {
1109 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1110 PARAM(output_descriptor), PARAM(backward_output_data),
1111 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1112 PARAM(backward_filter_data));
1113
1114 if (ok()) {
1115 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1116 DeviceMemory<uint8> scratch_memory;
1117 dnn::AlgorithmDesc algorithm_desc;
1118 auto status =
1119 dnn->PrepareForConvolution(
1120 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1121 input_data, filter_descriptor, *backward_filter_data,
1122 output_descriptor, backward_output_data,
1123 convolution_descriptor, algorithm_config, scratch_allocator,
1124 &algorithm_desc, &scratch_memory)
1125 .ok();
1126 if (status) {
1127 status = dnn->DoConvolveBackwardFilter(
1128 this, input_descriptor, input_data, output_descriptor,
1129 backward_output_data, convolution_descriptor, filter_descriptor,
1130 backward_filter_data, algorithm_desc, &scratch_memory,
1131 output_profile_result);
1132 }
1133 if (!status && !output_profile_result) {
1134 SetError();
1135 }
1136 } else {
1137 SetErrorAndLogNoDnnSupport();
1138 }
1139 }
1140 return *this;
1141 }
1142
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1143 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1144 const dnn::BatchDescriptor &input_descriptor,
1145 const DeviceMemory<float> &input_data,
1146 const dnn::BatchDescriptor &output_descriptor,
1147 DeviceMemory<float> backward_output_data,
1148 const dnn::ConvolutionDescriptor &convolution_descriptor,
1149 const dnn::FilterDescriptor &filter_descriptor,
1150 DeviceMemory<float> *backward_filter_data,
1151 ScratchAllocator *scratch_allocator,
1152 const dnn::AlgorithmConfig &algorithm_config,
1153 dnn::ProfileResult *output_profile_result) {
1154 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1155 PARAM(output_descriptor), PARAM(backward_output_data),
1156 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1157 PARAM(backward_filter_data));
1158
1159 if (ok()) {
1160 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1161 DeviceMemory<uint8> scratch_memory;
1162 dnn::AlgorithmDesc algorithm_desc;
1163 auto status =
1164 dnn->PrepareForConvolution(
1165 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1166 input_data, filter_descriptor, *backward_filter_data,
1167 output_descriptor, backward_output_data,
1168 convolution_descriptor, algorithm_config, scratch_allocator,
1169 &algorithm_desc, &scratch_memory)
1170 .ok();
1171 if (status) {
1172 status = dnn->DoConvolveBackwardFilter(
1173 this, input_descriptor, input_data, output_descriptor,
1174 backward_output_data, convolution_descriptor, filter_descriptor,
1175 backward_filter_data, algorithm_desc, &scratch_memory,
1176 output_profile_result);
1177 }
1178 if (!status && !output_profile_result) {
1179 SetError();
1180 }
1181 } else {
1182 SetErrorAndLogNoDnnSupport();
1183 }
1184 }
1185 return *this;
1186 }
1187
1188 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1189 const dnn::BatchDescriptor &input_descriptor,
1190 const DeviceMemory<Eigen::half> &input_data,
1191 const dnn::BatchDescriptor &output_descriptor,
1192 DeviceMemory<Eigen::half> backward_output_data,
1193 const dnn::ConvolutionDescriptor &convolution_descriptor,
1194 const dnn::FilterDescriptor &filter_descriptor,
1195 DeviceMemory<Eigen::half> *backward_filter_data,
1196 ScratchAllocator *scratch_allocator,
1197 const dnn::AlgorithmConfig &algorithm_config,
1198 dnn::ProfileResult *output_profile_result) {
1199 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1200 PARAM(output_descriptor), PARAM(backward_output_data),
1201 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1202 PARAM(backward_filter_data));
1203
1204 if (ok()) {
1205 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1206 DeviceMemory<uint8> scratch_memory;
1207 dnn::AlgorithmDesc algorithm_desc;
1208 auto status =
1209 dnn->PrepareForConvolution(
1210 dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1211 input_data, filter_descriptor, *backward_filter_data,
1212 output_descriptor, backward_output_data,
1213 convolution_descriptor, algorithm_config, scratch_allocator,
1214 &algorithm_desc, &scratch_memory)
1215 .ok();
1216 if (status) {
1217 status = dnn->DoConvolveBackwardFilter(
1218 this, input_descriptor, input_data, output_descriptor,
1219 backward_output_data, convolution_descriptor, filter_descriptor,
1220 backward_filter_data, algorithm_desc, &scratch_memory,
1221 output_profile_result);
1222 }
1223 if (!status && !output_profile_result) {
1224 SetError();
1225 }
1226 } else {
1227 SetErrorAndLogNoDnnSupport();
1228 }
1229 }
1230 return *this;
1231 }
1232
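// Shared implementation for the typed ThenConvolveBackwardBias() overloads
// below; each overload simply forwards its arguments to this template.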
1233 template <typename T>
1234 Stream &Stream::ThenConvolveBackwardBiasImpl(
1235 const dnn::BatchDescriptor &input_descriptor,
1236 const DeviceMemory<T> &input_data,
1237 const dnn::BatchDescriptor &bias_descriptor,
1238 DeviceMemory<T> *backward_bias_data) {
1239 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
1240 PARAM(backward_bias_data));
1241
1242 if (ok()) {
1243 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1244 CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
1245 bias_descriptor,
1246 backward_bias_data));
1247 } else {
1248 SetErrorAndLogNoDnnSupport();
1249 }
1250 }
1251 return *this;
1252 }
1253
1254 Stream &Stream::ThenConvolveBackwardBias(
1255 const dnn::BatchDescriptor &input_descriptor,
1256 const DeviceMemory<double> &input_data,
1257 const dnn::BatchDescriptor &bias_descriptor,
1258 DeviceMemory<double> *backward_bias_data) {
1259 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1260 bias_descriptor, backward_bias_data);
1261 }
1262
1263 Stream &Stream::ThenConvolveBackwardBias(
1264 const dnn::BatchDescriptor &input_descriptor,
1265 const DeviceMemory<float> &input_data,
1266 const dnn::BatchDescriptor &bias_descriptor,
1267 DeviceMemory<float> *backward_bias_data) {
1268 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1269 bias_descriptor, backward_bias_data);
1270 }
1271
1272 Stream &Stream::ThenConvolveBackwardBias(
1273 const dnn::BatchDescriptor &input_descriptor,
1274 const DeviceMemory<Eigen::half> &input_data,
1275 const dnn::BatchDescriptor &bias_descriptor,
1276 DeviceMemory<Eigen::half> *backward_bias_data) {
1277 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1278 bias_descriptor, backward_bias_data);
1279 }
1280
1281 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
1282 const DeviceMemory<float> &weights,
1283 const dnn::BatchDescriptor &input_dimensions,
1284 const dnn::BatchDescriptor &output_dimensions,
1285 DeviceMemory<float> *output_data) {
1286 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
1287 PARAM(output_dimensions), PARAM(output_data));
1288
1289 if (ok()) {
1290 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1291 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
1292 output_dimensions, output_data));
1293 } else {
1294 SetErrorAndLogNoDnnSupport();
1295 }
1296 }
1297 return *this;
1298 }
1299
1300 Stream &Stream::ThenMatMulQuantized(
1301 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
1302 const DeviceMemory<float> &weight_scales,
1303 const dnn::BatchDescriptor &input_dimensions,
1304 const dnn::BatchDescriptor &output_dimensions,
1305 DeviceMemory<float> *output_data) {
1306 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1307 PARAM(input_dimensions), PARAM(output_dimensions),
1308 PARAM(output_data));
1309
1310 if (ok()) {
1311 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1312 CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1313 weight_scales, input_dimensions,
1314 output_dimensions, output_data));
1315 } else {
1316 SetErrorAndLogNoDnnSupport();
1317 }
1318 }
1319 return *this;
1320 }
1321
1322 Stream &Stream::ThenMatMulQuantized(
1323 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
1324 const DeviceMemory<float> &weight_scales,
1325 const dnn::BatchDescriptor &input_dimensions,
1326 const dnn::BatchDescriptor &output_dimensions,
1327 DeviceMemory<float> *output_data) {
1328 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1329 PARAM(input_dimensions), PARAM(output_dimensions),
1330 PARAM(output_data));
1331
1332 if (ok()) {
1333 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1334 CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1335 weight_scales, input_dimensions,
1336 output_dimensions, output_data));
1337 } else {
1338 SetErrorAndLogNoDnnSupport();
1339 }
1340 }
1341 return *this;
1342 }
1343
1344 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
1345 const DeviceMemory<float> &biases,
1346 const dnn::BatchDescriptor &dimensions,
1347 DeviceMemory<float> *output_data) {
1348 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
1349 PARAM(output_data));
1350
1351 if (ok()) {
1352 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1353 CheckError(
1354 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
1355 } else {
1356 SetErrorAndLogNoDnnSupport();
1357 }
1358 }
1359 return *this;
1360 }
1361
1362 Stream &Stream::ThenPoolForward(
1363 const dnn::PoolingDescriptor &pooling_dimensions,
1364 const dnn::BatchDescriptor &input_dimensions,
1365 const DeviceMemory<double> &input_data,
1366 const dnn::BatchDescriptor &output_dimensions,
1367 DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
1368 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1369 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1370 PARAM(workspace_allocator));
1371
1372 if (ok()) {
1373 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1374 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1375 input_data, output_dimensions, output_data,
1376 workspace_allocator));
1377 } else {
1378 SetError();
1379 LOG(WARNING)
1380 << "attempting to perform DNN operation using StreamExecutor "
1381 "without DNN support";
1382 }
1383 }
1384 return *this;
1385 }
1386
1387 Stream &Stream::ThenPoolForward(
1388 const dnn::PoolingDescriptor &pooling_dimensions,
1389 const dnn::BatchDescriptor &input_dimensions,
1390 const DeviceMemory<float> &input_data,
1391 const dnn::BatchDescriptor &output_dimensions,
1392 DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
1393 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1394 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1395 PARAM(workspace_allocator));
1396
1397 if (ok()) {
1398 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1399 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1400 input_data, output_dimensions, output_data,
1401 workspace_allocator));
1402 } else {
1403 SetErrorAndLogNoDnnSupport();
1404 }
1405 }
1406 return *this;
1407 }
1408
1409 Stream &Stream::ThenPoolForward(
1410 const dnn::PoolingDescriptor &pooling_dimensions,
1411 const dnn::BatchDescriptor &input_dimensions,
1412 const DeviceMemory<Eigen::half> &input_data,
1413 const dnn::BatchDescriptor &output_dimensions,
1414 DeviceMemory<Eigen::half> *output_data,
1415 ScratchAllocator *workspace_allocator) {
1416 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1417 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1418 PARAM(workspace_allocator));
1419
1420 if (ok()) {
1421 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1422 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1423 input_data, output_dimensions, output_data,
1424 workspace_allocator));
1425 } else {
1426 SetErrorAndLogNoDnnSupport();
1427 }
1428 }
1429 return *this;
1430 }
1431
1432 Stream &Stream::ThenPoolForward(
1433 const dnn::PoolingDescriptor &pooling_dimensions,
1434 const dnn::BatchDescriptor &input_dimensions,
1435 const DeviceMemory<int8> &input_data,
1436 const dnn::BatchDescriptor &output_dimensions,
1437 DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
1438 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1439 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1440 PARAM(workspace_allocator));
1441
1442 if (ok()) {
1443 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1444 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1445 input_data, output_dimensions, output_data,
1446 workspace_allocator));
1447 } else {
1448 SetErrorAndLogNoDnnSupport();
1449 }
1450 }
1451 return *this;
1452 }
1453
1454 Stream &Stream::ThenPoolBackward(
1455 const dnn::PoolingDescriptor &pooling_dimensions,
1456 const dnn::BatchDescriptor &input_dimensions,
1457 const DeviceMemory<double> &input_data,
1458 const dnn::BatchDescriptor &output_dimensions,
1459 const DeviceMemory<double> &output_data,
1460 const DeviceMemory<double> &input_diff_data,
1461 DeviceMemory<double> *output_diff_data,
1462 ScratchAllocator *workspace_allocator) {
1463 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1464 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1465 PARAM(input_diff_data), PARAM(output_diff_data),
1466 PARAM(workspace_allocator));
1467
1468 if (ok()) {
1469 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1470 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1471 input_data, output_dimensions, output_data,
1472 input_diff_data, output_diff_data,
1473 workspace_allocator));
1474 } else {
1475 SetError();
1476 LOG(WARNING)
1477 << "attempting to perform DNN operation using StreamExecutor "
1478 "without DNN support";
1479 }
1480 }
1481 return *this;
1482 }
1483
1484 Stream &Stream::ThenPoolBackward(
1485 const dnn::PoolingDescriptor &pooling_dimensions,
1486 const dnn::BatchDescriptor &input_dimensions,
1487 const DeviceMemory<float> &input_data,
1488 const dnn::BatchDescriptor &output_dimensions,
1489 const DeviceMemory<float> &output_data,
1490 const DeviceMemory<float> &input_diff_data,
1491 DeviceMemory<float> *output_diff_data,
1492 ScratchAllocator *workspace_allocator) {
1493 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1494 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1495 PARAM(input_diff_data), PARAM(output_diff_data),
1496 PARAM(workspace_allocator));
1497
1498 if (ok()) {
1499 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1500 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1501 input_data, output_dimensions, output_data,
1502 input_diff_data, output_diff_data,
1503 workspace_allocator));
1504 } else {
1505 SetErrorAndLogNoDnnSupport();
1506 }
1507 }
1508 return *this;
1509 }
1510
1511 Stream &Stream::ThenPoolBackward(
1512 const dnn::PoolingDescriptor &pooling_dimensions,
1513 const dnn::BatchDescriptor &input_dimensions,
1514 const DeviceMemory<Eigen::half> &input_data,
1515 const dnn::BatchDescriptor &output_dimensions,
1516 const DeviceMemory<Eigen::half> &output_data,
1517 const DeviceMemory<Eigen::half> &input_diff_data,
1518 DeviceMemory<Eigen::half> *output_diff_data,
1519 ScratchAllocator *workspace_allocator) {
1520 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1521 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1522 PARAM(input_diff_data), PARAM(output_diff_data),
1523 PARAM(workspace_allocator));
1524
1525 if (ok()) {
1526 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1527 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1528 input_data, output_dimensions, output_data,
1529 input_diff_data, output_diff_data,
1530 workspace_allocator));
1531 } else {
1532 SetErrorAndLogNoDnnSupport();
1533 }
1534 }
1535 return *this;
1536 }
1537
1538 Stream &Stream::ThenNormalizeWithDimensions(
1539 const dnn::NormalizeDescriptor &normalize_descriptor,
1540 const dnn::BatchDescriptor &dimensions,
1541 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1542 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
1543 PARAM(output_data));
1544
1545 if (ok()) {
1546 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1547 CheckError(dnn->DoNormalizeWithDimensions(
1548 this, normalize_descriptor, dimensions, input_data, output_data));
1549 } else {
1550 SetErrorAndLogNoDnnSupport();
1551 }
1552 }
1553 return *this;
1554 }
1555
1556 Stream &Stream::ThenNormalizeBackwardWithDimensions(
1557 const dnn::NormalizeDescriptor &normalize_descriptor,
1558 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
1559 const DeviceMemory<float> &normalized_data,
1560 const DeviceMemory<float> &normalized_variable_gradient,
1561 DeviceMemory<float> *raw_variable_gradient,
1562 ScratchAllocator *workspace_allocator) {
1563 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1564 PARAM(normalized_data), PARAM(normalized_variable_gradient),
1565 PARAM(raw_variable_gradient), PARAM(workspace_allocator));
1566
1567 if (ok()) {
1568 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1569 CheckError(dnn->DoNormalizeBackwardWithDimensions(
1570 this, normalize_descriptor, dimensions, raw_data, normalized_data,
1571 normalized_variable_gradient, raw_variable_gradient,
1572 workspace_allocator));
1573 } else {
1574 SetErrorAndLogNoDnnSupport();
1575 }
1576 }
1577 return *this;
1578 }
1579
1580 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1581 const dnn::BatchDescriptor &dimensions,
1582 const DeviceMemory<float> &input_data,
1583 DeviceMemory<float> *output_data) {
1584 return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1585 output_data, /*options=*/0);
1586 }
1587
1588 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1589 const dnn::BatchDescriptor &dimensions,
1590 const DeviceMemory<float> &input_data,
1591 DeviceMemory<float> *output_data,
1592 uint64 options) {
1593 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1594 PARAM(output_data), PARAM(options));
1595
1596 if (ok()) {
1597 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1598 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1599 output_data, options));
1600 } else {
1601 SetErrorAndLogNoDnnSupport();
1602 }
1603 }
1604 return *this;
1605 }
1606
1607 Stream &Stream::ThenDepthConcatenate(
1608 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1609 port::ArraySlice<const DeviceMemory<float> *> input_data,
1610 DeviceMemory<float> *output_data) {
1611 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1612
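// All inputs must agree on batch count, height, and width; only the
// feature-map (depth) dimension may differ across the inputs being
// concatenated.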
1613 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1614 if (input_dimensions[i].count() != input_dimensions[0].count() ||
1615 input_dimensions[i].height() != input_dimensions[0].height() ||
1616 input_dimensions[i].width() != input_dimensions[0].width()) {
1617 SetError();
1618 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1619 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1620 << "input_dimensions[" << i
1621 << "]: " << input_dimensions[i].ToString();
1622 return *this;
1623 }
1624 }
1625
1626 if (ok()) {
1627 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1628 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1629 output_data));
1630 } else {
1631 SetErrorAndLogNoDnnSupport();
1632 }
1633 }
1634 return *this;
1635 }
1636
1637 Stream &Stream::ThenSpaceConcatenate(
1638 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1639 port::ArraySlice<const DeviceMemory<float> *> input_data,
1640 DeviceMemory<float> *output_data,
1641 dnn::SpaceConcatenateMode concat_direction) {
1642 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1643
1644 // Check that the input dimensions of all the other batches match those of the
1645 // first batch.
1646 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1647 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1648 (input_dimensions[i].count() != input_dimensions[0].count() ||
1649 input_dimensions[i].height() != input_dimensions[0].height() ||
1650 input_dimensions[i].feature_map_count() !=
1651 input_dimensions[0].feature_map_count())) {
1652 SetError();
1653 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1654 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1655 << "input_dimensions[" << i
1656 << "]: " << input_dimensions[i].ToString();
1657 return *this;
1658 }
1659
1660 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1661 (input_dimensions[i].count() != input_dimensions[0].count() ||
1662 input_dimensions[i].width() != input_dimensions[0].width() ||
1663 input_dimensions[i].feature_map_count() !=
1664 input_dimensions[0].feature_map_count())) {
1665 SetError();
1666 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1667 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1668 << "input_dimensions[" << i
1669 << "]: " << input_dimensions[i].ToString();
1670 return *this;
1671 }
1672 }
1673 if (ok()) {
1674 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1675 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1676 output_data, concat_direction));
1677 } else {
1678 SetErrorAndLogNoDnnSupport();
1679 }
1680 }
1681 return *this;
1682 }
1683
1684 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1685 const DeviceMemory<float> &input_data,
1686 const dnn::BatchDescriptor &output_dimensions,
1687 DeviceMemory<float> *output_data) {
1688 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1689 PARAM(output_dimensions), PARAM(output_data));
1690
1691 if (ok()) {
1692 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1693 CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1694 output_dimensions, output_data));
1695 } else {
1696 SetErrorAndLogNoDnnSupport();
1697 }
1698 }
1699 return *this;
1700 }
1701
1702 Stream &Stream::ThenDepthToSpace(
1703 const dnn::BatchDescriptor &input_dimensions,
1704 const DeviceMemory<float> &input_data,
1705 const dnn::DepthToSpaceLayout &depth_to_space_layout,
1706 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1707 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1708 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1709 PARAM(output_data));
1710
1711 if (ok()) {
1712 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1713 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1714 depth_to_space_layout,
1715 sqrt_depth_reduction, output_data));
1716 } else {
1717 SetErrorAndLogNoDnnSupport();
1718 }
1719 }
1720 return *this;
1721 }
1722
1723 Stream &Stream::ThenSpaceToDepth(
1724 const dnn::BatchDescriptor &input_dimensions,
1725 const DeviceMemory<float> &input_data,
1726 const dnn::DepthToSpaceLayout &space_to_depth_layout,
1727 const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1728 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1729 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1730 PARAM(output_data));
1731
1732 if (ok()) {
1733 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1734 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1735 space_to_depth_layout, sqrt_depth_increase,
1736 output_data));
1737 } else {
1738 SetErrorAndLogNoDnnSupport();
1739 }
1740 }
1741 return *this;
1742 }
1743
1744 Stream &Stream::ThenElementwiseOperate(
1745 dnn::ElementwiseOperation operation,
1746 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1747 port::ArraySlice<const DeviceMemory<float> *> input_data,
1748 const dnn::BatchDescriptor &output_dimensions,
1749 DeviceMemory<float> *output_data) {
1750 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1751 PARAM(output_dimensions), PARAM(output_data));
1752
1753 if (ok()) {
1754 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1755 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1756 input_data, output_dimensions,
1757 output_data));
1758 } else {
1759 SetErrorAndLogNoDnnSupport();
1760 }
1761 }
1762 return *this;
1763 }
1764
1765 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1766 dnn::ElementwiseOperation operation,
1767 port::ArraySlice<int> input_multiplicands, int output_divisor,
1768 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1769 port::ArraySlice<const DeviceMemory<float> *> input_data,
1770 const dnn::BatchDescriptor &output_dimensions,
1771 DeviceMemory<float> *output_data) {
1772 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1773 PARAM(input_dimensions), PARAM(input_data),
1774 PARAM(output_dimensions), PARAM(output_data));
1775
1776 if (ok()) {
1777 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1778 CheckError(dnn->DoElementwiseOperateScaledQuantized(
1779 this, operation, input_multiplicands, output_divisor,
1780 input_dimensions, input_data, output_dimensions, output_data));
1781 } else {
1782 SetErrorAndLogNoDnnSupport();
1783 }
1784 }
1785 return *this;
1786 }
1787
1788 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1789 const DeviceMemory<float> &input_data, int64 left_pad,
1790 int64 right_pad, int64 top_pad, int64 bottom_pad,
1791 DeviceMemory<float> *output_data) {
1792 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1793 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1794 PARAM(output_data));
1795
1796 if (ok()) {
1797 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1798 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1799 top_pad, bottom_pad, output_data));
1800 } else {
1801 SetErrorAndLogNoDnnSupport();
1802 }
1803 }
1804 return *this;
1805 }
1806
1807 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1808 const DeviceMemory<float> &input_data,
1809 int64 left_trim, int64 right_trim, int64 top_trim,
1810 int64 bottom_trim,
1811 DeviceMemory<float> *output_data) {
1812 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1813 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1814 PARAM(output_data));
1815
1816 if (ok()) {
1817 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1818 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1819 right_trim, top_trim, bottom_trim,
1820 output_data));
1821 } else {
1822 SetErrorAndLogNoDnnSupport();
1823 }
1824 }
1825 return *this;
1826 }
1827
1828 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1829 const DeviceMemory<float> &input_data,
1830 int64 replicate_x, int64 replicate_y,
1831 DeviceMemory<float> *output_data) {
1832 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1833 PARAM(replicate_y), PARAM(output_data));
1834
1835 if (ok()) {
1836 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1837 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1838 replicate_y, output_data));
1839 } else {
1840 SetErrorAndLogNoDnnSupport();
1841 }
1842 }
1843 return *this;
1844 }
1845
1846 Stream &Stream::ThenMemcpyD2HQuantized(
1847 const DeviceMemory<float> &gpu_unquantized_src,
1848 dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1849 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1850 PARAM(size));
1851
1852 if (ok()) {
1853 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1854 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1855 host_dst, size));
1856 } else {
1857 SetErrorAndLogNoDnnSupport();
1858 }
1859 }
1860 return *this;
1861 }
1862
1863 Stream &Stream::ThenMemcpyH2DQuantized(
1864 const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1865 DeviceMemory<float> *gpu_unquantized_dst) {
1866 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1867 PARAM(gpu_unquantized_dst));
1868
1869 if (ok()) {
1870 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1871 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1872 gpu_unquantized_dst));
1873 } else {
1874 SetErrorAndLogNoDnnSupport();
1875 }
1876 }
1877 return *this;
1878 }
1879
1880 Stream *Stream::GetOrCreateSubStream() {
1881 absl::MutexLock lock(&mu_);
1882
1883 // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1884 // we encounter along the way.
1885 for (int64 index = 0; index < sub_streams_.size();) {
1886 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1887 if (pair.second) {
1888 // The sub_stream is reusable.
1889 Stream *sub_stream = pair.first.get();
1890 if (sub_stream->ok()) {
1891 VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1892 << sub_stream->DebugStreamPointers();
1893 pair.second = false;
1894 return sub_stream;
1895 }
1896
1897 // The stream is reusable and not ok. Streams have a monotonic state
1898 // machine; the stream will remain in !ok forever. Swap it with the last
1899 // stream and pop it off.
1900 const int64 last = sub_streams_.size() - 1;
1901 if (index != last) {
1902 std::swap(pair, sub_streams_[last]);
1903 }
1904 sub_streams_.pop_back();
1905 VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1906 << sub_stream->DebugStreamPointers();
1907 } else {
1908 // The sub_stream is not reusable, move on to the next one.
1909 ++index;
1910 }
1911 }
1912
1913 // No streams are reusable; create a new stream.
1914 sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1915 false);
1916 Stream *sub_stream = sub_streams_.back().first.get();
1917 sub_stream->Init();
1918 if (!sub_stream->ok_) {
1919 LOG(ERROR) << "sub-stream failed to be initialized";
1920 }
1921 VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1922 << sub_stream->DebugStreamPointers();
1923
1924 return sub_stream;
1925 }
1926
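// Illustrative use of the sub-stream pool (a sketch, not part of the original
// file): a caller borrows a sub-stream, enqueues work on it, and hands it
// back once that work has been enqueued, e.g.
//
//   Stream *sub = stream->GetOrCreateSubStream();
//   sub->ThenMemcpy(&device_dst, host_src, size);  // hypothetical work
//   stream->ThenWaitFor(sub);  // order the parent stream after the work
//   stream->ReturnSubStream(sub);
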
1927 void Stream::ReturnSubStream(Stream *sub_stream) {
1928 absl::MutexLock lock(&mu_);
1929
1930 // Look for the sub-stream.
1931 for (int64 index = 0; index < sub_streams_.size(); ++index) {
1932 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1933 if (pair.first.get() != sub_stream) {
1934 continue;
1935 }
1936
1937 // Found the sub_stream.
1938 if (sub_stream->ok()) {
1939 VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1940 << sub_stream->DebugStreamPointers();
1941 pair.second = true;
1942 } else {
1943 // The returned stream is not ok. Streams have a monotonic state
1944 // machine; the stream will remain in !ok forever. Swap it with the last
1945 // stream and pop it off.
1946 VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1947 << sub_stream->DebugStreamPointers();
1948 const int64 last = sub_streams_.size() - 1;
1949 if (index != last) {
1950 std::swap(pair, sub_streams_[last]);
1951 }
1952 sub_streams_.pop_back();
1953 }
1954 return;
1955 }
1956
1957 LOG(FATAL) << DebugStreamPointers()
1958 << " did not create the returned sub-stream "
1959 << sub_stream->DebugStreamPointers();
1960 }
1961
1962 Stream &Stream::ThenStartTimer(Timer *t) {
1963 VLOG_CALL(PARAM(t));
1964
1965 if (ok()) {
1966 CheckError(parent_->StartTimer(this, t));
1967 } else {
1968 LOG(INFO) << DebugStreamPointers()
1969 << " did not enqueue 'start timer': " << t;
1970 }
1971 return *this;
1972 }
1973
1974 Stream &Stream::ThenStopTimer(Timer *t) {
1975 VLOG_CALL(PARAM(t));
1976
1977 if (ok()) {
1978 CheckError(parent_->StopTimer(this, t));
1979 } else {
1980 LOG(INFO) << DebugStreamPointers()
1981 << " did not enqueue 'stop timer': " << t;
1982 }
1983 return *this;
1984 }
1985
1986 Stream &Stream::ThenWaitFor(Stream *other) {
1987 VLOG_CALL(PARAM(other));
1988
1989 CHECK(this != other) << "stream cannot wait for itself";
1990 if (ok() && other->ok()) {
1991 CheckError(parent_->CreateStreamDependency(this, other));
1992 } else {
1993 SetError();
1994 LOG(INFO) << DebugStreamPointers() << " did not wait for "
1995 << other->DebugStreamPointers();
1996 }
1997 return *this;
1998 }
1999
2000 Stream &Stream::ThenWaitFor(Event *event) {
2001 VLOG_CALL(PARAM(event));
2002
2003 if (ok()) {
2004 port::Status status = parent_->WaitForEvent(this, event);
2005 if (!status.ok()) {
2006 LOG(ERROR) << "Error waiting for event in stream: "
2007 << status.error_message()
2008 << "; not marking stream as bad, as the Event object may be "
2009 << "at fault. Monitor for further errors.";
2010 }
2011 } else {
2012 LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
2013 }
2014 return *this;
2015 }
2016
2017 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
2018 // functions and logs for errors.
2019 template <typename... Args>
2020 struct ThenBlasImpl {
2021 // blas_func is the DoBlasXXX member function pointer, and args are its
2022 // arguments except the first one of Stream* type.
2023 Stream &operator()(Stream *stream,
2024 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
2025 Args... args) {
2026 return Run(stream, blas_func, /*record_error=*/true, args...);
2027 }
2028
2029 // Like operator(), but only calls stream->CheckError() if record_error is
2030 // true.
2031 Stream &Run(Stream *stream,
2032 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
2033 bool record_error, Args... args);
2034 };
2035
2036 template <typename... Args>
2037 Stream &ThenBlasImpl<Args...>::Run(
2038 Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
2039 bool record_error, Args... args) {
2040 if (stream->ok()) {
2041 bool ok;
2042 if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
2043 ok = (blas->*blas_func)(stream, args...);
2044 } else {
2045 LOG(WARNING)
2046 << "attempting to perform BLAS operation using StreamExecutor "
2047 "without BLAS support";
2048 ok = false;
2049 }
2050 if (record_error) {
2051 stream->CheckError(ok);
2052 }
2053 }
2054 return *stream;
2055 }
2056
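// Each ThenBlasXXX wrapper below instantiates ThenBlasImpl with the argument
// types of the matching DoBlasXXX overload and invokes it with the
// member-function pointer, so error handling and the no-BLAS-support warning
// funnel through ThenBlasImpl::Run() above.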
2057 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
2058 int incx, DeviceMemory<float> *result) {
2059 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2060
2061 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2062 impl;
2063 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2064 result);
2065 }
2066
2067 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
2068 int incx, DeviceMemory<double> *result) {
2069 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2070
2071 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2072 DeviceMemory<double> *> impl;
2073 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2074 result);
2075 }
2076
2077 Stream &Stream::ThenBlasAsum(uint64 elem_count,
2078 const DeviceMemory<std::complex<float>> &x,
2079 int incx, DeviceMemory<float> *result) {
2080 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2081
2082 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2083 DeviceMemory<float> *> impl;
2084 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2085 result);
2086 }
2087
2088 Stream &Stream::ThenBlasAsum(uint64 elem_count,
2089 const DeviceMemory<std::complex<double>> &x,
2090 int incx, DeviceMemory<double> *result) {
2091 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2092
2093 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2094 DeviceMemory<double> *> impl;
2095 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2096 result);
2097 }
2098
2099 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
2100 const DeviceMemory<float> &x, int incx,
2101 DeviceMemory<float> *y, int incy) {
2102 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2103 PARAM(incy));
2104
2105 ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
2106 DeviceMemory<float> *, int> impl;
2107 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2108 y, incy);
2109 }
2110
2111 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
2112 const DeviceMemory<double> &x, int incx,
2113 DeviceMemory<double> *y, int incy) {
2114 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2115 PARAM(incy));
2116
2117 ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
2118 DeviceMemory<double> *, int> impl;
2119 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2120 y, incy);
2121 }
2122
2123 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
2124 const DeviceMemory<std::complex<float>> &x,
2125 int incx, DeviceMemory<std::complex<float>> *y,
2126 int incy) {
2127 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2128 PARAM(incy));
2129
2130 ThenBlasImpl<uint64, std::complex<float>,
2131 const DeviceMemory<std::complex<float>> &, int,
2132 DeviceMemory<std::complex<float>> *, int> impl;
2133 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2134 y, incy);
2135 }
2136
2137 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
2138 const DeviceMemory<std::complex<double>> &x,
2139 int incx, DeviceMemory<std::complex<double>> *y,
2140 int incy) {
2141 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2142 PARAM(incy));
2143
2144 ThenBlasImpl<uint64, std::complex<double>,
2145 const DeviceMemory<std::complex<double>> &, int,
2146 DeviceMemory<std::complex<double>> *, int> impl;
2147 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2148 y, incy);
2149 }
2150
2151 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
2152 int incx, DeviceMemory<float> *y, int incy) {
2153 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2154
2155 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2156 int> impl;
2157 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2158 incy);
2159 }
2160
2161 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
2162 int incx, DeviceMemory<double> *y, int incy) {
2163 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2164
2165 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2166 DeviceMemory<double> *, int> impl;
2167 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2168 incy);
2169 }
2170
2171 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2172 const DeviceMemory<std::complex<float>> &x,
2173 int incx, DeviceMemory<std::complex<float>> *y,
2174 int incy) {
2175 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2176
2177 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2178 DeviceMemory<std::complex<float>> *, int> impl;
2179 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2180 incy);
2181 }
2182
2183 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2184 const DeviceMemory<std::complex<double>> &x,
2185 int incx, DeviceMemory<std::complex<double>> *y,
2186 int incy) {
2187 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2188
2189 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2190 DeviceMemory<std::complex<double>> *, int> impl;
2191 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2192 incy);
2193 }
2194
2195 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
2196 int incx, const DeviceMemory<float> &y, int incy,
2197 DeviceMemory<float> *result) {
2198 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2199 PARAM(result));
2200
2201 ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
2202 const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
2203 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2204 result);
2205 }
2206
2207 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
2208 int incx, const DeviceMemory<double> &y, int incy,
2209 DeviceMemory<double> *result) {
2210 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2211 PARAM(result));
2212
2213 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2214 const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
2215 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2216 result);
2217 }
2218
2219 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2220 const DeviceMemory<std::complex<float>> &x,
2221 int incx,
2222 const DeviceMemory<std::complex<float>> &y,
2223 int incy,
2224 DeviceMemory<std::complex<float>> *result) {
2225 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2226 PARAM(result));
2227
2228 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2229 const DeviceMemory<std::complex<float>> &, int,
2230 DeviceMemory<std::complex<float>> *> impl;
2231 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2232 incy, result);
2233 }
2234
2235 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2236 const DeviceMemory<std::complex<double>> &x,
2237 int incx,
2238 const DeviceMemory<std::complex<double>> &y,
2239 int incy,
2240 DeviceMemory<std::complex<double>> *result) {
2241 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2242 PARAM(result));
2243
2244 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2245 const DeviceMemory<std::complex<double>> &, int,
2246 DeviceMemory<std::complex<double>> *> impl;
2247 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2248 incy, result);
2249 }
2250
2251 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2252 const DeviceMemory<std::complex<float>> &x,
2253 int incx,
2254 const DeviceMemory<std::complex<float>> &y,
2255 int incy,
2256 DeviceMemory<std::complex<float>> *result) {
2257 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2258 PARAM(result));
2259
2260 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2261 const DeviceMemory<std::complex<float>> &, int,
2262 DeviceMemory<std::complex<float>> *> impl;
2263 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2264 incy, result);
2265 }
2266
2267 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2268 const DeviceMemory<std::complex<double>> &x,
2269 int incx,
2270 const DeviceMemory<std::complex<double>> &y,
2271 int incy,
2272 DeviceMemory<std::complex<double>> *result) {
2273 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2274 PARAM(result));
2275
2276 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2277 const DeviceMemory<std::complex<double>> &, int,
2278 DeviceMemory<std::complex<double>> *> impl;
2279 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2280 incy, result);
2281 }
2282
2283 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
2284 int incx, DeviceMemory<float> *result) {
2285 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2286
2287 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2288 impl;
2289 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2290 result);
2291 }
2292
2293 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
2294 int incx, DeviceMemory<double> *result) {
2295 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2296
2297 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2298 DeviceMemory<double> *> impl;
2299 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2300 result);
2301 }
2302
2303 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2304 const DeviceMemory<std::complex<float>> &x,
2305 int incx, DeviceMemory<float> *result) {
2306 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2307
2308 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2309 DeviceMemory<float> *> impl;
2310 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2311 result);
2312 }
2313
2314 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2315 const DeviceMemory<std::complex<double>> &x,
2316 int incx, DeviceMemory<double> *result) {
2317 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2318
2319 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2320 DeviceMemory<double> *> impl;
2321 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2322 result);
2323 }
2324
2325 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
2326 DeviceMemory<float> *y, int incy, float c,
2327 float s) {
2328 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2329 PARAM(c), PARAM(s));
2330
2331 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2332 float, float> impl;
2333 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2334 c, s);
2335 }
2336
2337 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
2338 int incx, DeviceMemory<double> *y, int incy,
2339 double c, double s) {
2340 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2341 PARAM(c), PARAM(s));
2342
2343 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2344 double, double> impl;
2345 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2346 c, s);
2347 }
2348
2349 Stream &Stream::ThenBlasRot(uint64 elem_count,
2350 DeviceMemory<std::complex<float>> *x, int incx,
2351 DeviceMemory<std::complex<float>> *y, int incy,
2352 float c, float s) {
2353 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2354 PARAM(c), PARAM(s));
2355
2356 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2357 DeviceMemory<std::complex<float>> *, int, float, float> impl;
2358 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2359 c, s);
2360 }
2361
2362 Stream &Stream::ThenBlasRot(uint64 elem_count,
2363 DeviceMemory<std::complex<double>> *x, int incx,
2364 DeviceMemory<std::complex<double>> *y, int incy,
2365 double c, double s) {
2366 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2367 PARAM(c), PARAM(s));
2368
2369 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2370 DeviceMemory<std::complex<double>> *, int, double, double> impl;
2371 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2372 c, s);
2373 }
2374
2375 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
2376 DeviceMemory<float> *c, DeviceMemory<float> *s) {
2377 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2378
2379 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2380 DeviceMemory<float> *, DeviceMemory<float> *> impl;
2381 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2382 }
2383
2384 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
2385 DeviceMemory<double> *c, DeviceMemory<double> *s) {
2386 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2387
2388 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2389 DeviceMemory<double> *, DeviceMemory<double> *> impl;
2390 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2391 }
2392
2393 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
2394 DeviceMemory<std::complex<float>> *b,
2395 DeviceMemory<float> *c,
2396 DeviceMemory<std::complex<float>> *s) {
2397 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2398
2399 ThenBlasImpl<DeviceMemory<std::complex<float>> *,
2400 DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
2401 DeviceMemory<std::complex<float>> *> impl;
2402 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2403 }
2404
2405 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
2406 DeviceMemory<std::complex<double>> *b,
2407 DeviceMemory<double> *c,
2408 DeviceMemory<std::complex<double>> *s) {
2409 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2410
2411 ThenBlasImpl<DeviceMemory<std::complex<double>> *,
2412 DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
2413 DeviceMemory<std::complex<double>> *> impl;
2414 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2415 }
2416
2417 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
2418 int incx, DeviceMemory<float> *y, int incy,
2419 const DeviceMemory<float> &param) {
2420 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2421 PARAM(param));
2422
2423 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2424 const DeviceMemory<float> &> impl;
2425 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2426 incy, param);
2427 }
2428
2429 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
2430 int incx, DeviceMemory<double> *y, int incy,
2431 const DeviceMemory<double> &param) {
2432 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2433 PARAM(param));
2434
2435 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2436 const DeviceMemory<double> &> impl;
2437 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2438 incy, param);
2439 }
2440
2441 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
2442 DeviceMemory<float> *x1,
2443 const DeviceMemory<float> &y1,
2444 DeviceMemory<float> *param) {
2445 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2446
2447 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2448 DeviceMemory<float> *, const DeviceMemory<float> &,
2449 DeviceMemory<float> *> impl;
2450 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2451 }
2452
2453 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
2454 DeviceMemory<double> *d2,
2455 DeviceMemory<double> *x1,
2456 const DeviceMemory<double> &y1,
2457 DeviceMemory<double> *param) {
2458 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2459
2460 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2461 DeviceMemory<double> *, const DeviceMemory<double> &,
2462 DeviceMemory<double> *> impl;
2463 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2464 }
2465
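// Scaling (scal) wrappers: x <- alpha * x, with real alpha applied to real or
// complex vectors and complex alpha applied to complex vectors.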
2466 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2467 DeviceMemory<float> *x, int incx) {
2468 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2469
2470 ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
2471 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2472 }
2473
2474 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2475 DeviceMemory<double> *x, int incx) {
2476 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2477
2478 ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
2479 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2480 }
2481
2482 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2483 DeviceMemory<std::complex<float>> *x, int incx) {
2484 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2485
2486 ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
2487 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2488 }
2489
2490 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2491 DeviceMemory<std::complex<double>> *x, int incx) {
2492 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2493
2494 ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
2495 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2496 }
2497
2498 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
2499 DeviceMemory<std::complex<float>> *x, int incx) {
2500 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2501
2502 ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
2503 int> impl;
2504 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2505 }
2506
2507 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
2508 DeviceMemory<std::complex<double>> *x, int incx) {
2509 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2510
2511 ThenBlasImpl<uint64, std::complex<double>,
2512 DeviceMemory<std::complex<double>> *, int> impl;
2513 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2514 }
2515
2516 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
2517 int incx, DeviceMemory<float> *y, int incy) {
2518 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2519
2520 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
2521 impl;
2522 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2523 incy);
2524 }
2525
2526 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
2527 int incx, DeviceMemory<double> *y, int incy) {
2528 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2529
2530 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
2531 impl;
2532 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2533 incy);
2534 }
2535
2536 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2537 DeviceMemory<std::complex<float>> *x, int incx,
2538 DeviceMemory<std::complex<float>> *y, int incy) {
2539 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2540
2541 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2542 DeviceMemory<std::complex<float>> *, int> impl;
2543 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2544 incy);
2545 }
2546
2547 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2548 DeviceMemory<std::complex<double>> *x, int incx,
2549 DeviceMemory<std::complex<double>> *y, int incy) {
2550 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2551
2552 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2553 DeviceMemory<std::complex<double>> *, int> impl;
2554 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2555 incy);
2556 }
2557
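// iamax/iamin wrappers: the index of the element with maximum (respectively
// minimum) absolute value is written to *result in device memory.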
2558 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2559 int incx, DeviceMemory<int> *result) {
2560 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2561
2562 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2563 impl;
2564 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2565 result);
2566 }
2567
2568 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2569 int incx, DeviceMemory<int> *result) {
2570 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2571
2572 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2573 impl;
2574 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2575 result);
2576 }
2577
2578 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2579 const DeviceMemory<std::complex<float>> &x,
2580 int incx, DeviceMemory<int> *result) {
2581 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2582
2583 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2584 DeviceMemory<int> *> impl;
2585 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2586 result);
2587 }
2588
2589 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2590 const DeviceMemory<std::complex<double>> &x,
2591 int incx, DeviceMemory<int> *result) {
2592 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2593
2594 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2595 DeviceMemory<int> *> impl;
2596 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2597 result);
2598 }
2599
2600 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2601 int incx, DeviceMemory<int> *result) {
2602 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2603
2604 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2605 impl;
2606 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2607 result);
2608 }
2609
2610 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2611 int incx, DeviceMemory<int> *result) {
2612 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2613
2614 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2615 impl;
2616 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2617 result);
2618 }
2619
2620 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2621 const DeviceMemory<std::complex<float>> &x,
2622 int incx, DeviceMemory<int> *result) {
2623 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2624
2625 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2626 DeviceMemory<int> *> impl;
2627 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2628 result);
2629 }
2630
2631 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2632 const DeviceMemory<std::complex<double>> &x,
2633 int incx, DeviceMemory<int> *result) {
2634 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2635
2636 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2637 DeviceMemory<int> *> impl;
2638 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2639 result);
2640 }
2641
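// BLAS level-2 wrappers start here (banded, general, symmetric/Hermitian, and
// triangular matrix-vector routines), still following the VLOG-then-dispatch
// pattern.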
2642 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2643 uint64 kl, uint64 ku, float alpha,
2644 const DeviceMemory<float> &a, int lda,
2645 const DeviceMemory<float> &x, int incx, float beta,
2646 DeviceMemory<float> *y, int incy) {
2647 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2648 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2649 PARAM(beta), PARAM(y), PARAM(incy));
2650
2651 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2652 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2653 int, float, DeviceMemory<float> *, int> impl;
2654 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2655 a, lda, x, incx, beta, y, incy);
2656 }
2657
2658 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2659 uint64 kl, uint64 ku, double alpha,
2660 const DeviceMemory<double> &a, int lda,
2661 const DeviceMemory<double> &x, int incx,
2662 double beta, DeviceMemory<double> *y, int incy) {
2663 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2664 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2665 PARAM(beta), PARAM(y), PARAM(incy));
2666
2667 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2668 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2669 int, double, DeviceMemory<double> *, int> impl;
2670 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2671 a, lda, x, incx, beta, y, incy);
2672 }
2673
2674 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2675 uint64 kl, uint64 ku, std::complex<float> alpha,
2676 const DeviceMemory<std::complex<float>> &a,
2677 int lda,
2678 const DeviceMemory<std::complex<float>> &x,
2679 int incx, std::complex<float> beta,
2680 DeviceMemory<std::complex<float>> *y, int incy) {
2681 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2682 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2683 PARAM(beta), PARAM(y), PARAM(incy));
2684
2685 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2686 std::complex<float>, const DeviceMemory<std::complex<float>> &,
2687 int, const DeviceMemory<std::complex<float>> &, int,
2688 std::complex<float>, DeviceMemory<std::complex<float>> *,
2689 int> impl;
2690 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2691 a, lda, x, incx, beta, y, incy);
2692 }
2693
2694 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2695 uint64 kl, uint64 ku, std::complex<double> alpha,
2696 const DeviceMemory<std::complex<double>> &a,
2697 int lda,
2698 const DeviceMemory<std::complex<double>> &x,
2699 int incx, std::complex<double> beta,
2700 DeviceMemory<std::complex<double>> *y, int incy) {
2701 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2702 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2703 PARAM(beta), PARAM(y), PARAM(incy));
2704
2705 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2706 std::complex<double>, const DeviceMemory<std::complex<double>> &,
2707 int, const DeviceMemory<std::complex<double>> &, int,
2708 std::complex<double>, DeviceMemory<std::complex<double>> *,
2709 int> impl;
2710 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2711 a, lda, x, incx, beta, y, incy);
2712 }
2713
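// A minimal caller-side sketch for the gemv wrappers defined below (the
// "executor" and "stream" handles and the m/n extents are hypothetical, and a
// column-major m x n matrix with lda == m is assumed):
//
//   DeviceMemory<float> a = executor->AllocateArray<float>(m * n);
//   DeviceMemory<float> x = executor->AllocateArray<float>(n);
//   DeviceMemory<float> y = executor->AllocateArray<float>(m);
//   stream->ThenBlasGemv(blas::Transpose::kNoTranspose, m, n, /*alpha=*/1.0f,
//                        a, /*lda=*/m, x, /*incx=*/1, /*beta=*/0.0f, &y,
//                        /*incy=*/1);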
2714 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2715 float alpha, const DeviceMemory<float> &a, int lda,
2716 const DeviceMemory<float> &x, int incx, float beta,
2717 DeviceMemory<float> *y, int incy) {
2718 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2719 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2720 PARAM(incy));
2721
2722 ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2723 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2724 int, float, DeviceMemory<float> *, int> impl;
2725 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2726 x, incx, beta, y, incy);
2727 }
2728
2729 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2730 double alpha, const DeviceMemory<double> &a,
2731 int lda, const DeviceMemory<double> &x, int incx,
2732 double beta, DeviceMemory<double> *y, int incy) {
2733 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2734 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2735 PARAM(incy));
2736
2737 ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2738 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2739 int, double, DeviceMemory<double> *, int> impl;
2740 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2741 x, incx, beta, y, incy);
2742 }
2743
2744 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2745 std::complex<float> alpha,
2746 const DeviceMemory<std::complex<float>> &a,
2747 int lda,
2748 const DeviceMemory<std::complex<float>> &x,
2749 int incx, std::complex<float> beta,
2750 DeviceMemory<std::complex<float>> *y, int incy) {
2751 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2752 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2753 PARAM(incy));
2754
2755 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2756 const DeviceMemory<std::complex<float>> &, int,
2757 const DeviceMemory<std::complex<float>> &, int,
2758 std::complex<float>, DeviceMemory<std::complex<float>> *,
2759 int> impl;
2760 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2761 x, incx, beta, y, incy);
2762 }
2763
2764 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2765 std::complex<double> alpha,
2766 const DeviceMemory<std::complex<double>> &a,
2767 int lda,
2768 const DeviceMemory<std::complex<double>> &x,
2769 int incx, std::complex<double> beta,
2770 DeviceMemory<std::complex<double>> *y, int incy) {
2771 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2772 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2773 PARAM(incy));
2774
2775 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2776 const DeviceMemory<std::complex<double>> &, int,
2777 const DeviceMemory<std::complex<double>> &, int,
2778 std::complex<double>, DeviceMemory<std::complex<double>> *,
2779 int> impl;
2780 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2781 x, incx, beta, y, incy);
2782 }
2783
2784 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2785 const DeviceMemory<float> &x, int incx,
2786 const DeviceMemory<float> &y, int incy,
2787 DeviceMemory<float> *a, int lda) {
2788 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2789 PARAM(incy), PARAM(a), PARAM(lda));
2790
2791 ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2792 const DeviceMemory<float> &, int, DeviceMemory<float> *,
2793 int> impl;
2794 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2795 incy, a, lda);
2796 }
2797
2798 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2799 const DeviceMemory<double> &x, int incx,
2800 const DeviceMemory<double> &y, int incy,
2801 DeviceMemory<double> *a, int lda) {
2802 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2803 PARAM(incy), PARAM(a), PARAM(lda));
2804
2805 ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2806 const DeviceMemory<double> &, int, DeviceMemory<double> *,
2807 int> impl;
2808 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2809 incy, a, lda);
2810 }
2811
2812 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2813 const DeviceMemory<std::complex<float>> &x,
2814 int incx,
2815 const DeviceMemory<std::complex<float>> &y,
2816 int incy, DeviceMemory<std::complex<float>> *a,
2817 int lda) {
2818 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2819 PARAM(incy), PARAM(a), PARAM(lda));
2820
2821 ThenBlasImpl<uint64, uint64, std::complex<float>,
2822 const DeviceMemory<std::complex<float>> &, int,
2823 const DeviceMemory<std::complex<float>> &, int,
2824 DeviceMemory<std::complex<float>> *, int> impl;
2825 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2826 incy, a, lda);
2827 }
2828
2829 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2830 const DeviceMemory<std::complex<double>> &x,
2831 int incx,
2832 const DeviceMemory<std::complex<double>> &y,
2833 int incy, DeviceMemory<std::complex<double>> *a,
2834 int lda) {
2835 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2836 PARAM(incy), PARAM(a), PARAM(lda));
2837
2838 ThenBlasImpl<uint64, uint64, std::complex<double>,
2839 const DeviceMemory<std::complex<double>> &, int,
2840 const DeviceMemory<std::complex<double>> &, int,
2841 DeviceMemory<std::complex<double>> *, int> impl;
2842 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2843 incy, a, lda);
2844 }
2845
2846 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2847 const DeviceMemory<std::complex<float>> &x,
2848 int incx,
2849 const DeviceMemory<std::complex<float>> &y,
2850 int incy, DeviceMemory<std::complex<float>> *a,
2851 int lda) {
2852 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2853 PARAM(incy), PARAM(a), PARAM(lda));
2854
2855 ThenBlasImpl<uint64, uint64, std::complex<float>,
2856 const DeviceMemory<std::complex<float>> &, int,
2857 const DeviceMemory<std::complex<float>> &, int,
2858 DeviceMemory<std::complex<float>> *, int> impl;
2859 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2860 incy, a, lda);
2861 }
2862
2863 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2864 const DeviceMemory<std::complex<double>> &x,
2865 int incx,
2866 const DeviceMemory<std::complex<double>> &y,
2867 int incy, DeviceMemory<std::complex<double>> *a,
2868 int lda) {
2869 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2870 PARAM(incy), PARAM(a), PARAM(lda));
2871
2872 ThenBlasImpl<uint64, uint64, std::complex<double>,
2873 const DeviceMemory<std::complex<double>> &, int,
2874 const DeviceMemory<std::complex<double>> &, int,
2875 DeviceMemory<std::complex<double>> *, int> impl;
2876 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2877 incy, a, lda);
2878 }
2879
2880 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2881 std::complex<float> alpha,
2882 const DeviceMemory<std::complex<float>> &a,
2883 int lda,
2884 const DeviceMemory<std::complex<float>> &x,
2885 int incx, std::complex<float> beta,
2886 DeviceMemory<std::complex<float>> *y, int incy) {
2887 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2888 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2889
2890 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2891 const DeviceMemory<std::complex<float>> &, int,
2892 const DeviceMemory<std::complex<float>> &, int,
2893 std::complex<float>, DeviceMemory<std::complex<float>> *,
2894 int> impl;
2895 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2896 x, incx, beta, y, incy);
2897 }
2898
2899 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2900 std::complex<double> alpha,
2901 const DeviceMemory<std::complex<double>> &a,
2902 int lda,
2903 const DeviceMemory<std::complex<double>> &x,
2904 int incx, std::complex<double> beta,
2905 DeviceMemory<std::complex<double>> *y, int incy) {
2906 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2907 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2908
2909 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2910 const DeviceMemory<std::complex<double>> &, int,
2911 const DeviceMemory<std::complex<double>> &, int,
2912 std::complex<double>, DeviceMemory<std::complex<double>> *,
2913 int> impl;
2914 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2915 x, incx, beta, y, incy);
2916 }
2917
2918 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2919 std::complex<float> alpha,
2920 const DeviceMemory<std::complex<float>> &a,
2921 int lda,
2922 const DeviceMemory<std::complex<float>> &x,
2923 int incx, std::complex<float> beta,
2924 DeviceMemory<std::complex<float>> *y, int incy) {
2925 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2926 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2927
2928 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2929 const DeviceMemory<std::complex<float>> &, int,
2930 const DeviceMemory<std::complex<float>> &, int,
2931 std::complex<float>, DeviceMemory<std::complex<float>> *,
2932 int> impl;
2933 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2934 incx, beta, y, incy);
2935 }
2936
2937 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2938 std::complex<double> alpha,
2939 const DeviceMemory<std::complex<double>> &a,
2940 int lda,
2941 const DeviceMemory<std::complex<double>> &x,
2942 int incx, std::complex<double> beta,
2943 DeviceMemory<std::complex<double>> *y, int incy) {
2944 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2945 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2946
2947 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2948 const DeviceMemory<std::complex<double>> &, int,
2949 const DeviceMemory<std::complex<double>> &, int,
2950 std::complex<double>, DeviceMemory<std::complex<double>> *,
2951 int> impl;
2952 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2953 incx, beta, y, incy);
2954 }
2955
2956 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2957 const DeviceMemory<std::complex<float>> &x,
2958 int incx, DeviceMemory<std::complex<float>> *a,
2959 int lda) {
2960 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2961 PARAM(a), PARAM(lda));
2962
2963 ThenBlasImpl<blas::UpperLower, uint64, float,
2964 const DeviceMemory<std::complex<float>> &, int,
2965 DeviceMemory<std::complex<float>> *, int> impl;
2966 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2967 lda);
2968 }
2969
2970 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2971 const DeviceMemory<std::complex<double>> &x,
2972 int incx, DeviceMemory<std::complex<double>> *a,
2973 int lda) {
2974 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2975 PARAM(a), PARAM(lda));
2976
2977 ThenBlasImpl<blas::UpperLower, uint64, double,
2978 const DeviceMemory<std::complex<double>> &, int,
2979 DeviceMemory<std::complex<double>> *, int> impl;
2980 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2981 lda);
2982 }
2983
2984 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2985 std::complex<float> alpha,
2986 const DeviceMemory<std::complex<float>> &x,
2987 int incx,
2988 const DeviceMemory<std::complex<float>> &y,
2989 int incy, DeviceMemory<std::complex<float>> *a,
2990 int lda) {
2991 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2992 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2993
2994 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2995 const DeviceMemory<std::complex<float>> &, int,
2996 const DeviceMemory<std::complex<float>> &, int,
2997 DeviceMemory<std::complex<float>> *, int> impl;
2998 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2999 incy, a, lda);
3000 }
3001
3002 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
3003 std::complex<double> alpha,
3004 const DeviceMemory<std::complex<double>> &x,
3005 int incx,
3006 const DeviceMemory<std::complex<double>> &y,
3007 int incy, DeviceMemory<std::complex<double>> *a,
3008 int lda) {
3009 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3010 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3011
3012 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
3013 const DeviceMemory<std::complex<double>> &, int,
3014 const DeviceMemory<std::complex<double>> &, int,
3015 DeviceMemory<std::complex<double>> *, int> impl;
3016 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
3017 incy, a, lda);
3018 }
3019
3020 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
3021 std::complex<float> alpha,
3022 const DeviceMemory<std::complex<float>> &ap,
3023 const DeviceMemory<std::complex<float>> &x,
3024 int incx, std::complex<float> beta,
3025 DeviceMemory<std::complex<float>> *y, int incy) {
3026 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3027 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3028
3029 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
3030 const DeviceMemory<std::complex<float>> &,
3031 const DeviceMemory<std::complex<float>> &, int,
3032 std::complex<float>, DeviceMemory<std::complex<float>> *,
3033 int> impl;
3034 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
3035 beta, y, incy);
3036 }
3037
3038 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
3039 std::complex<double> alpha,
3040 const DeviceMemory<std::complex<double>> &ap,
3041 const DeviceMemory<std::complex<double>> &x,
3042 int incx, std::complex<double> beta,
3043 DeviceMemory<std::complex<double>> *y, int incy) {
3044 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3045 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3046
3047 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
3048 const DeviceMemory<std::complex<double>> &,
3049 const DeviceMemory<std::complex<double>> &, int,
3050 std::complex<double>, DeviceMemory<std::complex<double>> *,
3051 int> impl;
3052 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
3053 beta, y, incy);
3054 }
3055
3056 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
3057 const DeviceMemory<std::complex<float>> &x,
3058 int incx, DeviceMemory<std::complex<float>> *ap) {
3059 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3060 PARAM(ap));
3061
3062 ThenBlasImpl<blas::UpperLower, uint64, float,
3063 const DeviceMemory<std::complex<float>> &, int,
3064 DeviceMemory<std::complex<float>> *> impl;
3065 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
3066 }
3067
3068 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
3069 const DeviceMemory<std::complex<double>> &x,
3070 int incx, DeviceMemory<std::complex<double>> *ap) {
3071 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3072 PARAM(ap));
3073
3074 ThenBlasImpl<blas::UpperLower, uint64, double,
3075 const DeviceMemory<std::complex<double>> &, int,
3076 DeviceMemory<std::complex<double>> *> impl;
3077 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
3078 }
3079
3080 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
3081 std::complex<float> alpha,
3082 const DeviceMemory<std::complex<float>> &x,
3083 int incx,
3084 const DeviceMemory<std::complex<float>> &y,
3085 int incy, DeviceMemory<std::complex<float>> *ap) {
3086 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3087 PARAM(y), PARAM(incy), PARAM(ap));
3088
3089 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
3090 const DeviceMemory<std::complex<float>> &, int,
3091 const DeviceMemory<std::complex<float>> &, int,
3092 DeviceMemory<std::complex<float>> *> impl;
3093 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
3094 incy, ap);
3095 }
3096
3097 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
3098 std::complex<double> alpha,
3099 const DeviceMemory<std::complex<double>> &x,
3100 int incx,
3101 const DeviceMemory<std::complex<double>> &y,
3102 int incy, DeviceMemory<std::complex<double>> *ap) {
3103 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3104 PARAM(y), PARAM(incy), PARAM(ap));
3105
3106 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
3107 const DeviceMemory<std::complex<double>> &, int,
3108 const DeviceMemory<std::complex<double>> &, int,
3109 DeviceMemory<std::complex<double>> *> impl;
3110 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
3111 incy, ap);
3112 }
3113
3114 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
3115 float alpha, const DeviceMemory<float> &a, int lda,
3116 const DeviceMemory<float> &x, int incx, float beta,
3117 DeviceMemory<float> *y, int incy) {
3118 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
3119 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3120
3121 ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
3122 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3123 int, float, DeviceMemory<float> *, int> impl;
3124 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
3125 x, incx, beta, y, incy);
3126 }
3127
3128 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
3129 double alpha, const DeviceMemory<double> &a,
3130 int lda, const DeviceMemory<double> &x, int incx,
3131 double beta, DeviceMemory<double> *y, int incy) {
3132 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
3133 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3134
3135 ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
3136 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3137 int, double, DeviceMemory<double> *, int> impl;
3138 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
3139 x, incx, beta, y, incy);
3140 }
3141
3142 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
3143 const DeviceMemory<float> &ap,
3144 const DeviceMemory<float> &x, int incx, float beta,
3145 DeviceMemory<float> *y, int incy) {
3146 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3147 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3148
3149 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3150 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3151 int> impl;
3152 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3153 beta, y, incy);
3154 }
3155
3156 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
3157 const DeviceMemory<double> &ap,
3158 const DeviceMemory<double> &x, int incx,
3159 double beta, DeviceMemory<double> *y, int incy) {
3160 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3161 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3162
3163 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3164 const DeviceMemory<double> &, int, double,
3165 DeviceMemory<double> *, int> impl;
3166 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3167 beta, y, incy);
3168 }
3169
3170 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
3171 const DeviceMemory<float> &x, int incx,
3172 DeviceMemory<float> *ap) {
3173 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3174 PARAM(ap));
3175
3176 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3177 int, DeviceMemory<float> *> impl;
3178 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3179 }
3180
3181 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
3182 const DeviceMemory<double> &x, int incx,
3183 DeviceMemory<double> *ap) {
3184 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3185 PARAM(ap));
3186
3187 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3188 int, DeviceMemory<double> *> impl;
3189 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3190 }
3191
3192 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
3193 const DeviceMemory<float> &x, int incx,
3194 const DeviceMemory<float> &y, int incy,
3195 DeviceMemory<float> *ap) {
3196 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3197 PARAM(y), PARAM(incy), PARAM(ap));
3198
3199 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3200 int, const DeviceMemory<float> &, int,
3201 DeviceMemory<float> *> impl;
3202 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3203 incy, ap);
3204 }
3205
3206 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
3207 const DeviceMemory<double> &x, int incx,
3208 const DeviceMemory<double> &y, int incy,
3209 DeviceMemory<double> *ap) {
3210 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3211 PARAM(y), PARAM(incy), PARAM(ap));
3212
3213 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3214 int, const DeviceMemory<double> &, int,
3215 DeviceMemory<double> *> impl;
3216 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3217 incy, ap);
3218 }
3219
3220 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
3221 const DeviceMemory<float> &a, int lda,
3222 const DeviceMemory<float> &x, int incx, float beta,
3223 DeviceMemory<float> *y, int incy) {
3224 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3225 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3226
3227 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3228 int, const DeviceMemory<float> &, int, float,
3229 DeviceMemory<float> *, int> impl;
3230 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3231 incx, beta, y, incy);
3232 }
3233
3234 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
3235 const DeviceMemory<double> &a, int lda,
3236 const DeviceMemory<double> &x, int incx,
3237 double beta, DeviceMemory<double> *y, int incy) {
3238 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3239 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3240
3241 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3242 int, const DeviceMemory<double> &, int, double,
3243 DeviceMemory<double> *, int> impl;
3244 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3245 incx, beta, y, incy);
3246 }
3247
3248 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
3249 const DeviceMemory<float> &x, int incx,
3250 DeviceMemory<float> *a, int lda) {
3251 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3252 PARAM(a), PARAM(lda));
3253
3254 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3255 int, DeviceMemory<float> *, int> impl;
3256 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3257 lda);
3258 }
3259
3260 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
3261 const DeviceMemory<double> &x, int incx,
3262 DeviceMemory<double> *a, int lda) {
3263 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3264 PARAM(a), PARAM(lda));
3265
3266 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3267 int, DeviceMemory<double> *, int> impl;
3268 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3269 lda);
3270 }
3271
3272 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
3273 const DeviceMemory<float> &x, int incx,
3274 const DeviceMemory<float> &y, int incy,
3275 DeviceMemory<float> *a, int lda) {
3276 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3277 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3278
3279 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3280 int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3281 int> impl;
3282 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3283 incy, a, lda);
3284 }
3285
3286 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
3287 const DeviceMemory<double> &x, int incx,
3288 const DeviceMemory<double> &y, int incy,
3289 DeviceMemory<double> *a, int lda) {
3290 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3291 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3292
3293 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3294 int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
3295 int> impl;
3296 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3297 incy, a, lda);
3298 }
3299
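// Triangular banded/packed wrappers (tbmv, tbsv, tpmv, ...) follow; x is
// updated in place with the product or the solution.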
3300 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3301 blas::Diagonal diag, uint64 n, uint64 k,
3302 const DeviceMemory<float> &a, int lda,
3303 DeviceMemory<float> *x, int incx) {
3304 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3305 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3306
3307 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3308 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3309 int> impl;
3310 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3311 lda, x, incx);
3312 }
3313
3314 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3315 blas::Diagonal diag, uint64 n, uint64 k,
3316 const DeviceMemory<double> &a, int lda,
3317 DeviceMemory<double> *x, int incx) {
3318 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3319 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3320
3321 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3322 uint64, const DeviceMemory<double> &, int,
3323 DeviceMemory<double> *, int> impl;
3324 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3325 lda, x, incx);
3326 }
3327
3328 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3329 blas::Diagonal diag, uint64 n, uint64 k,
3330 const DeviceMemory<std::complex<float>> &a,
3331 int lda, DeviceMemory<std::complex<float>> *x,
3332 int incx) {
3333 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3334 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3335
3336 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3337 uint64, const DeviceMemory<std::complex<float>> &, int,
3338 DeviceMemory<std::complex<float>> *, int> impl;
3339 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3340 lda, x, incx);
3341 }
3342
3343 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3344 blas::Diagonal diag, uint64 n, uint64 k,
3345 const DeviceMemory<std::complex<double>> &a,
3346 int lda, DeviceMemory<std::complex<double>> *x,
3347 int incx) {
3348 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3349 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3350
3351 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3352 uint64, const DeviceMemory<std::complex<double>> &, int,
3353 DeviceMemory<std::complex<double>> *, int> impl;
3354 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3355 lda, x, incx);
3356 }
3357
3358 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3359 blas::Diagonal diag, uint64 n, uint64 k,
3360 const DeviceMemory<float> &a, int lda,
3361 DeviceMemory<float> *x, int incx) {
3362 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3363 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3364
3365 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3366 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3367 int> impl;
3368 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3369 lda, x, incx);
3370 }
3371
3372 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3373 blas::Diagonal diag, uint64 n, uint64 k,
3374 const DeviceMemory<double> &a, int lda,
3375 DeviceMemory<double> *x, int incx) {
3376 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3377 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3378
3379 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3380 uint64, const DeviceMemory<double> &, int,
3381 DeviceMemory<double> *, int> impl;
3382 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3383 lda, x, incx);
3384 }
3385
3386 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3387 blas::Diagonal diag, uint64 n, uint64 k,
3388 const DeviceMemory<std::complex<float>> &a,
3389 int lda, DeviceMemory<std::complex<float>> *x,
3390 int incx) {
3391 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3392 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3393
3394 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3395 uint64, const DeviceMemory<std::complex<float>> &, int,
3396 DeviceMemory<std::complex<float>> *, int> impl;
3397 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3398 lda, x, incx);
3399 }
3400
3401 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3402 blas::Diagonal diag, uint64 n, uint64 k,
3403 const DeviceMemory<std::complex<double>> &a,
3404 int lda, DeviceMemory<std::complex<double>> *x,
3405 int incx) {
3406 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3407 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3408
3409 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3410 uint64, const DeviceMemory<std::complex<double>> &, int,
3411 DeviceMemory<std::complex<double>> *, int> impl;
3412 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3413 lda, x, incx);
3414 }
3415
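// Triangular packed matrix-vector multiply (BLAS xTPMV): x <- op(a) * x, with
// the triangular matrix a supplied in packed storage (ap).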
3416 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3417 blas::Diagonal diag, uint64 n,
3418 const DeviceMemory<float> &ap,
3419 DeviceMemory<float> *x, int incx) {
3420 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3421 PARAM(x), PARAM(incx));
3422
3423 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3424 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3425 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3426 incx);
3427 }
3428
3429 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3430 blas::Diagonal diag, uint64 n,
3431 const DeviceMemory<double> &ap,
3432 DeviceMemory<double> *x, int incx) {
3433 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3434 PARAM(x), PARAM(incx));
3435
3436 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3437 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3438 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3439 incx);
3440 }
3441
3442 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3443 blas::Diagonal diag, uint64 n,
3444 const DeviceMemory<std::complex<float>> &ap,
3445 DeviceMemory<std::complex<float>> *x, int incx) {
3446 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3447 PARAM(x), PARAM(incx));
3448
3449 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3450 const DeviceMemory<std::complex<float>> &,
3451 DeviceMemory<std::complex<float>> *, int> impl;
3452 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3453 incx);
3454 }
3455
3456 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3457 blas::Diagonal diag, uint64 n,
3458 const DeviceMemory<std::complex<double>> &ap,
3459 DeviceMemory<std::complex<double>> *x, int incx) {
3460 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3461 PARAM(x), PARAM(incx));
3462
3463 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3464 const DeviceMemory<std::complex<double>> &,
3465 DeviceMemory<std::complex<double>> *, int> impl;
3466 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3467 incx);
3468 }
3469
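// Triangular packed solve (BLAS xTPSV): overwrites x with the solution of
// op(a) * x = b for a triangular matrix a supplied in packed storage (ap).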
3470 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3471 blas::Diagonal diag, uint64 n,
3472 const DeviceMemory<float> &ap,
3473 DeviceMemory<float> *x, int incx) {
3474 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3475 PARAM(x), PARAM(incx));
3476
3477 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3478 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3479 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3480 incx);
3481 }
3482
3483 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3484 blas::Diagonal diag, uint64 n,
3485 const DeviceMemory<double> &ap,
3486 DeviceMemory<double> *x, int incx) {
3487 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3488 PARAM(x), PARAM(incx));
3489
3490 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3491 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3492 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3493 incx);
3494 }
3495
3496 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3497 blas::Diagonal diag, uint64 n,
3498 const DeviceMemory<std::complex<float>> &ap,
3499 DeviceMemory<std::complex<float>> *x, int incx) {
3500 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3501 PARAM(x), PARAM(incx));
3502
3503 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3504 const DeviceMemory<std::complex<float>> &,
3505 DeviceMemory<std::complex<float>> *, int> impl;
3506 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3507 incx);
3508 }
3509
3510 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3511 blas::Diagonal diag, uint64 n,
3512 const DeviceMemory<std::complex<double>> &ap,
3513 DeviceMemory<std::complex<double>> *x, int incx) {
3514 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3515 PARAM(x), PARAM(incx));
3516
3517 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3518 const DeviceMemory<std::complex<double>> &,
3519 DeviceMemory<std::complex<double>> *, int> impl;
3520 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3521 incx);
3522 }
3523
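// Triangular matrix-vector multiply (BLAS xTRMV): x <- op(a) * x.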
3524 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3525 blas::Diagonal diag, uint64 n,
3526 const DeviceMemory<float> &a, int lda,
3527 DeviceMemory<float> *x, int incx) {
3528 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3529 PARAM(lda), PARAM(x), PARAM(incx));
3530
3531 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3532 const DeviceMemory<float> &, int, DeviceMemory<float> *,
3533 int> impl;
3534 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3535 lda, x, incx);
3536 }
3537
3538 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3539 blas::Diagonal diag, uint64 n,
3540 const DeviceMemory<double> &a, int lda,
3541 DeviceMemory<double> *x, int incx) {
3542 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3543 PARAM(lda), PARAM(x), PARAM(incx));
3544
3545 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3546 const DeviceMemory<double> &, int, DeviceMemory<double> *,
3547 int> impl;
3548 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3549 lda, x, incx);
3550 }
3551
3552 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3553 blas::Diagonal diag, uint64 n,
3554 const DeviceMemory<std::complex<float>> &a,
3555 int lda, DeviceMemory<std::complex<float>> *x,
3556 int incx) {
3557 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3558 PARAM(lda), PARAM(x), PARAM(incx));
3559
3560 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3561 const DeviceMemory<std::complex<float>> &, int,
3562 DeviceMemory<std::complex<float>> *, int> impl;
3563 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3564 lda, x, incx);
3565 }
3566
3567 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3568 blas::Diagonal diag, uint64 n,
3569 const DeviceMemory<std::complex<double>> &a,
3570 int lda, DeviceMemory<std::complex<double>> *x,
3571 int incx) {
3572 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3573 PARAM(lda), PARAM(x), PARAM(incx));
3574
3575 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3576 const DeviceMemory<std::complex<double>> &, int,
3577 DeviceMemory<std::complex<double>> *, int> impl;
3578 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3579 lda, x, incx);
3580 }
3581
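// Triangular solve with a single right-hand side (BLAS xTRSV): overwrites x
// with the solution of op(a) * x = b.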
3582 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3583 blas::Diagonal diag, uint64 n,
3584 const DeviceMemory<float> &a, int lda,
3585 DeviceMemory<float> *x, int incx) {
3586 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3587 PARAM(lda), PARAM(x), PARAM(incx));
3588
3589 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3590 const DeviceMemory<float> &, int, DeviceMemory<float> *,
3591 int> impl;
3592 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3593 lda, x, incx);
3594 }
3595
3596 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3597 blas::Diagonal diag, uint64 n,
3598 const DeviceMemory<double> &a, int lda,
3599 DeviceMemory<double> *x, int incx) {
3600 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3601 PARAM(lda), PARAM(x), PARAM(incx));
3602
3603 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3604 const DeviceMemory<double> &, int, DeviceMemory<double> *,
3605 int> impl;
3606 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3607 lda, x, incx);
3608 }
3609
3610 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3611 blas::Diagonal diag, uint64 n,
3612 const DeviceMemory<std::complex<float>> &a,
3613 int lda, DeviceMemory<std::complex<float>> *x,
3614 int incx) {
3615 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3616 PARAM(lda), PARAM(x), PARAM(incx));
3617
3618 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3619 const DeviceMemory<std::complex<float>> &, int,
3620 DeviceMemory<std::complex<float>> *, int> impl;
3621 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3622 lda, x, incx);
3623 }
3624
3625 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3626 blas::Diagonal diag, uint64 n,
3627 const DeviceMemory<std::complex<double>> &a,
3628 int lda, DeviceMemory<std::complex<double>> *x,
3629 int incx) {
3630 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3631 PARAM(lda), PARAM(x), PARAM(incx));
3632
3633 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3634 const DeviceMemory<std::complex<double>> &, int,
3635 DeviceMemory<std::complex<double>> *, int> impl;
3636 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3637 lda, x, incx);
3638 }
3639
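// General matrix-matrix multiply (BLAS xGEMM):
// c <- alpha * op(a) * op(b) + beta * c.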
3640 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3641 uint64 m, uint64 n, uint64 k, float alpha,
3642 const DeviceMemory<Eigen::half> &a, int lda,
3643 const DeviceMemory<Eigen::half> &b, int ldb,
3644 float beta,
3645 DeviceMemory<Eigen::half> *c, int ldc) {
3646 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3647 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3648 PARAM(beta), PARAM(c), PARAM(ldc));
3649
3650 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3651 const DeviceMemory<Eigen::half> &, int,
3652 const DeviceMemory<Eigen::half> &, int,
3653 float, DeviceMemory<Eigen::half> *, int> impl;
3654 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3655 alpha, a, lda, b, ldb, beta, c, ldc);
3656 }
3657
3658 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3659 uint64 m, uint64 n, uint64 k, float alpha,
3660 const DeviceMemory<float> &a, int lda,
3661 const DeviceMemory<float> &b, int ldb, float beta,
3662 DeviceMemory<float> *c, int ldc) {
3663 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3664 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3665 PARAM(beta), PARAM(c), PARAM(ldc));
3666
3667 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3668 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3669 int, float, DeviceMemory<float> *, int> impl;
3670 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3671 alpha, a, lda, b, ldb, beta, c, ldc);
3672 }
3673
3674 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3675 uint64 m, uint64 n, uint64 k, double alpha,
3676 const DeviceMemory<double> &a, int lda,
3677 const DeviceMemory<double> &b, int ldb,
3678 double beta, DeviceMemory<double> *c, int ldc) {
3679 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3680 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3681 PARAM(beta), PARAM(c), PARAM(ldc));
3682
3683 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3684 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3685 int, double, DeviceMemory<double> *, int> impl;
3686 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3687 alpha, a, lda, b, ldb, beta, c, ldc);
3688 }
3689
3690 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3691 uint64 m, uint64 n, uint64 k,
3692 std::complex<float> alpha,
3693 const DeviceMemory<std::complex<float>> &a,
3694 int lda,
3695 const DeviceMemory<std::complex<float>> &b,
3696 int ldb, std::complex<float> beta,
3697 DeviceMemory<std::complex<float>> *c, int ldc) {
3698 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3699 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3700 PARAM(beta), PARAM(c), PARAM(ldc));
3701
3702 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3703 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3704 int, const DeviceMemory<std::complex<float>> &, int,
3705 std::complex<float>, DeviceMemory<std::complex<float>> *,
3706 int> impl;
3707 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3708 alpha, a, lda, b, ldb, beta, c, ldc);
3709 }
3710
3711 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3712 uint64 m, uint64 n, uint64 k,
3713 std::complex<double> alpha,
3714 const DeviceMemory<std::complex<double>> &a,
3715 int lda,
3716 const DeviceMemory<std::complex<double>> &b,
3717 int ldb, std::complex<double> beta,
3718 DeviceMemory<std::complex<double>> *c, int ldc) {
3719 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3720 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3721 PARAM(beta), PARAM(c), PARAM(ldc));
3722
3723 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3724 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3725 int, const DeviceMemory<std::complex<double>> &, int,
3726 std::complex<double>, DeviceMemory<std::complex<double>> *,
3727 int> impl;
3728 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3729 alpha, a, lda, b, ldb, beta, c, ldc);
3730 }
3731
3732 namespace {
3733 // Like ThenBlasImpl, except that the last argument of blas_func is expected
3734 // to be a blas::ProfileResult*. When the profile result is non-null, a failing
3735 // op does not put the stream into an error state; instead, success or failure
3736 // is reported through the profile result itself.
3737 template <typename... Args>
3738 struct ThenBlasWithProfileImpl {
3739   Stream &operator()(Stream *stream,
3740 bool (blas::BlasSupport::*blas_func)(
3741 Stream *, Args..., blas::ProfileResult *),
3742 Args... args, blas::ProfileResult *profile_result) {
3743 ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3744 bool record_error = profile_result == nullptr;
3745 return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3746 }
3747 };
3748 } // anonymous namespace
3749
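// Profiled GEMV: as ThenBlasGemv, but dispatched through
// ThenBlasWithProfileImpl above, so timing and success/failure are reported
// via *output_profile_result rather than by putting the stream into an error
// state. A caller would typically pass a blas::ProfileResult and inspect it
// after the stream work completes; the details depend on the backing BLAS
// implementation.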
3750 Stream &Stream::ThenBlasGemvWithProfiling(
3751 blas::Transpose trans, uint64 m, uint64 n, float alpha,
3752 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3753 int incx, float beta, DeviceMemory<float> *y, int incy,
3754 blas::ProfileResult *output_profile_result) {
3755 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3756 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3757 PARAM(incy));
3758
3759 ThenBlasWithProfileImpl<
3760 blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3761 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3762 impl;
3763 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3764 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3765 }
3766
3767 Stream &Stream::ThenBlasGemvWithProfiling(
3768 blas::Transpose trans, uint64 m, uint64 n, double alpha,
3769 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3770 int incx, double beta, DeviceMemory<double> *y, int incy,
3771 blas::ProfileResult *output_profile_result) {
3772 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3773 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3774 PARAM(incy));
3775
3776 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3777 const DeviceMemory<double> &, int,
3778 const DeviceMemory<double> &, int, double,
3779 DeviceMemory<double> *, int>
3780 impl;
3781 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3782 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3783 }
3784
3785 Stream &Stream::ThenBlasGemvWithProfiling(
3786 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3787 const DeviceMemory<std::complex<float>> &a, int lda,
3788 const DeviceMemory<std::complex<float>> &x, int incx,
3789 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3790 blas::ProfileResult *output_profile_result) {
3791 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3792 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3793 PARAM(incy));
3794
3795 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3796 const DeviceMemory<std::complex<float>> &, int,
3797 const DeviceMemory<std::complex<float>> &, int,
3798 std::complex<float>,
3799 DeviceMemory<std::complex<float>> *, int>
3800 impl;
3801 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3802 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3803 }
3804
3805 Stream &Stream::ThenBlasGemvWithProfiling(
3806 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3807 const DeviceMemory<std::complex<double>> &a, int lda,
3808 const DeviceMemory<std::complex<double>> &x, int incx,
3809 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3810 blas::ProfileResult *output_profile_result) {
3811 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3812 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3813 PARAM(incy));
3814
3815 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3816 const DeviceMemory<std::complex<double>> &, int,
3817 const DeviceMemory<std::complex<double>> &, int,
3818 std::complex<double>,
3819 DeviceMemory<std::complex<double>> *, int>
3820 impl;
3821 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3822 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3823 }
3824
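// Profiled GEMM: as ThenBlasGemm, with timing and success/failure reported via
// *output_profile_result.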
3825 Stream &Stream::ThenBlasGemmWithProfiling(
3826 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3827 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3828 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3829 DeviceMemory<Eigen::half> *c, int ldc,
3830 blas::ProfileResult *output_profile_result) {
3831 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3832 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3833 PARAM(beta), PARAM(c), PARAM(ldc));
3834
3835 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3836 uint64, float, const DeviceMemory<Eigen::half> &, int,
3837 const DeviceMemory<Eigen::half> &, int, float,
3838 DeviceMemory<Eigen::half> *, int>
3839 impl;
3840 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3841 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3842 output_profile_result);
3843 }
3844
3845 Stream &Stream::ThenBlasGemmWithProfiling(
3846 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3847 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3848 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3849 int ldc, blas::ProfileResult *output_profile_result) {
3850 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3851 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3852 PARAM(beta), PARAM(c), PARAM(ldc));
3853
3854 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3855 uint64, float, const DeviceMemory<float> &, int,
3856 const DeviceMemory<float> &, int, float,
3857 DeviceMemory<float> *, int>
3858 impl;
3859 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3860 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3861 output_profile_result);
3862 }
3863
3864 Stream &Stream::ThenBlasGemmWithProfiling(
3865 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3866 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3867 const DeviceMemory<double> &b, int ldb, double beta,
3868 DeviceMemory<double> *c, int ldc,
3869 blas::ProfileResult *output_profile_result) {
3870 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3871 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3872 PARAM(beta), PARAM(c), PARAM(ldc));
3873
3874 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3875 uint64, double, const DeviceMemory<double> &, int,
3876 const DeviceMemory<double> &, int, double,
3877 DeviceMemory<double> *, int>
3878 impl;
3879 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3880 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3881 output_profile_result);
3882 }
3883
3884 Stream &Stream::ThenBlasGemmWithProfiling(
3885 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3886 uint64 k, std::complex<float> alpha,
3887 const DeviceMemory<std::complex<float>> &a, int lda,
3888 const DeviceMemory<std::complex<float>> &b, int ldb,
3889 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3890 blas::ProfileResult *output_profile_result) {
3891 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3892 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3893 PARAM(beta), PARAM(c), PARAM(ldc));
3894
3895 ThenBlasWithProfileImpl<
3896 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3897 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3898 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3899 DeviceMemory<std::complex<float>> *, int>
3900 impl;
3901 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3902 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3903 output_profile_result);
3904 }
3905
3906 Stream &Stream::ThenBlasGemmWithProfiling(
3907 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3908 uint64 k, std::complex<double> alpha,
3909 const DeviceMemory<std::complex<double>> &a, int lda,
3910 const DeviceMemory<std::complex<double>> &b, int ldb,
3911 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3912 blas::ProfileResult *output_profile_result) {
3913 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3914 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3915 PARAM(beta), PARAM(c), PARAM(ldc));
3916
3917 ThenBlasWithProfileImpl<
3918 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3919 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3920 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3921 DeviceMemory<std::complex<double>> *, int>
3922 impl;
3923 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3924 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3925 output_profile_result);
3926 }
3927
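// GEMM with an explicitly selected algorithm: the caller chooses the backend
// algorithm and computation type (typically during autotuning), and the
// outcome and timing are reported via *output_profile_result.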
3928 Stream &Stream::ThenBlasGemmWithAlgorithm(
3929 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3930 uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
3931 const DeviceMemory<Eigen::half> &a, int lda,
3932 const DeviceMemory<Eigen::half> &b, int ldb,
3933 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
3934 int ldc, blas::ComputationType computation_type,
3935 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3936 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3937 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3938 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3939 PARAM(algorithm));
3940
3941 ThenBlasWithProfileImpl<
3942 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3943 const HostOrDeviceScalar<Eigen::half> &,
3944 const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
3945 int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3946 int, blas::ComputationType, blas::AlgorithmType>
3947 impl;
3948 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3949 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3950 algorithm, output_profile_result);
3951 }
3952
3953 Stream &Stream::ThenBlasGemmWithAlgorithm(
3954 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3955 uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
3956 int lda, const DeviceMemory<int8> &b, int ldb,
3957 const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
3958 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3959 blas::ProfileResult *output_profile_result) {
3960 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3961 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3962 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3963 PARAM(algorithm));
3964
3965 ThenBlasWithProfileImpl<
3966 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3967 const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
3968 const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
3969 DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3970 impl;
3971 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3972 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3973 algorithm, output_profile_result);
3974 }
3975
3976 Stream &Stream::ThenBlasGemmWithAlgorithm(
3977 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3978 uint64 k, const HostOrDeviceScalar<float> &alpha,
3979 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
3980 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
3981 int ldc, blas::ComputationType computation_type,
3982 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3983 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3984 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3985 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3986 PARAM(algorithm));
3987
3988 ThenBlasWithProfileImpl<
3989 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3990 const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
3991 const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
3992 DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3993 impl;
3994 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3995 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3996 algorithm, output_profile_result);
3997 }
3998
3999 Stream &Stream::ThenBlasGemmWithAlgorithm(
4000 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4001 uint64 k, const HostOrDeviceScalar<double> &alpha,
4002 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
4003 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
4004 int ldc, blas::ComputationType computation_type,
4005 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
4006 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4007 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4008 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
4009 PARAM(algorithm));
4010
4011 ThenBlasWithProfileImpl<
4012 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4013 const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
4014 const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
4015 DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
4016 impl;
4017 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
4018               m, n, k, alpha, a, lda, b, ldb,
4019               beta, c, ldc, computation_type,
4020 algorithm, output_profile_result);
4021 }
4022
4023 Stream &Stream::ThenBlasGemmWithAlgorithm(
4024 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4025 uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
4026 const DeviceMemory<std::complex<float>> &a, int lda,
4027 const DeviceMemory<std::complex<float>> &b, int ldb,
4028 const HostOrDeviceScalar<std::complex<float>> &beta,
4029 DeviceMemory<std::complex<float>> *c, int ldc,
4030 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
4031 blas::ProfileResult *output_profile_result) {
4032 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4033 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4034 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
4035 PARAM(algorithm));
4036
4037 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
4038 uint64,
4039 const HostOrDeviceScalar<std::complex<float>> &,
4040 const DeviceMemory<std::complex<float>> &, int,
4041 const DeviceMemory<std::complex<float>> &, int,
4042 const HostOrDeviceScalar<std::complex<float>> &,
4043 DeviceMemory<std::complex<float>> *, int,
4044 blas::ComputationType, blas::AlgorithmType>
4045 impl;
4046 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
4047 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
4048 algorithm, output_profile_result);
4049 }
4050
4051 Stream &Stream::ThenBlasGemmWithAlgorithm(
4052 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4053 uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
4054 const DeviceMemory<std::complex<double>> &a, int lda,
4055 const DeviceMemory<std::complex<double>> &b, int ldb,
4056 const HostOrDeviceScalar<std::complex<double>> &beta,
4057 DeviceMemory<std::complex<double>> *c, int ldc,
4058 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
4059 blas::ProfileResult *output_profile_result) {
4060 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4061 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4062 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
4063 PARAM(algorithm));
4064
4065 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
4066 uint64,
4067 const HostOrDeviceScalar<std::complex<double>> &,
4068 const DeviceMemory<std::complex<double>> &, int,
4069 const DeviceMemory<std::complex<double>> &, int,
4070 const HostOrDeviceScalar<std::complex<double>> &,
4071 DeviceMemory<std::complex<double>> *, int,
4072 blas::ComputationType, blas::AlgorithmType>
4073 impl;
4074 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
4075 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
4076 algorithm, output_profile_result);
4077 }
4078
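// Hermitian matrix-matrix multiply (BLAS xHEMM):
// c <- alpha * a * b + beta * c (or b * a, per 'side'), with a Hermitian.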
4079 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
4080 uint64 n, std::complex<float> alpha,
4081 const DeviceMemory<std::complex<float>> &a,
4082 int lda,
4083 const DeviceMemory<std::complex<float>> &b,
4084 int ldb, std::complex<float> beta,
4085 DeviceMemory<std::complex<float>> *c, int ldc) {
4086 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4087 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4088 PARAM(ldc));
4089
4090 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4091 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4092 int, const DeviceMemory<std::complex<float>> &, int,
4093 std::complex<float>, DeviceMemory<std::complex<float>> *,
4094 int> impl;
4095 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
4096 lda, b, ldb, beta, c, ldc);
4097 }
4098
4099 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
4100 uint64 n, std::complex<double> alpha,
4101 const DeviceMemory<std::complex<double>> &a,
4102 int lda,
4103 const DeviceMemory<std::complex<double>> &b,
4104 int ldb, std::complex<double> beta,
4105 DeviceMemory<std::complex<double>> *c, int ldc) {
4106 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4107 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4108 PARAM(ldc));
4109
4110 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4111 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4112 int, const DeviceMemory<std::complex<double>> &, int,
4113 std::complex<double>, DeviceMemory<std::complex<double>> *,
4114 int> impl;
4115 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
4116 lda, b, ldb, beta, c, ldc);
4117 }
4118
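// Hermitian rank-k update (BLAS xHERK): updates the Hermitian matrix c with
// alpha * a * a^H + beta * c (or a^H * a, depending on 'trans').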
4119 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
4120 uint64 n, uint64 k, float alpha,
4121 const DeviceMemory<std::complex<float>> &a,
4122 int lda, float beta,
4123 DeviceMemory<std::complex<float>> *c, int ldc) {
4124 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4125 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4126
4127 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4128 const DeviceMemory<std::complex<float>> &, int, float,
4129 DeviceMemory<std::complex<float>> *, int> impl;
4130 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
4131 lda, beta, c, ldc);
4132 }
4133
4134 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
4135 uint64 n, uint64 k, double alpha,
4136 const DeviceMemory<std::complex<double>> &a,
4137 int lda, double beta,
4138 DeviceMemory<std::complex<double>> *c, int ldc) {
4139 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4140 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4141
4142 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4143 const DeviceMemory<std::complex<double>> &, int, double,
4144 DeviceMemory<std::complex<double>> *, int> impl;
4145 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
4146 lda, beta, c, ldc);
4147 }
4148
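// Hermitian rank-2k update (BLAS xHER2K): in the non-transposed case,
// c <- alpha * a * b^H + conj(alpha) * b * a^H + beta * c, with c Hermitian.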
4149 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4150 uint64 n, uint64 k, std::complex<float> alpha,
4151 const DeviceMemory<std::complex<float>> &a,
4152 int lda,
4153 const DeviceMemory<std::complex<float>> &b,
4154 int ldb, float beta,
4155 DeviceMemory<std::complex<float>> *c, int ldc) {
4156 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4157 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4158 PARAM(ldc));
4159
4160 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4161 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4162 int, const DeviceMemory<std::complex<float>> &, int, float,
4163 DeviceMemory<std::complex<float>> *, int> impl;
4164 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4165 a, lda, b, ldb, beta, c, ldc);
4166 }
4167
4168 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4169 uint64 n, uint64 k, std::complex<double> alpha,
4170 const DeviceMemory<std::complex<double>> &a,
4171 int lda,
4172 const DeviceMemory<std::complex<double>> &b,
4173 int ldb, double beta,
4174 DeviceMemory<std::complex<double>> *c, int ldc) {
4175 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4176 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4177 PARAM(ldc));
4178
4179 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4180 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4181 int, const DeviceMemory<std::complex<double>> &, int, double,
4182 DeviceMemory<std::complex<double>> *, int> impl;
4183 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4184 a, lda, b, ldb, beta, c, ldc);
4185 }
4186
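// Symmetric matrix-matrix multiply (BLAS xSYMM):
// c <- alpha * a * b + beta * c (or b * a, per 'side'), with a symmetric.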
4187 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4188 uint64 n, float alpha,
4189 const DeviceMemory<float> &a, int lda,
4190 const DeviceMemory<float> &b, int ldb, float beta,
4191 DeviceMemory<float> *c, int ldc) {
4192 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4193 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4194 PARAM(ldc));
4195
4196 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
4197 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4198 int, float, DeviceMemory<float> *, int> impl;
4199 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4200 lda, b, ldb, beta, c, ldc);
4201 }
4202
4203 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4204 uint64 n, double alpha,
4205 const DeviceMemory<double> &a, int lda,
4206 const DeviceMemory<double> &b, int ldb,
4207 double beta, DeviceMemory<double> *c, int ldc) {
4208 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4209 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4210 PARAM(ldc));
4211
4212 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
4213 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4214 int, double, DeviceMemory<double> *, int> impl;
4215 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4216 lda, b, ldb, beta, c, ldc);
4217 }
4218
4219 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4220 uint64 n, std::complex<float> alpha,
4221 const DeviceMemory<std::complex<float>> &a,
4222 int lda,
4223 const DeviceMemory<std::complex<float>> &b,
4224 int ldb, std::complex<float> beta,
4225 DeviceMemory<std::complex<float>> *c, int ldc) {
4226 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4227 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4228 PARAM(ldc));
4229
4230 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4231 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4232 int, const DeviceMemory<std::complex<float>> &, int,
4233 std::complex<float>, DeviceMemory<std::complex<float>> *,
4234 int> impl;
4235 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4236 lda, b, ldb, beta, c, ldc);
4237 }
4238
4239 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4240 uint64 n, std::complex<double> alpha,
4241 const DeviceMemory<std::complex<double>> &a,
4242 int lda,
4243 const DeviceMemory<std::complex<double>> &b,
4244 int ldb, std::complex<double> beta,
4245 DeviceMemory<std::complex<double>> *c, int ldc) {
4246 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4247 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4248 PARAM(ldc));
4249
4250 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4251 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4252 int, const DeviceMemory<std::complex<double>> &, int,
4253 std::complex<double>, DeviceMemory<std::complex<double>> *,
4254 int> impl;
4255 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4256 lda, b, ldb, beta, c, ldc);
4257 }
4258
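// Symmetric rank-k update (BLAS xSYRK):
// c <- alpha * a * a^T + beta * c (or a^T * a, per 'trans'), with c symmetric.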
4259 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4260 uint64 n, uint64 k, float alpha,
4261 const DeviceMemory<float> &a, int lda, float beta,
4262 DeviceMemory<float> *c, int ldc) {
4263 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4264 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4265
4266 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4267 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
4268 int> impl;
4269 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4270 lda, beta, c, ldc);
4271 }
4272
4273 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4274 uint64 n, uint64 k, double alpha,
4275 const DeviceMemory<double> &a, int lda,
4276 double beta, DeviceMemory<double> *c, int ldc) {
4277 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4278 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4279
4280 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4281 const DeviceMemory<double> &, int, double,
4282 DeviceMemory<double> *, int> impl;
4283 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4284 lda, beta, c, ldc);
4285 }
4286
4287 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4288 uint64 n, uint64 k, std::complex<float> alpha,
4289 const DeviceMemory<std::complex<float>> &a,
4290 int lda, std::complex<float> beta,
4291 DeviceMemory<std::complex<float>> *c, int ldc) {
4292 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4293 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4294
4295 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4296 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4297 int, std::complex<float>, DeviceMemory<std::complex<float>> *,
4298 int> impl;
4299 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4300 lda, beta, c, ldc);
4301 }
4302
4303 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4304 uint64 n, uint64 k, std::complex<double> alpha,
4305 const DeviceMemory<std::complex<double>> &a,
4306 int lda, std::complex<double> beta,
4307 DeviceMemory<std::complex<double>> *c, int ldc) {
4308 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4309 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4310
4311 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4312 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4313 int, std::complex<double>, DeviceMemory<std::complex<double>> *,
4314 int> impl;
4315 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4316 lda, beta, c, ldc);
4317 }
4318
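// Symmetric rank-2k update (BLAS xSYR2K): in the non-transposed case,
// c <- alpha * (a * b^T + b * a^T) + beta * c, with c symmetric.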
4319 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4320 uint64 n, uint64 k, float alpha,
4321 const DeviceMemory<float> &a, int lda,
4322 const DeviceMemory<float> &b, int ldb, float beta,
4323 DeviceMemory<float> *c, int ldc) {
4324 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4325 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4326 PARAM(ldc));
4327
4328 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4329 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4330 int, float, DeviceMemory<float> *, int> impl;
4331 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4332 a, lda, b, ldb, beta, c, ldc);
4333 }
4334
4335 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4336 uint64 n, uint64 k, double alpha,
4337 const DeviceMemory<double> &a, int lda,
4338 const DeviceMemory<double> &b, int ldb,
4339 double beta, DeviceMemory<double> *c, int ldc) {
4340 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4341 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4342 PARAM(ldc));
4343
4344 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4345 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4346 int, double, DeviceMemory<double> *, int> impl;
4347 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4348 a, lda, b, ldb, beta, c, ldc);
4349 }
4350
4351 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4352 uint64 n, uint64 k, std::complex<float> alpha,
4353 const DeviceMemory<std::complex<float>> &a,
4354 int lda,
4355 const DeviceMemory<std::complex<float>> &b,
4356 int ldb, std::complex<float> beta,
4357 DeviceMemory<std::complex<float>> *c, int ldc) {
4358 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4359 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4360 PARAM(ldc));
4361
4362 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4363 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4364 int, const DeviceMemory<std::complex<float>> &, int,
4365 std::complex<float>, DeviceMemory<std::complex<float>> *,
4366 int> impl;
4367 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4368 a, lda, b, ldb, beta, c, ldc);
4369 }
4370
4371 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4372 uint64 n, uint64 k, std::complex<double> alpha,
4373 const DeviceMemory<std::complex<double>> &a,
4374 int lda,
4375 const DeviceMemory<std::complex<double>> &b,
4376 int ldb, std::complex<double> beta,
4377 DeviceMemory<std::complex<double>> *c, int ldc) {
4378 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4379 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4380 PARAM(ldc));
4381
4382 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4383 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4384 int, const DeviceMemory<std::complex<double>> &, int,
4385 std::complex<double>, DeviceMemory<std::complex<double>> *,
4386 int> impl;
4387 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4388 a, lda, b, ldb, beta, c, ldc);
4389 }
4390
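// Triangular matrix-matrix multiply (TRMM) wrappers for float, double, and
// complex element types.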
4391 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4392 blas::Transpose transa, blas::Diagonal diag,
4393 uint64 m, uint64 n, float alpha,
4394 const DeviceMemory<float> &a, int lda,
4395 DeviceMemory<float> *b, int ldb) {
4396 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4397 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4398
4399 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4400 uint64, uint64, float, const DeviceMemory<float> &, int,
4401 DeviceMemory<float> *, int> impl;
4402 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4403 n, alpha, a, lda, b, ldb);
4404 }
4405
4406 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4407 blas::Transpose transa, blas::Diagonal diag,
4408 uint64 m, uint64 n, double alpha,
4409 const DeviceMemory<double> &a, int lda,
4410 DeviceMemory<double> *b, int ldb) {
4411 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4412 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4413
4414 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4415 uint64, uint64, double, const DeviceMemory<double> &, int,
4416 DeviceMemory<double> *, int> impl;
4417 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4418 n, alpha, a, lda, b, ldb);
4419 }
4420
4421 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4422 blas::Transpose transa, blas::Diagonal diag,
4423 uint64 m, uint64 n, std::complex<float> alpha,
4424 const DeviceMemory<std::complex<float>> &a,
4425 int lda, DeviceMemory<std::complex<float>> *b,
4426 int ldb) {
4427 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4428 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4429
4430 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4431 uint64, uint64, std::complex<float>,
4432 const DeviceMemory<std::complex<float>> &, int,
4433 DeviceMemory<std::complex<float>> *, int> impl;
4434 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4435 n, alpha, a, lda, b, ldb);
4436 }
4437
4438 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4439 blas::Transpose transa, blas::Diagonal diag,
4440 uint64 m, uint64 n, std::complex<double> alpha,
4441 const DeviceMemory<std::complex<double>> &a,
4442 int lda, DeviceMemory<std::complex<double>> *b,
4443 int ldb) {
4444 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4445 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4446
4447 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4448 uint64, uint64, std::complex<double>,
4449 const DeviceMemory<std::complex<double>> &, int,
4450 DeviceMemory<std::complex<double>> *, int> impl;
4451 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4452 n, alpha, a, lda, b, ldb);
4453 }
4454
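// Triangular solve (TRSM) wrappers for float, double, and complex element
// types.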
4455 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4456 blas::Transpose transa, blas::Diagonal diag,
4457 uint64 m, uint64 n, float alpha,
4458 const DeviceMemory<float> &a, int lda,
4459 DeviceMemory<float> *b, int ldb) {
4460 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4461 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4462
4463 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4464 uint64, uint64, float, const DeviceMemory<float> &, int,
4465 DeviceMemory<float> *, int> impl;
4466 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4467 n, alpha, a, lda, b, ldb);
4468 }
4469
4470 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4471 blas::Transpose transa, blas::Diagonal diag,
4472 uint64 m, uint64 n, double alpha,
4473 const DeviceMemory<double> &a, int lda,
4474 DeviceMemory<double> *b, int ldb) {
4475 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4476 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4477
4478 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4479 uint64, uint64, double, const DeviceMemory<double> &, int,
4480 DeviceMemory<double> *, int> impl;
4481 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4482 n, alpha, a, lda, b, ldb);
4483 }
4484
4485 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4486 blas::Transpose transa, blas::Diagonal diag,
4487 uint64 m, uint64 n, std::complex<float> alpha,
4488 const DeviceMemory<std::complex<float>> &a,
4489 int lda, DeviceMemory<std::complex<float>> *b,
4490 int ldb) {
4491 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4492 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4493
4494 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4495 uint64, uint64, std::complex<float>,
4496 const DeviceMemory<std::complex<float>> &, int,
4497 DeviceMemory<std::complex<float>> *, int> impl;
4498 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4499 n, alpha, a, lda, b, ldb);
4500 }
4501
4502 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4503 blas::Transpose transa, blas::Diagonal diag,
4504 uint64 m, uint64 n, std::complex<double> alpha,
4505 const DeviceMemory<std::complex<double>> &a,
4506 int lda, DeviceMemory<std::complex<double>> *b,
4507 int ldb) {
4508 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4509 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4510
4511 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4512 uint64, uint64, std::complex<double>,
4513 const DeviceMemory<std::complex<double>> &, int,
4514 DeviceMemory<std::complex<double>> *, int> impl;
4515 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4516 n, alpha, a, lda, b, ldb);
4517 }
4518
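// Batched GEMM wrappers. The overloads without a ScratchAllocator simply
// delegate to the corresponding *WithScratch variant with a null scratch
// allocator.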
4519 Stream &Stream::ThenBlasGemmBatched(
4520 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4521 uint64 k, float alpha,
4522 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4523 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4524 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4525 int batch_count) {
4526 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4527 b, ldb, beta, c, ldc, batch_count,
4528 /*scratch_allocator=*/nullptr);
4529 }
4530
4531 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4532 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4533 uint64 k, float alpha,
4534 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4535 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4536 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4537 int batch_count, ScratchAllocator *scratch_allocator) {
4538 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4539 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4540 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4541
4542 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4543 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4544 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4545 float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
4546 int, int, ScratchAllocator *>
4547 impl;
4548 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4549 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4550 scratch_allocator);
4551 }
4552
4553 Stream &Stream::ThenBlasGemmBatched(
4554 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4555 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4556 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4557 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4558 int batch_count) {
4559 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4560 b, ldb, beta, c, ldc, batch_count,
4561 /*scratch_allocator=*/nullptr);
4562 }
4563
4564 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4565 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4566 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4567 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4568 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4569 int batch_count, ScratchAllocator *scratch_allocator) {
4570 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4571 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4572 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4573
4574 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4575 const port::ArraySlice<DeviceMemory<float> *> &, int,
4576 const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4577 const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4578 ScratchAllocator *>
4579 impl;
4580 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4581 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4582 scratch_allocator);
4583 }
4584
4585 Stream &Stream::ThenBlasGemmBatched(
4586 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4587 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4588 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4589 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4590 int batch_count) {
4591 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4592 b, ldb, beta, c, ldc, batch_count,
4593 /*scratch_allocator=*/nullptr);
4594 }
4595
4596 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4597 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4598 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4599 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4600 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4601 int batch_count, ScratchAllocator *scratch_allocator) {
4602 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4603 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4604 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4605
4606 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4607 const port::ArraySlice<DeviceMemory<double> *> &, int,
4608 const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4609 const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4610 ScratchAllocator *>
4611 impl;
4612 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4613 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4614 scratch_allocator);
4615 }
4616
4617 Stream &Stream::ThenBlasGemmBatched(
4618 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4619 uint64 k, std::complex<float> alpha,
4620 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4621 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4622 std::complex<float> beta,
4623 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4624 int batch_count) {
4625 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4626 b, ldb, beta, c, ldc, batch_count,
4627 /*scratch_allocator=*/nullptr);
4628 }
4629
4630 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4631 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4632 uint64 k, std::complex<float> alpha,
4633 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4634 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4635 std::complex<float> beta,
4636 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4637 int batch_count, ScratchAllocator *scratch_allocator) {
4638 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4639 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4640 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4641
4642 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4643 std::complex<float>,
4644 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4645 int,
4646 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4647 int, std::complex<float>,
4648 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4649 int, int, ScratchAllocator *>
4650 impl;
4651 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4652 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4653 scratch_allocator);
4654 }
4655
4656 Stream &Stream::ThenBlasGemmBatched(
4657 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4658 uint64 k, std::complex<double> alpha,
4659 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4660 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4661 std::complex<double> beta,
4662 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4663 int batch_count) {
4664 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4665 b, ldb, beta, c, ldc, batch_count,
4666 /*scratch_allocator=*/nullptr);
4667 }
4668
4669 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4670 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4671 uint64 k, std::complex<double> alpha,
4672 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4673 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4674 std::complex<double> beta,
4675 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4676 int batch_count, ScratchAllocator *scratch_allocator) {
4677 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4678 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4679 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4680
4681 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4682 std::complex<double>,
4683 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4684 int,
4685 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4686 int, std::complex<double>,
4687 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4688 int, int, ScratchAllocator *>
4689 impl;
4690 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4691 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4692 scratch_allocator);
4693 }
4694
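// Strided batched GEMM wrappers: all batch entries live in a single
// DeviceMemory allocation per operand and are addressed via the
// stride_a/stride_b/stride_c offsets.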
4695 Stream &Stream::ThenBlasGemmStridedBatched(
4696 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4697 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
4698 int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
4699 float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
4700 int batch_count) {
4701 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4702 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4703 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4704 PARAM(stride_c), PARAM(batch_count));
4705
4706 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4707 const DeviceMemory<Eigen::half> &, int, int64,
4708 const DeviceMemory<Eigen::half> &, int, int64, float,
4709 DeviceMemory<Eigen::half> *, int, int64, int>
4710 impl;
4711 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4712 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4713 c, ldc, stride_c, batch_count);
4714 }
4715
4716 Stream &Stream::ThenBlasGemmStridedBatched(
4717 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4718 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
4719 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
4720 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
4721 int batch_count) {
4722 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4723 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4724 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4725 PARAM(stride_c), PARAM(batch_count));
4726
4727 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4728 const DeviceMemory<float> &, int, int64,
4729 const DeviceMemory<float> &, int, int64, float,
4730 DeviceMemory<float> *, int, int64, int>
4731 impl;
4732 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4733 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4734 c, ldc, stride_c, batch_count);
4735 }
4736
4737 Stream &Stream::ThenBlasGemmStridedBatched(
4738 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4739 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
4740 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
4741 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
4742 int batch_count) {
4743 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4744 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4745 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4746 PARAM(stride_c), PARAM(batch_count));
4747
4748 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4749 const DeviceMemory<double> &, int, int64,
4750 const DeviceMemory<double> &, int, int64, double,
4751 DeviceMemory<double> *, int, int64, int>
4752 impl;
4753 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4754 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4755 c, ldc, stride_c, batch_count);
4756 }
4757
4758 Stream &Stream::ThenBlasGemmStridedBatched(
4759 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4760 uint64 k, std::complex<float> alpha,
4761 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
4762 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
4763 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
4764 int64 stride_c, int batch_count) {
4765 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4766 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4767 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4768 PARAM(stride_c), PARAM(batch_count));
4769
4770 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4771 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4772 int, int64, const DeviceMemory<std::complex<float>> &, int,
4773 int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
4774 int, int64, int>
4775 impl;
4776 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4777 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4778 c, ldc, stride_c, batch_count);
4779 }
4780
4781 Stream &Stream::ThenBlasGemmStridedBatched(
4782 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4783 uint64 k, std::complex<double> alpha,
4784 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
4785 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
4786 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
4787 int64 stride_c, int batch_count) {
4788 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4789 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4790 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4791 PARAM(stride_c), PARAM(batch_count));
4792
4793 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4794 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4795 int, int64, const DeviceMemory<std::complex<double>> &, int,
4796 int64, std::complex<double>,
4797 DeviceMemory<std::complex<double>> *, int, int64, int>
4798 impl;
4799 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4800 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4801 c, ldc, stride_c, batch_count);
4802 }
4803
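// Random-number generation. Each call marks the stream as in error if the
// StreamExecutor's platform provides no RNG support.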
4804 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4805 VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4806
4807 if (ok()) {
4808 if (rng::RngSupport *rng = parent_->AsRng()) {
4809 CheckError(rng->SetSeed(this, seed, seed_bytes));
4810 } else {
4811 SetError();
4812 LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
4813 }
4814 } else {
4815 LOG(INFO) << DebugStreamPointers()
4816 << " did not set RNG seed: " << static_cast<const void *>(seed)
4817 << "; bytes: " << seed_bytes;
4818 }
4819 return *this;
4820 }
4821
4822 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4823 VLOG_CALL(PARAM(values));
4824
4825 if (ok()) {
4826 if (rng::RngSupport *rng = parent_->AsRng()) {
4827 CheckError(rng->DoPopulateRandUniform(this, values));
4828 } else {
4829 SetError();
4830 LOG(INFO) << DebugStreamPointers()
4831 << " attempting to perform RNG operation using StreamExecutor"
4832 " without RNG support.";
4833 }
4834 }
4835 return *this;
4836 }
4837
4838 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4839 DeviceMemory<float> *values) {
4840 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4841
4842 if (ok()) {
4843 if (rng::RngSupport *rng = parent_->AsRng()) {
4844 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4845 } else {
4846 SetError();
4847 LOG(INFO) << DebugStreamPointers()
4848 << " attempting to perform RNG operation using StreamExecutor"
4849 " without RNG support.";
4850 }
4851 }
4852 return *this;
4853 }
4854
4855 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4856 DeviceMemory<double> *values) {
4857 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4858
4859 if (ok()) {
4860 if (rng::RngSupport *rng = parent_->AsRng()) {
4861 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4862 } else {
4863 SetError();
4864 LOG(INFO) << DebugStreamPointers()
4865 << " attempting to perform RNG operation using StreamExecutor"
4866 " without RNG support.";
4867 }
4868 }
4869 return *this;
4870 }
4871
4872 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4873 VLOG_CALL(PARAM(values));
4874
4875 if (ok()) {
4876 if (rng::RngSupport *rng = parent_->AsRng()) {
4877 CheckError(rng->DoPopulateRandUniform(this, values));
4878 } else {
4879 SetError();
4880 LOG(INFO) << DebugStreamPointers()
4881 << " attempting to perform RNG operation using StreamExecutor"
4882 " without RNG support.";
4883 }
4884 }
4885 return *this;
4886 }
4887
4888 Stream &Stream::ThenPopulateRandUniform(
4889 DeviceMemory<std::complex<float>> *values) {
4890 VLOG_CALL(PARAM(values));
4891
4892 if (ok()) {
4893 if (rng::RngSupport *rng = parent_->AsRng()) {
4894 CheckError(rng->DoPopulateRandUniform(this, values));
4895 } else {
4896 SetError();
4897 LOG(INFO) << DebugStreamPointers()
4898 << " attempting to perform RNG operation using StreamExecutor"
4899 " without RNG support.";
4900 }
4901 }
4902 return *this;
4903 }
4904
4905 Stream &Stream::ThenPopulateRandUniform(
4906 DeviceMemory<std::complex<double>> *values) {
4907 VLOG_CALL(PARAM(values));
4908
4909 if (ok()) {
4910 if (rng::RngSupport *rng = parent_->AsRng()) {
4911 CheckError(rng->DoPopulateRandUniform(this, values));
4912 } else {
4913 SetError();
4914 LOG(INFO) << DebugStreamPointers()
4915 << " attempting to perform RNG operation using StreamExecutor"
4916 " without RNG support.";
4917 }
4918 }
4919 return *this;
4920 }
4921
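// Memory transfer and fill operations: host<->device and device-to-device
// copies, zeroing, and 32-bit pattern fills.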
4922 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4923 uint64 size) {
4924 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4925
4926 if (ok()) {
4927 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4928 } else {
4929 LOG(INFO) << DebugStreamPointers()
4930 << " did not memcpy device-to-host; source: " << gpu_src.opaque();
4931 }
4932 return *this;
4933 }
4934
4935 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4936 uint64 size) {
4937 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4938
4939 if (ok()) {
4940 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4941 } else {
4942 LOG(INFO) << DebugStreamPointers()
4943 << " did not memcpy host-to-device; source: " << host_src;
4944 }
4945 return *this;
4946 }
4947
4948 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4949 const DeviceMemoryBase &gpu_src, uint64 size) {
4950 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4951
4952 if (ok()) {
4953 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4954 } else {
4955 LOG(INFO) << DebugStreamPointers()
4956 << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
4957 }
4958 return *this;
4959 }
4960
4961 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4962 VLOG_CALL(PARAM(location), PARAM(size));
4963
4964 if (ok()) {
4965 CheckStatus(parent_->MemZero(this, location, size));
4966 } else {
4967 LOG(INFO) << DebugStreamPointers()
4968 << " did not memzero GPU location; source: " << location;
4969 }
4970 return *this;
4971 }
4972
4973 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4974 uint64 size) {
4975 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4976
4977 if (ok()) {
4978 CheckStatus(parent_->Memset32(this, location, pattern, size));
4979 } else {
4980 LOG(INFO) << DebugStreamPointers()
4981 << " did not memset GPU location; source: " << location
4982 << "; size: " << size << "; pattern: " << std::hex << pattern;
4983 }
4984 return *this;
4985 }
4986
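// RNN forward-pass wrappers (Eigen::half, float, and double). A failed call
// only marks the stream as in error when no ProfileResult was requested.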
4987 Stream &Stream::ThenRnnForward(
4988 const dnn::RnnDescriptor &rnn_desc,
4989 const dnn::RnnSequenceTensorDescriptor &input_desc,
4990 const DeviceMemory<Eigen::half> &input_data,
4991 const dnn::RnnStateTensorDescriptor &input_h_desc,
4992 const DeviceMemory<Eigen::half> &input_h_data,
4993 const dnn::RnnStateTensorDescriptor &input_c_desc,
4994 const DeviceMemory<Eigen::half> &input_c_data,
4995 const DeviceMemory<Eigen::half> ¶ms,
4996 const dnn::RnnSequenceTensorDescriptor &output_desc,
4997 DeviceMemory<Eigen::half> *output_data,
4998 const dnn::RnnStateTensorDescriptor &output_h_desc,
4999 DeviceMemory<Eigen::half> *output_h_data,
5000 const dnn::RnnStateTensorDescriptor &output_c_desc,
5001 DeviceMemory<Eigen::half> *output_c_data, bool is_training,
5002 ScratchAllocator *reserve_space_allocator,
5003 ScratchAllocator *workspace_allocator,
5004 dnn::ProfileResult *output_profile_result) {
5005 // TODO(zhengxq): add VLOG PARAM calls.
5006 if (ok()) {
5007 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5008 auto status = dnn->DoRnnForward(
5009 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5010 input_c_desc, input_c_data, params, output_desc, output_data,
5011 output_h_desc, output_h_data, output_c_desc, output_c_data,
5012 is_training, reserve_space_allocator, workspace_allocator,
5013 output_profile_result);
5014 if (!status && !output_profile_result) {
5015 SetError();
5016 }
5017 } else {
5018 SetErrorAndLogNoDnnSupport();
5019 }
5020 }
5021 return *this;
5022 }
5023
5024 Stream &Stream::ThenRnnForward(
5025 const dnn::RnnDescriptor &rnn_desc,
5026 const dnn::RnnSequenceTensorDescriptor &input_desc,
5027 const DeviceMemory<float> &input_data,
5028 const dnn::RnnStateTensorDescriptor &input_h_desc,
5029 const DeviceMemory<float> &input_h_data,
5030 const dnn::RnnStateTensorDescriptor &input_c_desc,
5031 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
5032 const dnn::RnnSequenceTensorDescriptor &output_desc,
5033 DeviceMemory<float> *output_data,
5034 const dnn::RnnStateTensorDescriptor &output_h_desc,
5035 DeviceMemory<float> *output_h_data,
5036 const dnn::RnnStateTensorDescriptor &output_c_desc,
5037 DeviceMemory<float> *output_c_data, bool is_training,
5038 ScratchAllocator *reserve_space_allocator,
5039 ScratchAllocator *workspace_allocator,
5040 dnn::ProfileResult *output_profile_result) {
5041 // TODO(zhengxq): add VLOG PARAM calls.
5042 if (ok()) {
5043 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5044 auto status = dnn->DoRnnForward(
5045 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5046 input_c_desc, input_c_data, params, output_desc, output_data,
5047 output_h_desc, output_h_data, output_c_desc, output_c_data,
5048 is_training, reserve_space_allocator, workspace_allocator,
5049 output_profile_result);
5050 if (!status && !output_profile_result) {
5051 SetError();
5052 }
5053 } else {
5054 SetErrorAndLogNoDnnSupport();
5055 }
5056 }
5057 return *this;
5058 }
5059
5060 Stream &Stream::ThenRnnForward(
5061 const dnn::RnnDescriptor &rnn_desc,
5062 const dnn::RnnSequenceTensorDescriptor &input_desc,
5063 const DeviceMemory<double> &input_data,
5064 const dnn::RnnStateTensorDescriptor &input_h_desc,
5065 const DeviceMemory<double> &input_h_data,
5066 const dnn::RnnStateTensorDescriptor &input_c_desc,
5067 const DeviceMemory<double> &input_c_data,
5068 const DeviceMemory<double> ¶ms,
5069 const dnn::RnnSequenceTensorDescriptor &output_desc,
5070 DeviceMemory<double> *output_data,
5071 const dnn::RnnStateTensorDescriptor &output_h_desc,
5072 DeviceMemory<double> *output_h_data,
5073 const dnn::RnnStateTensorDescriptor &output_c_desc,
5074 DeviceMemory<double> *output_c_data, bool is_training,
5075 ScratchAllocator *reserve_space_allocator,
5076 ScratchAllocator *workspace_allocator,
5077 dnn::ProfileResult *output_profile_result) {
5078 // TODO(zhengxq): add VLOG PARAM calls.
5079 if (ok()) {
5080 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5081 auto status = dnn->DoRnnForward(
5082 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5083 input_c_desc, input_c_data, params, output_desc, output_data,
5084 output_h_desc, output_h_data, output_c_desc, output_c_data,
5085 is_training, reserve_space_allocator, workspace_allocator,
5086 output_profile_result);
5087 if (!status && !output_profile_result) {
5088 SetError();
5089 }
5090 } else {
5091 SetErrorAndLogNoDnnSupport();
5092 }
5093 }
5094 return *this;
5095 }
5096
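// RNN backward-pass wrappers, mirroring the forward-pass overloads above.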
5097 Stream &Stream::ThenRnnBackward(
5098 const dnn::RnnDescriptor &rnn_desc,
5099 const dnn::RnnSequenceTensorDescriptor &input_desc,
5100 const DeviceMemory<Eigen::half> &input_data,
5101 const dnn::RnnStateTensorDescriptor &input_h_desc,
5102 const DeviceMemory<Eigen::half> &input_h_data,
5103 const dnn::RnnStateTensorDescriptor &input_c_desc,
5104 const DeviceMemory<Eigen::half> &input_c_data,
5105 const DeviceMemory<Eigen::half> ¶ms,
5106 const dnn::RnnSequenceTensorDescriptor &output_desc,
5107 const DeviceMemory<Eigen::half> &output_data,
5108 const dnn::RnnStateTensorDescriptor &output_h_desc,
5109 const DeviceMemory<Eigen::half> &output_h_data,
5110 const dnn::RnnStateTensorDescriptor &output_c_desc,
5111 const DeviceMemory<Eigen::half> &output_c_data,
5112 const DeviceMemory<Eigen::half> &output_backprop_data,
5113 const DeviceMemory<Eigen::half> &output_h_backprop_data,
5114 const DeviceMemory<Eigen::half> &output_c_backprop_data,
5115 DeviceMemory<Eigen::half> *input_backprop_data,
5116 DeviceMemory<Eigen::half> *input_h_backprop_data,
5117 DeviceMemory<Eigen::half> *input_c_backprop_data,
5118 DeviceMemory<Eigen::half> *params_backprop_data,
5119 DeviceMemory<uint8> *reserve_space_data,
5120 ScratchAllocator *workspace_allocator,
5121 dnn::ProfileResult *output_profile_result) {
5122 // TODO(zhengxq): add VLOG PARAM calls.
5123 if (ok()) {
5124 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5125 auto status = dnn->DoRnnBackward(
5126 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5127 input_c_desc, input_c_data, params, output_desc, output_data,
5128 output_h_desc, output_h_data, output_c_desc, output_c_data,
5129 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5130 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5131 params_backprop_data, reserve_space_data, workspace_allocator,
5132 output_profile_result);
5133 if (!status && !output_profile_result) {
5134 SetError();
5135 }
5136 } else {
5137 SetError();
5138 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5139 }
5140 }
5141 return *this;
5142 }
5143
5144 Stream &Stream::ThenRnnBackward(
5145 const dnn::RnnDescriptor &rnn_desc,
5146 const dnn::RnnSequenceTensorDescriptor &input_desc,
5147 const DeviceMemory<float> &input_data,
5148 const dnn::RnnStateTensorDescriptor &input_h_desc,
5149 const DeviceMemory<float> &input_h_data,
5150 const dnn::RnnStateTensorDescriptor &input_c_desc,
5151 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
5152 const dnn::RnnSequenceTensorDescriptor &output_desc,
5153 const DeviceMemory<float> &output_data,
5154 const dnn::RnnStateTensorDescriptor &output_h_desc,
5155 const DeviceMemory<float> &output_h_data,
5156 const dnn::RnnStateTensorDescriptor &output_c_desc,
5157 const DeviceMemory<float> &output_c_data,
5158 const DeviceMemory<float> &output_backprop_data,
5159 const DeviceMemory<float> &output_h_backprop_data,
5160 const DeviceMemory<float> &output_c_backprop_data,
5161 DeviceMemory<float> *input_backprop_data,
5162 DeviceMemory<float> *input_h_backprop_data,
5163 DeviceMemory<float> *input_c_backprop_data,
5164 DeviceMemory<float> *params_backprop_data,
5165 DeviceMemory<uint8> *reserve_space_data,
5166 ScratchAllocator *workspace_allocator,
5167 dnn::ProfileResult *output_profile_result) {
5168 // TODO(zhengxq): add VLOG PARAM calls.
5169 if (ok()) {
5170 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5171 auto status = dnn->DoRnnBackward(
5172 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5173 input_c_desc, input_c_data, params, output_desc, output_data,
5174 output_h_desc, output_h_data, output_c_desc, output_c_data,
5175 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5176 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5177 params_backprop_data, reserve_space_data, workspace_allocator,
5178 output_profile_result);
5179 if (!status && !output_profile_result) {
5180 SetError();
5181 }
5182 } else {
5183 SetError();
5184 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5185 }
5186 }
5187 return *this;
5188 }
5189
5190 Stream &Stream::ThenRnnBackward(
5191 const dnn::RnnDescriptor &rnn_desc,
5192 const dnn::RnnSequenceTensorDescriptor &input_desc,
5193 const DeviceMemory<double> &input_data,
5194 const dnn::RnnStateTensorDescriptor &input_h_desc,
5195 const DeviceMemory<double> &input_h_data,
5196 const dnn::RnnStateTensorDescriptor &input_c_desc,
5197 const DeviceMemory<double> &input_c_data,
5198 const DeviceMemory<double> ¶ms,
5199 const dnn::RnnSequenceTensorDescriptor &output_desc,
5200 const DeviceMemory<double> &output_data,
5201 const dnn::RnnStateTensorDescriptor &output_h_desc,
5202 const DeviceMemory<double> &output_h_data,
5203 const dnn::RnnStateTensorDescriptor &output_c_desc,
5204 const DeviceMemory<double> &output_c_data,
5205 const DeviceMemory<double> &output_backprop_data,
5206 const DeviceMemory<double> &output_h_backprop_data,
5207 const DeviceMemory<double> &output_c_backprop_data,
5208 DeviceMemory<double> *input_backprop_data,
5209 DeviceMemory<double> *input_h_backprop_data,
5210 DeviceMemory<double> *input_c_backprop_data,
5211 DeviceMemory<double> *params_backprop_data,
5212 DeviceMemory<uint8> *reserve_space_data,
5213 ScratchAllocator *workspace_allocator,
5214 dnn::ProfileResult *output_profile_result) {
5215 // TODO(zhengxq): add VLOG PARAM calls.
5216 if (ok()) {
5217 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5218 auto status = dnn->DoRnnBackward(
5219 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5220 input_c_desc, input_c_data, params, output_desc, output_data,
5221 output_h_desc, output_h_data, output_c_desc, output_c_data,
5222 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5223 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5224 params_backprop_data, reserve_space_data, workspace_allocator,
5225 output_profile_result);
5226 if (!status && !output_profile_result) {
5227 SetError();
5228 }
5229 } else {
5230 SetError();
5231 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5232 }
5233 }
5234 return *this;
5235 }
5236
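// CTC loss: first obtains scratch memory via PrepareForCtcLoss, then runs
// DoCtcLoss with that scratch memory; any failure marks the stream as in
// error.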
5237 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
5238 const DeviceMemory<float> &probs_data,
5239 absl::Span<const int> labels_data,
5240 absl::Span<const int> labels_lengths_data,
5241 absl::Span<const int> input_lengths_data,
5242 DeviceMemory<float> *costs_data,
5243 const dnn::RnnStateTensorDescriptor &grads_desc,
5244 DeviceMemory<float> *grads_data,
5245 ScratchAllocator *workspace_allocator) {
5246 if (ok()) {
5247 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5248 DeviceMemory<uint8> scratch_memory;
5249 auto status = dnn->PrepareForCtcLoss(
5250 this, probs_desc, probs_data, grads_desc,
5251 labels_data, labels_lengths_data, input_lengths_data,
5252 workspace_allocator, &scratch_memory)
5253 .ok();
5254 if (status) {
5255 status =
5256 dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
5257 labels_lengths_data, input_lengths_data, costs_data,
5258 grads_desc, grads_data, &scratch_memory);
5259 }
5260 if (!status) {
5261 SetError();
5262 }
5263 } else {
5264 SetErrorAndLogNoDnnSupport();
5265 }
5266 }
5267 return *this;
5268 }
5269
5270 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
5271 dnn::DataType input_type,
5272 const DeviceMemoryBase &input_data,
5273 const dnn::BatchDescriptor &output_desc,
5274 dnn::DataType output_type, float scale,
5275 DeviceMemoryBase *output_data) {
5276 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
5277 PARAM(output_desc), PARAM(output_type), PARAM(scale),
5278 PARAM(output_data));
5279 if (ok()) {
5280 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5281 CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
5282 input_data, output_desc, output_type,
5283 scale, output_data));
5284 } else {
5285 SetErrorAndLogNoDnnSupport();
5286 }
5287 }
5288 return *this;
5289 }
5290
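// Host callback plumbing. ThenDoHostCallback and ThenDoHostCallbackWithStatus
// hand the callback to the StreamExecutor via HostCallback(), and still
// enqueue it (after logging) when the stream is already in an error state.
// ThenRunAfterNextBlockHostUntilDone instead queues the callback to run after
// the next BlockHostUntilDone.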
5291 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
5292 VLOG_CALL(PARAM(callback));
5293
5294 if (!ok()) {
5295 LOG(INFO) << DebugStreamPointers()
5296 << " was in error state before adding host callback";
5297 }
5298 CheckError(parent_->HostCallback(this, std::move(callback)));
5299 return *this;
5300 }
5301
5302 Stream &Stream::ThenDoHostCallbackWithStatus(
5303 std::function<port::Status()> callback) {
5304 VLOG_CALL(PARAM(callback));
5305
5306 if (!ok()) {
5307 LOG(INFO) << DebugStreamPointers()
5308 << " was in error state before adding host callback";
5309 }
5310 CheckError(parent_->HostCallback(this, std::move(callback)));
5311 return *this;
5312 }
5313
5314 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
5315 std::function<void()> callback) {
5316 VLOG_CALL(PARAM(callback));
5317
5318 if (!ok()) {
5319 LOG(INFO) << DebugStreamPointers()
5320 << " was in error state before adding callback to be run after "
5321 "next block-host-until-done.";
5322 }
5323 absl::MutexLock lock(&mu_);
5324 after_block_host_until_done_callbacks_.push_back(std::move(callback));
5325 return *this;
5326 }
5327
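// The ThenFft overloads below differ only in element type and transform
// direction (complex-to-complex, real-to-complex and complex-to-real, in
// single and double precision). Each forwards the pre-built fft::Plan to
// FftSupport::DoFft, recording an error if the platform has no FFT support.
//
// A minimal usage sketch (illustrative only; `executor`, `plan`, `in` and
// `out` are assumed to have been created elsewhere through the executor's
// FFT-plan and device-memory APIs):
//
//   Stream stream(executor);
//   stream.Init();
//   stream.ThenFft(plan, in, &out);
//   port::Status status = stream.BlockHostUntilDone();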
5328 Stream &Stream::ThenFft(fft::Plan *plan,
5329 const DeviceMemory<std::complex<float>> &input,
5330 DeviceMemory<std::complex<float>> *output) {
5331 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5332
5333 if (ok()) {
5334 if (fft::FftSupport *fft = parent_->AsFft()) {
5335 CheckError(fft->DoFft(this, plan, input, output));
5336 } else {
5337 SetError();
5338 LOG(INFO) << DebugStreamPointers()
5339 << " attempting to perform FFT operation using StreamExecutor"
5340 " without FFT support";
5341 }
5342 }
5343 return *this;
5344 }
5345
5346 Stream &Stream::ThenFft(fft::Plan *plan,
5347 const DeviceMemory<std::complex<double>> &input,
5348 DeviceMemory<std::complex<double>> *output) {
5349 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5350
5351 if (ok()) {
5352 if (fft::FftSupport *fft = parent_->AsFft()) {
5353 CheckError(fft->DoFft(this, plan, input, output));
5354 } else {
5355 SetError();
5356 LOG(INFO) << DebugStreamPointers()
5357 << " attempting to perform FFT operation using StreamExecutor"
5358 " without FFT support";
5359 }
5360 }
5361 return *this;
5362 }
5363
5364 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
5365 DeviceMemory<std::complex<float>> *output) {
5366 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5367
5368 if (ok()) {
5369 if (fft::FftSupport *fft = parent_->AsFft()) {
5370 CheckError(fft->DoFft(this, plan, input, output));
5371 } else {
5372 SetError();
5373 LOG(INFO) << DebugStreamPointers()
5374 << " attempting to perform FFT operation using StreamExecutor"
5375 " without FFT support";
5376 }
5377 }
5378 return *this;
5379 }
5380
5381 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
5382 DeviceMemory<std::complex<double>> *output) {
5383 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5384
5385 if (ok()) {
5386 if (fft::FftSupport *fft = parent_->AsFft()) {
5387 CheckError(fft->DoFft(this, plan, input, output));
5388 } else {
5389 SetError();
5390 LOG(INFO) << DebugStreamPointers()
5391 << " attempting to perform FFT operation using StreamExecutor"
5392 " without FFT support";
5393 }
5394 }
5395 return *this;
5396 }
5397
5398 Stream &Stream::ThenFft(fft::Plan *plan,
5399 const DeviceMemory<std::complex<float>> &input,
5400 DeviceMemory<float> *output) {
5401 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5402
5403 if (ok()) {
5404 if (fft::FftSupport *fft = parent_->AsFft()) {
5405 CheckError(fft->DoFft(this, plan, input, output));
5406 } else {
5407 SetError();
5408 LOG(INFO) << DebugStreamPointers()
5409 << " attempting to perform FFT operation using StreamExecutor"
5410 " without FFT support";
5411 }
5412 }
5413 return *this;
5414 }
5415
5416 Stream &Stream::ThenFft(fft::Plan *plan,
5417 const DeviceMemory<std::complex<double>> &input,
5418 DeviceMemory<double> *output) {
5419 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5420
5421 if (ok()) {
5422 if (fft::FftSupport *fft = parent_->AsFft()) {
5423 CheckError(fft->DoFft(this, plan, input, output));
5424 } else {
5425 SetError();
5426 LOG(INFO) << DebugStreamPointers()
5427 << " attempting to perform FFT operation using StreamExecutor"
5428 " without FFT support";
5429 }
5430 }
5431 return *this;
5432 }
5433
5434 // This may look confusing, but all it does is insert a callback at the
5435 // current point in the stream that enqueues `task` on the host executor.
5436 Stream &Stream::ThenEnqueueOnBackgroundThread(
5437 std::function<void(StreamExecutor *)> task) {
5438 VLOG_CALL(PARAM(task));
5439
5440 StreamExecutor *stream_executor = this->parent_;
5441 std::function<void()> bound_task = std::bind(task, stream_executor);
5442
5443 return ThenDoHostCallback([stream_executor, bound_task]() {
5444 stream_executor->EnqueueOnBackgroundThread(bound_task);
5445 });
5446 }
5447
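// Blocks the calling host thread until all work previously enqueued on this
// stream has completed. If the stream is already in an error state this
// returns an INTERNAL error immediately; otherwise finalized temporary
// allocations are released, the wait is delegated to the parent
// StreamExecutor, and any callbacks registered via
// ThenRunAfterNextBlockHostUntilDone() are run before returning.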
5448 port::Status Stream::BlockHostUntilDone() {
5449 VLOG_CALL();
5450
5451 if (!ok()) {
5452 port::Status status = port::Status(
5453 port::error::INTERNAL,
5454 "stream did not block host until done; was already in an error state");
5455 LOG(INFO) << DebugStreamPointers() << " " << status;
5456 return status;
5457 }
5458
5459 temporary_memory_manager_.DeallocateFinalizedTemporaries();
5460
5461 port::Status error = parent_->BlockHostUntilDone(this);
5462 CheckError(error.ok());
5463
5464 RunAfterBlockHostUntilDoneCallbacks();
5465 return error;
5466 }
5467
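// Runs and clears the callbacks registered through
// ThenRunAfterNextBlockHostUntilDone(). The pending list is swapped out under
// `mu_` so the callbacks themselves execute without the lock held.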
5468 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
5469 std::vector<std::function<void()>> callbacks;
5470 {
5471 absl::MutexLock lock(&mu_);
5472 std::swap(callbacks, after_block_host_until_done_callbacks_);
5473 }
5474 for (const auto &fn : callbacks) {
5475 fn();
5476 }
5477 }
5478
5479 string Stream::DebugStreamPointers() const {
5480 // Relies on the ToVlogString(const void*) overload above.
5481 return absl::StrCat("[stream=", ToVlogString(this),
5482 ",impl=", ToVlogString(implementation_.get()), "]");
5483 }
5484
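// Logs a non-OK status and moves the stream into an error state; an OK status
// is a no-op.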
5485 void Stream::CheckStatus(port::Status status) {
5486 if (status.ok()) {
5487 return;
5488 }
5489 LOG(ERROR) << status;
5490 absl::MutexLock lock(&mu_);
5491 ok_ = false;
5492 }
5493
5494 } // namespace stream_executor
5495