1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/stream.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "third_party/eigen3/Eigen/Core"
20 #include "tensorflow/stream_executor/blas.h"
21 #include "tensorflow/stream_executor/host_or_device_scalar.h"
22 #include "tensorflow/stream_executor/lib/stacktrace.h"
23 #include "tensorflow/stream_executor/platform.h"
24 #include "tensorflow/stream_executor/platform/logging.h"
25 #include "tensorflow/stream_executor/platform/port.h"
26 #include "tensorflow/stream_executor/rng.h"
27 #include "tensorflow/stream_executor/stream_executor_internal.h"
28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
29
30 namespace stream_executor {
31
32 namespace {
33 // Code to turn parameters to functions on stream into strings that
34 // will be VLOG'ed. We need overloads, instead of
35 // e.g. BatchDescriptorToVlogString(), as the code that calls these
36 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)37 std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
38 return descriptor.ToShortString();
39 }
40
ToVlogString(const dnn::FilterDescriptor & descriptor)41 std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
42 return descriptor.ToShortString();
43 }
44
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)45 std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
46 return descriptor.ToShortString();
47 }
48
ToVlogString(const dnn::PoolingDescriptor & descriptor)49 std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
50 return descriptor.ToShortString();
51 }
52
ToVlogString(const dnn::NormalizeDescriptor & descriptor)53 std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
54 return descriptor.ToShortString();
55 }
56
ToVlogString(dnn::ActivationMode mode)57 std::string ToVlogString(dnn::ActivationMode mode) {
58 return dnn::ActivationModeString(mode);
59 }
60
ToVlogString(const dnn::AlgorithmConfig & algo_config)61 std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
62 return algo_config.ToString();
63 }
64
ToVlogString(dnn::ElementwiseOperation op)65 std::string ToVlogString(dnn::ElementwiseOperation op) {
66 return dnn::ElementwiseOperationString(op);
67 }
68
ToVlogString(dnn::QuantizedActivationMode mode)69 std::string ToVlogString(dnn::QuantizedActivationMode mode) {
70 return dnn::QuantizedActivationModeString(mode);
71 }
72
ToVlogString(blas::Transpose t)73 std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
74
ToVlogString(blas::UpperLower ul)75 std::string ToVlogString(blas::UpperLower ul) {
76 return blas::UpperLowerString(ul);
77 }
78
ToVlogString(blas::Diagonal d)79 std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
80
ToVlogString(blas::Side s)81 std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
82
ToVlogString(blas::ComputationType ty)83 std::string ToVlogString(blas::ComputationType ty) {
84 return blas::ComputationTypeString(ty);
85 }
86
// Renders a raw pointer for VLOG output. A null pointer becomes the literal
// "null"; any other pointer is formatted by streaming it into an
// std::ostringstream, because StrCat does not convert pointers to text.
std::string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    std::ostringstream formatted;
    formatted << ptr;
    return formatted.str();
  }
  return "null";
}
97
// Renders a complex number for VLOG output by streaming it into an
// std::ostringstream, because StrCat does not convert std::complex to text.
template <class T>
std::string ToVlogString(const std::complex<T> &c) {
  std::ostringstream formatted;
  formatted << c;
  std::string text = formatted.str();
  return text;
}
105
// Renders a std::function for VLOG output. Only whether it holds a target is
// reported: "null" when empty, "<non-null function>" otherwise.
template <class T>
std::string ToVlogString(const std::function<T> &f) {
  if (f == nullptr) {
    return "null";
  }
  return "<non-null function>";
}
110
ToVlogString(const DeviceMemoryBase & memory)111 std::string ToVlogString(const DeviceMemoryBase &memory) {
112 return ToVlogString(memory.opaque());
113 }
114
ToVlogString(const DeviceMemoryBase * memory)115 std::string ToVlogString(const DeviceMemoryBase *memory) {
116 return memory == nullptr ? "null" : ToVlogString(*memory);
117 }
118
ToVlogString(const Eigen::half & h)119 std::string ToVlogString(const Eigen::half &h) {
120 return absl::StrCat(static_cast<float>(h));
121 }
122
ToVlogString(int i)123 std::string ToVlogString(int i) { return absl::StrCat(i); }
124
ToVlogString(uint32 i)125 std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
126
ToVlogString(uint64 i)127 std::string ToVlogString(uint64 i) { return absl::StrCat(i); }
128
ToVlogString(int64 i)129 std::string ToVlogString(int64 i) { return absl::StrCat(i); }
130
ToVlogString(float f)131 std::string ToVlogString(float f) { return absl::StrCat(f); }
132
ToVlogString(double d)133 std::string ToVlogString(double d) { return absl::StrCat(d); }
134
135 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)136 std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
137 if (memory_or_constant.is_pointer()) {
138 return ToVlogString(memory_or_constant.pointer());
139 }
140 return ToVlogString(memory_or_constant.value());
141 }
142
143 template <class T>
ToVlogString(port::ArraySlice<T> elements)144 std::string ToVlogString(port::ArraySlice<T> elements) {
145 std::string str = absl::StrCat(
146 ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
147 elements.size(), "]{");
148 const char *separator = "";
149 size_t max_to_show = std::numeric_limits<size_t>::max();
150 if (!VLOG_IS_ON(2)) {
151 max_to_show = 5;
152 } else if (!VLOG_IS_ON(3)) {
153 max_to_show = 20;
154 } else if (!VLOG_IS_ON(11)) {
155 max_to_show = 1000;
156 }
157 for (size_t i = 0; i < elements.size(); ++i) {
158 if (i == max_to_show) {
159 str += ", ...";
160 break;
161 }
162 absl::StrAppend(&str, separator, ToVlogString(elements[i]));
163 separator = ", ";
164 }
165 str += "}";
166 return str;
167 }
168
169 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)170 std::string ToVlogString(port::MutableArraySlice<T> elements) {
171 return ToVlogString(port::ArraySlice<T>(elements));
172 }
173
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)174 std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
175 switch (depth_to_space_layout) {
176 case dnn::DepthToSpaceLayout::DepthHeightWidth:
177 return "DepthToSpaceLayout::DepthHeightWidth";
178 }
179 return "unknown DepthToSpaceLayout";
180 }
181
ToVlogString(dnn::DataType data_type)182 std::string ToVlogString(dnn::DataType data_type) {
183 switch (data_type) {
184 case dnn::DataType::kFloat:
185 return "dnn::DataType::kFloat";
186 case dnn::DataType::kDouble:
187 return "dnn::DataType::kDouble";
188 case dnn::DataType::kHalf:
189 return "dnn::DataType::kHalf";
190 case dnn::DataType::kInt8:
191 return "dnn::DataType::kInt8";
192 case dnn::DataType::kInt32:
193 return "dnn::DataType::kInt32";
194 default:
195 return "unknown DataType";
196 }
197 }
198
199 // Used together with PARAM to VLOG calls made to the stream. Intended
200 // to be used like this:
201 //
202 // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
203 //
204 // where a and b are the parameters to MyFunction.
205 //
206 // See VLOG_CALL for a short-hand for this. This way of doing it saves
207 // a tremendous amount of boilerplate code given how many functions
208 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,std::string>> params)209 std::string CallStr(const char *function_name, Stream *stream,
210 std::vector<std::pair<const char *, std::string>> params) {
211 // Do not call this function unless VLOG is on since just
212 // constructing all the strings in params is expensive.
213 CHECK(VLOG_IS_ON(1));
214
215 std::string str = absl::StrCat(stream->DebugStreamPointers(),
216 " Called Stream::", function_name, "(");
217 const char *separator = "";
218 for (const auto ¶m : params) {
219 absl::StrAppend(&str, separator, param.first, "=", param.second);
220 separator = ", ";
221 }
222 absl::StrAppend(&str, ")");
223 if (VLOG_IS_ON(10)) {
224 absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
225 }
226 return str;
227 }
228
// Use this macro to avoid having to type every parameter twice to log
// it with VLOG and CallStr. It expands to a braced {name, value} initializer
// whose first element is the stringified argument and whose second element is
// ToVlogString(argument), matching the pair type CallStr takes in `params`.
#define PARAM(parameter) \
  { #parameter, ToVlogString(parameter) }
233
// Use this macro to avoid having to type out the name of each
// function and to save some boilerplate. Intended to be used like this:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// This saves a tremendous amount of boilerplate compared to the alternative:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// Note here that most of the parameter names are not short and that
// most of the functions take many more than 2 parameters.
//
// The expansion passes __func__ as the function name and `this` as the
// Stream, so the macro is only usable inside Stream member functions.
#define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
247
248 } // namespace
249
// Constructs a stream owned by `parent`, taking the stream implementation
// from parent->implementation()->GetStreamImplementation(). `parent` is
// dereferenced here, so it must not be null.
//
// The stream starts out unallocated with an internal-error status
// ("Uninitialized stream"); Init() is what marks it allocated and OK.
Stream::Stream(StreamExecutor *parent)
    : parent_(parent),
      implementation_(parent->implementation()->GetStreamImplementation()),
      allocated_(false),
      status_(port::InternalError("Uninitialized stream")),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent));
}
258
// Constructs a stream owned by `parent` that uses the caller-supplied
// `implementation` instead of asking the parent for one.
//
// The stream starts out unallocated with an internal-error status
// ("Uninitialized stream"); Init() is what marks it allocated and OK.
Stream::Stream(StreamExecutor *parent,
               internal::StreamInterface *implementation)
    : parent_(parent),
      implementation_(implementation),
      allocated_(false),
      status_(port::InternalError("Uninitialized stream")),
      temporary_memory_manager_(this) {
  VLOG_CALL(PARAM(parent), PARAM(implementation));
}
268
~Stream()269 Stream::~Stream() {
270 VLOG_CALL();
271
272 // Ensure the stream is completed.
273 auto status = BlockHostUntilDone();
274 if (!status.ok()) {
275 LOG(WARNING) << "Error blocking host until done in stream destructor: "
276 << status;
277 }
278 temporary_memory_manager_.ForceDeallocateAll();
279 RunAfterBlockHostUntilDoneCallbacks();
280
281 if (allocated_) {
282 parent_->DeallocateStream(this);
283 }
284 }
285
RefreshStatus()286 port::Status Stream::RefreshStatus() {
287 port::Status status = parent_->GetStatus(this);
288 // We should not put the stream in an error state, just because the GetStatus
289 // method is unimplemented.
290 if (status != port::Status(port::error::UNIMPLEMENTED,
291 "GetStatus is not supported on this executor.")) {
292 CheckStatus(status);
293 }
294 return status;
295 }
296
Init()297 Stream &Stream::Init() {
298 VLOG_CALL();
299
300 absl::MutexLock lock(&mu_);
301 CHECK_EQ(false, allocated_)
302 << "stream appears to already have been initialized";
303 CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization";
304
305 if (parent_->AllocateStream(this)) {
306 // Successful initialization!
307 allocated_ = true;
308 status_ = port::Status::OK();
309 } else {
310 LOG(ERROR) << "failed to allocate stream during initialization";
311 }
312
313 return *this;
314 }
315
InitTimer(Timer * timer)316 Stream &Stream::InitTimer(Timer *timer) {
317 VLOG_CALL(PARAM(timer));
318
319 CheckError(parent_->AllocateTimer(timer));
320 return *this;
321 }
322
InitWithTimer(Timer * timer)323 Stream &Stream::InitWithTimer(Timer *timer) {
324 VLOG_CALL(PARAM(timer));
325
326 return Init().InitTimer(timer);
327 }
328
ThenRecordEvent(Event * event)329 Stream &Stream::ThenRecordEvent(Event *event) {
330 VLOG_CALL(PARAM(event));
331
332 port::Status status = parent_->RecordEvent(this, event);
333 if (!status.ok()) {
334 LOG(ERROR) << "Error recording event in stream: " << status.error_message()
335 << "; not marking stream as bad, as the Event object may be "
336 << "at fault. Monitor for further errors.";
337 }
338
339 return *this;
340 }
341
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)342 Stream &Stream::ThenBatchNormalizationForward(
343 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
344 const DeviceMemory<float> &offset,
345 const DeviceMemory<float> &estimated_mean,
346 const DeviceMemory<float> &estimated_variance,
347 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
348 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
349 const double exponential_average_factor,
350 dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
351 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
352 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
353 bool is_training,
354 ScratchAllocator *reserve_space_allocator,
355 ScratchAllocator *workspace_allocator) {
356 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
357 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
358 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
359 CheckError(dnn->DoBatchNormalizationForward(
360 this, x, scale, offset, estimated_mean, estimated_variance, side_input,
361 x_desc, scale_offset_desc, epsilon, exponential_average_factor,
362 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
363 is_training, reserve_space_allocator, workspace_allocator));
364 } else {
365 SetErrorAndLogNoDnnSupport();
366 }
367 return *this;
368 }
369
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)370 Stream &Stream::ThenBatchNormalizationBackward(
371 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
372 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
373 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
374 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
375 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
376 DeviceMemory<float> *offset_backprop,
377 DeviceMemory<uint8> *reserve_space_data,
378 ScratchAllocator *workspace_allocator) {
379 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
380 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
381 PARAM(scale_backprop), PARAM(offset_backprop));
382 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
383 CheckError(dnn->DoBatchNormalizationBackward(
384 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
385 epsilon, x_backprop, scale_backprop, offset_backprop,
386 reserve_space_data, workspace_allocator));
387 } else {
388 SetErrorAndLogNoDnnSupport();
389 }
390 return *this;
391 }
392
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)393 Stream &Stream::ThenBatchNormalizationForward(
394 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
395 const DeviceMemory<float> &offset,
396 const DeviceMemory<float> &estimated_mean,
397 const DeviceMemory<float> &estimated_variance,
398 const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
399 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
400 const double exponential_average_factor,
401 dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
402 DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
403 DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
404 bool is_training,
405 ScratchAllocator *reserve_space_allocator,
406 ScratchAllocator *workspace_allocator) {
407 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
408 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
409 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
410 CheckError(dnn->DoBatchNormalizationForward(
411 this, x, scale, offset, estimated_mean, estimated_variance, side_input,
412 x_desc, scale_offset_desc, epsilon, exponential_average_factor,
413 activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
414 is_training, reserve_space_allocator, workspace_allocator));
415 } else {
416 SetErrorAndLogNoDnnSupport();
417 }
418 return *this;
419 }
420
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)421 Stream &Stream::ThenBatchNormalizationBackward(
422 const DeviceMemory<Eigen::half> &y_backprop,
423 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
424 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
425 const dnn::BatchDescriptor &x_desc,
426 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
427 DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
428 DeviceMemory<float> *offset_backprop,
429 DeviceMemory<uint8> *reserve_space_data,
430 ScratchAllocator *workspace_allocator) {
431 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
432 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
433 PARAM(scale_backprop), PARAM(offset_backprop));
434 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
435 CheckError(dnn->DoBatchNormalizationBackward(
436 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
437 epsilon, x_backprop, scale_backprop, offset_backprop,
438 reserve_space_data, workspace_allocator));
439
440 } else {
441 SetErrorAndLogNoDnnSupport();
442 }
443 return *this;
444 }
445
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<double> & conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<double> & side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<double> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)446 port::Status Stream::FusedConvolveWithAlgorithm(
447 const dnn::BatchDescriptor &conv_input_descriptor,
448 const DeviceMemory<double> &conv_input_data, double conv_input_scale,
449 const dnn::FilterDescriptor &filter_descriptor,
450 const DeviceMemory<double> &filter_data,
451 const dnn::ConvolutionDescriptor &convolution_descriptor,
452 const DeviceMemory<double> &side_input_data, double side_input_scale,
453 const dnn::BatchDescriptor &bias_descriptor,
454 const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
455 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
456 ScratchAllocator *scratch_allocator,
457 const dnn::AlgorithmConfig &algorithm_config,
458 dnn::ProfileResult *output_profile_result) {
459 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
460 PARAM(conv_input_scale), PARAM(filter_descriptor),
461 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
462 PARAM(side_input_data), PARAM(side_input_scale),
463 PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
464 PARAM(algorithm_config));
465
466 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
467 return dnn->DoFusedConvolve(
468 this, conv_input_descriptor, conv_input_data, conv_input_scale,
469 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
470 side_input_scale, bias_descriptor, biases, activation_mode,
471 output_descriptor, output, scratch_allocator, algorithm_config,
472 output_profile_result);
473 }
474 return port::UnimplementedError("DNN library is not found.");
475 }
476
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)477 port::Status Stream::FusedConvolveWithAlgorithm(
478 const dnn::BatchDescriptor &conv_input_descriptor,
479 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
480 const dnn::FilterDescriptor &filter_descriptor,
481 const DeviceMemory<float> &filter_data,
482 const dnn::ConvolutionDescriptor &convolution_descriptor,
483 const DeviceMemory<float> &side_input_data, float side_input_scale,
484 const dnn::BatchDescriptor &bias_descriptor,
485 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
486 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
487 ScratchAllocator *scratch_allocator,
488 const dnn::AlgorithmConfig &algorithm_config,
489 dnn::ProfileResult *output_profile_result) {
490 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
491 PARAM(conv_input_scale), PARAM(filter_descriptor),
492 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
493 PARAM(side_input_data), PARAM(side_input_scale),
494 PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
495 PARAM(algorithm_config));
496
497 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
498 return dnn->DoFusedConvolve(
499 this, conv_input_descriptor, conv_input_data, conv_input_scale,
500 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
501 side_input_scale, bias_descriptor, biases, activation_mode,
502 output_descriptor, output, scratch_allocator, algorithm_config,
503 output_profile_result);
504 }
505 return port::UnimplementedError("DNN library is not found.");
506 }
507
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)508 port::Status Stream::FusedConvolveWithAlgorithm(
509 const dnn::BatchDescriptor &conv_input_descriptor,
510 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
511 const dnn::FilterDescriptor &filter_descriptor,
512 const DeviceMemory<Eigen::half> &filter_data,
513 const dnn::ConvolutionDescriptor &convolution_descriptor,
514 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
515 const dnn::BatchDescriptor &bias_descriptor,
516 const DeviceMemory<Eigen::half> &biases,
517 dnn::ActivationMode activation_mode,
518 const dnn::BatchDescriptor &output_descriptor,
519 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
520 const dnn::AlgorithmConfig &algorithm_config,
521 dnn::ProfileResult *output_profile_result) {
522 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
523 PARAM(conv_input_scale), PARAM(filter_descriptor),
524 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
525 PARAM(side_input_data), PARAM(side_input_scale),
526 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
527 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
528
529 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
530 return dnn->DoFusedConvolve(
531 this, conv_input_descriptor, conv_input_data, conv_input_scale,
532 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
533 side_input_scale, bias_descriptor, biases, activation_mode,
534 output_descriptor, output, scratch_allocator, algorithm_config,
535 output_profile_result);
536 }
537 return port::UnimplementedError("DNN library is not found.");
538 }
539
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)540 port::Status Stream::FusedConvolveWithAlgorithm(
541 const dnn::BatchDescriptor &conv_input_descriptor,
542 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
543 const dnn::FilterDescriptor &filter_descriptor,
544 const DeviceMemory<int8> &filter_data,
545 const dnn::ConvolutionDescriptor &convolution_descriptor,
546 const DeviceMemory<int8> &side_input_data, float side_input_scale,
547 const dnn::BatchDescriptor &bias_descriptor,
548 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
549 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
550 ScratchAllocator *scratch_allocator,
551 const dnn::AlgorithmConfig &algorithm_config,
552 dnn::ProfileResult *output_profile_result) {
553 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
554 PARAM(conv_input_scale), PARAM(filter_descriptor),
555 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
556 PARAM(side_input_data), PARAM(side_input_scale),
557 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
558 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
559
560 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
561 return dnn->DoFusedConvolve(
562 this, conv_input_descriptor, conv_input_data, conv_input_scale,
563 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
564 side_input_scale, bias_descriptor, biases, activation_mode,
565 output_descriptor, output, scratch_allocator, algorithm_config,
566 output_profile_result);
567 }
568 return port::UnimplementedError("DNN library is not found.");
569 }
570
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)571 port::Status Stream::FusedConvolveWithAlgorithm(
572 const dnn::BatchDescriptor &conv_input_descriptor,
573 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
574 const dnn::FilterDescriptor &filter_descriptor,
575 const DeviceMemory<int8> &filter_data,
576 const dnn::ConvolutionDescriptor &convolution_descriptor,
577 const DeviceMemory<float> &side_input_data, float side_input_scale,
578 const dnn::BatchDescriptor &bias_descriptor,
579 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
580 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
581 ScratchAllocator *scratch_allocator,
582 const dnn::AlgorithmConfig &algorithm_config,
583 dnn::ProfileResult *output_profile_result) {
584 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
585 PARAM(conv_input_scale), PARAM(filter_descriptor),
586 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
587 PARAM(side_input_data), PARAM(side_input_scale),
588 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
589 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
590
591 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
592 return dnn->DoFusedConvolve(
593 this, conv_input_descriptor, conv_input_data, conv_input_scale,
594 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
595 side_input_scale, bias_descriptor, biases, activation_mode,
596 output_descriptor, output, scratch_allocator, algorithm_config,
597 output_profile_result);
598 }
599 return port::UnimplementedError("DNN library is not found.");
600 }
601
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)602 Stream &Stream::ThenConvolve(
603 const dnn::BatchDescriptor &input_descriptor,
604 const DeviceMemory<float> &input_data,
605 const dnn::FilterDescriptor &filter_descriptor,
606 const DeviceMemory<float> &filter_data,
607 const dnn::ConvolutionDescriptor &convolution_descriptor,
608 const dnn::BatchDescriptor &output_descriptor,
609 DeviceMemory<float> *output) {
610 if (ok()) {
611 CheckError(ConvolveWithAlgorithm(
612 input_descriptor, input_data, filter_descriptor, filter_data,
613 convolution_descriptor, output_descriptor, output,
614 /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
615 /*output_profile_result=*/nullptr)
616 .ok());
617 }
618 return *this;
619 }
620
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)621 Stream &Stream::ThenConvolveQuantized(
622 const dnn::BatchDescriptor &input_descriptor,
623 const DeviceMemory<float> &input_data,
624 const dnn::FilterDescriptor &filter_descriptor,
625 const DeviceMemory<int8> &filter_coefficients,
626 const DeviceMemory<float> &coefficient_scales,
627 const dnn::ConvolutionDescriptor &convolution_descriptor,
628 const dnn::BatchDescriptor &output_descriptor,
629 DeviceMemory<float> *output) {
630 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
631 PARAM(filter_descriptor), PARAM(filter_coefficients),
632 PARAM(coefficient_scales), PARAM(convolution_descriptor),
633 PARAM(output_descriptor), PARAM(output));
634
635 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
636 CheckError(dnn->DoConvolveQuantized(
637 this, input_descriptor, input_data, filter_descriptor,
638 filter_coefficients, coefficient_scales, convolution_descriptor,
639 output_descriptor, output));
640 } else {
641 SetError();
642 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
643 "without DNN support";
644 }
645 return *this;
646 }
647
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)648 Stream &Stream::ThenConvolveQuantized(
649 const dnn::BatchDescriptor &input_descriptor,
650 const DeviceMemory<float> &input_data,
651 const dnn::FilterDescriptor &filter_descriptor,
652 const DeviceMemory<int16> &filter_coefficients,
653 const DeviceMemory<float> &coefficient_scales,
654 const dnn::ConvolutionDescriptor &convolution_descriptor,
655 const dnn::BatchDescriptor &output_descriptor,
656 DeviceMemory<float> *output) {
657 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
658 PARAM(filter_descriptor), PARAM(filter_coefficients),
659 PARAM(coefficient_scales), PARAM(convolution_descriptor),
660 PARAM(output_descriptor), PARAM(output));
661
662 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
663 CheckError(dnn->DoConvolveQuantized(
664 this, input_descriptor, input_data, filter_descriptor,
665 filter_coefficients, coefficient_scales, convolution_descriptor,
666 output_descriptor, output));
667 } else {
668 SetError();
669 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
670 "without DNN support";
671 }
672 return *this;
673 }
674
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)675 Stream &Stream::ThenSeparableConvolve(
676 const dnn::BatchDescriptor &batch_descriptor,
677 const DeviceMemory<float> &input_data,
678 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
679 const DeviceMemory<float> &first_weights,
680 const DeviceMemory<float> &second_weights,
681 const dnn::ConvolutionDescriptor &convolution_descriptor,
682 const dnn::BatchDescriptor &output_descriptor,
683 DeviceMemory<float> *output) {
684 VLOG_CALL(
685 PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
686 PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
687 PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
688
689 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
690 CheckError(dnn->DoSeparableConvolve(
691 this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
692 first_weights, second_weights, convolution_descriptor,
693 output_descriptor, output));
694 } else {
695 SetErrorAndLogNoDnnSupport();
696 }
697 return *this;
698 }
699
700 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)701 Stream &Stream::ThenConvolveBackwardBiasImpl(
702 const dnn::BatchDescriptor &input_descriptor,
703 const DeviceMemory<T> &input_data,
704 const dnn::BatchDescriptor &bias_descriptor,
705 DeviceMemory<T> *backward_bias_data) {
706 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
707 PARAM(backward_bias_data));
708
709 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
710 CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
711 bias_descriptor,
712 backward_bias_data));
713 } else {
714 SetErrorAndLogNoDnnSupport();
715 }
716 return *this;
717 }
718
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)719 Stream &Stream::ThenConvolveBackwardBias(
720 const dnn::BatchDescriptor &input_descriptor,
721 const DeviceMemory<double> &input_data,
722 const dnn::BatchDescriptor &bias_descriptor,
723 DeviceMemory<double> *backward_bias_data) {
724 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
725 bias_descriptor, backward_bias_data);
726 }
727
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)728 Stream &Stream::ThenConvolveBackwardBias(
729 const dnn::BatchDescriptor &input_descriptor,
730 const DeviceMemory<float> &input_data,
731 const dnn::BatchDescriptor &bias_descriptor,
732 DeviceMemory<float> *backward_bias_data) {
733 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
734 bias_descriptor, backward_bias_data);
735 }
736
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)737 Stream &Stream::ThenConvolveBackwardBias(
738 const dnn::BatchDescriptor &input_descriptor,
739 const DeviceMemory<Eigen::half> &input_data,
740 const dnn::BatchDescriptor &bias_descriptor,
741 DeviceMemory<Eigen::half> *backward_bias_data) {
742 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
743 bias_descriptor, backward_bias_data);
744 }
745
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)746 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
747 const DeviceMemory<float> &weights,
748 const dnn::BatchDescriptor &input_dimensions,
749 const dnn::BatchDescriptor &output_dimensions,
750 DeviceMemory<float> *output_data) {
751 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
752 PARAM(output_dimensions), PARAM(output_data));
753
754 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
755 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
756 output_dimensions, output_data));
757 } else {
758 SetErrorAndLogNoDnnSupport();
759 }
760 return *this;
761 }
762
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)763 Stream &Stream::ThenMatMulQuantized(
764 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
765 const DeviceMemory<float> &weight_scales,
766 const dnn::BatchDescriptor &input_dimensions,
767 const dnn::BatchDescriptor &output_dimensions,
768 DeviceMemory<float> *output_data) {
769 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
770 PARAM(input_dimensions), PARAM(output_dimensions),
771 PARAM(output_data));
772
773 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
774 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
775 input_dimensions, output_dimensions,
776 output_data));
777 } else {
778 SetErrorAndLogNoDnnSupport();
779 }
780 return *this;
781 }
782
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)783 Stream &Stream::ThenMatMulQuantized(
784 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
785 const DeviceMemory<float> &weight_scales,
786 const dnn::BatchDescriptor &input_dimensions,
787 const dnn::BatchDescriptor &output_dimensions,
788 DeviceMemory<float> *output_data) {
789 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
790 PARAM(input_dimensions), PARAM(output_dimensions),
791 PARAM(output_data));
792
793 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
794 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
795 input_dimensions, output_dimensions,
796 output_data));
797 } else {
798 SetErrorAndLogNoDnnSupport();
799 }
800 return *this;
801 }
802
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)803 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
804 const DeviceMemory<float> &biases,
805 const dnn::BatchDescriptor &dimensions,
806 DeviceMemory<float> *output_data) {
807 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
808 PARAM(output_data));
809
810 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
811 CheckError(
812 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
813 } else {
814 SetErrorAndLogNoDnnSupport();
815 }
816 return *this;
817 }
818
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)819 Stream &Stream::ThenPoolForward(
820 const dnn::PoolingDescriptor &pooling_dimensions,
821 const dnn::BatchDescriptor &input_dimensions,
822 const DeviceMemory<double> &input_data,
823 const dnn::BatchDescriptor &output_dimensions,
824 DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
825 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
826 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
827 PARAM(workspace_allocator));
828
829 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
830 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
831 input_data, output_dimensions, output_data,
832 workspace_allocator));
833 } else {
834 SetError();
835 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
836 "without DNN support";
837 }
838 return *this;
839 }
840
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)841 Stream &Stream::ThenPoolForward(
842 const dnn::PoolingDescriptor &pooling_dimensions,
843 const dnn::BatchDescriptor &input_dimensions,
844 const DeviceMemory<float> &input_data,
845 const dnn::BatchDescriptor &output_dimensions,
846 DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
847 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
848 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
849 PARAM(workspace_allocator));
850
851 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
852 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
853 input_data, output_dimensions, output_data,
854 workspace_allocator));
855 } else {
856 SetErrorAndLogNoDnnSupport();
857 }
858 return *this;
859 }
860
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)861 Stream &Stream::ThenPoolForward(
862 const dnn::PoolingDescriptor &pooling_dimensions,
863 const dnn::BatchDescriptor &input_dimensions,
864 const DeviceMemory<Eigen::half> &input_data,
865 const dnn::BatchDescriptor &output_dimensions,
866 DeviceMemory<Eigen::half> *output_data,
867 ScratchAllocator *workspace_allocator) {
868 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
869 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
870 PARAM(workspace_allocator));
871
872 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
873 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
874 input_data, output_dimensions, output_data,
875 workspace_allocator));
876 } else {
877 SetErrorAndLogNoDnnSupport();
878 }
879 return *this;
880 }
881
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)882 Stream &Stream::ThenPoolForward(
883 const dnn::PoolingDescriptor &pooling_dimensions,
884 const dnn::BatchDescriptor &input_dimensions,
885 const DeviceMemory<int8> &input_data,
886 const dnn::BatchDescriptor &output_dimensions,
887 DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
888 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
889 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
890 PARAM(workspace_allocator));
891
892 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
893 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
894 input_data, output_dimensions, output_data,
895 workspace_allocator));
896 } else {
897 SetErrorAndLogNoDnnSupport();
898 }
899 return *this;
900 }
901
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)902 Stream &Stream::ThenPoolBackward(
903 const dnn::PoolingDescriptor &pooling_dimensions,
904 const dnn::BatchDescriptor &input_dimensions,
905 const DeviceMemory<double> &input_data,
906 const dnn::BatchDescriptor &output_dimensions,
907 const DeviceMemory<double> &output_data,
908 const DeviceMemory<double> &input_diff_data,
909 DeviceMemory<double> *output_diff_data,
910 ScratchAllocator *workspace_allocator) {
911 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
912 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
913 PARAM(input_diff_data), PARAM(output_diff_data),
914 PARAM(workspace_allocator));
915
916 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
917 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
918 input_data, output_dimensions, output_data,
919 input_diff_data, output_diff_data,
920 workspace_allocator));
921 } else {
922 SetError();
923 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
924 "without DNN support";
925 }
926 return *this;
927 }
928
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)929 Stream &Stream::ThenPoolBackward(
930 const dnn::PoolingDescriptor &pooling_dimensions,
931 const dnn::BatchDescriptor &input_dimensions,
932 const DeviceMemory<float> &input_data,
933 const dnn::BatchDescriptor &output_dimensions,
934 const DeviceMemory<float> &output_data,
935 const DeviceMemory<float> &input_diff_data,
936 DeviceMemory<float> *output_diff_data,
937 ScratchAllocator *workspace_allocator) {
938 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
939 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
940 PARAM(input_diff_data), PARAM(output_diff_data),
941 PARAM(workspace_allocator));
942
943 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
944 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
945 input_data, output_dimensions, output_data,
946 input_diff_data, output_diff_data,
947 workspace_allocator));
948 } else {
949 SetErrorAndLogNoDnnSupport();
950 }
951 return *this;
952 }
953
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)954 Stream &Stream::ThenPoolBackward(
955 const dnn::PoolingDescriptor &pooling_dimensions,
956 const dnn::BatchDescriptor &input_dimensions,
957 const DeviceMemory<Eigen::half> &input_data,
958 const dnn::BatchDescriptor &output_dimensions,
959 const DeviceMemory<Eigen::half> &output_data,
960 const DeviceMemory<Eigen::half> &input_diff_data,
961 DeviceMemory<Eigen::half> *output_diff_data,
962 ScratchAllocator *workspace_allocator) {
963 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
964 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
965 PARAM(input_diff_data), PARAM(output_diff_data),
966 PARAM(workspace_allocator));
967
968 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
969 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
970 input_data, output_dimensions, output_data,
971 input_diff_data, output_diff_data,
972 workspace_allocator));
973 } else {
974 SetErrorAndLogNoDnnSupport();
975 }
976 return *this;
977 }
978
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)979 Stream &Stream::ThenNormalizeWithDimensions(
980 const dnn::NormalizeDescriptor &normalize_descriptor,
981 const dnn::BatchDescriptor &dimensions,
982 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
983 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
984 PARAM(output_data));
985
986 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
987 CheckError(dnn->DoNormalizeWithDimensions(
988 this, normalize_descriptor, dimensions, input_data, output_data));
989 } else {
990 SetErrorAndLogNoDnnSupport();
991 }
992 return *this;
993 }
994
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)995 Stream &Stream::ThenNormalizeBackwardWithDimensions(
996 const dnn::NormalizeDescriptor &normalize_descriptor,
997 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
998 const DeviceMemory<float> &normalized_data,
999 const DeviceMemory<float> &normalized_variable_gradient,
1000 DeviceMemory<float> *raw_variable_gradient,
1001 ScratchAllocator *workspace_allocator) {
1002 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1003 PARAM(normalized_data), PARAM(normalized_variable_gradient),
1004 PARAM(raw_variable_gradient), PARAM(workspace_allocator));
1005
1006 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1007 CheckError(dnn->DoNormalizeBackwardWithDimensions(
1008 this, normalize_descriptor, dimensions, raw_data, normalized_data,
1009 normalized_variable_gradient, raw_variable_gradient,
1010 workspace_allocator));
1011 } else {
1012 SetErrorAndLogNoDnnSupport();
1013 }
1014 return *this;
1015 }
1016
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1017 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1018 const dnn::BatchDescriptor &dimensions,
1019 const DeviceMemory<float> &input_data,
1020 DeviceMemory<float> *output_data) {
1021 return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1022 output_data, /*options=*/0);
1023 }
1024
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1025 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1026 const dnn::BatchDescriptor &dimensions,
1027 const DeviceMemory<float> &input_data,
1028 DeviceMemory<float> *output_data,
1029 uint64 options) {
1030 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1031 PARAM(output_data), PARAM(options));
1032
1033 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1034 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1035 output_data, options));
1036 } else {
1037 SetErrorAndLogNoDnnSupport();
1038 }
1039 return *this;
1040 }
1041
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1042 Stream &Stream::ThenDepthConcatenate(
1043 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1044 port::ArraySlice<const DeviceMemory<float> *> input_data,
1045 DeviceMemory<float> *output_data) {
1046 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1047
1048 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1049 if (input_dimensions[i].count() != input_dimensions[0].count() ||
1050 input_dimensions[i].height() != input_dimensions[0].height() ||
1051 input_dimensions[i].width() != input_dimensions[0].width()) {
1052 SetError();
1053 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1054 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1055 << "input_dimensions[" << i
1056 << "]: " << input_dimensions[i].ToString();
1057 return *this;
1058 }
1059 }
1060
1061 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1062 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1063 output_data));
1064 } else {
1065 SetErrorAndLogNoDnnSupport();
1066 }
1067 return *this;
1068 }
1069
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1070 Stream &Stream::ThenSpaceConcatenate(
1071 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1072 port::ArraySlice<const DeviceMemory<float> *> input_data,
1073 DeviceMemory<float> *output_data,
1074 dnn::SpaceConcatenateMode concat_direction) {
1075 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1076
1077 // Check that the input dimensions of all the other batches match those of the
1078 // first batch.
1079 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1080 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1081 (input_dimensions[i].count() != input_dimensions[0].count() ||
1082 input_dimensions[i].height() != input_dimensions[0].height() ||
1083 input_dimensions[i].feature_map_count() !=
1084 input_dimensions[0].feature_map_count())) {
1085 SetError();
1086 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1087 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1088 << "input_dimensions[" << i
1089 << "]: " << input_dimensions[i].ToString();
1090 return *this;
1091 }
1092
1093 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1094 (input_dimensions[i].count() != input_dimensions[0].count() ||
1095 input_dimensions[i].width() != input_dimensions[0].width() ||
1096 input_dimensions[i].feature_map_count() !=
1097 input_dimensions[0].feature_map_count())) {
1098 SetError();
1099 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1100 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1101 << "input_dimensions[" << i
1102 << "]: " << input_dimensions[i].ToString();
1103 return *this;
1104 }
1105 }
1106 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1107 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1108 output_data, concat_direction));
1109 } else {
1110 SetErrorAndLogNoDnnSupport();
1111 }
1112 return *this;
1113 }
1114
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1115 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1116 const DeviceMemory<float> &input_data,
1117 const dnn::BatchDescriptor &output_dimensions,
1118 DeviceMemory<float> *output_data) {
1119 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1120 PARAM(output_dimensions), PARAM(output_data));
1121
1122 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1123 CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1124 output_dimensions, output_data));
1125 } else {
1126 SetErrorAndLogNoDnnSupport();
1127 }
1128 return *this;
1129 }
1130
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1131 Stream &Stream::ThenDepthToSpace(
1132 const dnn::BatchDescriptor &input_dimensions,
1133 const DeviceMemory<float> &input_data,
1134 const dnn::DepthToSpaceLayout &depth_to_space_layout,
1135 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1136 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1137 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1138 PARAM(output_data));
1139
1140 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1141 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1142 depth_to_space_layout, sqrt_depth_reduction,
1143 output_data));
1144 } else {
1145 SetErrorAndLogNoDnnSupport();
1146 }
1147 return *this;
1148 }
1149
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1150 Stream &Stream::ThenSpaceToDepth(
1151 const dnn::BatchDescriptor &input_dimensions,
1152 const DeviceMemory<float> &input_data,
1153 const dnn::DepthToSpaceLayout &space_to_depth_layout,
1154 const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1155 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1156 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1157 PARAM(output_data));
1158
1159 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1160 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1161 space_to_depth_layout, sqrt_depth_increase,
1162 output_data));
1163 } else {
1164 SetErrorAndLogNoDnnSupport();
1165 }
1166 return *this;
1167 }
1168
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1169 Stream &Stream::ThenElementwiseOperate(
1170 dnn::ElementwiseOperation operation,
1171 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1172 port::ArraySlice<const DeviceMemory<float> *> input_data,
1173 const dnn::BatchDescriptor &output_dimensions,
1174 DeviceMemory<float> *output_data) {
1175 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1176 PARAM(output_dimensions), PARAM(output_data));
1177
1178 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1179 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1180 input_data, output_dimensions,
1181 output_data));
1182 } else {
1183 SetErrorAndLogNoDnnSupport();
1184 }
1185 return *this;
1186 }
1187
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1188 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1189 dnn::ElementwiseOperation operation,
1190 port::ArraySlice<int> input_multiplicands, int output_divisor,
1191 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1192 port::ArraySlice<const DeviceMemory<float> *> input_data,
1193 const dnn::BatchDescriptor &output_dimensions,
1194 DeviceMemory<float> *output_data) {
1195 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1196 PARAM(input_dimensions), PARAM(input_data),
1197 PARAM(output_dimensions), PARAM(output_data));
1198
1199 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1200 CheckError(dnn->DoElementwiseOperateScaledQuantized(
1201 this, operation, input_multiplicands, output_divisor, input_dimensions,
1202 input_data, output_dimensions, output_data));
1203 } else {
1204 SetErrorAndLogNoDnnSupport();
1205 }
1206 return *this;
1207 }
1208
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1209 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1210 const DeviceMemory<float> &input_data, int64 left_pad,
1211 int64 right_pad, int64 top_pad, int64 bottom_pad,
1212 DeviceMemory<float> *output_data) {
1213 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1214 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1215 PARAM(output_data));
1216
1217 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1218 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1219 top_pad, bottom_pad, output_data));
1220 } else {
1221 SetErrorAndLogNoDnnSupport();
1222 }
1223 return *this;
1224 }
1225
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1226 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1227 const DeviceMemory<float> &input_data,
1228 int64 left_trim, int64 right_trim, int64 top_trim,
1229 int64 bottom_trim,
1230 DeviceMemory<float> *output_data) {
1231 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1232 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1233 PARAM(output_data));
1234
1235 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1236 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1237 right_trim, top_trim, bottom_trim, output_data));
1238 } else {
1239 SetErrorAndLogNoDnnSupport();
1240 }
1241 return *this;
1242 }
1243
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1244 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1245 const DeviceMemory<float> &input_data,
1246 int64 replicate_x, int64 replicate_y,
1247 DeviceMemory<float> *output_data) {
1248 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1249 PARAM(replicate_y), PARAM(output_data));
1250
1251 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1252 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1253 replicate_y, output_data));
1254 } else {
1255 SetErrorAndLogNoDnnSupport();
1256 }
1257 return *this;
1258 }
1259
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1260 Stream &Stream::ThenMemcpyD2HQuantized(
1261 const DeviceMemory<float> &gpu_unquantized_src,
1262 dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1263 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1264 PARAM(size));
1265
1266 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1267 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1268 host_dst, size));
1269 } else {
1270 SetErrorAndLogNoDnnSupport();
1271 }
1272 return *this;
1273 }
1274
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1275 Stream &Stream::ThenMemcpyH2DQuantized(
1276 const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1277 DeviceMemory<float> *gpu_unquantized_dst) {
1278 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1279 PARAM(gpu_unquantized_dst));
1280
1281 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1282 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1283 gpu_unquantized_dst));
1284 } else {
1285 SetErrorAndLogNoDnnSupport();
1286 }
1287 return *this;
1288 }
1289
GetOrCreateSubStream()1290 Stream *Stream::GetOrCreateSubStream() {
1291 // Do not destroy bad streams when holding mu_ because ~Stream() may
1292 // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1293 std::vector<std::unique_ptr<Stream>> bad_streams;
1294
1295 absl::MutexLock lock(&mu_);
1296
1297 // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1298 // we encounter along the way.
1299 for (size_t index = 0; index < sub_streams_.size();) {
1300 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1301 if (pair.second) {
1302 // The sub_stream is reusable.
1303 Stream *sub_stream = pair.first.get();
1304 if (sub_stream->ok()) {
1305 VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1306 << sub_stream->DebugStreamPointers();
1307 pair.second = false;
1308 return sub_stream;
1309 }
1310
1311 // The stream is reusable and not ok. Streams have a monotonic state
1312 // machine; the stream will remain in !ok forever. Swap it with the last
1313 // stream and pop it off.
1314 const int64 last = sub_streams_.size() - 1;
1315 if (index != last) {
1316 std::swap(pair, sub_streams_[last]);
1317 }
1318 bad_streams.push_back(std::move(sub_streams_.back().first));
1319 sub_streams_.pop_back();
1320 VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1321 << sub_stream->DebugStreamPointers();
1322 } else {
1323 // The sub_stream is not reusable, move on to the next one.
1324 ++index;
1325 }
1326 }
1327
1328 // No streams are reusable; create a new stream.
1329 sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1330 false);
1331 Stream *sub_stream = sub_streams_.back().first.get();
1332 sub_stream->Init();
1333 if (!sub_stream->ok()) {
1334 LOG(ERROR) << "sub-stream failed to be initialized";
1335 }
1336 VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1337 << sub_stream->DebugStreamPointers();
1338
1339 return sub_stream;
1340 }
1341
ReturnSubStream(Stream * sub_stream)1342 void Stream::ReturnSubStream(Stream *sub_stream) {
1343 // Do not destroy bad streams when holding mu_ because ~Stream() may
1344 // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1345 std::unique_ptr<Stream> bad_stream;
1346
1347 absl::MutexLock lock(&mu_);
1348
1349 // Look for the sub-stream.
1350 for (int64 index = 0, end = sub_streams_.size(); index < end; ++index) {
1351 std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1352 if (pair.first.get() != sub_stream) {
1353 continue;
1354 }
1355
1356 // Found the sub_stream.
1357 if (sub_stream->ok()) {
1358 VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1359 << sub_stream->DebugStreamPointers();
1360 pair.second = true;
1361 } else {
1362 // The returned stream is not ok. Streams have a monotonic state
1363 // machine; the stream will remain in !ok forever. Swap it with the last
1364 // stream and pop it off.
1365 VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1366 << sub_stream->DebugStreamPointers();
1367 const int64 last = sub_streams_.size() - 1;
1368 if (index != last) {
1369 std::swap(pair, sub_streams_[last]);
1370 }
1371 std::swap(bad_stream, sub_streams_.back().first);
1372 sub_streams_.pop_back();
1373 }
1374 return;
1375 }
1376
1377 LOG(FATAL) << DebugStreamPointers()
1378 << " did not create the returned sub-stream "
1379 << sub_stream->DebugStreamPointers();
1380 }
1381
ThenStartTimer(Timer * t)1382 Stream &Stream::ThenStartTimer(Timer *t) {
1383 VLOG_CALL(PARAM(t));
1384
1385 CheckError(parent_->StartTimer(this, t));
1386 return *this;
1387 }
1388
ThenStopTimer(Timer * t)1389 Stream &Stream::ThenStopTimer(Timer *t) {
1390 VLOG_CALL(PARAM(t));
1391
1392 CheckError(parent_->StopTimer(this, t));
1393 return *this;
1394 }
1395
ThenWaitFor(Stream * other)1396 Stream &Stream::ThenWaitFor(Stream *other) {
1397 VLOG_CALL(PARAM(other));
1398
1399 CHECK(this != other) << "stream cannot wait for itself";
1400 if (ok() && other->ok()) {
1401 CheckError(parent_->CreateStreamDependency(this, other));
1402 } else {
1403 SetError();
1404 LOG(INFO) << DebugStreamPointers() << " did not wait for "
1405 << other->DebugStreamPointers();
1406 }
1407 return *this;
1408 }
1409
ThenWaitFor(Event * event)1410 Stream &Stream::ThenWaitFor(Event *event) {
1411 VLOG_CALL(PARAM(event));
1412
1413 if (ok()) {
1414 port::Status status = parent_->WaitForEvent(this, event);
1415 if (!status.ok()) {
1416 LOG(ERROR) << "Error waiting for event in stream: "
1417 << status.error_message()
1418 << "; not marking stream as bad, as the Event object may be "
1419 << "at fault. Monitor for further errors.";
1420 }
1421 } else {
1422 LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1423 }
1424 return *this;
1425 }
1426
1427 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1428 // functions and logs for errors.
1429 template <typename... Args>
1430 struct ThenBlasImpl {
1431 // blas_func is the DoBlasXXX member function pointer, and args are its
1432 // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl1433 Stream &operator()(Stream *stream,
1434 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1435 Args... args) {
1436 return Run(stream, blas_func, /*record_error=*/true, args...);
1437 }
1438
1439 // Like operator(), but only calls stream->CheckError() if record_error is
1440 // true.
1441 Stream &Run(Stream *stream,
1442 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1443 bool record_error, Args... args);
1444 };
1445
1446 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1447 Stream &ThenBlasImpl<Args...>::Run(
1448 Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1449 bool record_error, Args... args) {
1450 if (stream->ok()) {
1451 bool ok;
1452 if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1453 ok = (blas->*blas_func)(stream, args...);
1454 } else {
1455 LOG(WARNING)
1456 << "attempting to perform BLAS operation using StreamExecutor "
1457 "without BLAS support";
1458 ok = false;
1459 }
1460 if (record_error) {
1461 stream->CheckError(ok);
1462 }
1463 }
1464 return *stream;
1465 }
1466
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1467 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1468 int incx, DeviceMemory<float> *result) {
1469 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1470
1471 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1472 impl;
1473 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1474 result);
1475 }
1476
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1477 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1478 int incx, DeviceMemory<double> *result) {
1479 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1480
1481 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1482 DeviceMemory<double> *>
1483 impl;
1484 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1485 result);
1486 }
1487
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1488 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1489 const DeviceMemory<std::complex<float>> &x,
1490 int incx, DeviceMemory<float> *result) {
1491 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1492
1493 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1494 DeviceMemory<float> *>
1495 impl;
1496 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1497 result);
1498 }
1499
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1500 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1501 const DeviceMemory<std::complex<double>> &x,
1502 int incx, DeviceMemory<double> *result) {
1503 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1504
1505 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1506 DeviceMemory<double> *>
1507 impl;
1508 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1509 result);
1510 }
1511
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1512 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1513 const DeviceMemory<float> &x, int incx,
1514 DeviceMemory<float> *y, int incy) {
1515 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1516 PARAM(incy));
1517
1518 ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1519 DeviceMemory<float> *, int>
1520 impl;
1521 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1522 y, incy);
1523 }
1524
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1525 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1526 const DeviceMemory<double> &x, int incx,
1527 DeviceMemory<double> *y, int incy) {
1528 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1529 PARAM(incy));
1530
1531 ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1532 DeviceMemory<double> *, int>
1533 impl;
1534 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1535 y, incy);
1536 }
1537
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1538 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1539 const DeviceMemory<std::complex<float>> &x,
1540 int incx, DeviceMemory<std::complex<float>> *y,
1541 int incy) {
1542 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1543 PARAM(incy));
1544
1545 ThenBlasImpl<uint64, std::complex<float>,
1546 const DeviceMemory<std::complex<float>> &, int,
1547 DeviceMemory<std::complex<float>> *, int>
1548 impl;
1549 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1550 y, incy);
1551 }
1552
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1553 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1554 const DeviceMemory<std::complex<double>> &x,
1555 int incx, DeviceMemory<std::complex<double>> *y,
1556 int incy) {
1557 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1558 PARAM(incy));
1559
1560 ThenBlasImpl<uint64, std::complex<double>,
1561 const DeviceMemory<std::complex<double>> &, int,
1562 DeviceMemory<std::complex<double>> *, int>
1563 impl;
1564 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1565 y, incy);
1566 }
1567
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1568 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
1569 int incx, DeviceMemory<float> *y, int incy) {
1570 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1571
1572 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
1573 int>
1574 impl;
1575 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1576 incy);
1577 }
1578
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1579 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
1580 int incx, DeviceMemory<double> *y, int incy) {
1581 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1582
1583 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1584 DeviceMemory<double> *, int>
1585 impl;
1586 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1587 incy);
1588 }
1589
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1590 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1591 const DeviceMemory<std::complex<float>> &x,
1592 int incx, DeviceMemory<std::complex<float>> *y,
1593 int incy) {
1594 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1595
1596 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1597 DeviceMemory<std::complex<float>> *, int>
1598 impl;
1599 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1600 incy);
1601 }
1602
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1603 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1604 const DeviceMemory<std::complex<double>> &x,
1605 int incx, DeviceMemory<std::complex<double>> *y,
1606 int incy) {
1607 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1608
1609 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1610 DeviceMemory<std::complex<double>> *, int>
1611 impl;
1612 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1613 incy);
1614 }
1615
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)1616 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
1617 int incx, const DeviceMemory<float> &y, int incy,
1618 DeviceMemory<float> *result) {
1619 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1620 PARAM(result));
1621
1622 ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
1623 const DeviceMemory<float> &, int, DeviceMemory<float> *>
1624 impl;
1625 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1626 result);
1627 }
1628
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)1629 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
1630 int incx, const DeviceMemory<double> &y, int incy,
1631 DeviceMemory<double> *result) {
1632 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1633 PARAM(result));
1634
1635 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1636 const DeviceMemory<double> &, int, DeviceMemory<double> *>
1637 impl;
1638 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1639 result);
1640 }
1641
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1642 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1643 const DeviceMemory<std::complex<float>> &x,
1644 int incx,
1645 const DeviceMemory<std::complex<float>> &y,
1646 int incy,
1647 DeviceMemory<std::complex<float>> *result) {
1648 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1649 PARAM(result));
1650
1651 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1652 const DeviceMemory<std::complex<float>> &, int,
1653 DeviceMemory<std::complex<float>> *>
1654 impl;
1655 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1656 incy, result);
1657 }
1658
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1659 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1660 const DeviceMemory<std::complex<double>> &x,
1661 int incx,
1662 const DeviceMemory<std::complex<double>> &y,
1663 int incy,
1664 DeviceMemory<std::complex<double>> *result) {
1665 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1666 PARAM(result));
1667
1668 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1669 const DeviceMemory<std::complex<double>> &, int,
1670 DeviceMemory<std::complex<double>> *>
1671 impl;
1672 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1673 incy, result);
1674 }
1675
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1676 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1677 const DeviceMemory<std::complex<float>> &x,
1678 int incx,
1679 const DeviceMemory<std::complex<float>> &y,
1680 int incy,
1681 DeviceMemory<std::complex<float>> *result) {
1682 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1683 PARAM(result));
1684
1685 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1686 const DeviceMemory<std::complex<float>> &, int,
1687 DeviceMemory<std::complex<float>> *>
1688 impl;
1689 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1690 incy, result);
1691 }
1692
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1693 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1694 const DeviceMemory<std::complex<double>> &x,
1695 int incx,
1696 const DeviceMemory<std::complex<double>> &y,
1697 int incy,
1698 DeviceMemory<std::complex<double>> *result) {
1699 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1700 PARAM(result));
1701
1702 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1703 const DeviceMemory<std::complex<double>> &, int,
1704 DeviceMemory<std::complex<double>> *>
1705 impl;
1706 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1707 incy, result);
1708 }
1709
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1710 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
1711 int incx, DeviceMemory<float> *result) {
1712 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1713
1714 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1715 impl;
1716 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1717 result);
1718 }
1719
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1720 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
1721 int incx, DeviceMemory<double> *result) {
1722 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1723
1724 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1725 DeviceMemory<double> *>
1726 impl;
1727 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1728 result);
1729 }
1730
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1731 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1732 const DeviceMemory<std::complex<float>> &x,
1733 int incx, DeviceMemory<float> *result) {
1734 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1735
1736 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1737 DeviceMemory<float> *>
1738 impl;
1739 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1740 result);
1741 }
1742
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1743 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1744 const DeviceMemory<std::complex<double>> &x,
1745 int incx, DeviceMemory<double> *result) {
1746 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1747
1748 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1749 DeviceMemory<double> *>
1750 impl;
1751 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1752 result);
1753 }
1754
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)1755 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
1756 DeviceMemory<float> *y, int incy, float c,
1757 float s) {
1758 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1759 PARAM(c), PARAM(s));
1760
1761 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1762 float, float>
1763 impl;
1764 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1765 c, s);
1766 }
1767
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)1768 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
1769 int incx, DeviceMemory<double> *y, int incy,
1770 double c, double s) {
1771 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1772 PARAM(c), PARAM(s));
1773
1774 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1775 double, double>
1776 impl;
1777 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1778 c, s);
1779 }
1780
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)1781 Stream &Stream::ThenBlasRot(uint64 elem_count,
1782 DeviceMemory<std::complex<float>> *x, int incx,
1783 DeviceMemory<std::complex<float>> *y, int incy,
1784 float c, float s) {
1785 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1786 PARAM(c), PARAM(s));
1787
1788 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1789 DeviceMemory<std::complex<float>> *, int, float, float>
1790 impl;
1791 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1792 c, s);
1793 }
1794
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)1795 Stream &Stream::ThenBlasRot(uint64 elem_count,
1796 DeviceMemory<std::complex<double>> *x, int incx,
1797 DeviceMemory<std::complex<double>> *y, int incy,
1798 double c, double s) {
1799 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1800 PARAM(c), PARAM(s));
1801
1802 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1803 DeviceMemory<std::complex<double>> *, int, double, double>
1804 impl;
1805 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1806 c, s);
1807 }
1808
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)1809 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
1810 DeviceMemory<float> *c, DeviceMemory<float> *s) {
1811 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1812
1813 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1814 DeviceMemory<float> *, DeviceMemory<float> *>
1815 impl;
1816 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1817 }
1818
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)1819 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
1820 DeviceMemory<double> *c, DeviceMemory<double> *s) {
1821 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1822
1823 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1824 DeviceMemory<double> *, DeviceMemory<double> *>
1825 impl;
1826 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1827 }
1828
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)1829 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
1830 DeviceMemory<std::complex<float>> *b,
1831 DeviceMemory<float> *c,
1832 DeviceMemory<std::complex<float>> *s) {
1833 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1834
1835 ThenBlasImpl<DeviceMemory<std::complex<float>> *,
1836 DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
1837 DeviceMemory<std::complex<float>> *>
1838 impl;
1839 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1840 }
1841
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)1842 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
1843 DeviceMemory<std::complex<double>> *b,
1844 DeviceMemory<double> *c,
1845 DeviceMemory<std::complex<double>> *s) {
1846 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1847
1848 ThenBlasImpl<DeviceMemory<std::complex<double>> *,
1849 DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
1850 DeviceMemory<std::complex<double>> *>
1851 impl;
1852 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1853 }
1854
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)1855 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
1856 int incx, DeviceMemory<float> *y, int incy,
1857 const DeviceMemory<float> ¶m) {
1858 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1859 PARAM(param));
1860
1861 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1862 const DeviceMemory<float> &>
1863 impl;
1864 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1865 incy, param);
1866 }
1867
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)1868 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
1869 int incx, DeviceMemory<double> *y, int incy,
1870 const DeviceMemory<double> ¶m) {
1871 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1872 PARAM(param));
1873
1874 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1875 const DeviceMemory<double> &>
1876 impl;
1877 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1878 incy, param);
1879 }
1880
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)1881 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
1882 DeviceMemory<float> *x1,
1883 const DeviceMemory<float> &y1,
1884 DeviceMemory<float> *param) {
1885 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1886
1887 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1888 DeviceMemory<float> *, const DeviceMemory<float> &,
1889 DeviceMemory<float> *>
1890 impl;
1891 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1892 }
1893
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)1894 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
1895 DeviceMemory<double> *d2,
1896 DeviceMemory<double> *x1,
1897 const DeviceMemory<double> &y1,
1898 DeviceMemory<double> *param) {
1899 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1900
1901 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1902 DeviceMemory<double> *, const DeviceMemory<double> &,
1903 DeviceMemory<double> *>
1904 impl;
1905 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1906 }
1907
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)1908 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1909 DeviceMemory<float> *x, int incx) {
1910 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1911
1912 ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
1913 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1914 }
1915
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)1916 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1917 DeviceMemory<double> *x, int incx) {
1918 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1919
1920 ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
1921 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1922 }
1923
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)1924 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1925 DeviceMemory<std::complex<float>> *x, int incx) {
1926 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1927
1928 ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
1929 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1930 }
1931
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)1932 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1933 DeviceMemory<std::complex<double>> *x, int incx) {
1934 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1935
1936 ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
1937 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1938 }
1939
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)1940 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
1941 DeviceMemory<std::complex<float>> *x, int incx) {
1942 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1943
1944 ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
1945 int>
1946 impl;
1947 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1948 }
1949
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)1950 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
1951 DeviceMemory<std::complex<double>> *x, int incx) {
1952 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1953
1954 ThenBlasImpl<uint64, std::complex<double>,
1955 DeviceMemory<std::complex<double>> *, int>
1956 impl;
1957 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1958 }
1959
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)1960 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
1961 int incx, DeviceMemory<float> *y, int incy) {
1962 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1963
1964 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
1965 impl;
1966 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1967 incy);
1968 }
1969
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)1970 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
1971 int incx, DeviceMemory<double> *y, int incy) {
1972 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1973
1974 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
1975 impl;
1976 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1977 incy);
1978 }
1979
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1980 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1981 DeviceMemory<std::complex<float>> *x, int incx,
1982 DeviceMemory<std::complex<float>> *y, int incy) {
1983 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1984
1985 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1986 DeviceMemory<std::complex<float>> *, int>
1987 impl;
1988 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1989 incy);
1990 }
1991
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1992 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1993 DeviceMemory<std::complex<double>> *x, int incx,
1994 DeviceMemory<std::complex<double>> *y, int incy) {
1995 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1996
1997 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1998 DeviceMemory<std::complex<double>> *, int>
1999 impl;
2000 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2001 incy);
2002 }
2003
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2004 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2005 int incx, DeviceMemory<int> *result) {
2006 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2007
2008 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2009 impl;
2010 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2011 result);
2012 }
2013
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2014 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2015 int incx, DeviceMemory<int> *result) {
2016 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2017
2018 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2019 impl;
2020 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2021 result);
2022 }
2023
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2024 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2025 const DeviceMemory<std::complex<float>> &x,
2026 int incx, DeviceMemory<int> *result) {
2027 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2028
2029 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2030 DeviceMemory<int> *>
2031 impl;
2032 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2033 result);
2034 }
2035
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2036 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2037 const DeviceMemory<std::complex<double>> &x,
2038 int incx, DeviceMemory<int> *result) {
2039 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2040
2041 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2042 DeviceMemory<int> *>
2043 impl;
2044 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2045 result);
2046 }
2047
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2048 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2049 int incx, DeviceMemory<int> *result) {
2050 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2051
2052 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2053 impl;
2054 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2055 result);
2056 }
2057
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2058 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2059 int incx, DeviceMemory<int> *result) {
2060 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2061
2062 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2063 impl;
2064 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2065 result);
2066 }
2067
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2068 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2069 const DeviceMemory<std::complex<float>> &x,
2070 int incx, DeviceMemory<int> *result) {
2071 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2072
2073 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2074 DeviceMemory<int> *>
2075 impl;
2076 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2077 result);
2078 }
2079
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2080 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2081 const DeviceMemory<std::complex<double>> &x,
2082 int incx, DeviceMemory<int> *result) {
2083 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2084
2085 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2086 DeviceMemory<int> *>
2087 impl;
2088 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2089 result);
2090 }
2091
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2092 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2093 uint64 kl, uint64 ku, float alpha,
2094 const DeviceMemory<float> &a, int lda,
2095 const DeviceMemory<float> &x, int incx, float beta,
2096 DeviceMemory<float> *y, int incy) {
2097 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2098 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2099 PARAM(beta), PARAM(y), PARAM(incy));
2100
2101 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2102 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2103 int, float, DeviceMemory<float> *, int>
2104 impl;
2105 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2106 a, lda, x, incx, beta, y, incy);
2107 }
2108
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2109 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2110 uint64 kl, uint64 ku, double alpha,
2111 const DeviceMemory<double> &a, int lda,
2112 const DeviceMemory<double> &x, int incx,
2113 double beta, DeviceMemory<double> *y, int incy) {
2114 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2115 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2116 PARAM(beta), PARAM(y), PARAM(incy));
2117
2118 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2119 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2120 int, double, DeviceMemory<double> *, int>
2121 impl;
2122 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2123 a, lda, x, incx, beta, y, incy);
2124 }
2125
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2126 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2127 uint64 kl, uint64 ku, std::complex<float> alpha,
2128 const DeviceMemory<std::complex<float>> &a,
2129 int lda,
2130 const DeviceMemory<std::complex<float>> &x,
2131 int incx, std::complex<float> beta,
2132 DeviceMemory<std::complex<float>> *y, int incy) {
2133 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2134 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2135 PARAM(beta), PARAM(y), PARAM(incy));
2136
2137 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2138 std::complex<float>, const DeviceMemory<std::complex<float>> &,
2139 int, const DeviceMemory<std::complex<float>> &, int,
2140 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2141 impl;
2142 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2143 a, lda, x, incx, beta, y, incy);
2144 }
2145
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2146 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2147 uint64 kl, uint64 ku, std::complex<double> alpha,
2148 const DeviceMemory<std::complex<double>> &a,
2149 int lda,
2150 const DeviceMemory<std::complex<double>> &x,
2151 int incx, std::complex<double> beta,
2152 DeviceMemory<std::complex<double>> *y, int incy) {
2153 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2154 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2155 PARAM(beta), PARAM(y), PARAM(incy));
2156
2157 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2158 std::complex<double>, const DeviceMemory<std::complex<double>> &,
2159 int, const DeviceMemory<std::complex<double>> &, int,
2160 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2161 impl;
2162 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2163 a, lda, x, incx, beta, y, incy);
2164 }
2165
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2166 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2167 float alpha, const DeviceMemory<float> &a, int lda,
2168 const DeviceMemory<float> &x, int incx, float beta,
2169 DeviceMemory<float> *y, int incy) {
2170 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2171 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2172 PARAM(incy));
2173
2174 ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2175 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2176 int, float, DeviceMemory<float> *, int>
2177 impl;
2178 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2179 x, incx, beta, y, incy);
2180 }
2181
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2182 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2183 double alpha, const DeviceMemory<double> &a,
2184 int lda, const DeviceMemory<double> &x, int incx,
2185 double beta, DeviceMemory<double> *y, int incy) {
2186 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2187 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2188 PARAM(incy));
2189
2190 ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2191 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2192 int, double, DeviceMemory<double> *, int>
2193 impl;
2194 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2195 x, incx, beta, y, incy);
2196 }
2197
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2198 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2199 std::complex<float> alpha,
2200 const DeviceMemory<std::complex<float>> &a,
2201 int lda,
2202 const DeviceMemory<std::complex<float>> &x,
2203 int incx, std::complex<float> beta,
2204 DeviceMemory<std::complex<float>> *y, int incy) {
2205 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2206 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2207 PARAM(incy));
2208
2209 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2210 const DeviceMemory<std::complex<float>> &, int,
2211 const DeviceMemory<std::complex<float>> &, int,
2212 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2213 impl;
2214 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2215 x, incx, beta, y, incy);
2216 }
2217
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2218 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2219 std::complex<double> alpha,
2220 const DeviceMemory<std::complex<double>> &a,
2221 int lda,
2222 const DeviceMemory<std::complex<double>> &x,
2223 int incx, std::complex<double> beta,
2224 DeviceMemory<std::complex<double>> *y, int incy) {
2225 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2226 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2227 PARAM(incy));
2228
2229 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2230 const DeviceMemory<std::complex<double>> &, int,
2231 const DeviceMemory<std::complex<double>> &, int,
2232 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2233 impl;
2234 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2235 x, incx, beta, y, incy);
2236 }
2237
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2238 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2239 const DeviceMemory<float> &x, int incx,
2240 const DeviceMemory<float> &y, int incy,
2241 DeviceMemory<float> *a, int lda) {
2242 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2243 PARAM(incy), PARAM(a), PARAM(lda));
2244
2245 ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2246 const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2247 impl;
2248 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2249 incy, a, lda);
2250 }
2251
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2252 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2253 const DeviceMemory<double> &x, int incx,
2254 const DeviceMemory<double> &y, int incy,
2255 DeviceMemory<double> *a, int lda) {
2256 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2257 PARAM(incy), PARAM(a), PARAM(lda));
2258
2259 ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2260 const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2261 impl;
2262 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2263 incy, a, lda);
2264 }
2265
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2266 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2267 const DeviceMemory<std::complex<float>> &x,
2268 int incx,
2269 const DeviceMemory<std::complex<float>> &y,
2270 int incy, DeviceMemory<std::complex<float>> *a,
2271 int lda) {
2272 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2273 PARAM(incy), PARAM(a), PARAM(lda));
2274
2275 ThenBlasImpl<uint64, uint64, std::complex<float>,
2276 const DeviceMemory<std::complex<float>> &, int,
2277 const DeviceMemory<std::complex<float>> &, int,
2278 DeviceMemory<std::complex<float>> *, int>
2279 impl;
2280 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2281 incy, a, lda);
2282 }
2283
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2284 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2285 const DeviceMemory<std::complex<double>> &x,
2286 int incx,
2287 const DeviceMemory<std::complex<double>> &y,
2288 int incy, DeviceMemory<std::complex<double>> *a,
2289 int lda) {
2290 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2291 PARAM(incy), PARAM(a), PARAM(lda));
2292
2293 ThenBlasImpl<uint64, uint64, std::complex<double>,
2294 const DeviceMemory<std::complex<double>> &, int,
2295 const DeviceMemory<std::complex<double>> &, int,
2296 DeviceMemory<std::complex<double>> *, int>
2297 impl;
2298 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2299 incy, a, lda);
2300 }
2301
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2302 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2303 const DeviceMemory<std::complex<float>> &x,
2304 int incx,
2305 const DeviceMemory<std::complex<float>> &y,
2306 int incy, DeviceMemory<std::complex<float>> *a,
2307 int lda) {
2308 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2309 PARAM(incy), PARAM(a), PARAM(lda));
2310
2311 ThenBlasImpl<uint64, uint64, std::complex<float>,
2312 const DeviceMemory<std::complex<float>> &, int,
2313 const DeviceMemory<std::complex<float>> &, int,
2314 DeviceMemory<std::complex<float>> *, int>
2315 impl;
2316 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2317 incy, a, lda);
2318 }
2319
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2320 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2321 const DeviceMemory<std::complex<double>> &x,
2322 int incx,
2323 const DeviceMemory<std::complex<double>> &y,
2324 int incy, DeviceMemory<std::complex<double>> *a,
2325 int lda) {
2326 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2327 PARAM(incy), PARAM(a), PARAM(lda));
2328
2329 ThenBlasImpl<uint64, uint64, std::complex<double>,
2330 const DeviceMemory<std::complex<double>> &, int,
2331 const DeviceMemory<std::complex<double>> &, int,
2332 DeviceMemory<std::complex<double>> *, int>
2333 impl;
2334 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2335 incy, a, lda);
2336 }
2337
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2338 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2339 std::complex<float> alpha,
2340 const DeviceMemory<std::complex<float>> &a,
2341 int lda,
2342 const DeviceMemory<std::complex<float>> &x,
2343 int incx, std::complex<float> beta,
2344 DeviceMemory<std::complex<float>> *y, int incy) {
2345 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2346 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2347
2348 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2349 const DeviceMemory<std::complex<float>> &, int,
2350 const DeviceMemory<std::complex<float>> &, int,
2351 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2352 impl;
2353 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2354 x, incx, beta, y, incy);
2355 }
2356
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2357 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2358 std::complex<double> alpha,
2359 const DeviceMemory<std::complex<double>> &a,
2360 int lda,
2361 const DeviceMemory<std::complex<double>> &x,
2362 int incx, std::complex<double> beta,
2363 DeviceMemory<std::complex<double>> *y, int incy) {
2364 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2365 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2366
2367 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2368 const DeviceMemory<std::complex<double>> &, int,
2369 const DeviceMemory<std::complex<double>> &, int,
2370 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2371 impl;
2372 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2373 x, incx, beta, y, incy);
2374 }
2375
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2376 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2377 std::complex<float> alpha,
2378 const DeviceMemory<std::complex<float>> &a,
2379 int lda,
2380 const DeviceMemory<std::complex<float>> &x,
2381 int incx, std::complex<float> beta,
2382 DeviceMemory<std::complex<float>> *y, int incy) {
2383 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2384 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2385
2386 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2387 const DeviceMemory<std::complex<float>> &, int,
2388 const DeviceMemory<std::complex<float>> &, int,
2389 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2390 impl;
2391 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2392 incx, beta, y, incy);
2393 }
2394
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2395 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2396 std::complex<double> alpha,
2397 const DeviceMemory<std::complex<double>> &a,
2398 int lda,
2399 const DeviceMemory<std::complex<double>> &x,
2400 int incx, std::complex<double> beta,
2401 DeviceMemory<std::complex<double>> *y, int incy) {
2402 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2403 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2404
2405 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2406 const DeviceMemory<std::complex<double>> &, int,
2407 const DeviceMemory<std::complex<double>> &, int,
2408 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2409 impl;
2410 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2411 incx, beta, y, incy);
2412 }
2413
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2414 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2415 const DeviceMemory<std::complex<float>> &x,
2416 int incx, DeviceMemory<std::complex<float>> *a,
2417 int lda) {
2418 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2419 PARAM(a), PARAM(lda));
2420
2421 ThenBlasImpl<blas::UpperLower, uint64, float,
2422 const DeviceMemory<std::complex<float>> &, int,
2423 DeviceMemory<std::complex<float>> *, int>
2424 impl;
2425 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2426 lda);
2427 }
2428
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2429 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2430 const DeviceMemory<std::complex<double>> &x,
2431 int incx, DeviceMemory<std::complex<double>> *a,
2432 int lda) {
2433 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2434 PARAM(a), PARAM(lda));
2435
2436 ThenBlasImpl<blas::UpperLower, uint64, double,
2437 const DeviceMemory<std::complex<double>> &, int,
2438 DeviceMemory<std::complex<double>> *, int>
2439 impl;
2440 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2441 lda);
2442 }
2443
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2444 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2445 std::complex<float> alpha,
2446 const DeviceMemory<std::complex<float>> &x,
2447 int incx,
2448 const DeviceMemory<std::complex<float>> &y,
2449 int incy, DeviceMemory<std::complex<float>> *a,
2450 int lda) {
2451 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2452 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2453
2454 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2455 const DeviceMemory<std::complex<float>> &, int,
2456 const DeviceMemory<std::complex<float>> &, int,
2457 DeviceMemory<std::complex<float>> *, int>
2458 impl;
2459 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2460 incy, a, lda);
2461 }
2462
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2463 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2464 std::complex<double> alpha,
2465 const DeviceMemory<std::complex<double>> &x,
2466 int incx,
2467 const DeviceMemory<std::complex<double>> &y,
2468 int incy, DeviceMemory<std::complex<double>> *a,
2469 int lda) {
2470 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2471 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2472
2473 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2474 const DeviceMemory<std::complex<double>> &, int,
2475 const DeviceMemory<std::complex<double>> &, int,
2476 DeviceMemory<std::complex<double>> *, int>
2477 impl;
2478 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2479 incy, a, lda);
2480 }
2481
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2482 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2483 std::complex<float> alpha,
2484 const DeviceMemory<std::complex<float>> &ap,
2485 const DeviceMemory<std::complex<float>> &x,
2486 int incx, std::complex<float> beta,
2487 DeviceMemory<std::complex<float>> *y, int incy) {
2488 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2489 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2490
2491 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2492 const DeviceMemory<std::complex<float>> &,
2493 const DeviceMemory<std::complex<float>> &, int,
2494 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2495 impl;
2496 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2497 beta, y, incy);
2498 }
2499
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2500 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2501 std::complex<double> alpha,
2502 const DeviceMemory<std::complex<double>> &ap,
2503 const DeviceMemory<std::complex<double>> &x,
2504 int incx, std::complex<double> beta,
2505 DeviceMemory<std::complex<double>> *y, int incy) {
2506 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2507 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2508
2509 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2510 const DeviceMemory<std::complex<double>> &,
2511 const DeviceMemory<std::complex<double>> &, int,
2512 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2513 impl;
2514 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2515 beta, y, incy);
2516 }
2517
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2518 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2519 const DeviceMemory<std::complex<float>> &x,
2520 int incx, DeviceMemory<std::complex<float>> *ap) {
2521 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2522 PARAM(ap));
2523
2524 ThenBlasImpl<blas::UpperLower, uint64, float,
2525 const DeviceMemory<std::complex<float>> &, int,
2526 DeviceMemory<std::complex<float>> *>
2527 impl;
2528 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2529 }
2530
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2531 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2532 const DeviceMemory<std::complex<double>> &x,
2533 int incx, DeviceMemory<std::complex<double>> *ap) {
2534 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2535 PARAM(ap));
2536
2537 ThenBlasImpl<blas::UpperLower, uint64, double,
2538 const DeviceMemory<std::complex<double>> &, int,
2539 DeviceMemory<std::complex<double>> *>
2540 impl;
2541 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2542 }
2543
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2544 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2545 std::complex<float> alpha,
2546 const DeviceMemory<std::complex<float>> &x,
2547 int incx,
2548 const DeviceMemory<std::complex<float>> &y,
2549 int incy, DeviceMemory<std::complex<float>> *ap) {
2550 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2551 PARAM(y), PARAM(incy), PARAM(ap));
2552
2553 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2554 const DeviceMemory<std::complex<float>> &, int,
2555 const DeviceMemory<std::complex<float>> &, int,
2556 DeviceMemory<std::complex<float>> *>
2557 impl;
2558 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2559 incy, ap);
2560 }
2561
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2562 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2563 std::complex<double> alpha,
2564 const DeviceMemory<std::complex<double>> &x,
2565 int incx,
2566 const DeviceMemory<std::complex<double>> &y,
2567 int incy, DeviceMemory<std::complex<double>> *ap) {
2568 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2569 PARAM(y), PARAM(incy), PARAM(ap));
2570
2571 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2572 const DeviceMemory<std::complex<double>> &, int,
2573 const DeviceMemory<std::complex<double>> &, int,
2574 DeviceMemory<std::complex<double>> *>
2575 impl;
2576 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2577 incy, ap);
2578 }
2579
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2580 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2581 float alpha, const DeviceMemory<float> &a, int lda,
2582 const DeviceMemory<float> &x, int incx, float beta,
2583 DeviceMemory<float> *y, int incy) {
2584 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2585 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2586
2587 ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2588 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2589 int, float, DeviceMemory<float> *, int>
2590 impl;
2591 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2592 x, incx, beta, y, incy);
2593 }
2594
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2595 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2596 double alpha, const DeviceMemory<double> &a,
2597 int lda, const DeviceMemory<double> &x, int incx,
2598 double beta, DeviceMemory<double> *y, int incy) {
2599 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2600 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2601
2602 ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2603 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2604 int, double, DeviceMemory<double> *, int>
2605 impl;
2606 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2607 x, incx, beta, y, incy);
2608 }
2609
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2610 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
2611 const DeviceMemory<float> &ap,
2612 const DeviceMemory<float> &x, int incx, float beta,
2613 DeviceMemory<float> *y, int incy) {
2614 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2615 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2616
2617 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2618 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
2619 int>
2620 impl;
2621 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2622 beta, y, incy);
2623 }
2624
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2625 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
2626 const DeviceMemory<double> &ap,
2627 const DeviceMemory<double> &x, int incx,
2628 double beta, DeviceMemory<double> *y, int incy) {
2629 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2630 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2631
2632 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2633 const DeviceMemory<double> &, int, double,
2634 DeviceMemory<double> *, int>
2635 impl;
2636 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2637 beta, y, incy);
2638 }
2639
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)2640 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
2641 const DeviceMemory<float> &x, int incx,
2642 DeviceMemory<float> *ap) {
2643 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2644 PARAM(ap));
2645
2646 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2647 int, DeviceMemory<float> *>
2648 impl;
2649 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2650 }
2651
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)2652 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
2653 const DeviceMemory<double> &x, int incx,
2654 DeviceMemory<double> *ap) {
2655 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2656 PARAM(ap));
2657
2658 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2659 int, DeviceMemory<double> *>
2660 impl;
2661 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2662 }
2663
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)2664 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
2665 const DeviceMemory<float> &x, int incx,
2666 const DeviceMemory<float> &y, int incy,
2667 DeviceMemory<float> *ap) {
2668 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2669 PARAM(y), PARAM(incy), PARAM(ap));
2670
2671 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2672 int, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2673 impl;
2674 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2675 incy, ap);
2676 }
2677
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)2678 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
2679 const DeviceMemory<double> &x, int incx,
2680 const DeviceMemory<double> &y, int incy,
2681 DeviceMemory<double> *ap) {
2682 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2683 PARAM(y), PARAM(incy), PARAM(ap));
2684
2685 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2686 int, const DeviceMemory<double> &, int, DeviceMemory<double> *>
2687 impl;
2688 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2689 incy, ap);
2690 }
2691
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2692 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
2693 const DeviceMemory<float> &a, int lda,
2694 const DeviceMemory<float> &x, int incx, float beta,
2695 DeviceMemory<float> *y, int incy) {
2696 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2697 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2698
2699 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2700 int, const DeviceMemory<float> &, int, float,
2701 DeviceMemory<float> *, int>
2702 impl;
2703 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2704 incx, beta, y, incy);
2705 }
2706
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2707 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
2708 const DeviceMemory<double> &a, int lda,
2709 const DeviceMemory<double> &x, int incx,
2710 double beta, DeviceMemory<double> *y, int incy) {
2711 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2712 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2713
2714 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2715 int, const DeviceMemory<double> &, int, double,
2716 DeviceMemory<double> *, int>
2717 impl;
2718 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2719 incx, beta, y, incy);
2720 }
2721
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)2722 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
2723 const DeviceMemory<float> &x, int incx,
2724 DeviceMemory<float> *a, int lda) {
2725 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2726 PARAM(a), PARAM(lda));
2727
2728 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2729 int, DeviceMemory<float> *, int>
2730 impl;
2731 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2732 lda);
2733 }
2734
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)2735 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
2736 const DeviceMemory<double> &x, int incx,
2737 DeviceMemory<double> *a, int lda) {
2738 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2739 PARAM(a), PARAM(lda));
2740
2741 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2742 int, DeviceMemory<double> *, int>
2743 impl;
2744 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2745 lda);
2746 }
2747
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2748 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
2749 const DeviceMemory<float> &x, int incx,
2750 const DeviceMemory<float> &y, int incy,
2751 DeviceMemory<float> *a, int lda) {
2752 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2753 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2754
2755 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2756 int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2757 int>
2758 impl;
2759 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2760 incy, a, lda);
2761 }
2762
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2763 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
2764 const DeviceMemory<double> &x, int incx,
2765 const DeviceMemory<double> &y, int incy,
2766 DeviceMemory<double> *a, int lda) {
2767 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2768 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2769
2770 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2771 int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
2772 int>
2773 impl;
2774 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2775 incy, a, lda);
2776 }
2777
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2778 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2779 blas::Diagonal diag, uint64 n, uint64 k,
2780 const DeviceMemory<float> &a, int lda,
2781 DeviceMemory<float> *x, int incx) {
2782 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2783 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2784
2785 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2786 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2787 int>
2788 impl;
2789 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2790 lda, x, incx);
2791 }
2792
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2793 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2794 blas::Diagonal diag, uint64 n, uint64 k,
2795 const DeviceMemory<double> &a, int lda,
2796 DeviceMemory<double> *x, int incx) {
2797 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2798 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2799
2800 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2801 uint64, const DeviceMemory<double> &, int,
2802 DeviceMemory<double> *, int>
2803 impl;
2804 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2805 lda, x, incx);
2806 }
2807
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2808 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2809 blas::Diagonal diag, uint64 n, uint64 k,
2810 const DeviceMemory<std::complex<float>> &a,
2811 int lda, DeviceMemory<std::complex<float>> *x,
2812 int incx) {
2813 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2814 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2815
2816 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2817 uint64, const DeviceMemory<std::complex<float>> &, int,
2818 DeviceMemory<std::complex<float>> *, int>
2819 impl;
2820 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2821 lda, x, incx);
2822 }
2823
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2824 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2825 blas::Diagonal diag, uint64 n, uint64 k,
2826 const DeviceMemory<std::complex<double>> &a,
2827 int lda, DeviceMemory<std::complex<double>> *x,
2828 int incx) {
2829 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2830 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2831
2832 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2833 uint64, const DeviceMemory<std::complex<double>> &, int,
2834 DeviceMemory<std::complex<double>> *, int>
2835 impl;
2836 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2837 lda, x, incx);
2838 }
2839
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2840 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2841 blas::Diagonal diag, uint64 n, uint64 k,
2842 const DeviceMemory<float> &a, int lda,
2843 DeviceMemory<float> *x, int incx) {
2844 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2845 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2846
2847 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2848 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2849 int>
2850 impl;
2851 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2852 lda, x, incx);
2853 }
2854
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2855 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2856 blas::Diagonal diag, uint64 n, uint64 k,
2857 const DeviceMemory<double> &a, int lda,
2858 DeviceMemory<double> *x, int incx) {
2859 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2860 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2861
2862 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2863 uint64, const DeviceMemory<double> &, int,
2864 DeviceMemory<double> *, int>
2865 impl;
2866 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2867 lda, x, incx);
2868 }
2869
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2870 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2871 blas::Diagonal diag, uint64 n, uint64 k,
2872 const DeviceMemory<std::complex<float>> &a,
2873 int lda, DeviceMemory<std::complex<float>> *x,
2874 int incx) {
2875 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2876 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2877
2878 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2879 uint64, const DeviceMemory<std::complex<float>> &, int,
2880 DeviceMemory<std::complex<float>> *, int>
2881 impl;
2882 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2883 lda, x, incx);
2884 }
2885
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2886 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2887 blas::Diagonal diag, uint64 n, uint64 k,
2888 const DeviceMemory<std::complex<double>> &a,
2889 int lda, DeviceMemory<std::complex<double>> *x,
2890 int incx) {
2891 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2892 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2893
2894 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2895 uint64, const DeviceMemory<std::complex<double>> &, int,
2896 DeviceMemory<std::complex<double>> *, int>
2897 impl;
2898 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2899 lda, x, incx);
2900 }
2901
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2902 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2903 blas::Diagonal diag, uint64 n,
2904 const DeviceMemory<float> &ap,
2905 DeviceMemory<float> *x, int incx) {
2906 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2907 PARAM(x), PARAM(incx));
2908
2909 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2910 const DeviceMemory<float> &, DeviceMemory<float> *, int>
2911 impl;
2912 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2913 incx);
2914 }
2915
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2916 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2917 blas::Diagonal diag, uint64 n,
2918 const DeviceMemory<double> &ap,
2919 DeviceMemory<double> *x, int incx) {
2920 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2921 PARAM(x), PARAM(incx));
2922
2923 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2924 const DeviceMemory<double> &, DeviceMemory<double> *, int>
2925 impl;
2926 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2927 incx);
2928 }
2929
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2930 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2931 blas::Diagonal diag, uint64 n,
2932 const DeviceMemory<std::complex<float>> &ap,
2933 DeviceMemory<std::complex<float>> *x, int incx) {
2934 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2935 PARAM(x), PARAM(incx));
2936
2937 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2938 const DeviceMemory<std::complex<float>> &,
2939 DeviceMemory<std::complex<float>> *, int>
2940 impl;
2941 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2942 incx);
2943 }
2944
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)2945 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2946 blas::Diagonal diag, uint64 n,
2947 const DeviceMemory<std::complex<double>> &ap,
2948 DeviceMemory<std::complex<double>> *x, int incx) {
2949 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2950 PARAM(x), PARAM(incx));
2951
2952 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2953 const DeviceMemory<std::complex<double>> &,
2954 DeviceMemory<std::complex<double>> *, int>
2955 impl;
2956 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2957 incx);
2958 }
2959
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2960 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2961 blas::Diagonal diag, uint64 n,
2962 const DeviceMemory<float> &ap,
2963 DeviceMemory<float> *x, int incx) {
2964 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2965 PARAM(x), PARAM(incx));
2966
2967 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2968 const DeviceMemory<float> &, DeviceMemory<float> *, int>
2969 impl;
2970 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2971 incx);
2972 }
2973
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2974 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2975 blas::Diagonal diag, uint64 n,
2976 const DeviceMemory<double> &ap,
2977 DeviceMemory<double> *x, int incx) {
2978 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2979 PARAM(x), PARAM(incx));
2980
2981 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2982 const DeviceMemory<double> &, DeviceMemory<double> *, int>
2983 impl;
2984 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2985 incx);
2986 }
2987
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2988 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2989 blas::Diagonal diag, uint64 n,
2990 const DeviceMemory<std::complex<float>> &ap,
2991 DeviceMemory<std::complex<float>> *x, int incx) {
2992 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2993 PARAM(x), PARAM(incx));
2994
2995 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2996 const DeviceMemory<std::complex<float>> &,
2997 DeviceMemory<std::complex<float>> *, int>
2998 impl;
2999 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3000 incx);
3001 }
3002
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3003 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3004 blas::Diagonal diag, uint64 n,
3005 const DeviceMemory<std::complex<double>> &ap,
3006 DeviceMemory<std::complex<double>> *x, int incx) {
3007 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3008 PARAM(x), PARAM(incx));
3009
3010 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3011 const DeviceMemory<std::complex<double>> &,
3012 DeviceMemory<std::complex<double>> *, int>
3013 impl;
3014 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3015 incx);
3016 }
3017
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3018 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3019 blas::Diagonal diag, uint64 n,
3020 const DeviceMemory<float> &a, int lda,
3021 DeviceMemory<float> *x, int incx) {
3022 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3023 PARAM(lda), PARAM(x), PARAM(incx));
3024
3025 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3026 const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
3027 impl;
3028 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3029 lda, x, incx);
3030 }
3031
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3032 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3033 blas::Diagonal diag, uint64 n,
3034 const DeviceMemory<double> &a, int lda,
3035 DeviceMemory<double> *x, int incx) {
3036 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3037 PARAM(lda), PARAM(x), PARAM(incx));
3038
3039 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3040 const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
3041 impl;
3042 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3043 lda, x, incx);
3044 }
3045
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3046 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3047 blas::Diagonal diag, uint64 n,
3048 const DeviceMemory<std::complex<float>> &a,
3049 int lda, DeviceMemory<std::complex<float>> *x,
3050 int incx) {
3051 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3052 PARAM(lda), PARAM(x), PARAM(incx));
3053
3054 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3055 const DeviceMemory<std::complex<float>> &, int,
3056 DeviceMemory<std::complex<float>> *, int>
3057 impl;
3058 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3059 lda, x, incx);
3060 }
3061
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3062 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3063 blas::Diagonal diag, uint64 n,
3064 const DeviceMemory<std::complex<double>> &a,
3065 int lda, DeviceMemory<std::complex<double>> *x,
3066 int incx) {
3067 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3068 PARAM(lda), PARAM(x), PARAM(incx));
3069
3070 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3071 const DeviceMemory<std::complex<double>> &, int,
3072 DeviceMemory<std::complex<double>> *, int>
3073 impl;
3074 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3075 lda, x, incx);
3076 }
3077
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3078 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3079 blas::Diagonal diag, uint64 n,
3080 const DeviceMemory<float> &a, int lda,
3081 DeviceMemory<float> *x, int incx) {
3082 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3083 PARAM(lda), PARAM(x), PARAM(incx));
3084
3085 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3086 const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
3087 impl;
3088 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3089 lda, x, incx);
3090 }
3091
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3092 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3093 blas::Diagonal diag, uint64 n,
3094 const DeviceMemory<double> &a, int lda,
3095 DeviceMemory<double> *x, int incx) {
3096 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3097 PARAM(lda), PARAM(x), PARAM(incx));
3098
3099 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3100 const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
3101 impl;
3102 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3103 lda, x, incx);
3104 }
3105
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3106 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3107 blas::Diagonal diag, uint64 n,
3108 const DeviceMemory<std::complex<float>> &a,
3109 int lda, DeviceMemory<std::complex<float>> *x,
3110 int incx) {
3111 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3112 PARAM(lda), PARAM(x), PARAM(incx));
3113
3114 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3115 const DeviceMemory<std::complex<float>> &, int,
3116 DeviceMemory<std::complex<float>> *, int>
3117 impl;
3118 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3119 lda, x, incx);
3120 }
3121
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3122 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3123 blas::Diagonal diag, uint64 n,
3124 const DeviceMemory<std::complex<double>> &a,
3125 int lda, DeviceMemory<std::complex<double>> *x,
3126 int incx) {
3127 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3128 PARAM(lda), PARAM(x), PARAM(incx));
3129
3130 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3131 const DeviceMemory<std::complex<double>> &, int,
3132 DeviceMemory<std::complex<double>> *, int>
3133 impl;
3134 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3135 lda, x, incx);
3136 }
3137
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3138 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3139 uint64 m, uint64 n, uint64 k, float alpha,
3140 const DeviceMemory<Eigen::half> &a, int lda,
3141 const DeviceMemory<Eigen::half> &b, int ldb,
3142 float beta, DeviceMemory<Eigen::half> *c,
3143 int ldc) {
3144 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3145 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3146 PARAM(beta), PARAM(c), PARAM(ldc));
3147
3148 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3149 const DeviceMemory<Eigen::half> &, int,
3150 const DeviceMemory<Eigen::half> &, int, float,
3151 DeviceMemory<Eigen::half> *, int>
3152 impl;
3153 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3154 alpha, a, lda, b, ldb, beta, c, ldc);
3155 }
3156
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3157 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3158 uint64 m, uint64 n, uint64 k, float alpha,
3159 const DeviceMemory<float> &a, int lda,
3160 const DeviceMemory<float> &b, int ldb, float beta,
3161 DeviceMemory<float> *c, int ldc) {
3162 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3163 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3164 PARAM(beta), PARAM(c), PARAM(ldc));
3165
3166 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3167 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3168 int, float, DeviceMemory<float> *, int>
3169 impl;
3170 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3171 alpha, a, lda, b, ldb, beta, c, ldc);
3172 }
3173
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3174 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3175 uint64 m, uint64 n, uint64 k, double alpha,
3176 const DeviceMemory<double> &a, int lda,
3177 const DeviceMemory<double> &b, int ldb,
3178 double beta, DeviceMemory<double> *c, int ldc) {
3179 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3180 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3181 PARAM(beta), PARAM(c), PARAM(ldc));
3182
3183 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3184 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3185 int, double, DeviceMemory<double> *, int>
3186 impl;
3187 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3188 alpha, a, lda, b, ldb, beta, c, ldc);
3189 }
3190
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3191 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3192 uint64 m, uint64 n, uint64 k,
3193 std::complex<float> alpha,
3194 const DeviceMemory<std::complex<float>> &a,
3195 int lda,
3196 const DeviceMemory<std::complex<float>> &b,
3197 int ldb, std::complex<float> beta,
3198 DeviceMemory<std::complex<float>> *c, int ldc) {
3199 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3200 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3201 PARAM(beta), PARAM(c), PARAM(ldc));
3202
3203 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3204 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3205 int, const DeviceMemory<std::complex<float>> &, int,
3206 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3207 impl;
3208 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3209 alpha, a, lda, b, ldb, beta, c, ldc);
3210 }
3211
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3212 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3213 uint64 m, uint64 n, uint64 k,
3214 std::complex<double> alpha,
3215 const DeviceMemory<std::complex<double>> &a,
3216 int lda,
3217 const DeviceMemory<std::complex<double>> &b,
3218 int ldb, std::complex<double> beta,
3219 DeviceMemory<std::complex<double>> *c, int ldc) {
3220 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3221 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3222 PARAM(beta), PARAM(c), PARAM(ldc));
3223
3224 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3225 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3226 int, const DeviceMemory<std::complex<double>> &, int,
3227 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3228 impl;
3229 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3230 alpha, a, lda, b, ldb, beta, c, ldc);
3231 }
3232
3233 namespace {
3234 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3235 // blas::ProfileResult*. This functor doesn't put the stream into an error
3236 // state if the op fails and the profile result is non-null. Instead, the
3237 // error-ness is returned in the profile result itself.
3238 template <typename... Args>
3239 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon852f14a50211::ThenBlasWithProfileImpl3240 Stream &operator()(Stream *stream,
3241 bool (blas::BlasSupport::*blas_func)(
3242 Stream *, Args..., blas::ProfileResult *),
3243 Args... args, blas::ProfileResult *profile_result) {
3244 ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3245 bool record_error = profile_result == nullptr;
3246 return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3247 }
3248 };
3249 } // anonymous namespace
3250
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3251 Stream &Stream::ThenBlasGemvWithProfiling(
3252 blas::Transpose trans, uint64 m, uint64 n, float alpha,
3253 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3254 int incx, float beta, DeviceMemory<float> *y, int incy,
3255 blas::ProfileResult *output_profile_result) {
3256 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3257 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3258 PARAM(incy));
3259
3260 ThenBlasWithProfileImpl<
3261 blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3262 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3263 impl;
3264 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3265 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3266 }
3267
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3268 Stream &Stream::ThenBlasGemvWithProfiling(
3269 blas::Transpose trans, uint64 m, uint64 n, double alpha,
3270 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3271 int incx, double beta, DeviceMemory<double> *y, int incy,
3272 blas::ProfileResult *output_profile_result) {
3273 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3274 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3275 PARAM(incy));
3276
3277 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3278 const DeviceMemory<double> &, int,
3279 const DeviceMemory<double> &, int, double,
3280 DeviceMemory<double> *, int>
3281 impl;
3282 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3283 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3284 }
3285
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3286 Stream &Stream::ThenBlasGemvWithProfiling(
3287 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3288 const DeviceMemory<std::complex<float>> &a, int lda,
3289 const DeviceMemory<std::complex<float>> &x, int incx,
3290 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3291 blas::ProfileResult *output_profile_result) {
3292 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3293 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3294 PARAM(incy));
3295
3296 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3297 const DeviceMemory<std::complex<float>> &, int,
3298 const DeviceMemory<std::complex<float>> &, int,
3299 std::complex<float>,
3300 DeviceMemory<std::complex<float>> *, int>
3301 impl;
3302 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3303 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3304 }
3305
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3306 Stream &Stream::ThenBlasGemvWithProfiling(
3307 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3308 const DeviceMemory<std::complex<double>> &a, int lda,
3309 const DeviceMemory<std::complex<double>> &x, int incx,
3310 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3311 blas::ProfileResult *output_profile_result) {
3312 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3313 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3314 PARAM(incy));
3315
3316 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3317 const DeviceMemory<std::complex<double>> &, int,
3318 const DeviceMemory<std::complex<double>> &, int,
3319 std::complex<double>,
3320 DeviceMemory<std::complex<double>> *, int>
3321 impl;
3322 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3323 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3324 }
3325
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3326 Stream &Stream::ThenBlasGemmWithProfiling(
3327 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3328 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3329 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3330 DeviceMemory<Eigen::half> *c, int ldc,
3331 blas::ProfileResult *output_profile_result) {
3332 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3333 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3334 PARAM(beta), PARAM(c), PARAM(ldc));
3335
3336 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3337 uint64, float, const DeviceMemory<Eigen::half> &, int,
3338 const DeviceMemory<Eigen::half> &, int, float,
3339 DeviceMemory<Eigen::half> *, int>
3340 impl;
3341 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3342 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3343 output_profile_result);
3344 }
3345
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3346 Stream &Stream::ThenBlasGemmWithProfiling(
3347 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3348 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3349 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3350 int ldc, blas::ProfileResult *output_profile_result) {
3351 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3352 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3353 PARAM(beta), PARAM(c), PARAM(ldc));
3354
3355 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3356 uint64, float, const DeviceMemory<float> &, int,
3357 const DeviceMemory<float> &, int, float,
3358 DeviceMemory<float> *, int>
3359 impl;
3360 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3361 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3362 output_profile_result);
3363 }
3364
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3365 Stream &Stream::ThenBlasGemmWithProfiling(
3366 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3367 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3368 const DeviceMemory<double> &b, int ldb, double beta,
3369 DeviceMemory<double> *c, int ldc,
3370 blas::ProfileResult *output_profile_result) {
3371 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3372 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3373 PARAM(beta), PARAM(c), PARAM(ldc));
3374
3375 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3376 uint64, double, const DeviceMemory<double> &, int,
3377 const DeviceMemory<double> &, int, double,
3378 DeviceMemory<double> *, int>
3379 impl;
3380 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3381 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3382 output_profile_result);
3383 }
3384
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3385 Stream &Stream::ThenBlasGemmWithProfiling(
3386 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3387 uint64 k, std::complex<float> alpha,
3388 const DeviceMemory<std::complex<float>> &a, int lda,
3389 const DeviceMemory<std::complex<float>> &b, int ldb,
3390 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3391 blas::ProfileResult *output_profile_result) {
3392 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3393 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3394 PARAM(beta), PARAM(c), PARAM(ldc));
3395
3396 ThenBlasWithProfileImpl<
3397 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3398 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3399 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3400 DeviceMemory<std::complex<float>> *, int>
3401 impl;
3402 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3403 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3404 output_profile_result);
3405 }
3406
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3407 Stream &Stream::ThenBlasGemmWithProfiling(
3408 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3409 uint64 k, std::complex<double> alpha,
3410 const DeviceMemory<std::complex<double>> &a, int lda,
3411 const DeviceMemory<std::complex<double>> &b, int ldb,
3412 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3413 blas::ProfileResult *output_profile_result) {
3414 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3415 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3416 PARAM(beta), PARAM(c), PARAM(ldc));
3417
3418 ThenBlasWithProfileImpl<
3419 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3420 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3421 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3422 DeviceMemory<std::complex<double>> *, int>
3423 impl;
3424 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3425 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3426 output_profile_result);
3427 }
3428
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3429 Stream &Stream::ThenBlasGemmWithAlgorithm(
3430 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3431 uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
3432 const DeviceMemory<Eigen::half> &a, int lda,
3433 const DeviceMemory<Eigen::half> &b, int ldb,
3434 const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
3435 int ldc, blas::ComputationType computation_type,
3436 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3437 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3438 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3439 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3440 PARAM(algorithm));
3441
3442 ThenBlasWithProfileImpl<
3443 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3444 const HostOrDeviceScalar<Eigen::half> &,
3445 const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
3446 int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3447 int, blas::ComputationType, blas::AlgorithmType>
3448 impl;
3449 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3450 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3451 algorithm, output_profile_result);
3452 }
3453
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3454 Stream &Stream::ThenBlasGemmWithAlgorithm(
3455 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3456 uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
3457 int lda, const DeviceMemory<int8> &b, int ldb,
3458 const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
3459 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3460 blas::ProfileResult *output_profile_result) {
3461 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3462 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3463 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3464 PARAM(algorithm));
3465
3466 ThenBlasWithProfileImpl<
3467 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3468 const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
3469 const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
3470 DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3471 impl;
3472 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3473 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3474 algorithm, output_profile_result);
3475 }
3476
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3477 Stream &Stream::ThenBlasGemmWithAlgorithm(
3478 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3479 uint64 k, const HostOrDeviceScalar<float> &alpha,
3480 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
3481 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
3482 int ldc, blas::ComputationType computation_type,
3483 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3484 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3485 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3486 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3487 PARAM(algorithm));
3488
3489 ThenBlasWithProfileImpl<
3490 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3491 const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
3492 const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
3493 DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3494 impl;
3495 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3496 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3497 algorithm, output_profile_result);
3498 }
3499
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3500 Stream &Stream::ThenBlasGemmWithAlgorithm(
3501 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3502 uint64 k, const HostOrDeviceScalar<double> &alpha,
3503 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
3504 int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
3505 int ldc, blas::ComputationType computation_type,
3506 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3507 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3508 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3509 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3510 PARAM(algorithm));
3511
3512 ThenBlasWithProfileImpl<
3513 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3514 const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
3515 const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
3516 DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
3517 impl;
3518 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3519 m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
3520 HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
3521 algorithm, output_profile_result);
3522 }
3523
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3524 Stream &Stream::ThenBlasGemmWithAlgorithm(
3525 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3526 uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
3527 const DeviceMemory<std::complex<float>> &a, int lda,
3528 const DeviceMemory<std::complex<float>> &b, int ldb,
3529 const HostOrDeviceScalar<std::complex<float>> &beta,
3530 DeviceMemory<std::complex<float>> *c, int ldc,
3531 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3532 blas::ProfileResult *output_profile_result) {
3533 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3534 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3535 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3536 PARAM(algorithm));
3537
3538 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3539 uint64,
3540 const HostOrDeviceScalar<std::complex<float>> &,
3541 const DeviceMemory<std::complex<float>> &, int,
3542 const DeviceMemory<std::complex<float>> &, int,
3543 const HostOrDeviceScalar<std::complex<float>> &,
3544 DeviceMemory<std::complex<float>> *, int,
3545 blas::ComputationType, blas::AlgorithmType>
3546 impl;
3547 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3548 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3549 algorithm, output_profile_result);
3550 }
3551
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3552 Stream &Stream::ThenBlasGemmWithAlgorithm(
3553 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3554 uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
3555 const DeviceMemory<std::complex<double>> &a, int lda,
3556 const DeviceMemory<std::complex<double>> &b, int ldb,
3557 const HostOrDeviceScalar<std::complex<double>> &beta,
3558 DeviceMemory<std::complex<double>> *c, int ldc,
3559 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3560 blas::ProfileResult *output_profile_result) {
3561 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3562 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3563 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3564 PARAM(algorithm));
3565
3566 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3567 uint64,
3568 const HostOrDeviceScalar<std::complex<double>> &,
3569 const DeviceMemory<std::complex<double>> &, int,
3570 const DeviceMemory<std::complex<double>> &, int,
3571 const HostOrDeviceScalar<std::complex<double>> &,
3572 DeviceMemory<std::complex<double>> *, int,
3573 blas::ComputationType, blas::AlgorithmType>
3574 impl;
3575 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3576 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3577 algorithm, output_profile_result);
3578 }
3579
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3580 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3581 uint64 n, std::complex<float> alpha,
3582 const DeviceMemory<std::complex<float>> &a,
3583 int lda,
3584 const DeviceMemory<std::complex<float>> &b,
3585 int ldb, std::complex<float> beta,
3586 DeviceMemory<std::complex<float>> *c, int ldc) {
3587 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3588 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3589 PARAM(ldc));
3590
3591 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3592 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3593 int, const DeviceMemory<std::complex<float>> &, int,
3594 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3595 impl;
3596 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3597 lda, b, ldb, beta, c, ldc);
3598 }
3599
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3600 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3601 uint64 n, std::complex<double> alpha,
3602 const DeviceMemory<std::complex<double>> &a,
3603 int lda,
3604 const DeviceMemory<std::complex<double>> &b,
3605 int ldb, std::complex<double> beta,
3606 DeviceMemory<std::complex<double>> *c, int ldc) {
3607 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3608 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3609 PARAM(ldc));
3610
3611 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3612 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3613 int, const DeviceMemory<std::complex<double>> &, int,
3614 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3615 impl;
3616 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3617 lda, b, ldb, beta, c, ldc);
3618 }
3619
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3620 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3621 uint64 n, uint64 k, float alpha,
3622 const DeviceMemory<std::complex<float>> &a,
3623 int lda, float beta,
3624 DeviceMemory<std::complex<float>> *c, int ldc) {
3625 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3626 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3627
3628 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3629 const DeviceMemory<std::complex<float>> &, int, float,
3630 DeviceMemory<std::complex<float>> *, int>
3631 impl;
3632 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3633 lda, beta, c, ldc);
3634 }
3635
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3636 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3637 uint64 n, uint64 k, double alpha,
3638 const DeviceMemory<std::complex<double>> &a,
3639 int lda, double beta,
3640 DeviceMemory<std::complex<double>> *c, int ldc) {
3641 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3642 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3643
3644 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3645 const DeviceMemory<std::complex<double>> &, int, double,
3646 DeviceMemory<std::complex<double>> *, int>
3647 impl;
3648 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3649 lda, beta, c, ldc);
3650 }
3651
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3652 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3653 uint64 n, uint64 k, std::complex<float> alpha,
3654 const DeviceMemory<std::complex<float>> &a,
3655 int lda,
3656 const DeviceMemory<std::complex<float>> &b,
3657 int ldb, float beta,
3658 DeviceMemory<std::complex<float>> *c, int ldc) {
3659 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3660 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3661 PARAM(ldc));
3662
3663 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3664 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3665 int, const DeviceMemory<std::complex<float>> &, int, float,
3666 DeviceMemory<std::complex<float>> *, int>
3667 impl;
3668 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3669 a, lda, b, ldb, beta, c, ldc);
3670 }
3671
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3672 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3673 uint64 n, uint64 k, std::complex<double> alpha,
3674 const DeviceMemory<std::complex<double>> &a,
3675 int lda,
3676 const DeviceMemory<std::complex<double>> &b,
3677 int ldb, double beta,
3678 DeviceMemory<std::complex<double>> *c, int ldc) {
3679 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3680 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3681 PARAM(ldc));
3682
3683 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3684 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3685 int, const DeviceMemory<std::complex<double>> &, int, double,
3686 DeviceMemory<std::complex<double>> *, int>
3687 impl;
3688 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3689 a, lda, b, ldb, beta, c, ldc);
3690 }
3691
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3692 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3693 uint64 n, float alpha,
3694 const DeviceMemory<float> &a, int lda,
3695 const DeviceMemory<float> &b, int ldb, float beta,
3696 DeviceMemory<float> *c, int ldc) {
3697 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3698 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3699 PARAM(ldc));
3700
3701 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
3702 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3703 int, float, DeviceMemory<float> *, int>
3704 impl;
3705 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3706 lda, b, ldb, beta, c, ldc);
3707 }
3708
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3709 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3710 uint64 n, double alpha,
3711 const DeviceMemory<double> &a, int lda,
3712 const DeviceMemory<double> &b, int ldb,
3713 double beta, DeviceMemory<double> *c, int ldc) {
3714 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3715 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3716 PARAM(ldc));
3717
3718 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
3719 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3720 int, double, DeviceMemory<double> *, int>
3721 impl;
3722 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3723 lda, b, ldb, beta, c, ldc);
3724 }
3725
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3726 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3727 uint64 n, std::complex<float> alpha,
3728 const DeviceMemory<std::complex<float>> &a,
3729 int lda,
3730 const DeviceMemory<std::complex<float>> &b,
3731 int ldb, std::complex<float> beta,
3732 DeviceMemory<std::complex<float>> *c, int ldc) {
3733 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3734 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3735 PARAM(ldc));
3736
3737 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3738 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3739 int, const DeviceMemory<std::complex<float>> &, int,
3740 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3741 impl;
3742 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3743 lda, b, ldb, beta, c, ldc);
3744 }
3745
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3746 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3747 uint64 n, std::complex<double> alpha,
3748 const DeviceMemory<std::complex<double>> &a,
3749 int lda,
3750 const DeviceMemory<std::complex<double>> &b,
3751 int ldb, std::complex<double> beta,
3752 DeviceMemory<std::complex<double>> *c, int ldc) {
3753 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3754 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3755 PARAM(ldc));
3756
3757 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3758 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3759 int, const DeviceMemory<std::complex<double>> &, int,
3760 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3761 impl;
3762 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3763 lda, b, ldb, beta, c, ldc);
3764 }
3765
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)3766 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3767 uint64 n, uint64 k, float alpha,
3768 const DeviceMemory<float> &a, int lda, float beta,
3769 DeviceMemory<float> *c, int ldc) {
3770 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3771 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3772
3773 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3774 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3775 int>
3776 impl;
3777 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3778 lda, beta, c, ldc);
3779 }
3780
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)3781 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3782 uint64 n, uint64 k, double alpha,
3783 const DeviceMemory<double> &a, int lda,
3784 double beta, DeviceMemory<double> *c, int ldc) {
3785 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3786 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3787
3788 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3789 const DeviceMemory<double> &, int, double,
3790 DeviceMemory<double> *, int>
3791 impl;
3792 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3793 lda, beta, c, ldc);
3794 }
3795
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3796 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3797 uint64 n, uint64 k, std::complex<float> alpha,
3798 const DeviceMemory<std::complex<float>> &a,
3799 int lda, std::complex<float> beta,
3800 DeviceMemory<std::complex<float>> *c, int ldc) {
3801 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3802 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3803
3804 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3805 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3806 int, std::complex<float>, DeviceMemory<std::complex<float>> *,
3807 int>
3808 impl;
3809 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3810 lda, beta, c, ldc);
3811 }
3812
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3813 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3814 uint64 n, uint64 k, std::complex<double> alpha,
3815 const DeviceMemory<std::complex<double>> &a,
3816 int lda, std::complex<double> beta,
3817 DeviceMemory<std::complex<double>> *c, int ldc) {
3818 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3819 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3820
3821 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3822 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3823 int, std::complex<double>, DeviceMemory<std::complex<double>> *,
3824 int>
3825 impl;
3826 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3827 lda, beta, c, ldc);
3828 }
3829
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3830 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3831 uint64 n, uint64 k, float alpha,
3832 const DeviceMemory<float> &a, int lda,
3833 const DeviceMemory<float> &b, int ldb, float beta,
3834 DeviceMemory<float> *c, int ldc) {
3835 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3836 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3837 PARAM(ldc));
3838
3839 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3840 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3841 int, float, DeviceMemory<float> *, int>
3842 impl;
3843 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3844 a, lda, b, ldb, beta, c, ldc);
3845 }
3846
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3847 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3848 uint64 n, uint64 k, double alpha,
3849 const DeviceMemory<double> &a, int lda,
3850 const DeviceMemory<double> &b, int ldb,
3851 double beta, DeviceMemory<double> *c, int ldc) {
3852 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3853 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3854 PARAM(ldc));
3855
3856 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3857 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3858 int, double, DeviceMemory<double> *, int>
3859 impl;
3860 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3861 a, lda, b, ldb, beta, c, ldc);
3862 }
3863
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3864 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3865 uint64 n, uint64 k, std::complex<float> alpha,
3866 const DeviceMemory<std::complex<float>> &a,
3867 int lda,
3868 const DeviceMemory<std::complex<float>> &b,
3869 int ldb, std::complex<float> beta,
3870 DeviceMemory<std::complex<float>> *c, int ldc) {
3871 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3872 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3873 PARAM(ldc));
3874
3875 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3876 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3877 int, const DeviceMemory<std::complex<float>> &, int,
3878 std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3879 impl;
3880 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3881 a, lda, b, ldb, beta, c, ldc);
3882 }
3883
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3884 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3885 uint64 n, uint64 k, std::complex<double> alpha,
3886 const DeviceMemory<std::complex<double>> &a,
3887 int lda,
3888 const DeviceMemory<std::complex<double>> &b,
3889 int ldb, std::complex<double> beta,
3890 DeviceMemory<std::complex<double>> *c, int ldc) {
3891 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3892 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3893 PARAM(ldc));
3894
3895 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3896 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3897 int, const DeviceMemory<std::complex<double>> &, int,
3898 std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3899 impl;
3900 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3901 a, lda, b, ldb, beta, c, ldc);
3902 }
3903
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3904 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3905 blas::Transpose transa, blas::Diagonal diag,
3906 uint64 m, uint64 n, float alpha,
3907 const DeviceMemory<float> &a, int lda,
3908 DeviceMemory<float> *b, int ldb) {
3909 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3910 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3911
3912 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3913 uint64, uint64, float, const DeviceMemory<float> &, int,
3914 DeviceMemory<float> *, int>
3915 impl;
3916 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3917 n, alpha, a, lda, b, ldb);
3918 }
3919
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3920 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3921 blas::Transpose transa, blas::Diagonal diag,
3922 uint64 m, uint64 n, double alpha,
3923 const DeviceMemory<double> &a, int lda,
3924 DeviceMemory<double> *b, int ldb) {
3925 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3926 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3927
3928 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3929 uint64, uint64, double, const DeviceMemory<double> &, int,
3930 DeviceMemory<double> *, int>
3931 impl;
3932 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3933 n, alpha, a, lda, b, ldb);
3934 }
3935
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)3936 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3937 blas::Transpose transa, blas::Diagonal diag,
3938 uint64 m, uint64 n, std::complex<float> alpha,
3939 const DeviceMemory<std::complex<float>> &a,
3940 int lda, DeviceMemory<std::complex<float>> *b,
3941 int ldb) {
3942 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3943 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3944
3945 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3946 uint64, uint64, std::complex<float>,
3947 const DeviceMemory<std::complex<float>> &, int,
3948 DeviceMemory<std::complex<float>> *, int>
3949 impl;
3950 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3951 n, alpha, a, lda, b, ldb);
3952 }
3953
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)3954 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3955 blas::Transpose transa, blas::Diagonal diag,
3956 uint64 m, uint64 n, std::complex<double> alpha,
3957 const DeviceMemory<std::complex<double>> &a,
3958 int lda, DeviceMemory<std::complex<double>> *b,
3959 int ldb) {
3960 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3961 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3962
3963 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3964 uint64, uint64, std::complex<double>,
3965 const DeviceMemory<std::complex<double>> &, int,
3966 DeviceMemory<std::complex<double>> *, int>
3967 impl;
3968 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3969 n, alpha, a, lda, b, ldb);
3970 }
3971
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3972 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3973 blas::Transpose transa, blas::Diagonal diag,
3974 uint64 m, uint64 n, float alpha,
3975 const DeviceMemory<float> &a, int lda,
3976 DeviceMemory<float> *b, int ldb) {
3977 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3978 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3979
3980 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3981 uint64, uint64, float, const DeviceMemory<float> &, int,
3982 DeviceMemory<float> *, int>
3983 impl;
3984 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3985 n, alpha, a, lda, b, ldb);
3986 }
3987
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3988 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3989 blas::Transpose transa, blas::Diagonal diag,
3990 uint64 m, uint64 n, double alpha,
3991 const DeviceMemory<double> &a, int lda,
3992 DeviceMemory<double> *b, int ldb) {
3993 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3994 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3995
3996 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3997 uint64, uint64, double, const DeviceMemory<double> &, int,
3998 DeviceMemory<double> *, int>
3999 impl;
4000 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4001 n, alpha, a, lda, b, ldb);
4002 }
4003
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4004 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4005 blas::Transpose transa, blas::Diagonal diag,
4006 uint64 m, uint64 n, std::complex<float> alpha,
4007 const DeviceMemory<std::complex<float>> &a,
4008 int lda, DeviceMemory<std::complex<float>> *b,
4009 int ldb) {
4010 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4011 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4012
4013 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4014 uint64, uint64, std::complex<float>,
4015 const DeviceMemory<std::complex<float>> &, int,
4016 DeviceMemory<std::complex<float>> *, int>
4017 impl;
4018 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4019 n, alpha, a, lda, b, ldb);
4020 }
4021
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4022 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4023 blas::Transpose transa, blas::Diagonal diag,
4024 uint64 m, uint64 n, std::complex<double> alpha,
4025 const DeviceMemory<std::complex<double>> &a,
4026 int lda, DeviceMemory<std::complex<double>> *b,
4027 int ldb) {
4028 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4029 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4030
4031 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4032 uint64, uint64, std::complex<double>,
4033 const DeviceMemory<std::complex<double>> &, int,
4034 DeviceMemory<std::complex<double>> *, int>
4035 impl;
4036 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4037 n, alpha, a, lda, b, ldb);
4038 }
4039
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)4040 Stream &Stream::ThenBlasGemmBatched(
4041 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4042 uint64 k, float alpha,
4043 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4044 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4045 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4046 int batch_count) {
4047 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4048 b, ldb, beta, c, ldc, batch_count,
4049 /*scratch_allocator=*/nullptr);
4050 }
4051
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4052 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4053 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4054 uint64 k, float alpha,
4055 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4056 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4057 const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4058 int batch_count, ScratchAllocator *scratch_allocator) {
4059 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4060 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4061 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4062
4063 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4064 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4065 const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4066 float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
4067 int, int, ScratchAllocator *>
4068 impl;
4069 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4070 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4071 scratch_allocator);
4072 }
4073
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4074 Stream &Stream::ThenBlasGemmBatched(
4075 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4076 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4077 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4078 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4079 int batch_count) {
4080 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4081 b, ldb, beta, c, ldc, batch_count,
4082 /*scratch_allocator=*/nullptr);
4083 }
4084
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4085 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4086 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4087 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4088 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4089 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4090 int batch_count, ScratchAllocator *scratch_allocator) {
4091 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4092 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4093 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4094
4095 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4096 const port::ArraySlice<DeviceMemory<float> *> &, int,
4097 const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4098 const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4099 ScratchAllocator *>
4100 impl;
4101 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4102 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4103 scratch_allocator);
4104 }
4105
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4106 Stream &Stream::ThenBlasGemmBatched(
4107 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4108 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4109 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4110 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4111 int batch_count) {
4112 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4113 b, ldb, beta, c, ldc, batch_count,
4114 /*scratch_allocator=*/nullptr);
4115 }
4116
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4117 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4118 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4119 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4120 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4121 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4122 int batch_count, ScratchAllocator *scratch_allocator) {
4123 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4124 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4125 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4126
4127 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4128 const port::ArraySlice<DeviceMemory<double> *> &, int,
4129 const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4130 const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4131 ScratchAllocator *>
4132 impl;
4133 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4134 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4135 scratch_allocator);
4136 }
4137
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4138 Stream &Stream::ThenBlasGemmBatched(
4139 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4140 uint64 k, std::complex<float> alpha,
4141 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4142 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4143 std::complex<float> beta,
4144 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4145 int batch_count) {
4146 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4147 b, ldb, beta, c, ldc, batch_count,
4148 /*scratch_allocator=*/nullptr);
4149 }
4150
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4151 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4152 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4153 uint64 k, std::complex<float> alpha,
4154 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4155 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4156 std::complex<float> beta,
4157 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4158 int batch_count, ScratchAllocator *scratch_allocator) {
4159 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4160 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4161 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4162
4163 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4164 std::complex<float>,
4165 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4166 int,
4167 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4168 int, std::complex<float>,
4169 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4170 int, int, ScratchAllocator *>
4171 impl;
4172 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4173 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4174 scratch_allocator);
4175 }
4176
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4177 Stream &Stream::ThenBlasGemmBatched(
4178 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4179 uint64 k, std::complex<double> alpha,
4180 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4181 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4182 std::complex<double> beta,
4183 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4184 int batch_count) {
4185 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4186 b, ldb, beta, c, ldc, batch_count,
4187 /*scratch_allocator=*/nullptr);
4188 }
4189
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4190 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4191 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4192 uint64 k, std::complex<double> alpha,
4193 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4194 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4195 std::complex<double> beta,
4196 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4197 int batch_count, ScratchAllocator *scratch_allocator) {
4198 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4199 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4200 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4201
4202 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4203 std::complex<double>,
4204 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4205 int,
4206 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4207 int, std::complex<double>,
4208 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4209 int, int, ScratchAllocator *>
4210 impl;
4211 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4212 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4213 scratch_allocator);
4214 }
4215
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)4216 Stream &Stream::ThenBlasGemmStridedBatched(
4217 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4218 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
4219 int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
4220 float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
4221 int batch_count) {
4222 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4223 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4224 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4225 PARAM(stride_c), PARAM(batch_count));
4226
4227 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4228 const DeviceMemory<Eigen::half> &, int, int64,
4229 const DeviceMemory<Eigen::half> &, int, int64, float,
4230 DeviceMemory<Eigen::half> *, int, int64, int>
4231 impl;
4232 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4233 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4234 c, ldc, stride_c, batch_count);
4235 }
4236
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)4237 Stream &Stream::ThenBlasGemmStridedBatched(
4238 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4239 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
4240 int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
4241 float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
4242 int batch_count) {
4243 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4244 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4245 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4246 PARAM(stride_c), PARAM(batch_count));
4247
4248 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4249 const DeviceMemory<float> &, int, int64,
4250 const DeviceMemory<float> &, int, int64, float,
4251 DeviceMemory<float> *, int, int64, int>
4252 impl;
4253 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4254 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4255 c, ldc, stride_c, batch_count);
4256 }
4257
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)4258 Stream &Stream::ThenBlasGemmStridedBatched(
4259 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4260 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
4261 int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
4262 double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
4263 int batch_count) {
4264 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4265 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4266 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4267 PARAM(stride_c), PARAM(batch_count));
4268
4269 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4270 const DeviceMemory<double> &, int, int64,
4271 const DeviceMemory<double> &, int, int64, double,
4272 DeviceMemory<double> *, int, int64, int>
4273 impl;
4274 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4275 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4276 c, ldc, stride_c, batch_count);
4277 }
4278
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)4279 Stream &Stream::ThenBlasGemmStridedBatched(
4280 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4281 uint64 k, std::complex<float> alpha,
4282 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
4283 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
4284 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
4285 int64 stride_c, int batch_count) {
4286 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4287 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4288 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4289 PARAM(stride_c), PARAM(batch_count));
4290
4291 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4292 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4293 int, int64, const DeviceMemory<std::complex<float>> &, int,
4294 int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
4295 int, int64, int>
4296 impl;
4297 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4298 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4299 c, ldc, stride_c, batch_count);
4300 }
4301
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)4302 Stream &Stream::ThenBlasGemmStridedBatched(
4303 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4304 uint64 k, std::complex<double> alpha,
4305 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
4306 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
4307 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
4308 int64 stride_c, int batch_count) {
4309 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4310 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4311 PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4312 PARAM(stride_c), PARAM(batch_count));
4313
4314 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4315 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4316 int, int64, const DeviceMemory<std::complex<double>> &, int,
4317 int64, std::complex<double>,
4318 DeviceMemory<std::complex<double>> *, int, int64, int>
4319 impl;
4320 return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4321 transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4322 c, ldc, stride_c, batch_count);
4323 }
4324
4325 template <typename ABType, typename CType>
ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan * plan,const HostOrDeviceScalar<CType> & alpha,const DeviceMemory<ABType> & a,const DeviceMemory<ABType> & b,const HostOrDeviceScalar<CType> & beta,DeviceMemory<CType> * c,ScratchAllocator * scratch_allocator,const blas::IBlasLtMatmulAlgorithm * algorithm,const DeviceMemory<CType> & bias,blas::ProfileResult * output_profile_result)4326 Stream &Stream::ThenBlasLtMatmulImpl(
4327 const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
4328 const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
4329 const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
4330 ScratchAllocator *scratch_allocator,
4331 const blas::IBlasLtMatmulAlgorithm *algorithm,
4332 const DeviceMemory<CType> &bias,
4333 blas::ProfileResult *output_profile_result) {
4334 VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
4335 PARAM(c), PARAM(algorithm), PARAM(bias));
4336
4337 ThenBlasWithProfileImpl<
4338 const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
4339 const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
4340 const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
4341 ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
4342 const DeviceMemory<CType> &>
4343 impl;
4344 return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
4345 c, scratch_allocator, algorithm, bias, output_profile_result);
4346 }
4347
// Explicit template instantiations for each supported type combination.
// The template is defined in this translation unit, so each <ABType, CType>
// pair that should be usable from elsewhere is instantiated here. Every
// parameter list below must match the template's signature exactly.

// ABType = int8, CType = int32.
template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
    const DeviceMemory<int8> &, const DeviceMemory<int8> &,
    const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<int32> &, blas::ProfileResult *);

// ABType = CType = Eigen::half.
template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
    const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
    const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<Eigen::half> &, blas::ProfileResult *);

// ABType = CType = float.
template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
    const DeviceMemory<float> &, const DeviceMemory<float> &,
    const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<float> &, blas::ProfileResult *);

// ABType = CType = double.
template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
    const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
    const DeviceMemory<double> &, const DeviceMemory<double> &,
    const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
    ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<double> &, blas::ProfileResult *);

// ABType = CType = std::complex<float>.
template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
    const blas::IBlasLtMatmulPlan *,
    const HostOrDeviceScalar<std::complex<float>> &,
    const DeviceMemory<std::complex<float>> &,
    const DeviceMemory<std::complex<float>> &,
    const HostOrDeviceScalar<std::complex<float>> &,
    DeviceMemory<std::complex<float>> *, ScratchAllocator *,
    const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);

// ABType = CType = std::complex<double>.
template Stream &
Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
    const blas::IBlasLtMatmulPlan *,
    const HostOrDeviceScalar<std::complex<double>> &,
    const DeviceMemory<std::complex<double>> &,
    const DeviceMemory<std::complex<double>> &,
    const HostOrDeviceScalar<std::complex<double>> &,
    DeviceMemory<std::complex<double>> *, ScratchAllocator *,
    const blas::IBlasLtMatmulAlgorithm *,
    const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);
4398
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4399 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4400 VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4401
4402 if (rng::RngSupport *rng = parent_->AsRng()) {
4403 CheckError(rng->SetSeed(this, seed, seed_bytes));
4404 } else {
4405 SetError();
4406 LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
4407 }
4408 return *this;
4409 }
4410
ThenPopulateRandUniform(DeviceMemory<float> * values)4411 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4412 VLOG_CALL(PARAM(values));
4413
4414 if (rng::RngSupport *rng = parent_->AsRng()) {
4415 CheckError(rng->DoPopulateRandUniform(this, values));
4416 } else {
4417 SetError();
4418 LOG(INFO) << DebugStreamPointers()
4419 << " attempting to perform RNG operation using StreamExecutor"
4420 " without RNG support.";
4421 }
4422 return *this;
4423 }
4424
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4425 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4426 DeviceMemory<float> *values) {
4427 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4428
4429 if (rng::RngSupport *rng = parent_->AsRng()) {
4430 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4431 } else {
4432 SetError();
4433 LOG(INFO) << DebugStreamPointers()
4434 << " attempting to perform RNG operation using StreamExecutor"
4435 " without RNG support.";
4436 }
4437 return *this;
4438 }
4439
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4440 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4441 DeviceMemory<double> *values) {
4442 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4443
4444 if (rng::RngSupport *rng = parent_->AsRng()) {
4445 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4446 } else {
4447 SetError();
4448 LOG(INFO) << DebugStreamPointers()
4449 << " attempting to perform RNG operation using StreamExecutor"
4450 " without RNG support.";
4451 }
4452 return *this;
4453 }
4454
ThenPopulateRandUniform(DeviceMemory<double> * values)4455 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4456 VLOG_CALL(PARAM(values));
4457
4458 if (rng::RngSupport *rng = parent_->AsRng()) {
4459 CheckError(rng->DoPopulateRandUniform(this, values));
4460 } else {
4461 SetError();
4462 LOG(INFO) << DebugStreamPointers()
4463 << " attempting to perform RNG operation using StreamExecutor"
4464 " without RNG support.";
4465 }
4466 return *this;
4467 }
4468
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4469 Stream &Stream::ThenPopulateRandUniform(
4470 DeviceMemory<std::complex<float>> *values) {
4471 VLOG_CALL(PARAM(values));
4472
4473 if (rng::RngSupport *rng = parent_->AsRng()) {
4474 CheckError(rng->DoPopulateRandUniform(this, values));
4475 } else {
4476 SetError();
4477 LOG(INFO) << DebugStreamPointers()
4478 << " attempting to perform RNG operation using StreamExecutor"
4479 " without RNG support.";
4480 }
4481 return *this;
4482 }
4483
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4484 Stream &Stream::ThenPopulateRandUniform(
4485 DeviceMemory<std::complex<double>> *values) {
4486 VLOG_CALL(PARAM(values));
4487
4488 if (rng::RngSupport *rng = parent_->AsRng()) {
4489 CheckError(rng->DoPopulateRandUniform(this, values));
4490 } else {
4491 SetError();
4492 LOG(INFO) << DebugStreamPointers()
4493 << " attempting to perform RNG operation using StreamExecutor"
4494 " without RNG support.";
4495 }
4496 return *this;
4497 }
4498
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4499 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4500 uint64 size) {
4501 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4502
4503 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4504 return *this;
4505 }
4506
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4507 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4508 uint64 size) {
4509 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4510
4511 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4512 return *this;
4513 }
4514
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4515 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4516 const DeviceMemoryBase &gpu_src, uint64 size) {
4517 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4518
4519 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4520 return *this;
4521 }
4522
ThenMemZero(DeviceMemoryBase * location,uint64 size)4523 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4524 VLOG_CALL(PARAM(location), PARAM(size));
4525
4526 CheckStatus(parent_->MemZero(this, location, size));
4527 return *this;
4528 }
4529
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4530 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4531 uint64 size) {
4532 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4533
4534 CheckStatus(parent_->Memset32(this, location, pattern, size));
4535 return *this;
4536 }
4537
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4538 Stream &Stream::ThenRnnForward(
4539 const dnn::RnnDescriptor &rnn_desc,
4540 const dnn::RnnSequenceTensorDescriptor &input_desc,
4541 const DeviceMemory<Eigen::half> &input_data,
4542 const dnn::RnnStateTensorDescriptor &input_h_desc,
4543 const DeviceMemory<Eigen::half> &input_h_data,
4544 const dnn::RnnStateTensorDescriptor &input_c_desc,
4545 const DeviceMemory<Eigen::half> &input_c_data,
4546 const DeviceMemory<Eigen::half> ¶ms,
4547 const dnn::RnnSequenceTensorDescriptor &output_desc,
4548 DeviceMemory<Eigen::half> *output_data,
4549 const dnn::RnnStateTensorDescriptor &output_h_desc,
4550 DeviceMemory<Eigen::half> *output_h_data,
4551 const dnn::RnnStateTensorDescriptor &output_c_desc,
4552 DeviceMemory<Eigen::half> *output_c_data, bool is_training,
4553 ScratchAllocator *reserve_space_allocator,
4554 ScratchAllocator *workspace_allocator,
4555 dnn::ProfileResult *output_profile_result) {
4556 // TODO(zhengxq): add VLOG PARAM calls.
4557 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4558 auto status = dnn->DoRnnForward(
4559 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4560 input_c_desc, input_c_data, params, output_desc, output_data,
4561 output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
4562 reserve_space_allocator, workspace_allocator, output_profile_result);
4563 if (!status && !output_profile_result) {
4564 SetError();
4565 }
4566 } else {
4567 SetErrorAndLogNoDnnSupport();
4568 }
4569 return *this;
4570 }
4571
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4572 Stream &Stream::ThenRnnForward(
4573 const dnn::RnnDescriptor &rnn_desc,
4574 const dnn::RnnSequenceTensorDescriptor &input_desc,
4575 const DeviceMemory<float> &input_data,
4576 const dnn::RnnStateTensorDescriptor &input_h_desc,
4577 const DeviceMemory<float> &input_h_data,
4578 const dnn::RnnStateTensorDescriptor &input_c_desc,
4579 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4580 const dnn::RnnSequenceTensorDescriptor &output_desc,
4581 DeviceMemory<float> *output_data,
4582 const dnn::RnnStateTensorDescriptor &output_h_desc,
4583 DeviceMemory<float> *output_h_data,
4584 const dnn::RnnStateTensorDescriptor &output_c_desc,
4585 DeviceMemory<float> *output_c_data, bool is_training,
4586 ScratchAllocator *reserve_space_allocator,
4587 ScratchAllocator *workspace_allocator,
4588 dnn::ProfileResult *output_profile_result) {
4589 // TODO(zhengxq): add VLOG PARAM calls.
4590 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4591 auto status = dnn->DoRnnForward(
4592 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4593 input_c_desc, input_c_data, params, output_desc, output_data,
4594 output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
4595 reserve_space_allocator, workspace_allocator, output_profile_result);
4596 if (!status && !output_profile_result) {
4597 SetError();
4598 }
4599 } else {
4600 SetErrorAndLogNoDnnSupport();
4601 }
4602 return *this;
4603 }
4604
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4605 Stream &Stream::ThenRnnForward(
4606 const dnn::RnnDescriptor &rnn_desc,
4607 const dnn::RnnSequenceTensorDescriptor &input_desc,
4608 const DeviceMemory<double> &input_data,
4609 const dnn::RnnStateTensorDescriptor &input_h_desc,
4610 const DeviceMemory<double> &input_h_data,
4611 const dnn::RnnStateTensorDescriptor &input_c_desc,
4612 const DeviceMemory<double> &input_c_data,
4613 const DeviceMemory<double> ¶ms,
4614 const dnn::RnnSequenceTensorDescriptor &output_desc,
4615 DeviceMemory<double> *output_data,
4616 const dnn::RnnStateTensorDescriptor &output_h_desc,
4617 DeviceMemory<double> *output_h_data,
4618 const dnn::RnnStateTensorDescriptor &output_c_desc,
4619 DeviceMemory<double> *output_c_data, bool is_training,
4620 ScratchAllocator *reserve_space_allocator,
4621 ScratchAllocator *workspace_allocator,
4622 dnn::ProfileResult *output_profile_result) {
4623 // TODO(zhengxq): add VLOG PARAM calls.
4624 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4625 auto status = dnn->DoRnnForward(
4626 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4627 input_c_desc, input_c_data, params, output_desc, output_data,
4628 output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
4629 reserve_space_allocator, workspace_allocator, output_profile_result);
4630 if (!status && !output_profile_result) {
4631 SetError();
4632 }
4633 } else {
4634 SetErrorAndLogNoDnnSupport();
4635 }
4636 return *this;
4637 }
4638
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4639 Stream &Stream::ThenRnnBackward(
4640 const dnn::RnnDescriptor &rnn_desc,
4641 const dnn::RnnSequenceTensorDescriptor &input_desc,
4642 const DeviceMemory<Eigen::half> &input_data,
4643 const dnn::RnnStateTensorDescriptor &input_h_desc,
4644 const DeviceMemory<Eigen::half> &input_h_data,
4645 const dnn::RnnStateTensorDescriptor &input_c_desc,
4646 const DeviceMemory<Eigen::half> &input_c_data,
4647 const DeviceMemory<Eigen::half> ¶ms,
4648 const dnn::RnnSequenceTensorDescriptor &output_desc,
4649 const DeviceMemory<Eigen::half> &output_data,
4650 const dnn::RnnStateTensorDescriptor &output_h_desc,
4651 const DeviceMemory<Eigen::half> &output_h_data,
4652 const dnn::RnnStateTensorDescriptor &output_c_desc,
4653 const DeviceMemory<Eigen::half> &output_c_data,
4654 const DeviceMemory<Eigen::half> &output_backprop_data,
4655 const DeviceMemory<Eigen::half> &output_h_backprop_data,
4656 const DeviceMemory<Eigen::half> &output_c_backprop_data,
4657 DeviceMemory<Eigen::half> *input_backprop_data,
4658 DeviceMemory<Eigen::half> *input_h_backprop_data,
4659 DeviceMemory<Eigen::half> *input_c_backprop_data,
4660 DeviceMemory<Eigen::half> *params_backprop_data,
4661 DeviceMemory<uint8> *reserve_space_data,
4662 ScratchAllocator *workspace_allocator,
4663 dnn::ProfileResult *output_profile_result) {
4664 // TODO(zhengxq): add VLOG PARAM calls.
4665 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4666 auto status = dnn->DoRnnBackward(
4667 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4668 input_c_desc, input_c_data, params, output_desc, output_data,
4669 output_h_desc, output_h_data, output_c_desc, output_c_data,
4670 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4671 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4672 params_backprop_data, reserve_space_data, workspace_allocator,
4673 output_profile_result);
4674 if (!status && !output_profile_result) {
4675 SetError();
4676 }
4677 } else {
4678 SetError();
4679 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4680 }
4681 return *this;
4682 }
4683
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4684 Stream &Stream::ThenRnnBackward(
4685 const dnn::RnnDescriptor &rnn_desc,
4686 const dnn::RnnSequenceTensorDescriptor &input_desc,
4687 const DeviceMemory<float> &input_data,
4688 const dnn::RnnStateTensorDescriptor &input_h_desc,
4689 const DeviceMemory<float> &input_h_data,
4690 const dnn::RnnStateTensorDescriptor &input_c_desc,
4691 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4692 const dnn::RnnSequenceTensorDescriptor &output_desc,
4693 const DeviceMemory<float> &output_data,
4694 const dnn::RnnStateTensorDescriptor &output_h_desc,
4695 const DeviceMemory<float> &output_h_data,
4696 const dnn::RnnStateTensorDescriptor &output_c_desc,
4697 const DeviceMemory<float> &output_c_data,
4698 const DeviceMemory<float> &output_backprop_data,
4699 const DeviceMemory<float> &output_h_backprop_data,
4700 const DeviceMemory<float> &output_c_backprop_data,
4701 DeviceMemory<float> *input_backprop_data,
4702 DeviceMemory<float> *input_h_backprop_data,
4703 DeviceMemory<float> *input_c_backprop_data,
4704 DeviceMemory<float> *params_backprop_data,
4705 DeviceMemory<uint8> *reserve_space_data,
4706 ScratchAllocator *workspace_allocator,
4707 dnn::ProfileResult *output_profile_result) {
4708 // TODO(zhengxq): add VLOG PARAM calls.
4709 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4710 auto status = dnn->DoRnnBackward(
4711 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4712 input_c_desc, input_c_data, params, output_desc, output_data,
4713 output_h_desc, output_h_data, output_c_desc, output_c_data,
4714 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4715 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4716 params_backprop_data, reserve_space_data, workspace_allocator,
4717 output_profile_result);
4718 if (!status && !output_profile_result) {
4719 SetError();
4720 }
4721 } else {
4722 SetError();
4723 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4724 }
4725 return *this;
4726 }
4727
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4728 Stream &Stream::ThenRnnBackward(
4729 const dnn::RnnDescriptor &rnn_desc,
4730 const dnn::RnnSequenceTensorDescriptor &input_desc,
4731 const DeviceMemory<double> &input_data,
4732 const dnn::RnnStateTensorDescriptor &input_h_desc,
4733 const DeviceMemory<double> &input_h_data,
4734 const dnn::RnnStateTensorDescriptor &input_c_desc,
4735 const DeviceMemory<double> &input_c_data,
4736 const DeviceMemory<double> ¶ms,
4737 const dnn::RnnSequenceTensorDescriptor &output_desc,
4738 const DeviceMemory<double> &output_data,
4739 const dnn::RnnStateTensorDescriptor &output_h_desc,
4740 const DeviceMemory<double> &output_h_data,
4741 const dnn::RnnStateTensorDescriptor &output_c_desc,
4742 const DeviceMemory<double> &output_c_data,
4743 const DeviceMemory<double> &output_backprop_data,
4744 const DeviceMemory<double> &output_h_backprop_data,
4745 const DeviceMemory<double> &output_c_backprop_data,
4746 DeviceMemory<double> *input_backprop_data,
4747 DeviceMemory<double> *input_h_backprop_data,
4748 DeviceMemory<double> *input_c_backprop_data,
4749 DeviceMemory<double> *params_backprop_data,
4750 DeviceMemory<uint8> *reserve_space_data,
4751 ScratchAllocator *workspace_allocator,
4752 dnn::ProfileResult *output_profile_result) {
4753 // TODO(zhengxq): add VLOG PARAM calls.
4754 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4755 auto status = dnn->DoRnnBackward(
4756 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4757 input_c_desc, input_c_data, params, output_desc, output_data,
4758 output_h_desc, output_h_data, output_c_desc, output_c_data,
4759 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4760 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4761 params_backprop_data, reserve_space_data, workspace_allocator,
4762 output_profile_result);
4763 if (!status && !output_profile_result) {
4764 SetError();
4765 }
4766 } else {
4767 SetError();
4768 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4769 }
4770 return *this;
4771 }
4772
ThenCtcLoss(const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<float> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<float> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<float> * grads_data,ScratchAllocator * workspace_allocator)4773 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
4774 const DeviceMemory<float> &probs_data,
4775 absl::Span<const int> labels_data,
4776 absl::Span<const int> labels_lengths_data,
4777 absl::Span<const int> input_lengths_data,
4778 DeviceMemory<float> *costs_data,
4779 const dnn::RnnStateTensorDescriptor &grads_desc,
4780 DeviceMemory<float> *grads_data,
4781 ScratchAllocator *workspace_allocator) {
4782 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4783 DeviceMemory<uint8> scratch_memory;
4784 int ctc_loss_algo_id;
4785 auto status =
4786 dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
4787 labels_data, labels_lengths_data,
4788 input_lengths_data, workspace_allocator,
4789 &scratch_memory, &ctc_loss_algo_id)
4790 .ok();
4791 if (status) {
4792 status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
4793 labels_lengths_data, input_lengths_data,
4794 costs_data, grads_desc, grads_data,
4795 &scratch_memory, ctc_loss_algo_id);
4796 }
4797 if (!status) {
4798 SetError();
4799 }
4800 } else {
4801 SetErrorAndLogNoDnnSupport();
4802 }
4803 return *this;
4804 }
4805
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)4806 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
4807 dnn::DataType input_type,
4808 const DeviceMemoryBase &input_data,
4809 const dnn::BatchDescriptor &output_desc,
4810 dnn::DataType output_type, float scale,
4811 DeviceMemoryBase *output_data) {
4812 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
4813 PARAM(output_desc), PARAM(output_type), PARAM(scale),
4814 PARAM(output_data));
4815 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4816 CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data,
4817 output_desc, output_type, scale,
4818 output_data));
4819 } else {
4820 SetErrorAndLogNoDnnSupport();
4821 }
4822 return *this;
4823 }
4824
ThenDoHostCallback(std::function<void ()> callback)4825 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
4826 VLOG_CALL(PARAM(callback));
4827
4828 if (!ok()) {
4829 LOG(INFO) << DebugStreamPointers()
4830 << " was in error state before adding host callback";
4831 }
4832 CheckError(parent_->HostCallback(this, std::move(callback)));
4833 return *this;
4834 }
4835
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)4836 Stream &Stream::ThenDoHostCallbackWithStatus(
4837 std::function<port::Status()> callback) {
4838 VLOG_CALL(PARAM(callback));
4839
4840 if (!ok()) {
4841 LOG(INFO) << DebugStreamPointers()
4842 << " was in error state before adding host callback";
4843 }
4844 CheckError(parent_->HostCallback(this, std::move(callback)));
4845 return *this;
4846 }
4847
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)4848 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
4849 std::function<void()> callback) {
4850 VLOG_CALL(PARAM(callback));
4851
4852 if (!ok()) {
4853 LOG(INFO) << DebugStreamPointers()
4854 << " was in error state before adding callback to be run after "
4855 "next block-host-until-done.";
4856 }
4857 absl::MutexLock lock(&mu_);
4858 after_block_host_until_done_callbacks_.push_back(std::move(callback));
4859 return *this;
4860 }
4861
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)4862 Stream &Stream::ThenFft(fft::Plan *plan,
4863 const DeviceMemory<std::complex<float>> &input,
4864 DeviceMemory<std::complex<float>> *output) {
4865 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4866
4867 if (fft::FftSupport *fft = parent_->AsFft()) {
4868 CheckError(fft->DoFft(this, plan, input, output));
4869 } else {
4870 SetError();
4871 LOG(INFO) << DebugStreamPointers()
4872 << " attempting to perform FFT operation using StreamExecutor"
4873 " without FFT support";
4874 }
4875 return *this;
4876 }
4877
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)4878 Stream &Stream::ThenFft(fft::Plan *plan,
4879 const DeviceMemory<std::complex<double>> &input,
4880 DeviceMemory<std::complex<double>> *output) {
4881 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4882
4883 if (fft::FftSupport *fft = parent_->AsFft()) {
4884 CheckError(fft->DoFft(this, plan, input, output));
4885 } else {
4886 SetError();
4887 LOG(INFO) << DebugStreamPointers()
4888 << " attempting to perform FFT operation using StreamExecutor"
4889 " without FFT support";
4890 }
4891 return *this;
4892 }
4893
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)4894 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
4895 DeviceMemory<std::complex<float>> *output) {
4896 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4897
4898 if (fft::FftSupport *fft = parent_->AsFft()) {
4899 CheckError(fft->DoFft(this, plan, input, output));
4900 } else {
4901 SetError();
4902 LOG(INFO) << DebugStreamPointers()
4903 << " attempting to perform FFT operation using StreamExecutor"
4904 " without FFT support";
4905 }
4906 return *this;
4907 }
4908
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)4909 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
4910 DeviceMemory<std::complex<double>> *output) {
4911 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4912
4913 if (fft::FftSupport *fft = parent_->AsFft()) {
4914 CheckError(fft->DoFft(this, plan, input, output));
4915 } else {
4916 SetError();
4917 LOG(INFO) << DebugStreamPointers()
4918 << " attempting to perform FFT operation using StreamExecutor"
4919 " without FFT support";
4920 }
4921 return *this;
4922 }
4923
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)4924 Stream &Stream::ThenFft(fft::Plan *plan,
4925 const DeviceMemory<std::complex<float>> &input,
4926 DeviceMemory<float> *output) {
4927 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4928
4929 if (fft::FftSupport *fft = parent_->AsFft()) {
4930 CheckError(fft->DoFft(this, plan, input, output));
4931 } else {
4932 SetError();
4933 LOG(INFO) << DebugStreamPointers()
4934 << " attempting to perform FFT operation using StreamExecutor"
4935 " without FFT support";
4936 }
4937 return *this;
4938 }
4939
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)4940 Stream &Stream::ThenFft(fft::Plan *plan,
4941 const DeviceMemory<std::complex<double>> &input,
4942 DeviceMemory<double> *output) {
4943 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4944
4945 if (fft::FftSupport *fft = parent_->AsFft()) {
4946 CheckError(fft->DoFft(this, plan, input, output));
4947 } else {
4948 SetError();
4949 LOG(INFO) << DebugStreamPointers()
4950 << " attempting to perform FFT operation using StreamExecutor"
4951 " without FFT support";
4952 }
4953 return *this;
4954 }
4955
4956 // It looks confusing, but all this is doing is inserting a callback at the
4957 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)4958 Stream &Stream::ThenEnqueueOnBackgroundThread(
4959 std::function<void(StreamExecutor *)> task) {
4960 VLOG_CALL(PARAM(task));
4961
4962 StreamExecutor *stream_executor = this->parent_;
4963 std::function<void()> bound_task = std::bind(task, stream_executor);
4964
4965 return ThenDoHostCallback([stream_executor, bound_task]() {
4966 stream_executor->EnqueueOnBackgroundThread(bound_task);
4967 });
4968 }
4969
BlockHostUntilDone()4970 port::Status Stream::BlockHostUntilDone() {
4971 VLOG_CALL();
4972
4973 if (!ok()) {
4974 port::Status status = port::Status(
4975 port::error::INTERNAL,
4976 "stream did not block host until done; was already in an error state");
4977 LOG(INFO) << DebugStreamPointers() << " " << status;
4978 return status;
4979 }
4980
4981 temporary_memory_manager_.DeallocateFinalizedTemporaries();
4982
4983 port::Status error = parent_->BlockHostUntilDone(this);
4984 CheckError(error.ok());
4985
4986 RunAfterBlockHostUntilDoneCallbacks();
4987 return error;
4988 }
4989
RunAfterBlockHostUntilDoneCallbacks()4990 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
4991 std::vector<std::function<void()>> callbacks;
4992 {
4993 absl::MutexLock lock(&mu_);
4994 std::swap(callbacks, after_block_host_until_done_callbacks_);
4995 }
4996 for (const auto &fn : callbacks) {
4997 fn();
4998 }
4999 }
5000
DebugStreamPointers() const5001 std::string Stream::DebugStreamPointers() const {
5002 // Relies on the ToVlogString(const void*) overload above.
5003 return absl::StrCat("[stream=", ToVlogString(this),
5004 ",impl=", ToVlogString(implementation_.get()), "]");
5005 }
5006
CheckStatus(port::Status status)5007 void Stream::CheckStatus(port::Status status) {
5008 if (status.ok()) {
5009 return;
5010 }
5011 LOG(ERROR) << status;
5012 absl::MutexLock lock(&mu_);
5013 status_ = status;
5014 }
5015
5016 } // namespace stream_executor
5017