• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/stream.h"
17 
18 #include "tensorflow/stream_executor/platform/port.h"
19 
20 #include "absl/strings/str_cat.h"
21 #include "third_party/eigen3/Eigen/Core"
22 #include "tensorflow/stream_executor/blas.h"
23 #include "tensorflow/stream_executor/host_or_device_scalar.h"
24 #include "tensorflow/stream_executor/lib/stacktrace.h"
25 #include "tensorflow/stream_executor/platform.h"
26 #include "tensorflow/stream_executor/platform/logging.h"
27 #include "tensorflow/stream_executor/rng.h"
28 #include "tensorflow/stream_executor/stream_executor_internal.h"
29 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
30 
31 namespace stream_executor {
32 
33 namespace {
34 // Code to turn parameters to functions on stream into strings that
35 // will be VLOG'ed. We need overloads, instead of
36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
37 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
39   return descriptor.ToShortString();
40 }
41 
ToVlogString(const dnn::FilterDescriptor & descriptor)42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
43   return descriptor.ToShortString();
44 }
45 
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
47   return descriptor.ToShortString();
48 }
49 
ToVlogString(const dnn::PoolingDescriptor & descriptor)50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
51   return descriptor.ToShortString();
52 }
53 
ToVlogString(const dnn::NormalizeDescriptor & descriptor)54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
55   return descriptor.ToShortString();
56 }
57 
ToVlogString(dnn::ActivationMode mode)58 string ToVlogString(dnn::ActivationMode mode) {
59   return dnn::ActivationModeString(mode);
60 }
61 
ToVlogString(const dnn::AlgorithmConfig & algo_config)62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
63   return algo_config.ToString();
64 }
65 
ToVlogString(dnn::ElementwiseOperation op)66 string ToVlogString(dnn::ElementwiseOperation op) {
67   return dnn::ElementwiseOperationString(op);
68 }
69 
ToVlogString(dnn::QuantizedActivationMode mode)70 string ToVlogString(dnn::QuantizedActivationMode mode) {
71   return dnn::QuantizedActivationModeString(mode);
72 }
73 
ToVlogString(blas::Transpose t)74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
75 
ToVlogString(blas::UpperLower ul)76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
77 
ToVlogString(blas::Diagonal d)78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
79 
ToVlogString(blas::Side s)80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
81 
ToVlogString(blas::ComputationType ty)82 string ToVlogString(blas::ComputationType ty) {
83   return blas::ComputationTypeString(ty);
84 }
85 
// VLOG text for a raw pointer: "null" for a null pointer, otherwise whatever
// operator<< writes for the address.
string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    // StrCat does not convert pointers to text, so stream the address.
    std::ostringstream address_stream;
    address_stream << ptr;
    return address_stream.str();
  }
  return "null";
}
96 
// VLOG text for a complex number, formatted by the stream insertion operator
// of std::complex (StrCat does not convert std::complex to text).
template <class T>
string ToVlogString(const std::complex<T> &c) {
  std::ostringstream formatted;
  formatted << c;
  string text = formatted.str();
  return text;
}
104 
// VLOG text for a std::function: only reports whether a target is set, since
// the callable itself cannot be printed.
template <class T>
string ToVlogString(const std::function<T> &f) {
  if (f == nullptr) {
    return "null";
  }
  return "<non-null function>";
}
109 
ToVlogString(const DeviceMemoryBase & memory)110 string ToVlogString(const DeviceMemoryBase &memory) {
111   return ToVlogString(memory.opaque());
112 }
113 
ToVlogString(const DeviceMemoryBase * memory)114 string ToVlogString(const DeviceMemoryBase *memory) {
115   return memory == nullptr ? "null" : ToVlogString(*memory);
116 }
117 
ToVlogString(const Eigen::half & h)118 string ToVlogString(const Eigen::half &h) {
119   return absl::StrCat(static_cast<float>(h));
120 }
121 
ToVlogString(int i)122 string ToVlogString(int i) { return absl::StrCat(i); }
123 
ToVlogString(uint32 i)124 string ToVlogString(uint32 i) { return absl::StrCat(i); }
125 
ToVlogString(uint64 i)126 string ToVlogString(uint64 i) { return absl::StrCat(i); }
127 
ToVlogString(int64 i)128 string ToVlogString(int64 i) { return absl::StrCat(i); }
129 
ToVlogString(float f)130 string ToVlogString(float f) { return absl::StrCat(f); }
131 
ToVlogString(double d)132 string ToVlogString(double d) { return absl::StrCat(d); }
133 
134 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)135 string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
136   if (memory_or_constant.is_pointer()) {
137     return ToVlogString(memory_or_constant.pointer());
138   }
139   return ToVlogString(memory_or_constant.value());
140 }
141 
142 template <class T>
ToVlogString(port::ArraySlice<T> elements)143 string ToVlogString(port::ArraySlice<T> elements) {
144   string str = absl::StrCat(
145       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
146       elements.size(), "]{");
147   const char *separator = "";
148   size_t max_to_show = std::numeric_limits<size_t>::max();
149   if (!VLOG_IS_ON(2)) {
150     max_to_show = 5;
151   } else if (!VLOG_IS_ON(3)) {
152     max_to_show = 20;
153   } else if (!VLOG_IS_ON(11)) {
154     max_to_show = 1000;
155   }
156   for (size_t i = 0; i < elements.size(); ++i) {
157     if (i == max_to_show) {
158       str += ", ...";
159       break;
160     }
161     absl::StrAppend(&str, separator, ToVlogString(elements[i]));
162     separator = ", ";
163   }
164   str += "}";
165   return str;
166 }
167 
168 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)169 string ToVlogString(port::MutableArraySlice<T> elements) {
170   return ToVlogString(port::ArraySlice<T>(elements));
171 }
172 
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)173 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
174   switch (depth_to_space_layout) {
175     case dnn::DepthToSpaceLayout::DepthHeightWidth:
176       return "DepthToSpaceLayout::DepthHeightWidth";
177   }
178   return "unknown DepthToSpaceLayout";
179 }
180 
ToVlogString(dnn::DataType data_type)181 string ToVlogString(dnn::DataType data_type) {
182   switch (data_type) {
183     case dnn::DataType::kFloat:
184       return "dnn::DataType::kFloat";
185     case dnn::DataType::kDouble:
186       return "dnn::DataType::kDouble";
187     case dnn::DataType::kHalf:
188       return "dnn::DataType::kHalf";
189     case dnn::DataType::kInt8:
190       return "dnn::DataType::kInt8";
191     case dnn::DataType::kInt32:
192       return "dnn::DataType::kInt32";
193     default:
194       return "unknown DataType";
195   }
196 }
197 
198 // Used together with PARAM to VLOG calls made to the stream. Intended
199 // to be used like this:
200 //
201 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
202 //
203 // where a and b are the parameters to MyFunction.
204 //
205 // See VLOG_CALL for a short-hand for this. This way of doing it saves
206 // a tremendous amount of boilerplate code given how many functions
207 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,string>> params)208 string CallStr(const char *function_name, Stream *stream,
209                std::vector<std::pair<const char *, string>> params) {
210   // Do not call this function unless VLOG is on since just
211   // constructing all the strings in params is expensive.
212   CHECK(VLOG_IS_ON(1));
213 
214   string str = absl::StrCat(stream->DebugStreamPointers(),
215                             " Called Stream::", function_name, "(");
216   const char *separator = "";
217   for (const auto &param : params) {
218     absl::StrAppend(&str, separator, param.first, "=", param.second);
219     separator = ", ";
220   }
221   absl::StrAppend(&str, ")");
222   if (VLOG_IS_ON(10)) {
223     absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
224   }
225   return str;
226 }
227 
// Expands to one {name, text} entry for CallStr: the argument expression
// spelled out as a string, paired with its ToVlogString() rendering. This
// avoids typing every parameter twice at the logging call site.
#define PARAM(arg) \
  { #arg, ToVlogString(arg) }
232 
// Logs a call to the enclosing Stream member function at VLOG level 1,
// filling in the function name (__func__) and `this` automatically. Intended
// use:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// which replaces the hand-written equivalent:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// The saving matters because most parameter names are long and most functions
// on Stream take far more than two parameters.
#define VLOG_CALL(...) \
  VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
246 
247 }  // namespace
248 
Stream(StreamExecutor * parent)249 Stream::Stream(StreamExecutor *parent)
250     : parent_(parent),
251       implementation_(parent->implementation()->GetStreamImplementation()),
252       allocated_(false),
253       ok_(false),
254       temporary_memory_manager_(this) {
255   VLOG_CALL(PARAM(parent));
256 }
257 
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)258 Stream::Stream(StreamExecutor *parent,
259                internal::StreamInterface *implementation)
260     : parent_(parent),
261       implementation_(implementation),
262       allocated_(false),
263       ok_(false),
264       temporary_memory_manager_(this) {
265   VLOG_CALL(PARAM(parent), PARAM(implementation));
266 }
267 
~Stream()268 Stream::~Stream() {
269   VLOG_CALL();
270 
271   // Ensure the stream is completed.
272   auto status = BlockHostUntilDone();
273   if (!status.ok()) {
274     LOG(WARNING) << "Error blocking host until done in stream destructor: "
275                  << status;
276   }
277   temporary_memory_manager_.ForceDeallocateAll();
278   RunAfterBlockHostUntilDoneCallbacks();
279 
280   if (allocated_) {
281     parent_->DeallocateStream(this);
282   }
283 }
284 
RefreshStatus()285 port::Status Stream::RefreshStatus() {
286   port::Status status = parent_->GetStatus(this);
287   CheckStatus(status);
288   return status;
289 }
290 
Init()291 Stream &Stream::Init() {
292   VLOG_CALL();
293 
294   absl::MutexLock lock(&mu_);
295   CHECK_EQ(false, allocated_)
296       << "stream appears to already have been initialized";
297   CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
298 
299   if (parent_->AllocateStream(this)) {
300     // Successful initialization!
301     allocated_ = true;
302     ok_ = true;
303   } else {
304     LOG(ERROR) << "failed to allocate stream during initialization";
305   }
306 
307   return *this;
308 }
309 
InitTimer(Timer * timer)310 Stream &Stream::InitTimer(Timer *timer) {
311   VLOG_CALL(PARAM(timer));
312 
313   if (ok()) {
314     CheckError(parent_->AllocateTimer(timer));
315   } else {
316     LOG(INFO) << "did not allocate timer: " << timer;
317   }
318   return *this;
319 }
320 
InitWithTimer(Timer * timer)321 Stream &Stream::InitWithTimer(Timer *timer) {
322   VLOG_CALL(PARAM(timer));
323 
324   return Init().InitTimer(timer);
325 }
326 
ThenRecordEvent(Event * event)327 Stream &Stream::ThenRecordEvent(Event *event) {
328   VLOG_CALL(PARAM(event));
329 
330   port::Status status = parent_->RecordEvent(this, event);
331   if (!status.ok()) {
332     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
333                << "; not marking stream as bad, as the Event object may be "
334                << "at fault. Monitor for further errors.";
335   }
336 
337   return *this;
338 }
339 
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)340 Stream &Stream::ThenBatchNormalizationForward(
341     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
342     const DeviceMemory<float> &offset,
343     const DeviceMemory<float> &estimated_mean,
344     const DeviceMemory<float> &estimated_variance,
345     const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
346     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
347     const double exponential_average_factor,
348     dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
349     DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
350     DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
351     bool is_training,
352     std::function<const DeviceMemory<float> &()> var_to_inv_var,
353     std::function<void()> inv_var_to_var,
354     ScratchAllocator *reserve_space_allocator,
355     ScratchAllocator *workspace_allocator) {
356   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
357             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
358   if (ok()) {
359     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
360       CheckError(dnn->DoBatchNormalizationForward(
361           this, x, scale, offset, estimated_mean, estimated_variance,
362           side_input, x_desc, scale_offset_desc, epsilon,
363           exponential_average_factor, activation_mode, y, batch_mean, batch_var,
364           saved_mean, saved_inv_var, is_training, reserve_space_allocator,
365           workspace_allocator, std::move(var_to_inv_var),
366           std::move(inv_var_to_var)));
367     } else {
368       SetErrorAndLogNoDnnSupport();
369     }
370   }
371   return *this;
372 }
373 
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)374 Stream &Stream::ThenBatchNormalizationBackward(
375     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
376     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
377     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
378     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
379     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
380     DeviceMemory<float> *offset_backprop,
381     DeviceMemory<uint8> *reserve_space_data,
382     ScratchAllocator *workspace_allocator) {
383   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
384             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
385             PARAM(scale_backprop), PARAM(offset_backprop));
386   if (ok()) {
387     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
388       CheckError(dnn->DoBatchNormalizationBackward(
389           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
390           epsilon, x_backprop, scale_backprop, offset_backprop,
391           reserve_space_data, workspace_allocator));
392     } else {
393       SetErrorAndLogNoDnnSupport();
394     }
395   }
396   return *this;
397 }
398 
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)399 Stream &Stream::ThenBatchNormalizationForward(
400     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
401     const DeviceMemory<float> &offset,
402     const DeviceMemory<float> &estimated_mean,
403     const DeviceMemory<float> &estimated_variance,
404     const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
405     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
406     const double exponential_average_factor,
407     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
408     DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
409     DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
410     bool is_training,
411     std::function<const DeviceMemory<float> &()> var_to_inv_var,
412     std::function<void()> inv_var_to_var,
413     ScratchAllocator *reserve_space_allocator,
414     ScratchAllocator *workspace_allocator) {
415   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
416             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
417   if (ok()) {
418     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
419       CheckError(dnn->DoBatchNormalizationForward(
420           this, x, scale, offset, estimated_mean, estimated_variance,
421           side_input, x_desc, scale_offset_desc, epsilon,
422           exponential_average_factor, activation_mode, y, batch_mean, batch_var,
423           saved_mean, saved_inv_var, is_training, reserve_space_allocator,
424           workspace_allocator, std::move(var_to_inv_var),
425           std::move(inv_var_to_var)));
426     } else {
427       SetErrorAndLogNoDnnSupport();
428     }
429   }
430   return *this;
431 }
432 
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)433 Stream &Stream::ThenBatchNormalizationBackward(
434     const DeviceMemory<Eigen::half> &y_backprop,
435     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
436     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
437     const dnn::BatchDescriptor &x_desc,
438     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
439     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
440     DeviceMemory<float> *offset_backprop,
441     DeviceMemory<uint8> *reserve_space_data,
442     ScratchAllocator *workspace_allocator) {
443   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
444             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
445             PARAM(scale_backprop), PARAM(offset_backprop));
446   if (ok()) {
447     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
448       CheckError(dnn->DoBatchNormalizationBackward(
449           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
450           epsilon, x_backprop, scale_backprop, offset_backprop,
451           reserve_space_data, workspace_allocator));
452 
453     } else {
454       SetErrorAndLogNoDnnSupport();
455     }
456   }
457   return *this;
458 }
459 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<double> & conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<double> & side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<double> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)460 Stream &Stream::ThenFusedConvolveWithAlgorithm(
461     const dnn::BatchDescriptor &conv_input_descriptor,
462     const DeviceMemory<double> &conv_input_data, double conv_input_scale,
463     const dnn::FilterDescriptor &filter_descriptor,
464     const DeviceMemory<double> &filter_data,
465     const dnn::ConvolutionDescriptor &convolution_descriptor,
466     const DeviceMemory<double> &side_input_data, double side_input_scale,
467     const dnn::BatchDescriptor &bias_descriptor,
468     const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
469     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
470     ScratchAllocator *scratch_allocator,
471     const dnn::AlgorithmConfig &algorithm_config,
472     dnn::ProfileResult *output_profile_result) {
473   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
474             PARAM(conv_input_scale), PARAM(filter_descriptor),
475             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
476             PARAM(side_input_data), PARAM(side_input_scale),
477             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
478             PARAM(algorithm_config));
479 
480   if (ok()) {
481     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
482       auto status = dnn->DoFusedConvolve(
483           this, conv_input_descriptor, conv_input_data, conv_input_scale,
484           filter_descriptor, filter_data, convolution_descriptor,
485           side_input_data, side_input_scale, bias_descriptor, biases,
486           activation_mode, output_descriptor, output, scratch_allocator,
487           algorithm_config, output_profile_result);
488       if (!status && !output_profile_result) {
489         SetError();
490       }
491     } else {
492       SetErrorAndLogNoDnnSupport();
493     }
494   }
495   return *this;
496 }
497 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)498 Stream &Stream::ThenFusedConvolveWithAlgorithm(
499     const dnn::BatchDescriptor &conv_input_descriptor,
500     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
501     const dnn::FilterDescriptor &filter_descriptor,
502     const DeviceMemory<float> &filter_data,
503     const dnn::ConvolutionDescriptor &convolution_descriptor,
504     const DeviceMemory<float> &side_input_data, float side_input_scale,
505     const dnn::BatchDescriptor &bias_descriptor,
506     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
507     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
508     ScratchAllocator *scratch_allocator,
509     const dnn::AlgorithmConfig &algorithm_config,
510     dnn::ProfileResult *output_profile_result) {
511   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
512             PARAM(conv_input_scale), PARAM(filter_descriptor),
513             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
514             PARAM(side_input_data), PARAM(side_input_scale),
515             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
516             PARAM(algorithm_config));
517 
518   if (ok()) {
519     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
520       auto status = dnn->DoFusedConvolve(
521           this, conv_input_descriptor, conv_input_data, conv_input_scale,
522           filter_descriptor, filter_data, convolution_descriptor,
523           side_input_data, side_input_scale, bias_descriptor, biases,
524           activation_mode, output_descriptor, output, scratch_allocator,
525           algorithm_config, output_profile_result);
526       if (!status && !output_profile_result) {
527         SetError();
528       }
529     } else {
530       SetErrorAndLogNoDnnSupport();
531     }
532   }
533   return *this;
534 }
535 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)536 Stream &Stream::ThenFusedConvolveWithAlgorithm(
537     const dnn::BatchDescriptor &conv_input_descriptor,
538     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
539     const dnn::FilterDescriptor &filter_descriptor,
540     const DeviceMemory<Eigen::half> &filter_data,
541     const dnn::ConvolutionDescriptor &convolution_descriptor,
542     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
543     const dnn::BatchDescriptor &bias_descriptor,
544     const DeviceMemory<Eigen::half> &biases,
545     dnn::ActivationMode activation_mode,
546     const dnn::BatchDescriptor &output_descriptor,
547     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
548     const dnn::AlgorithmConfig &algorithm_config,
549     dnn::ProfileResult *output_profile_result) {
550   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
551             PARAM(conv_input_scale), PARAM(filter_descriptor),
552             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
553             PARAM(side_input_data), PARAM(side_input_scale),
554             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
555             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
556 
557   if (ok()) {
558     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
559       auto status = dnn->DoFusedConvolve(
560           this, conv_input_descriptor, conv_input_data, conv_input_scale,
561           filter_descriptor, filter_data, convolution_descriptor,
562           side_input_data, side_input_scale, bias_descriptor, biases,
563           activation_mode, output_descriptor, output, scratch_allocator,
564           algorithm_config, output_profile_result);
565       if (!status && !output_profile_result) {
566         SetError();
567       }
568     } else {
569       SetErrorAndLogNoDnnSupport();
570     }
571   }
572   return *this;
573 }
574 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)575 Stream &Stream::ThenFusedConvolveWithAlgorithm(
576     const dnn::BatchDescriptor &conv_input_descriptor,
577     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
578     const dnn::FilterDescriptor &filter_descriptor,
579     const DeviceMemory<int8> &filter_data,
580     const dnn::ConvolutionDescriptor &convolution_descriptor,
581     const DeviceMemory<int8> &side_input_data, float side_input_scale,
582     const dnn::BatchDescriptor &bias_descriptor,
583     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
584     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
585     ScratchAllocator *scratch_allocator,
586     const dnn::AlgorithmConfig &algorithm_config,
587     dnn::ProfileResult *output_profile_result) {
588   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
589             PARAM(conv_input_scale), PARAM(filter_descriptor),
590             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
591             PARAM(side_input_data), PARAM(side_input_scale),
592             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
593             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
594 
595   if (ok()) {
596     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
597       auto status = dnn->DoFusedConvolve(
598           this, conv_input_descriptor, conv_input_data, conv_input_scale,
599           filter_descriptor, filter_data, convolution_descriptor,
600           side_input_data, side_input_scale, bias_descriptor, biases,
601           activation_mode, output_descriptor, output, scratch_allocator,
602           algorithm_config, output_profile_result);
603       if (!status && !output_profile_result) {
604         SetError();
605       }
606     } else {
607       SetErrorAndLogNoDnnSupport();
608     }
609   }
610   return *this;
611 }
612 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)613 Stream &Stream::ThenFusedConvolveWithAlgorithm(
614     const dnn::BatchDescriptor &conv_input_descriptor,
615     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
616     const dnn::FilterDescriptor &filter_descriptor,
617     const DeviceMemory<int8> &filter_data,
618     const dnn::ConvolutionDescriptor &convolution_descriptor,
619     const DeviceMemory<float> &side_input_data, float side_input_scale,
620     const dnn::BatchDescriptor &bias_descriptor,
621     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
622     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
623     ScratchAllocator *scratch_allocator,
624     const dnn::AlgorithmConfig &algorithm_config,
625     dnn::ProfileResult *output_profile_result) {
626   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
627             PARAM(conv_input_scale), PARAM(filter_descriptor),
628             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
629             PARAM(side_input_data), PARAM(side_input_scale),
630             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
631             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
632 
633   if (ok()) {
634     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
635       auto status = dnn->DoFusedConvolve(
636           this, conv_input_descriptor, conv_input_data, conv_input_scale,
637           filter_descriptor, filter_data, convolution_descriptor,
638           side_input_data, side_input_scale, bias_descriptor, biases,
639           activation_mode, output_descriptor, output, scratch_allocator,
640           algorithm_config, output_profile_result);
641       if (!status && !output_profile_result) {
642         SetError();
643       }
644     } else {
645       SetErrorAndLogNoDnnSupport();
646     }
647   }
648   return *this;
649 }
650 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)651 Stream &Stream::ThenConvolveWithAlgorithm(
652     const dnn::BatchDescriptor &input_descriptor,
653     const DeviceMemory<double> &input_data,
654     const dnn::FilterDescriptor &filter_descriptor,
655     const DeviceMemory<double> &filter_data,
656     const dnn::ConvolutionDescriptor &convolution_descriptor,
657     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
658     ScratchAllocator *scratch_allocator,
659     const dnn::AlgorithmConfig &algorithm_config,
660     dnn::ProfileResult *output_profile_result) {
661   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
662             PARAM(filter_descriptor), PARAM(filter_data),
663             PARAM(convolution_descriptor), PARAM(output_descriptor),
664             PARAM(output), PARAM(algorithm_config));
665 
666   if (ok()) {
667     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
668       DeviceMemory<uint8> scratch_memory;
669       dnn::AlgorithmDesc algorithm_desc;
670       auto status =
671           dnn->PrepareForConvolution(
672                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
673                  input_data, filter_descriptor, filter_data, output_descriptor,
674                  *output, convolution_descriptor, algorithm_config,
675                  scratch_allocator, &algorithm_desc, &scratch_memory)
676               .ok();
677       if (status) {
678         status = dnn->DoConvolve(
679             this, input_descriptor, input_data, filter_descriptor, filter_data,
680             convolution_descriptor, output_descriptor, output, algorithm_desc,
681             &scratch_memory, output_profile_result);
682       }
683       if (!status && !output_profile_result) {
684         SetError();
685       }
686     } else {
687       SetErrorAndLogNoDnnSupport();
688     }
689   }
690   return *this;
691 }
692 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)693 Stream &Stream::ThenConvolveWithAlgorithm(
694     const dnn::BatchDescriptor &input_descriptor,
695     const DeviceMemory<float> &input_data,
696     const dnn::FilterDescriptor &filter_descriptor,
697     const DeviceMemory<float> &filter_data,
698     const dnn::ConvolutionDescriptor &convolution_descriptor,
699     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
700     ScratchAllocator *scratch_allocator,
701     const dnn::AlgorithmConfig &algorithm_config,
702     dnn::ProfileResult *output_profile_result) {
703   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
704             PARAM(filter_descriptor), PARAM(filter_data),
705             PARAM(convolution_descriptor), PARAM(output_descriptor),
706             PARAM(output), PARAM(algorithm_config));
707 
708   if (ok()) {
709     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
710       DeviceMemory<uint8> scratch_memory;
711       dnn::AlgorithmDesc algorithm_desc;
712       auto status =
713           dnn->PrepareForConvolution(
714                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
715                  input_data, filter_descriptor, filter_data, output_descriptor,
716                  *output, convolution_descriptor, algorithm_config,
717                  scratch_allocator, &algorithm_desc, &scratch_memory)
718               .ok();
719       if (status) {
720         status = dnn->DoConvolve(
721             this, input_descriptor, input_data, filter_descriptor, filter_data,
722             convolution_descriptor, output_descriptor, output, algorithm_desc,
723             &scratch_memory, output_profile_result);
724       }
725       if (!status && !output_profile_result) {
726         SetError();
727       }
728     } else {
729       SetErrorAndLogNoDnnSupport();
730     }
731   }
732   return *this;
733 }
734 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)735 Stream &Stream::ThenConvolveWithAlgorithm(
736     const dnn::BatchDescriptor &input_descriptor,
737     const DeviceMemory<Eigen::half> &input_data,
738     const dnn::FilterDescriptor &filter_descriptor,
739     const DeviceMemory<Eigen::half> &filter_data,
740     const dnn::ConvolutionDescriptor &convolution_descriptor,
741     const dnn::BatchDescriptor &output_descriptor,
742     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
743     const dnn::AlgorithmConfig &algorithm_config,
744     dnn::ProfileResult *output_profile_result) {
745   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
746             PARAM(filter_descriptor), PARAM(filter_data),
747             PARAM(convolution_descriptor), PARAM(output_descriptor),
748             PARAM(output), PARAM(algorithm_config));
749 
750   if (ok()) {
751     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
752       DeviceMemory<uint8> scratch_memory;
753       dnn::AlgorithmDesc algorithm_desc;
754       auto status =
755           dnn->PrepareForConvolution(
756                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
757                  input_data, filter_descriptor, filter_data, output_descriptor,
758                  *output, convolution_descriptor, algorithm_config,
759                  scratch_allocator, &algorithm_desc, &scratch_memory)
760               .ok();
761       if (status) {
762         status = dnn->DoConvolve(
763             this, input_descriptor, input_data, filter_descriptor, filter_data,
764             convolution_descriptor, output_descriptor, output, algorithm_desc,
765             &scratch_memory, output_profile_result);
766       }
767       if (!status && !output_profile_result) {
768         SetError();
769       }
770     } else {
771       SetErrorAndLogNoDnnSupport();
772     }
773   }
774   return *this;
775 }
776 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<int8> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)777 Stream &Stream::ThenConvolveWithAlgorithm(
778     const dnn::BatchDescriptor &input_descriptor,
779     const DeviceMemory<int8> &input_data,
780     const dnn::FilterDescriptor &filter_descriptor,
781     const DeviceMemory<int8> &filter_data,
782     const dnn::ConvolutionDescriptor &convolution_descriptor,
783     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
784     ScratchAllocator *scratch_allocator,
785     const dnn::AlgorithmConfig &algorithm_config,
786     dnn::ProfileResult *output_profile_result) {
787   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
788             PARAM(filter_descriptor), PARAM(filter_data),
789             PARAM(convolution_descriptor), PARAM(output_descriptor),
790             PARAM(output), PARAM(algorithm_config));
791 
792   if (ok()) {
793     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
794       DeviceMemory<uint8> scratch_memory;
795       dnn::AlgorithmDesc algorithm_desc;
796       auto status =
797           dnn->PrepareForConvolution(
798                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
799                  input_data, filter_descriptor, filter_data, output_descriptor,
800                  *output, convolution_descriptor, algorithm_config,
801                  scratch_allocator, &algorithm_desc, &scratch_memory)
802               .ok();
803       if (status) {
804         status = dnn->DoConvolve(
805             this, input_descriptor, input_data, filter_descriptor, filter_data,
806             convolution_descriptor, output_descriptor, output, algorithm_desc,
807             &scratch_memory, output_profile_result);
808       }
809       if (!status && !output_profile_result) {
810         SetError();
811       }
812     } else {
813       SetErrorAndLogNoDnnSupport();
814     }
815   }
816   return *this;
817 }
818 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<int8> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)819 Stream &Stream::ThenConvolveWithAlgorithm(
820     const dnn::BatchDescriptor &input_descriptor,
821     const DeviceMemory<int8> &input_data,
822     const dnn::FilterDescriptor &filter_descriptor,
823     const DeviceMemory<int8> &filter_data,
824     const dnn::ConvolutionDescriptor &convolution_descriptor,
825     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
826     ScratchAllocator *scratch_allocator,
827     const dnn::AlgorithmConfig &algorithm_config,
828     dnn::ProfileResult *output_profile_result) {
829   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
830             PARAM(filter_descriptor), PARAM(filter_data),
831             PARAM(convolution_descriptor), PARAM(output_descriptor),
832             PARAM(output), PARAM(algorithm_config));
833 
834   if (ok()) {
835     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
836       DeviceMemory<uint8> scratch_memory;
837       dnn::AlgorithmDesc algorithm_desc;
838       auto status =
839           dnn->PrepareForConvolution(
840                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
841                  input_data, filter_descriptor, filter_data, output_descriptor,
842                  *output, convolution_descriptor, algorithm_config,
843                  scratch_allocator, &algorithm_desc, &scratch_memory)
844               .ok();
845       if (status) {
846         status = dnn->DoConvolve(
847             this, input_descriptor, input_data, filter_descriptor, filter_data,
848             convolution_descriptor, output_descriptor, output, algorithm_desc,
849             &scratch_memory, output_profile_result);
850       }
851       if (!status && !output_profile_result) {
852         SetError();
853       }
854     } else {
855       SetErrorAndLogNoDnnSupport();
856     }
857   }
858   return *this;
859 }
860 
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)861 Stream &Stream::ThenConvolve(
862     const dnn::BatchDescriptor &input_descriptor,
863     const DeviceMemory<float> &input_data,
864     const dnn::FilterDescriptor &filter_descriptor,
865     const DeviceMemory<float> &filter_data,
866     const dnn::ConvolutionDescriptor &convolution_descriptor,
867     const dnn::BatchDescriptor &output_descriptor,
868     DeviceMemory<float> *output) {
869   return ThenConvolveWithAlgorithm(
870       input_descriptor, input_data, filter_descriptor, filter_data,
871       convolution_descriptor, output_descriptor, output,
872       /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
873       /*output_profile_result=*/nullptr);
874 }
875 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)876 Stream &Stream::ThenConvolveQuantized(
877     const dnn::BatchDescriptor &input_descriptor,
878     const DeviceMemory<float> &input_data,
879     const dnn::FilterDescriptor &filter_descriptor,
880     const DeviceMemory<int8> &filter_coefficients,
881     const DeviceMemory<float> &coefficient_scales,
882     const dnn::ConvolutionDescriptor &convolution_descriptor,
883     const dnn::BatchDescriptor &output_descriptor,
884     DeviceMemory<float> *output) {
885   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
886             PARAM(filter_descriptor), PARAM(filter_coefficients),
887             PARAM(coefficient_scales), PARAM(convolution_descriptor),
888             PARAM(output_descriptor), PARAM(output));
889 
890   if (ok()) {
891     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
892       CheckError(dnn->DoConvolveQuantized(
893           this, input_descriptor, input_data, filter_descriptor,
894           filter_coefficients, coefficient_scales, convolution_descriptor,
895           output_descriptor, output));
896     } else {
897       SetError();
898       LOG(WARNING)
899           << "attempting to perform DNN operation using StreamExecutor "
900              "without DNN support";
901     }
902   }
903   return *this;
904 }
905 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)906 Stream &Stream::ThenConvolveQuantized(
907     const dnn::BatchDescriptor &input_descriptor,
908     const DeviceMemory<float> &input_data,
909     const dnn::FilterDescriptor &filter_descriptor,
910     const DeviceMemory<int16> &filter_coefficients,
911     const DeviceMemory<float> &coefficient_scales,
912     const dnn::ConvolutionDescriptor &convolution_descriptor,
913     const dnn::BatchDescriptor &output_descriptor,
914     DeviceMemory<float> *output) {
915   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
916             PARAM(filter_descriptor), PARAM(filter_coefficients),
917             PARAM(coefficient_scales), PARAM(convolution_descriptor),
918             PARAM(output_descriptor), PARAM(output));
919 
920   if (ok()) {
921     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
922       CheckError(dnn->DoConvolveQuantized(
923           this, input_descriptor, input_data, filter_descriptor,
924           filter_coefficients, coefficient_scales, convolution_descriptor,
925           output_descriptor, output));
926     } else {
927       SetError();
928       LOG(WARNING)
929           << "attempting to perform DNN operation using StreamExecutor "
930              "without DNN support";
931     }
932   }
933   return *this;
934 }
935 
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)936 Stream &Stream::ThenSeparableConvolve(
937     const dnn::BatchDescriptor &batch_descriptor,
938     const DeviceMemory<float> &input_data,
939     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
940     const DeviceMemory<float> &first_weights,
941     const DeviceMemory<float> &second_weights,
942     const dnn::ConvolutionDescriptor &convolution_descriptor,
943     const dnn::BatchDescriptor &output_descriptor,
944     DeviceMemory<float> *output) {
945   VLOG_CALL(
946       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
947       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
948       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
949 
950   if (ok()) {
951     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
952       CheckError(dnn->DoSeparableConvolve(
953           this, batch_descriptor, input_data, filter_descriptor,
954           depth_multiplier, first_weights, second_weights,
955           convolution_descriptor, output_descriptor, output));
956     } else {
957       SetErrorAndLogNoDnnSupport();
958     }
959   }
960   return *this;
961 }
962 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<double> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)963 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
964     const dnn::FilterDescriptor &filter_descriptor,
965     const DeviceMemory<double> &filter_data,
966     const dnn::BatchDescriptor &output_descriptor,
967     DeviceMemory<double> backward_output_data,
968     const dnn::ConvolutionDescriptor &convolution_descriptor,
969     const dnn::BatchDescriptor &input_descriptor,
970     DeviceMemory<double> *backward_input_data,
971     ScratchAllocator *scratch_allocator,
972     const dnn::AlgorithmConfig &algorithm_config,
973     dnn::ProfileResult *output_profile_result) {
974   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
975             PARAM(output_descriptor), PARAM(backward_output_data),
976             PARAM(convolution_descriptor), PARAM(input_descriptor),
977             PARAM(backward_input_data));
978 
979   if (ok()) {
980     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
981       DeviceMemory<uint8> scratch_memory;
982       dnn::AlgorithmDesc algorithm_desc;
983       auto status =
984           dnn->PrepareForConvolution(
985                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
986                  *backward_input_data, filter_descriptor, filter_data,
987                  output_descriptor, backward_output_data,
988                  convolution_descriptor, algorithm_config, scratch_allocator,
989                  &algorithm_desc, &scratch_memory)
990               .ok();
991       if (status) {
992         status = dnn->DoConvolveBackwardData(
993             this, filter_descriptor, filter_data, output_descriptor,
994             backward_output_data, convolution_descriptor, input_descriptor,
995             backward_input_data, algorithm_desc, &scratch_memory,
996             output_profile_result);
997       }
998       if (!status && !output_profile_result) {
999         SetError();
1000       }
1001     } else {
1002       SetErrorAndLogNoDnnSupport();
1003     }
1004   }
1005   return *this;
1006 }
1007 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1008 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
1009     const dnn::FilterDescriptor &filter_descriptor,
1010     const DeviceMemory<float> &filter_data,
1011     const dnn::BatchDescriptor &output_descriptor,
1012     DeviceMemory<float> backward_output_data,
1013     const dnn::ConvolutionDescriptor &convolution_descriptor,
1014     const dnn::BatchDescriptor &input_descriptor,
1015     DeviceMemory<float> *backward_input_data,
1016     ScratchAllocator *scratch_allocator,
1017     const dnn::AlgorithmConfig &algorithm_config,
1018     dnn::ProfileResult *output_profile_result) {
1019   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
1020             PARAM(output_descriptor), PARAM(backward_output_data),
1021             PARAM(convolution_descriptor), PARAM(input_descriptor),
1022             PARAM(backward_input_data));
1023 
1024   if (ok()) {
1025     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1026       DeviceMemory<uint8> scratch_memory;
1027       dnn::AlgorithmDesc algorithm_desc;
1028       auto status =
1029           dnn->PrepareForConvolution(
1030                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
1031                  *backward_input_data, filter_descriptor, filter_data,
1032                  output_descriptor, backward_output_data,
1033                  convolution_descriptor, algorithm_config, scratch_allocator,
1034                  &algorithm_desc, &scratch_memory)
1035               .ok();
1036       if (status) {
1037         status = dnn->DoConvolveBackwardData(
1038             this, filter_descriptor, filter_data, output_descriptor,
1039             backward_output_data, convolution_descriptor, input_descriptor,
1040             backward_input_data, algorithm_desc, &scratch_memory,
1041             output_profile_result);
1042       }
1043       if (!status && !output_profile_result) {
1044         SetError();
1045       }
1046     } else {
1047       SetErrorAndLogNoDnnSupport();
1048     }
1049   }
1050   return *this;
1051 }
1052 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1053 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
1054     const dnn::FilterDescriptor &filter_descriptor,
1055     const DeviceMemory<Eigen::half> &filter_data,
1056     const dnn::BatchDescriptor &output_descriptor,
1057     DeviceMemory<Eigen::half> backward_output_data,
1058     const dnn::ConvolutionDescriptor &convolution_descriptor,
1059     const dnn::BatchDescriptor &input_descriptor,
1060     DeviceMemory<Eigen::half> *backward_input_data,
1061     ScratchAllocator *scratch_allocator,
1062     const dnn::AlgorithmConfig &algorithm_config,
1063     dnn::ProfileResult *output_profile_result) {
1064   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
1065             PARAM(output_descriptor), PARAM(backward_output_data),
1066             PARAM(convolution_descriptor), PARAM(input_descriptor),
1067             PARAM(backward_input_data), PARAM(algorithm_config));
1068 
1069   if (ok()) {
1070     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1071       DeviceMemory<uint8> scratch_memory;
1072       dnn::AlgorithmDesc algorithm_desc;
1073       auto status =
1074           dnn->PrepareForConvolution(
1075                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
1076                  *backward_input_data, filter_descriptor, filter_data,
1077                  output_descriptor, backward_output_data,
1078                  convolution_descriptor, algorithm_config, scratch_allocator,
1079                  &algorithm_desc, &scratch_memory)
1080               .ok();
1081       if (status) {
1082         status = dnn->DoConvolveBackwardData(
1083             this, filter_descriptor, filter_data, output_descriptor,
1084             backward_output_data, convolution_descriptor, input_descriptor,
1085             backward_input_data, algorithm_desc, &scratch_memory,
1086             output_profile_result);
1087       }
1088       if (!status && !output_profile_result) {
1089         SetError();
1090       }
1091     } else {
1092       SetErrorAndLogNoDnnSupport();
1093     }
1094   }
1095   return *this;
1096 }
1097 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<double> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1098 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1099     const dnn::BatchDescriptor &input_descriptor,
1100     const DeviceMemory<double> &input_data,
1101     const dnn::BatchDescriptor &output_descriptor,
1102     DeviceMemory<double> backward_output_data,
1103     const dnn::ConvolutionDescriptor &convolution_descriptor,
1104     const dnn::FilterDescriptor &filter_descriptor,
1105     DeviceMemory<double> *backward_filter_data,
1106     ScratchAllocator *scratch_allocator,
1107     const dnn::AlgorithmConfig &algorithm_config,
1108     dnn::ProfileResult *output_profile_result) {
1109   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1110             PARAM(output_descriptor), PARAM(backward_output_data),
1111             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1112             PARAM(backward_filter_data));
1113 
1114   if (ok()) {
1115     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1116       DeviceMemory<uint8> scratch_memory;
1117       dnn::AlgorithmDesc algorithm_desc;
1118       auto status =
1119           dnn->PrepareForConvolution(
1120                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1121                  input_data, filter_descriptor, *backward_filter_data,
1122                  output_descriptor, backward_output_data,
1123                  convolution_descriptor, algorithm_config, scratch_allocator,
1124                  &algorithm_desc, &scratch_memory)
1125               .ok();
1126       if (status) {
1127         status = dnn->DoConvolveBackwardFilter(
1128             this, input_descriptor, input_data, output_descriptor,
1129             backward_output_data, convolution_descriptor, filter_descriptor,
1130             backward_filter_data, algorithm_desc, &scratch_memory,
1131             output_profile_result);
1132       }
1133       if (!status && !output_profile_result) {
1134         SetError();
1135       }
1136     } else {
1137       SetErrorAndLogNoDnnSupport();
1138     }
1139   }
1140   return *this;
1141 }
1142 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1143 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1144     const dnn::BatchDescriptor &input_descriptor,
1145     const DeviceMemory<float> &input_data,
1146     const dnn::BatchDescriptor &output_descriptor,
1147     DeviceMemory<float> backward_output_data,
1148     const dnn::ConvolutionDescriptor &convolution_descriptor,
1149     const dnn::FilterDescriptor &filter_descriptor,
1150     DeviceMemory<float> *backward_filter_data,
1151     ScratchAllocator *scratch_allocator,
1152     const dnn::AlgorithmConfig &algorithm_config,
1153     dnn::ProfileResult *output_profile_result) {
1154   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1155             PARAM(output_descriptor), PARAM(backward_output_data),
1156             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1157             PARAM(backward_filter_data));
1158 
1159   if (ok()) {
1160     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1161       DeviceMemory<uint8> scratch_memory;
1162       dnn::AlgorithmDesc algorithm_desc;
1163       auto status =
1164           dnn->PrepareForConvolution(
1165                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1166                  input_data, filter_descriptor, *backward_filter_data,
1167                  output_descriptor, backward_output_data,
1168                  convolution_descriptor, algorithm_config, scratch_allocator,
1169                  &algorithm_desc, &scratch_memory)
1170               .ok();
1171       if (status) {
1172         status = dnn->DoConvolveBackwardFilter(
1173             this, input_descriptor, input_data, output_descriptor,
1174             backward_output_data, convolution_descriptor, filter_descriptor,
1175             backward_filter_data, algorithm_desc, &scratch_memory,
1176             output_profile_result);
1177       }
1178       if (!status && !output_profile_result) {
1179         SetError();
1180       }
1181     } else {
1182       SetErrorAndLogNoDnnSupport();
1183     }
1184   }
1185   return *this;
1186 }
1187 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1188 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1189     const dnn::BatchDescriptor &input_descriptor,
1190     const DeviceMemory<Eigen::half> &input_data,
1191     const dnn::BatchDescriptor &output_descriptor,
1192     DeviceMemory<Eigen::half> backward_output_data,
1193     const dnn::ConvolutionDescriptor &convolution_descriptor,
1194     const dnn::FilterDescriptor &filter_descriptor,
1195     DeviceMemory<Eigen::half> *backward_filter_data,
1196     ScratchAllocator *scratch_allocator,
1197     const dnn::AlgorithmConfig &algorithm_config,
1198     dnn::ProfileResult *output_profile_result) {
1199   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1200             PARAM(output_descriptor), PARAM(backward_output_data),
1201             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1202             PARAM(backward_filter_data));
1203 
1204   if (ok()) {
1205     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1206       DeviceMemory<uint8> scratch_memory;
1207       dnn::AlgorithmDesc algorithm_desc;
1208       auto status =
1209           dnn->PrepareForConvolution(
1210                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
1211                  input_data, filter_descriptor, *backward_filter_data,
1212                  output_descriptor, backward_output_data,
1213                  convolution_descriptor, algorithm_config, scratch_allocator,
1214                  &algorithm_desc, &scratch_memory)
1215               .ok();
1216       if (status) {
1217         status = dnn->DoConvolveBackwardFilter(
1218             this, input_descriptor, input_data, output_descriptor,
1219             backward_output_data, convolution_descriptor, filter_descriptor,
1220             backward_filter_data, algorithm_desc, &scratch_memory,
1221             output_profile_result);
1222       }
1223       if (!status && !output_profile_result) {
1224         SetError();
1225       }
1226     } else {
1227       SetErrorAndLogNoDnnSupport();
1228     }
1229   }
1230   return *this;
1231 }
1232 
1233 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)1234 Stream &Stream::ThenConvolveBackwardBiasImpl(
1235     const dnn::BatchDescriptor &input_descriptor,
1236     const DeviceMemory<T> &input_data,
1237     const dnn::BatchDescriptor &bias_descriptor,
1238     DeviceMemory<T> *backward_bias_data) {
1239   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
1240             PARAM(backward_bias_data));
1241 
1242   if (ok()) {
1243     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1244       CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
1245                                              bias_descriptor,
1246                                              backward_bias_data));
1247     } else {
1248       SetErrorAndLogNoDnnSupport();
1249     }
1250   }
1251   return *this;
1252 }
1253 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)1254 Stream &Stream::ThenConvolveBackwardBias(
1255     const dnn::BatchDescriptor &input_descriptor,
1256     const DeviceMemory<double> &input_data,
1257     const dnn::BatchDescriptor &bias_descriptor,
1258     DeviceMemory<double> *backward_bias_data) {
1259   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1260                                       bias_descriptor, backward_bias_data);
1261 }
1262 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)1263 Stream &Stream::ThenConvolveBackwardBias(
1264     const dnn::BatchDescriptor &input_descriptor,
1265     const DeviceMemory<float> &input_data,
1266     const dnn::BatchDescriptor &bias_descriptor,
1267     DeviceMemory<float> *backward_bias_data) {
1268   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1269                                       bias_descriptor, backward_bias_data);
1270 }
1271 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)1272 Stream &Stream::ThenConvolveBackwardBias(
1273     const dnn::BatchDescriptor &input_descriptor,
1274     const DeviceMemory<Eigen::half> &input_data,
1275     const dnn::BatchDescriptor &bias_descriptor,
1276     DeviceMemory<Eigen::half> *backward_bias_data) {
1277   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1278                                       bias_descriptor, backward_bias_data);
1279 }
1280 
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1281 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
1282                            const DeviceMemory<float> &weights,
1283                            const dnn::BatchDescriptor &input_dimensions,
1284                            const dnn::BatchDescriptor &output_dimensions,
1285                            DeviceMemory<float> *output_data) {
1286   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
1287             PARAM(output_dimensions), PARAM(output_data));
1288 
1289   if (ok()) {
1290     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1291       CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
1292                                output_dimensions, output_data));
1293     } else {
1294       SetErrorAndLogNoDnnSupport();
1295     }
1296   }
1297   return *this;
1298 }
1299 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1300 Stream &Stream::ThenMatMulQuantized(
1301     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
1302     const DeviceMemory<float> &weight_scales,
1303     const dnn::BatchDescriptor &input_dimensions,
1304     const dnn::BatchDescriptor &output_dimensions,
1305     DeviceMemory<float> *output_data) {
1306   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1307             PARAM(input_dimensions), PARAM(output_dimensions),
1308             PARAM(output_data));
1309 
1310   if (ok()) {
1311     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1312       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1313                                         weight_scales, input_dimensions,
1314                                         output_dimensions, output_data));
1315     } else {
1316       SetErrorAndLogNoDnnSupport();
1317     }
1318   }
1319   return *this;
1320 }
1321 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1322 Stream &Stream::ThenMatMulQuantized(
1323     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
1324     const DeviceMemory<float> &weight_scales,
1325     const dnn::BatchDescriptor &input_dimensions,
1326     const dnn::BatchDescriptor &output_dimensions,
1327     DeviceMemory<float> *output_data) {
1328   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1329             PARAM(input_dimensions), PARAM(output_dimensions),
1330             PARAM(output_data));
1331 
1332   if (ok()) {
1333     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1334       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1335                                         weight_scales, input_dimensions,
1336                                         output_dimensions, output_data));
1337     } else {
1338       SetErrorAndLogNoDnnSupport();
1339     }
1340   }
1341   return *this;
1342 }
1343 
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)1344 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
1345                             const DeviceMemory<float> &biases,
1346                             const dnn::BatchDescriptor &dimensions,
1347                             DeviceMemory<float> *output_data) {
1348   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
1349             PARAM(output_data));
1350 
1351   if (ok()) {
1352     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1353       CheckError(
1354           dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
1355     } else {
1356       SetErrorAndLogNoDnnSupport();
1357     }
1358   }
1359   return *this;
1360 }
1361 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)1362 Stream &Stream::ThenPoolForward(
1363     const dnn::PoolingDescriptor &pooling_dimensions,
1364     const dnn::BatchDescriptor &input_dimensions,
1365     const DeviceMemory<double> &input_data,
1366     const dnn::BatchDescriptor &output_dimensions,
1367     DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
1368   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1369             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1370             PARAM(workspace_allocator));
1371 
1372   if (ok()) {
1373     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1374       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1375                                     input_data, output_dimensions, output_data,
1376                                     workspace_allocator));
1377     } else {
1378       SetError();
1379       LOG(WARNING)
1380           << "attempting to perform DNN operation using StreamExecutor "
1381              "without DNN support";
1382     }
1383   }
1384   return *this;
1385 }
1386 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)1387 Stream &Stream::ThenPoolForward(
1388     const dnn::PoolingDescriptor &pooling_dimensions,
1389     const dnn::BatchDescriptor &input_dimensions,
1390     const DeviceMemory<float> &input_data,
1391     const dnn::BatchDescriptor &output_dimensions,
1392     DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
1393   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1394             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1395             PARAM(workspace_allocator));
1396 
1397   if (ok()) {
1398     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1399       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1400                                     input_data, output_dimensions, output_data,
1401                                     workspace_allocator));
1402     } else {
1403       SetErrorAndLogNoDnnSupport();
1404     }
1405   }
1406   return *this;
1407 }
1408 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)1409 Stream &Stream::ThenPoolForward(
1410     const dnn::PoolingDescriptor &pooling_dimensions,
1411     const dnn::BatchDescriptor &input_dimensions,
1412     const DeviceMemory<Eigen::half> &input_data,
1413     const dnn::BatchDescriptor &output_dimensions,
1414     DeviceMemory<Eigen::half> *output_data,
1415     ScratchAllocator *workspace_allocator) {
1416   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1417             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1418             PARAM(workspace_allocator));
1419 
1420   if (ok()) {
1421     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1422       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1423                                     input_data, output_dimensions, output_data,
1424                                     workspace_allocator));
1425     } else {
1426       SetErrorAndLogNoDnnSupport();
1427     }
1428   }
1429   return *this;
1430 }
1431 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)1432 Stream &Stream::ThenPoolForward(
1433     const dnn::PoolingDescriptor &pooling_dimensions,
1434     const dnn::BatchDescriptor &input_dimensions,
1435     const DeviceMemory<int8> &input_data,
1436     const dnn::BatchDescriptor &output_dimensions,
1437     DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
1438   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1439             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1440             PARAM(workspace_allocator));
1441 
1442   if (ok()) {
1443     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1444       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1445                                     input_data, output_dimensions, output_data,
1446                                     workspace_allocator));
1447     } else {
1448       SetErrorAndLogNoDnnSupport();
1449     }
1450   }
1451   return *this;
1452 }
1453 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)1454 Stream &Stream::ThenPoolBackward(
1455     const dnn::PoolingDescriptor &pooling_dimensions,
1456     const dnn::BatchDescriptor &input_dimensions,
1457     const DeviceMemory<double> &input_data,
1458     const dnn::BatchDescriptor &output_dimensions,
1459     const DeviceMemory<double> &output_data,
1460     const DeviceMemory<double> &input_diff_data,
1461     DeviceMemory<double> *output_diff_data,
1462     ScratchAllocator *workspace_allocator) {
1463   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1464             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1465             PARAM(input_diff_data), PARAM(output_diff_data),
1466             PARAM(workspace_allocator));
1467 
1468   if (ok()) {
1469     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1470       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1471                                      input_data, output_dimensions, output_data,
1472                                      input_diff_data, output_diff_data,
1473                                      workspace_allocator));
1474     } else {
1475       SetError();
1476       LOG(WARNING)
1477           << "attempting to perform DNN operation using StreamExecutor "
1478              "without DNN support";
1479     }
1480   }
1481   return *this;
1482 }
1483 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)1484 Stream &Stream::ThenPoolBackward(
1485     const dnn::PoolingDescriptor &pooling_dimensions,
1486     const dnn::BatchDescriptor &input_dimensions,
1487     const DeviceMemory<float> &input_data,
1488     const dnn::BatchDescriptor &output_dimensions,
1489     const DeviceMemory<float> &output_data,
1490     const DeviceMemory<float> &input_diff_data,
1491     DeviceMemory<float> *output_diff_data,
1492     ScratchAllocator *workspace_allocator) {
1493   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1494             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1495             PARAM(input_diff_data), PARAM(output_diff_data),
1496             PARAM(workspace_allocator));
1497 
1498   if (ok()) {
1499     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1500       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1501                                      input_data, output_dimensions, output_data,
1502                                      input_diff_data, output_diff_data,
1503                                      workspace_allocator));
1504     } else {
1505       SetErrorAndLogNoDnnSupport();
1506     }
1507   }
1508   return *this;
1509 }
1510 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)1511 Stream &Stream::ThenPoolBackward(
1512     const dnn::PoolingDescriptor &pooling_dimensions,
1513     const dnn::BatchDescriptor &input_dimensions,
1514     const DeviceMemory<Eigen::half> &input_data,
1515     const dnn::BatchDescriptor &output_dimensions,
1516     const DeviceMemory<Eigen::half> &output_data,
1517     const DeviceMemory<Eigen::half> &input_diff_data,
1518     DeviceMemory<Eigen::half> *output_diff_data,
1519     ScratchAllocator *workspace_allocator) {
1520   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1521             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1522             PARAM(input_diff_data), PARAM(output_diff_data),
1523             PARAM(workspace_allocator));
1524 
1525   if (ok()) {
1526     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1527       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1528                                      input_data, output_dimensions, output_data,
1529                                      input_diff_data, output_diff_data,
1530                                      workspace_allocator));
1531     } else {
1532       SetErrorAndLogNoDnnSupport();
1533     }
1534   }
1535   return *this;
1536 }
1537 
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1538 Stream &Stream::ThenNormalizeWithDimensions(
1539     const dnn::NormalizeDescriptor &normalize_descriptor,
1540     const dnn::BatchDescriptor &dimensions,
1541     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1542   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
1543             PARAM(output_data));
1544 
1545   if (ok()) {
1546     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1547       CheckError(dnn->DoNormalizeWithDimensions(
1548           this, normalize_descriptor, dimensions, input_data, output_data));
1549     } else {
1550       SetErrorAndLogNoDnnSupport();
1551     }
1552   }
1553   return *this;
1554 }
1555 
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)1556 Stream &Stream::ThenNormalizeBackwardWithDimensions(
1557     const dnn::NormalizeDescriptor &normalize_descriptor,
1558     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
1559     const DeviceMemory<float> &normalized_data,
1560     const DeviceMemory<float> &normalized_variable_gradient,
1561     DeviceMemory<float> *raw_variable_gradient,
1562     ScratchAllocator *workspace_allocator) {
1563   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1564             PARAM(normalized_data), PARAM(normalized_variable_gradient),
1565             PARAM(raw_variable_gradient), PARAM(workspace_allocator));
1566 
1567   if (ok()) {
1568     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1569       CheckError(dnn->DoNormalizeBackwardWithDimensions(
1570           this, normalize_descriptor, dimensions, raw_data, normalized_data,
1571           normalized_variable_gradient, raw_variable_gradient,
1572           workspace_allocator));
1573     } else {
1574       SetErrorAndLogNoDnnSupport();
1575     }
1576   }
1577   return *this;
1578 }
1579 
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1580 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1581                              const dnn::BatchDescriptor &dimensions,
1582                              const DeviceMemory<float> &input_data,
1583                              DeviceMemory<float> *output_data) {
1584   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1585                                  output_data, /*options=*/0);
1586 }
1587 
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1588 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1589                                         const dnn::BatchDescriptor &dimensions,
1590                                         const DeviceMemory<float> &input_data,
1591                                         DeviceMemory<float> *output_data,
1592                                         uint64 options) {
1593   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1594             PARAM(output_data), PARAM(options));
1595 
1596   if (ok()) {
1597     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1598       CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1599                                  output_data, options));
1600     } else {
1601       SetErrorAndLogNoDnnSupport();
1602     }
1603   }
1604   return *this;
1605 }
1606 
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1607 Stream &Stream::ThenDepthConcatenate(
1608     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1609     port::ArraySlice<const DeviceMemory<float> *> input_data,
1610     DeviceMemory<float> *output_data) {
1611   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1612 
1613   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1614     if (input_dimensions[i].count() != input_dimensions[0].count() ||
1615         input_dimensions[i].height() != input_dimensions[0].height() ||
1616         input_dimensions[i].width() != input_dimensions[0].width()) {
1617       SetError();
1618       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1619                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1620                  << "input_dimensions[" << i
1621                  << "]: " << input_dimensions[i].ToString();
1622       return *this;
1623     }
1624   }
1625 
1626   if (ok()) {
1627     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1628       CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1629                                          output_data));
1630     } else {
1631       SetErrorAndLogNoDnnSupport();
1632     }
1633   }
1634   return *this;
1635 }
1636 
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1637 Stream &Stream::ThenSpaceConcatenate(
1638     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1639     port::ArraySlice<const DeviceMemory<float> *> input_data,
1640     DeviceMemory<float> *output_data,
1641     dnn::SpaceConcatenateMode concat_direction) {
1642   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1643 
1644   // Check that the input dimensions of all the other batches match those of the
1645   // first batch.
1646   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1647     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1648         (input_dimensions[i].count() != input_dimensions[0].count() ||
1649          input_dimensions[i].height() != input_dimensions[0].height() ||
1650          input_dimensions[i].feature_map_count() !=
1651              input_dimensions[0].feature_map_count())) {
1652       SetError();
1653       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1654                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1655                  << "input_dimensions[" << i
1656                  << "]: " << input_dimensions[i].ToString();
1657       return *this;
1658     }
1659 
1660     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1661         (input_dimensions[i].count() != input_dimensions[0].count() ||
1662          input_dimensions[i].width() != input_dimensions[0].width() ||
1663          input_dimensions[i].feature_map_count() !=
1664              input_dimensions[0].feature_map_count())) {
1665       SetError();
1666       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1667                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1668                  << "input_dimensions[" << i
1669                  << "]: " << input_dimensions[i].ToString();
1670       return *this;
1671     }
1672   }
1673   if (ok()) {
1674     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1675       CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1676                                          output_data, concat_direction));
1677     } else {
1678       SetErrorAndLogNoDnnSupport();
1679     }
1680   }
1681   return *this;
1682 }
1683 
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1684 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1685                             const DeviceMemory<float> &input_data,
1686                             const dnn::BatchDescriptor &output_dimensions,
1687                             DeviceMemory<float> *output_data) {
1688   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1689             PARAM(output_dimensions), PARAM(output_data));
1690 
1691   if (ok()) {
1692     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1693       CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1694                                 output_dimensions, output_data));
1695     } else {
1696       SetErrorAndLogNoDnnSupport();
1697     }
1698   }
1699   return *this;
1700 }
1701 
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1702 Stream &Stream::ThenDepthToSpace(
1703     const dnn::BatchDescriptor &input_dimensions,
1704     const DeviceMemory<float> &input_data,
1705     const dnn::DepthToSpaceLayout &depth_to_space_layout,
1706     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1707   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1708             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1709             PARAM(output_data));
1710 
1711   if (ok()) {
1712     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1713       CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1714                                      depth_to_space_layout,
1715                                      sqrt_depth_reduction, output_data));
1716     } else {
1717       SetErrorAndLogNoDnnSupport();
1718     }
1719   }
1720   return *this;
1721 }
1722 
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1723 Stream &Stream::ThenSpaceToDepth(
1724     const dnn::BatchDescriptor &input_dimensions,
1725     const DeviceMemory<float> &input_data,
1726     const dnn::DepthToSpaceLayout &space_to_depth_layout,
1727     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1728   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1729             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1730             PARAM(output_data));
1731 
1732   if (ok()) {
1733     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1734       CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1735                                      space_to_depth_layout, sqrt_depth_increase,
1736                                      output_data));
1737     } else {
1738       SetErrorAndLogNoDnnSupport();
1739     }
1740   }
1741   return *this;
1742 }
1743 
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1744 Stream &Stream::ThenElementwiseOperate(
1745     dnn::ElementwiseOperation operation,
1746     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1747     port::ArraySlice<const DeviceMemory<float> *> input_data,
1748     const dnn::BatchDescriptor &output_dimensions,
1749     DeviceMemory<float> *output_data) {
1750   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1751             PARAM(output_dimensions), PARAM(output_data));
1752 
1753   if (ok()) {
1754     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1755       CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1756                                            input_data, output_dimensions,
1757                                            output_data));
1758     } else {
1759       SetErrorAndLogNoDnnSupport();
1760     }
1761   }
1762   return *this;
1763 }
1764 
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1765 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1766     dnn::ElementwiseOperation operation,
1767     port::ArraySlice<int> input_multiplicands, int output_divisor,
1768     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1769     port::ArraySlice<const DeviceMemory<float> *> input_data,
1770     const dnn::BatchDescriptor &output_dimensions,
1771     DeviceMemory<float> *output_data) {
1772   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1773             PARAM(input_dimensions), PARAM(input_data),
1774             PARAM(output_dimensions), PARAM(output_data));
1775 
1776   if (ok()) {
1777     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1778       CheckError(dnn->DoElementwiseOperateScaledQuantized(
1779           this, operation, input_multiplicands, output_divisor,
1780           input_dimensions, input_data, output_dimensions, output_data));
1781     } else {
1782       SetErrorAndLogNoDnnSupport();
1783     }
1784   }
1785   return *this;
1786 }
1787 
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1788 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1789                           const DeviceMemory<float> &input_data, int64 left_pad,
1790                           int64 right_pad, int64 top_pad, int64 bottom_pad,
1791                           DeviceMemory<float> *output_data) {
1792   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1793             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1794             PARAM(output_data));
1795 
1796   if (ok()) {
1797     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1798       CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1799                               top_pad, bottom_pad, output_data));
1800     } else {
1801       SetErrorAndLogNoDnnSupport();
1802     }
1803   }
1804   return *this;
1805 }
1806 
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1807 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1808                             const DeviceMemory<float> &input_data,
1809                             int64 left_trim, int64 right_trim, int64 top_trim,
1810                             int64 bottom_trim,
1811                             DeviceMemory<float> *output_data) {
1812   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1813             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1814             PARAM(output_data));
1815 
1816   if (ok()) {
1817     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1818       CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1819                                 right_trim, top_trim, bottom_trim,
1820                                 output_data));
1821     } else {
1822       SetErrorAndLogNoDnnSupport();
1823     }
1824   }
1825   return *this;
1826 }
1827 
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1828 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1829                                 const DeviceMemory<float> &input_data,
1830                                 int64 replicate_x, int64 replicate_y,
1831                                 DeviceMemory<float> *output_data) {
1832   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1833             PARAM(replicate_y), PARAM(output_data));
1834 
1835   if (ok()) {
1836     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1837       CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1838                                     replicate_y, output_data));
1839     } else {
1840       SetErrorAndLogNoDnnSupport();
1841     }
1842   }
1843   return *this;
1844 }
1845 
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1846 Stream &Stream::ThenMemcpyD2HQuantized(
1847     const DeviceMemory<float> &gpu_unquantized_src,
1848     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1849   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1850             PARAM(size));
1851 
1852   if (ok()) {
1853     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1854       CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1855                                            host_dst, size));
1856     } else {
1857       SetErrorAndLogNoDnnSupport();
1858     }
1859   }
1860   return *this;
1861 }
1862 
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1863 Stream &Stream::ThenMemcpyH2DQuantized(
1864     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1865     DeviceMemory<float> *gpu_unquantized_dst) {
1866   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1867             PARAM(gpu_unquantized_dst));
1868 
1869   if (ok()) {
1870     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1871       CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1872                                            gpu_unquantized_dst));
1873     } else {
1874       SetErrorAndLogNoDnnSupport();
1875     }
1876   }
1877   return *this;
1878 }
1879 
GetOrCreateSubStream()1880 Stream *Stream::GetOrCreateSubStream() {
1881   absl::MutexLock lock(&mu_);
1882 
1883   // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1884   // we encounter along the way.
1885   for (int64 index = 0; index < sub_streams_.size();) {
1886     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1887     if (pair.second) {
1888       // The sub_stream is reusable.
1889       Stream *sub_stream = pair.first.get();
1890       if (sub_stream->ok()) {
1891         VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1892                 << sub_stream->DebugStreamPointers();
1893         pair.second = false;
1894         return sub_stream;
1895       }
1896 
1897       // The stream is reusable and not ok. Streams have a monotonic state
1898       // machine; the stream will remain in !ok forever. Swap it with the last
1899       // stream and pop it off.
1900       const int64 last = sub_streams_.size() - 1;
1901       if (index != last) {
1902         std::swap(pair, sub_streams_[last]);
1903       }
1904       sub_streams_.pop_back();
1905       VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1906               << sub_stream->DebugStreamPointers();
1907     } else {
1908       // The sub_stream is not reusable, move on to the next one.
1909       ++index;
1910     }
1911   }
1912 
1913   // No streams are reusable; create a new stream.
1914   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1915                             false);
1916   Stream *sub_stream = sub_streams_.back().first.get();
1917   sub_stream->Init();
1918   if (!sub_stream->ok_) {
1919     LOG(ERROR) << "sub-stream failed to be initialized";
1920   }
1921   VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1922           << sub_stream->DebugStreamPointers();
1923 
1924   return sub_stream;
1925 }
1926 
ReturnSubStream(Stream * sub_stream)1927 void Stream::ReturnSubStream(Stream *sub_stream) {
1928   absl::MutexLock lock(&mu_);
1929 
1930   // Look for the sub-stream.
1931   for (int64 index = 0; index < sub_streams_.size(); ++index) {
1932     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1933     if (pair.first.get() != sub_stream) {
1934       continue;
1935     }
1936 
1937     // Found the sub_stream.
1938     if (sub_stream->ok()) {
1939       VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1940               << sub_stream->DebugStreamPointers();
1941       pair.second = true;
1942     } else {
1943       // The returned stream is not ok. Streams have a monotonic state
1944       // machine; the stream will remain in !ok forever. Swap it with the last
1945       // stream and pop it off.
1946       VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1947               << sub_stream->DebugStreamPointers();
1948       const int64 last = sub_streams_.size() - 1;
1949       if (index != last) {
1950         std::swap(pair, sub_streams_[last]);
1951       }
1952       sub_streams_.pop_back();
1953     }
1954     return;
1955   }
1956 
1957   LOG(FATAL) << DebugStreamPointers()
1958              << " did not create the returned sub-stream "
1959              << sub_stream->DebugStreamPointers();
1960 }
1961 
ThenStartTimer(Timer * t)1962 Stream &Stream::ThenStartTimer(Timer *t) {
1963   VLOG_CALL(PARAM(t));
1964 
1965   if (ok()) {
1966     CheckError(parent_->StartTimer(this, t));
1967   } else {
1968     LOG(INFO) << DebugStreamPointers()
1969               << " did not enqueue 'start timer': " << t;
1970   }
1971   return *this;
1972 }
1973 
ThenStopTimer(Timer * t)1974 Stream &Stream::ThenStopTimer(Timer *t) {
1975   VLOG_CALL(PARAM(t));
1976 
1977   if (ok()) {
1978     CheckError(parent_->StopTimer(this, t));
1979   } else {
1980     LOG(INFO) << DebugStreamPointers()
1981               << " did not enqueue 'stop timer': " << t;
1982   }
1983   return *this;
1984 }
1985 
ThenWaitFor(Stream * other)1986 Stream &Stream::ThenWaitFor(Stream *other) {
1987   VLOG_CALL(PARAM(other));
1988 
1989   CHECK(this != other) << "stream cannot wait for itself";
1990   if (ok() && other->ok()) {
1991     CheckError(parent_->CreateStreamDependency(this, other));
1992   } else {
1993     SetError();
1994     LOG(INFO) << DebugStreamPointers() << " did not wait for "
1995               << other->DebugStreamPointers();
1996   }
1997   return *this;
1998 }
1999 
ThenWaitFor(Event * event)2000 Stream &Stream::ThenWaitFor(Event *event) {
2001   VLOG_CALL(PARAM(event));
2002 
2003   if (ok()) {
2004     port::Status status = parent_->WaitForEvent(this, event);
2005     if (!status.ok()) {
2006       LOG(ERROR) << "Error waiting for event in stream: "
2007                  << status.error_message()
2008                  << "; not marking stream as bad, as the Event object may be "
2009                  << "at fault. Monitor for further errors.";
2010     }
2011   } else {
2012     LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
2013   }
2014   return *this;
2015 }
2016 
2017 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
2018 // functions and logs for errors.
2019 template <typename... Args>
2020 struct ThenBlasImpl {
2021   // blas_func is the DoBlasXXX member function pointer, and args are its
2022   // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl2023   Stream &operator()(Stream *stream,
2024                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
2025                      Args... args) {
2026     return Run(stream, blas_func, /*record_error=*/true, args...);
2027   }
2028 
2029   // Like operator(), but only calls stream->CheckError() if record_error is
2030   // true.
2031   Stream &Run(Stream *stream,
2032               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
2033               bool record_error, Args... args);
2034 };
2035 
2036 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)2037 Stream &ThenBlasImpl<Args...>::Run(
2038     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
2039     bool record_error, Args... args) {
2040   if (stream->ok()) {
2041     bool ok;
2042     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
2043       ok = (blas->*blas_func)(stream, args...);
2044     } else {
2045       LOG(WARNING)
2046           << "attempting to perform BLAS operation using StreamExecutor "
2047              "without BLAS support";
2048       ok = false;
2049     }
2050     if (record_error) {
2051       stream->CheckError(ok);
2052     }
2053   }
2054   return *stream;
2055 }
2056 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)2057 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
2058                              int incx, DeviceMemory<float> *result) {
2059   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2060 
2061   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2062       impl;
2063   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2064               result);
2065 }
2066 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)2067 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
2068                              int incx, DeviceMemory<double> *result) {
2069   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2070 
2071   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2072                DeviceMemory<double> *> impl;
2073   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2074               result);
2075 }
2076 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)2077 Stream &Stream::ThenBlasAsum(uint64 elem_count,
2078                              const DeviceMemory<std::complex<float>> &x,
2079                              int incx, DeviceMemory<float> *result) {
2080   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2081 
2082   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2083                DeviceMemory<float> *> impl;
2084   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2085               result);
2086 }
2087 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)2088 Stream &Stream::ThenBlasAsum(uint64 elem_count,
2089                              const DeviceMemory<std::complex<double>> &x,
2090                              int incx, DeviceMemory<double> *result) {
2091   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2092 
2093   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2094                DeviceMemory<double> *> impl;
2095   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
2096               result);
2097 }
2098 
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)2099 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
2100                              const DeviceMemory<float> &x, int incx,
2101                              DeviceMemory<float> *y, int incy) {
2102   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2103             PARAM(incy));
2104 
2105   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
2106                DeviceMemory<float> *, int> impl;
2107   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2108               y, incy);
2109 }
2110 
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)2111 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
2112                              const DeviceMemory<double> &x, int incx,
2113                              DeviceMemory<double> *y, int incy) {
2114   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2115             PARAM(incy));
2116 
2117   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
2118                DeviceMemory<double> *, int> impl;
2119   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2120               y, incy);
2121 }
2122 
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2123 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
2124                              const DeviceMemory<std::complex<float>> &x,
2125                              int incx, DeviceMemory<std::complex<float>> *y,
2126                              int incy) {
2127   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2128             PARAM(incy));
2129 
2130   ThenBlasImpl<uint64, std::complex<float>,
2131                const DeviceMemory<std::complex<float>> &, int,
2132                DeviceMemory<std::complex<float>> *, int> impl;
2133   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2134               y, incy);
2135 }
2136 
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2137 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
2138                              const DeviceMemory<std::complex<double>> &x,
2139                              int incx, DeviceMemory<std::complex<double>> *y,
2140                              int incy) {
2141   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2142             PARAM(incy));
2143 
2144   ThenBlasImpl<uint64, std::complex<double>,
2145                const DeviceMemory<std::complex<double>> &, int,
2146                DeviceMemory<std::complex<double>> *, int> impl;
2147   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2148               y, incy);
2149 }
2150 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)2151 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
2152                              int incx, DeviceMemory<float> *y, int incy) {
2153   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2154 
2155   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2156                int> impl;
2157   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2158               incy);
2159 }
2160 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)2161 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
2162                              int incx, DeviceMemory<double> *y, int incy) {
2163   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2164 
2165   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2166                DeviceMemory<double> *, int> impl;
2167   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2168               incy);
2169 }
2170 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2171 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2172                              const DeviceMemory<std::complex<float>> &x,
2173                              int incx, DeviceMemory<std::complex<float>> *y,
2174                              int incy) {
2175   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2176 
2177   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2178                DeviceMemory<std::complex<float>> *, int> impl;
2179   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2180               incy);
2181 }
2182 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2183 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2184                              const DeviceMemory<std::complex<double>> &x,
2185                              int incx, DeviceMemory<std::complex<double>> *y,
2186                              int incy) {
2187   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2188 
2189   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2190                DeviceMemory<std::complex<double>> *, int> impl;
2191   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2192               incy);
2193 }
2194 
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)2195 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
2196                             int incx, const DeviceMemory<float> &y, int incy,
2197                             DeviceMemory<float> *result) {
2198   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2199             PARAM(result));
2200 
2201   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
2202                const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
2203   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2204               result);
2205 }
2206 
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)2207 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
2208                             int incx, const DeviceMemory<double> &y, int incy,
2209                             DeviceMemory<double> *result) {
2210   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2211             PARAM(result));
2212 
2213   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2214                const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
2215   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2216               result);
2217 }
2218 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2219 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2220                              const DeviceMemory<std::complex<float>> &x,
2221                              int incx,
2222                              const DeviceMemory<std::complex<float>> &y,
2223                              int incy,
2224                              DeviceMemory<std::complex<float>> *result) {
2225   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2226             PARAM(result));
2227 
2228   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2229                const DeviceMemory<std::complex<float>> &, int,
2230                DeviceMemory<std::complex<float>> *> impl;
2231   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2232               incy, result);
2233 }
2234 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2235 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2236                              const DeviceMemory<std::complex<double>> &x,
2237                              int incx,
2238                              const DeviceMemory<std::complex<double>> &y,
2239                              int incy,
2240                              DeviceMemory<std::complex<double>> *result) {
2241   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2242             PARAM(result));
2243 
2244   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2245                const DeviceMemory<std::complex<double>> &, int,
2246                DeviceMemory<std::complex<double>> *> impl;
2247   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2248               incy, result);
2249 }
2250 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2251 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2252                              const DeviceMemory<std::complex<float>> &x,
2253                              int incx,
2254                              const DeviceMemory<std::complex<float>> &y,
2255                              int incy,
2256                              DeviceMemory<std::complex<float>> *result) {
2257   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2258             PARAM(result));
2259 
2260   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2261                const DeviceMemory<std::complex<float>> &, int,
2262                DeviceMemory<std::complex<float>> *> impl;
2263   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2264               incy, result);
2265 }
2266 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2267 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2268                              const DeviceMemory<std::complex<double>> &x,
2269                              int incx,
2270                              const DeviceMemory<std::complex<double>> &y,
2271                              int incy,
2272                              DeviceMemory<std::complex<double>> *result) {
2273   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2274             PARAM(result));
2275 
2276   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2277                const DeviceMemory<std::complex<double>> &, int,
2278                DeviceMemory<std::complex<double>> *> impl;
2279   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2280               incy, result);
2281 }
2282 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)2283 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
2284                              int incx, DeviceMemory<float> *result) {
2285   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2286 
2287   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2288       impl;
2289   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2290               result);
2291 }
2292 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)2293 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
2294                              int incx, DeviceMemory<double> *result) {
2295   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2296 
2297   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2298                DeviceMemory<double> *> impl;
2299   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2300               result);
2301 }
2302 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)2303 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2304                              const DeviceMemory<std::complex<float>> &x,
2305                              int incx, DeviceMemory<float> *result) {
2306   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2307 
2308   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2309                DeviceMemory<float> *> impl;
2310   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2311               result);
2312 }
2313 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)2314 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2315                              const DeviceMemory<std::complex<double>> &x,
2316                              int incx, DeviceMemory<double> *result) {
2317   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2318 
2319   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2320                DeviceMemory<double> *> impl;
2321   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2322               result);
2323 }
2324 
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)2325 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
2326                             DeviceMemory<float> *y, int incy, float c,
2327                             float s) {
2328   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2329             PARAM(c), PARAM(s));
2330 
2331   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2332                float, float> impl;
2333   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2334               c, s);
2335 }
2336 
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)2337 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
2338                             int incx, DeviceMemory<double> *y, int incy,
2339                             double c, double s) {
2340   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2341             PARAM(c), PARAM(s));
2342 
2343   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2344                double, double> impl;
2345   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2346               c, s);
2347 }
2348 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)2349 Stream &Stream::ThenBlasRot(uint64 elem_count,
2350                             DeviceMemory<std::complex<float>> *x, int incx,
2351                             DeviceMemory<std::complex<float>> *y, int incy,
2352                             float c, float s) {
2353   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2354             PARAM(c), PARAM(s));
2355 
2356   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2357                DeviceMemory<std::complex<float>> *, int, float, float> impl;
2358   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2359               c, s);
2360 }
2361 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)2362 Stream &Stream::ThenBlasRot(uint64 elem_count,
2363                             DeviceMemory<std::complex<double>> *x, int incx,
2364                             DeviceMemory<std::complex<double>> *y, int incy,
2365                             double c, double s) {
2366   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2367             PARAM(c), PARAM(s));
2368 
2369   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2370                DeviceMemory<std::complex<double>> *, int, double, double> impl;
2371   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2372               c, s);
2373 }
2374 
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)2375 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
2376                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
2377   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2378 
2379   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2380                DeviceMemory<float> *, DeviceMemory<float> *> impl;
2381   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2382 }
2383 
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)2384 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
2385                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
2386   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2387 
2388   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2389                DeviceMemory<double> *, DeviceMemory<double> *> impl;
2390   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2391 }
2392 
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)2393 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
2394                              DeviceMemory<std::complex<float>> *b,
2395                              DeviceMemory<float> *c,
2396                              DeviceMemory<std::complex<float>> *s) {
2397   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2398 
2399   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
2400                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
2401                DeviceMemory<std::complex<float>> *> impl;
2402   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2403 }
2404 
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)2405 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
2406                              DeviceMemory<std::complex<double>> *b,
2407                              DeviceMemory<double> *c,
2408                              DeviceMemory<std::complex<double>> *s) {
2409   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2410 
2411   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
2412                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
2413                DeviceMemory<std::complex<double>> *> impl;
2414   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2415 }
2416 
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)2417 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
2418                              int incx, DeviceMemory<float> *y, int incy,
2419                              const DeviceMemory<float> &param) {
2420   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2421             PARAM(param));
2422 
2423   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2424                const DeviceMemory<float> &> impl;
2425   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2426               incy, param);
2427 }
2428 
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)2429 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
2430                              int incx, DeviceMemory<double> *y, int incy,
2431                              const DeviceMemory<double> &param) {
2432   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2433             PARAM(param));
2434 
2435   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2436                const DeviceMemory<double> &> impl;
2437   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2438               incy, param);
2439 }
2440 
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)2441 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
2442                               DeviceMemory<float> *x1,
2443                               const DeviceMemory<float> &y1,
2444                               DeviceMemory<float> *param) {
2445   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2446 
2447   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2448                DeviceMemory<float> *, const DeviceMemory<float> &,
2449                DeviceMemory<float> *> impl;
2450   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2451 }
2452 
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)2453 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
2454                               DeviceMemory<double> *d2,
2455                               DeviceMemory<double> *x1,
2456                               const DeviceMemory<double> &y1,
2457                               DeviceMemory<double> *param) {
2458   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2459 
2460   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2461                DeviceMemory<double> *, const DeviceMemory<double> &,
2462                DeviceMemory<double> *> impl;
2463   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2464 }
2465 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)2466 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2467                              DeviceMemory<float> *x, int incx) {
2468   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2469 
2470   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
2471   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2472 }
2473 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)2474 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2475                              DeviceMemory<double> *x, int incx) {
2476   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2477 
2478   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
2479   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2480 }
2481 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)2482 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2483                              DeviceMemory<std::complex<float>> *x, int incx) {
2484   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2485 
2486   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
2487   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2488 }
2489 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)2490 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2491                              DeviceMemory<std::complex<double>> *x, int incx) {
2492   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2493 
2494   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
2495   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2496 }
2497 
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)2498 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
2499                              DeviceMemory<std::complex<float>> *x, int incx) {
2500   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2501 
2502   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
2503                int> impl;
2504   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2505 }
2506 
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)2507 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
2508                              DeviceMemory<std::complex<double>> *x, int incx) {
2509   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2510 
2511   ThenBlasImpl<uint64, std::complex<double>,
2512                DeviceMemory<std::complex<double>> *, int> impl;
2513   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2514 }
2515 
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)2516 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
2517                              int incx, DeviceMemory<float> *y, int incy) {
2518   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2519 
2520   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
2521       impl;
2522   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2523               incy);
2524 }
2525 
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)2526 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
2527                              int incx, DeviceMemory<double> *y, int incy) {
2528   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2529 
2530   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
2531       impl;
2532   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2533               incy);
2534 }
2535 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2536 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2537                              DeviceMemory<std::complex<float>> *x, int incx,
2538                              DeviceMemory<std::complex<float>> *y, int incy) {
2539   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2540 
2541   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2542                DeviceMemory<std::complex<float>> *, int> impl;
2543   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2544               incy);
2545 }
2546 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2547 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2548                              DeviceMemory<std::complex<double>> *x, int incx,
2549                              DeviceMemory<std::complex<double>> *y, int incy) {
2550   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2551 
2552   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2553                DeviceMemory<std::complex<double>> *, int> impl;
2554   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2555               incy);
2556 }
2557 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2558 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2559                               int incx, DeviceMemory<int> *result) {
2560   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2561 
2562   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2563       impl;
2564   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2565               result);
2566 }
2567 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2568 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2569                               int incx, DeviceMemory<int> *result) {
2570   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2571 
2572   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2573       impl;
2574   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2575               result);
2576 }
2577 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2578 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2579                               const DeviceMemory<std::complex<float>> &x,
2580                               int incx, DeviceMemory<int> *result) {
2581   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2582 
2583   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2584                DeviceMemory<int> *> impl;
2585   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2586               result);
2587 }
2588 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2589 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2590                               const DeviceMemory<std::complex<double>> &x,
2591                               int incx, DeviceMemory<int> *result) {
2592   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2593 
2594   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2595                DeviceMemory<int> *> impl;
2596   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2597               result);
2598 }
2599 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2600 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2601                               int incx, DeviceMemory<int> *result) {
2602   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2603 
2604   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2605       impl;
2606   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2607               result);
2608 }
2609 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2610 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2611                               int incx, DeviceMemory<int> *result) {
2612   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2613 
2614   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2615       impl;
2616   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2617               result);
2618 }
2619 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2620 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2621                               const DeviceMemory<std::complex<float>> &x,
2622                               int incx, DeviceMemory<int> *result) {
2623   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2624 
2625   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2626                DeviceMemory<int> *> impl;
2627   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2628               result);
2629 }
2630 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2631 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2632                               const DeviceMemory<std::complex<double>> &x,
2633                               int incx, DeviceMemory<int> *result) {
2634   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2635 
2636   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2637                DeviceMemory<int> *> impl;
2638   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2639               result);
2640 }
2641 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2642 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2643                              uint64 kl, uint64 ku, float alpha,
2644                              const DeviceMemory<float> &a, int lda,
2645                              const DeviceMemory<float> &x, int incx, float beta,
2646                              DeviceMemory<float> *y, int incy) {
2647   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2648             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2649             PARAM(beta), PARAM(y), PARAM(incy));
2650 
2651   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2652                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2653                int, float, DeviceMemory<float> *, int> impl;
2654   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2655               a, lda, x, incx, beta, y, incy);
2656 }
2657 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2658 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2659                              uint64 kl, uint64 ku, double alpha,
2660                              const DeviceMemory<double> &a, int lda,
2661                              const DeviceMemory<double> &x, int incx,
2662                              double beta, DeviceMemory<double> *y, int incy) {
2663   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2664             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2665             PARAM(beta), PARAM(y), PARAM(incy));
2666 
2667   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2668                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2669                int, double, DeviceMemory<double> *, int> impl;
2670   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2671               a, lda, x, incx, beta, y, incy);
2672 }
2673 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2674 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2675                              uint64 kl, uint64 ku, std::complex<float> alpha,
2676                              const DeviceMemory<std::complex<float>> &a,
2677                              int lda,
2678                              const DeviceMemory<std::complex<float>> &x,
2679                              int incx, std::complex<float> beta,
2680                              DeviceMemory<std::complex<float>> *y, int incy) {
2681   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2682             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2683             PARAM(beta), PARAM(y), PARAM(incy));
2684 
2685   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2686                std::complex<float>, const DeviceMemory<std::complex<float>> &,
2687                int, const DeviceMemory<std::complex<float>> &, int,
2688                std::complex<float>, DeviceMemory<std::complex<float>> *,
2689                int> impl;
2690   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2691               a, lda, x, incx, beta, y, incy);
2692 }
2693 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2694 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2695                              uint64 kl, uint64 ku, std::complex<double> alpha,
2696                              const DeviceMemory<std::complex<double>> &a,
2697                              int lda,
2698                              const DeviceMemory<std::complex<double>> &x,
2699                              int incx, std::complex<double> beta,
2700                              DeviceMemory<std::complex<double>> *y, int incy) {
2701   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2702             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2703             PARAM(beta), PARAM(y), PARAM(incy));
2704 
2705   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2706                std::complex<double>, const DeviceMemory<std::complex<double>> &,
2707                int, const DeviceMemory<std::complex<double>> &, int,
2708                std::complex<double>, DeviceMemory<std::complex<double>> *,
2709                int> impl;
2710   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2711               a, lda, x, incx, beta, y, incy);
2712 }
2713 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2714 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2715                              float alpha, const DeviceMemory<float> &a, int lda,
2716                              const DeviceMemory<float> &x, int incx, float beta,
2717                              DeviceMemory<float> *y, int incy) {
2718   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2719             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2720             PARAM(incy));
2721 
2722   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2723                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2724                int, float, DeviceMemory<float> *, int> impl;
2725   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2726               x, incx, beta, y, incy);
2727 }
2728 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2729 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2730                              double alpha, const DeviceMemory<double> &a,
2731                              int lda, const DeviceMemory<double> &x, int incx,
2732                              double beta, DeviceMemory<double> *y, int incy) {
2733   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2734             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2735             PARAM(incy));
2736 
2737   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2738                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2739                int, double, DeviceMemory<double> *, int> impl;
2740   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2741               x, incx, beta, y, incy);
2742 }
2743 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2744 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2745                              std::complex<float> alpha,
2746                              const DeviceMemory<std::complex<float>> &a,
2747                              int lda,
2748                              const DeviceMemory<std::complex<float>> &x,
2749                              int incx, std::complex<float> beta,
2750                              DeviceMemory<std::complex<float>> *y, int incy) {
2751   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2752             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2753             PARAM(incy));
2754 
2755   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2756                const DeviceMemory<std::complex<float>> &, int,
2757                const DeviceMemory<std::complex<float>> &, int,
2758                std::complex<float>, DeviceMemory<std::complex<float>> *,
2759                int> impl;
2760   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2761               x, incx, beta, y, incy);
2762 }
2763 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2764 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2765                              std::complex<double> alpha,
2766                              const DeviceMemory<std::complex<double>> &a,
2767                              int lda,
2768                              const DeviceMemory<std::complex<double>> &x,
2769                              int incx, std::complex<double> beta,
2770                              DeviceMemory<std::complex<double>> *y, int incy) {
2771   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2772             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2773             PARAM(incy));
2774 
2775   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2776                const DeviceMemory<std::complex<double>> &, int,
2777                const DeviceMemory<std::complex<double>> &, int,
2778                std::complex<double>, DeviceMemory<std::complex<double>> *,
2779                int> impl;
2780   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2781               x, incx, beta, y, incy);
2782 }
2783 
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2784 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2785                             const DeviceMemory<float> &x, int incx,
2786                             const DeviceMemory<float> &y, int incy,
2787                             DeviceMemory<float> *a, int lda) {
2788   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2789             PARAM(incy), PARAM(a), PARAM(lda));
2790 
2791   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2792                const DeviceMemory<float> &, int, DeviceMemory<float> *,
2793                int> impl;
2794   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2795               incy, a, lda);
2796 }
2797 
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2798 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2799                             const DeviceMemory<double> &x, int incx,
2800                             const DeviceMemory<double> &y, int incy,
2801                             DeviceMemory<double> *a, int lda) {
2802   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2803             PARAM(incy), PARAM(a), PARAM(lda));
2804 
2805   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2806                const DeviceMemory<double> &, int, DeviceMemory<double> *,
2807                int> impl;
2808   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2809               incy, a, lda);
2810 }
2811 
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2812 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2813                              const DeviceMemory<std::complex<float>> &x,
2814                              int incx,
2815                              const DeviceMemory<std::complex<float>> &y,
2816                              int incy, DeviceMemory<std::complex<float>> *a,
2817                              int lda) {
2818   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2819             PARAM(incy), PARAM(a), PARAM(lda));
2820 
2821   ThenBlasImpl<uint64, uint64, std::complex<float>,
2822                const DeviceMemory<std::complex<float>> &, int,
2823                const DeviceMemory<std::complex<float>> &, int,
2824                DeviceMemory<std::complex<float>> *, int> impl;
2825   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2826               incy, a, lda);
2827 }
2828 
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2829 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2830                              const DeviceMemory<std::complex<double>> &x,
2831                              int incx,
2832                              const DeviceMemory<std::complex<double>> &y,
2833                              int incy, DeviceMemory<std::complex<double>> *a,
2834                              int lda) {
2835   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2836             PARAM(incy), PARAM(a), PARAM(lda));
2837 
2838   ThenBlasImpl<uint64, uint64, std::complex<double>,
2839                const DeviceMemory<std::complex<double>> &, int,
2840                const DeviceMemory<std::complex<double>> &, int,
2841                DeviceMemory<std::complex<double>> *, int> impl;
2842   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2843               incy, a, lda);
2844 }
2845 
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2846 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2847                              const DeviceMemory<std::complex<float>> &x,
2848                              int incx,
2849                              const DeviceMemory<std::complex<float>> &y,
2850                              int incy, DeviceMemory<std::complex<float>> *a,
2851                              int lda) {
2852   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2853             PARAM(incy), PARAM(a), PARAM(lda));
2854 
2855   ThenBlasImpl<uint64, uint64, std::complex<float>,
2856                const DeviceMemory<std::complex<float>> &, int,
2857                const DeviceMemory<std::complex<float>> &, int,
2858                DeviceMemory<std::complex<float>> *, int> impl;
2859   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2860               incy, a, lda);
2861 }
2862 
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2863 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2864                              const DeviceMemory<std::complex<double>> &x,
2865                              int incx,
2866                              const DeviceMemory<std::complex<double>> &y,
2867                              int incy, DeviceMemory<std::complex<double>> *a,
2868                              int lda) {
2869   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2870             PARAM(incy), PARAM(a), PARAM(lda));
2871 
2872   ThenBlasImpl<uint64, uint64, std::complex<double>,
2873                const DeviceMemory<std::complex<double>> &, int,
2874                const DeviceMemory<std::complex<double>> &, int,
2875                DeviceMemory<std::complex<double>> *, int> impl;
2876   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2877               incy, a, lda);
2878 }
2879 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2880 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2881                              std::complex<float> alpha,
2882                              const DeviceMemory<std::complex<float>> &a,
2883                              int lda,
2884                              const DeviceMemory<std::complex<float>> &x,
2885                              int incx, std::complex<float> beta,
2886                              DeviceMemory<std::complex<float>> *y, int incy) {
2887   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2888             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2889 
2890   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2891                const DeviceMemory<std::complex<float>> &, int,
2892                const DeviceMemory<std::complex<float>> &, int,
2893                std::complex<float>, DeviceMemory<std::complex<float>> *,
2894                int> impl;
2895   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2896               x, incx, beta, y, incy);
2897 }
2898 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2899 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2900                              std::complex<double> alpha,
2901                              const DeviceMemory<std::complex<double>> &a,
2902                              int lda,
2903                              const DeviceMemory<std::complex<double>> &x,
2904                              int incx, std::complex<double> beta,
2905                              DeviceMemory<std::complex<double>> *y, int incy) {
2906   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2907             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2908 
2909   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2910                const DeviceMemory<std::complex<double>> &, int,
2911                const DeviceMemory<std::complex<double>> &, int,
2912                std::complex<double>, DeviceMemory<std::complex<double>> *,
2913                int> impl;
2914   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2915               x, incx, beta, y, incy);
2916 }
2917 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2918 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2919                              std::complex<float> alpha,
2920                              const DeviceMemory<std::complex<float>> &a,
2921                              int lda,
2922                              const DeviceMemory<std::complex<float>> &x,
2923                              int incx, std::complex<float> beta,
2924                              DeviceMemory<std::complex<float>> *y, int incy) {
2925   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2926             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2927 
2928   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2929                const DeviceMemory<std::complex<float>> &, int,
2930                const DeviceMemory<std::complex<float>> &, int,
2931                std::complex<float>, DeviceMemory<std::complex<float>> *,
2932                int> impl;
2933   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2934               incx, beta, y, incy);
2935 }
2936 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2937 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2938                              std::complex<double> alpha,
2939                              const DeviceMemory<std::complex<double>> &a,
2940                              int lda,
2941                              const DeviceMemory<std::complex<double>> &x,
2942                              int incx, std::complex<double> beta,
2943                              DeviceMemory<std::complex<double>> *y, int incy) {
2944   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2945             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2946 
2947   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2948                const DeviceMemory<std::complex<double>> &, int,
2949                const DeviceMemory<std::complex<double>> &, int,
2950                std::complex<double>, DeviceMemory<std::complex<double>> *,
2951                int> impl;
2952   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2953               incx, beta, y, incy);
2954 }
2955 
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2956 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2957                             const DeviceMemory<std::complex<float>> &x,
2958                             int incx, DeviceMemory<std::complex<float>> *a,
2959                             int lda) {
2960   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2961             PARAM(a), PARAM(lda));
2962 
2963   ThenBlasImpl<blas::UpperLower, uint64, float,
2964                const DeviceMemory<std::complex<float>> &, int,
2965                DeviceMemory<std::complex<float>> *, int> impl;
2966   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2967               lda);
2968 }
2969 
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2970 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2971                             const DeviceMemory<std::complex<double>> &x,
2972                             int incx, DeviceMemory<std::complex<double>> *a,
2973                             int lda) {
2974   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2975             PARAM(a), PARAM(lda));
2976 
2977   ThenBlasImpl<blas::UpperLower, uint64, double,
2978                const DeviceMemory<std::complex<double>> &, int,
2979                DeviceMemory<std::complex<double>> *, int> impl;
2980   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2981               lda);
2982 }
2983 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2984 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2985                              std::complex<float> alpha,
2986                              const DeviceMemory<std::complex<float>> &x,
2987                              int incx,
2988                              const DeviceMemory<std::complex<float>> &y,
2989                              int incy, DeviceMemory<std::complex<float>> *a,
2990                              int lda) {
2991   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2992             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2993 
2994   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2995                const DeviceMemory<std::complex<float>> &, int,
2996                const DeviceMemory<std::complex<float>> &, int,
2997                DeviceMemory<std::complex<float>> *, int> impl;
2998   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2999               incy, a, lda);
3000 }
3001 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)3002 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
3003                              std::complex<double> alpha,
3004                              const DeviceMemory<std::complex<double>> &x,
3005                              int incx,
3006                              const DeviceMemory<std::complex<double>> &y,
3007                              int incy, DeviceMemory<std::complex<double>> *a,
3008                              int lda) {
3009   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3010             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3011 
3012   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
3013                const DeviceMemory<std::complex<double>> &, int,
3014                const DeviceMemory<std::complex<double>> &, int,
3015                DeviceMemory<std::complex<double>> *, int> impl;
3016   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
3017               incy, a, lda);
3018 }
3019 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)3020 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
3021                              std::complex<float> alpha,
3022                              const DeviceMemory<std::complex<float>> &ap,
3023                              const DeviceMemory<std::complex<float>> &x,
3024                              int incx, std::complex<float> beta,
3025                              DeviceMemory<std::complex<float>> *y, int incy) {
3026   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3027             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3028 
3029   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
3030                const DeviceMemory<std::complex<float>> &,
3031                const DeviceMemory<std::complex<float>> &, int,
3032                std::complex<float>, DeviceMemory<std::complex<float>> *,
3033                int> impl;
3034   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
3035               beta, y, incy);
3036 }
3037 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)3038 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
3039                              std::complex<double> alpha,
3040                              const DeviceMemory<std::complex<double>> &ap,
3041                              const DeviceMemory<std::complex<double>> &x,
3042                              int incx, std::complex<double> beta,
3043                              DeviceMemory<std::complex<double>> *y, int incy) {
3044   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3045             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3046 
3047   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
3048                const DeviceMemory<std::complex<double>> &,
3049                const DeviceMemory<std::complex<double>> &, int,
3050                std::complex<double>, DeviceMemory<std::complex<double>> *,
3051                int> impl;
3052   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
3053               beta, y, incy);
3054 }
3055 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)3056 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
3057                             const DeviceMemory<std::complex<float>> &x,
3058                             int incx, DeviceMemory<std::complex<float>> *ap) {
3059   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3060             PARAM(ap));
3061 
3062   ThenBlasImpl<blas::UpperLower, uint64, float,
3063                const DeviceMemory<std::complex<float>> &, int,
3064                DeviceMemory<std::complex<float>> *> impl;
3065   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
3066 }
3067 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)3068 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
3069                             const DeviceMemory<std::complex<double>> &x,
3070                             int incx, DeviceMemory<std::complex<double>> *ap) {
3071   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3072             PARAM(ap));
3073 
3074   ThenBlasImpl<blas::UpperLower, uint64, double,
3075                const DeviceMemory<std::complex<double>> &, int,
3076                DeviceMemory<std::complex<double>> *> impl;
3077   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
3078 }
3079 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)3080 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
3081                              std::complex<float> alpha,
3082                              const DeviceMemory<std::complex<float>> &x,
3083                              int incx,
3084                              const DeviceMemory<std::complex<float>> &y,
3085                              int incy, DeviceMemory<std::complex<float>> *ap) {
3086   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3087             PARAM(y), PARAM(incy), PARAM(ap));
3088 
3089   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
3090                const DeviceMemory<std::complex<float>> &, int,
3091                const DeviceMemory<std::complex<float>> &, int,
3092                DeviceMemory<std::complex<float>> *> impl;
3093   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
3094               incy, ap);
3095 }
3096 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)3097 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
3098                              std::complex<double> alpha,
3099                              const DeviceMemory<std::complex<double>> &x,
3100                              int incx,
3101                              const DeviceMemory<std::complex<double>> &y,
3102                              int incy, DeviceMemory<std::complex<double>> *ap) {
3103   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3104             PARAM(y), PARAM(incy), PARAM(ap));
3105 
3106   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
3107                const DeviceMemory<std::complex<double>> &, int,
3108                const DeviceMemory<std::complex<double>> &, int,
3109                DeviceMemory<std::complex<double>> *> impl;
3110   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
3111               incy, ap);
3112 }
3113 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3114 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
3115                              float alpha, const DeviceMemory<float> &a, int lda,
3116                              const DeviceMemory<float> &x, int incx, float beta,
3117                              DeviceMemory<float> *y, int incy) {
3118   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
3119             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3120 
3121   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
3122                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3123                int, float, DeviceMemory<float> *, int> impl;
3124   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
3125               x, incx, beta, y, incy);
3126 }
3127 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3128 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
3129                              double alpha, const DeviceMemory<double> &a,
3130                              int lda, const DeviceMemory<double> &x, int incx,
3131                              double beta, DeviceMemory<double> *y, int incy) {
3132   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
3133             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3134 
3135   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
3136                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3137                int, double, DeviceMemory<double> *, int> impl;
3138   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
3139               x, incx, beta, y, incy);
3140 }
3141 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3142 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
3143                              const DeviceMemory<float> &ap,
3144                              const DeviceMemory<float> &x, int incx, float beta,
3145                              DeviceMemory<float> *y, int incy) {
3146   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3147             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3148 
3149   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3150                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3151                int> impl;
3152   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3153               beta, y, incy);
3154 }
3155 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3156 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
3157                              const DeviceMemory<double> &ap,
3158                              const DeviceMemory<double> &x, int incx,
3159                              double beta, DeviceMemory<double> *y, int incy) {
3160   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3161             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3162 
3163   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3164                const DeviceMemory<double> &, int, double,
3165                DeviceMemory<double> *, int> impl;
3166   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3167               beta, y, incy);
3168 }
3169 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)3170 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
3171                             const DeviceMemory<float> &x, int incx,
3172                             DeviceMemory<float> *ap) {
3173   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3174             PARAM(ap));
3175 
3176   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3177                int, DeviceMemory<float> *> impl;
3178   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3179 }
3180 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)3181 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
3182                             const DeviceMemory<double> &x, int incx,
3183                             DeviceMemory<double> *ap) {
3184   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3185             PARAM(ap));
3186 
3187   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3188                int, DeviceMemory<double> *> impl;
3189   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3190 }
3191 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)3192 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
3193                              const DeviceMemory<float> &x, int incx,
3194                              const DeviceMemory<float> &y, int incy,
3195                              DeviceMemory<float> *ap) {
3196   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3197             PARAM(y), PARAM(incy), PARAM(ap));
3198 
3199   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3200                int, const DeviceMemory<float> &, int,
3201                DeviceMemory<float> *> impl;
3202   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3203               incy, ap);
3204 }
3205 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)3206 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
3207                              const DeviceMemory<double> &x, int incx,
3208                              const DeviceMemory<double> &y, int incy,
3209                              DeviceMemory<double> *ap) {
3210   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3211             PARAM(y), PARAM(incy), PARAM(ap));
3212 
3213   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3214                int, const DeviceMemory<double> &, int,
3215                DeviceMemory<double> *> impl;
3216   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3217               incy, ap);
3218 }
3219 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3220 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
3221                              const DeviceMemory<float> &a, int lda,
3222                              const DeviceMemory<float> &x, int incx, float beta,
3223                              DeviceMemory<float> *y, int incy) {
3224   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3225             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3226 
3227   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3228                int, const DeviceMemory<float> &, int, float,
3229                DeviceMemory<float> *, int> impl;
3230   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3231               incx, beta, y, incy);
3232 }
3233 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3234 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
3235                              const DeviceMemory<double> &a, int lda,
3236                              const DeviceMemory<double> &x, int incx,
3237                              double beta, DeviceMemory<double> *y, int incy) {
3238   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3239             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3240 
3241   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3242                int, const DeviceMemory<double> &, int, double,
3243                DeviceMemory<double> *, int> impl;
3244   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3245               incx, beta, y, incy);
3246 }
3247 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)3248 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
3249                             const DeviceMemory<float> &x, int incx,
3250                             DeviceMemory<float> *a, int lda) {
3251   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3252             PARAM(a), PARAM(lda));
3253 
3254   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3255                int, DeviceMemory<float> *, int> impl;
3256   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3257               lda);
3258 }
3259 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)3260 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
3261                             const DeviceMemory<double> &x, int incx,
3262                             DeviceMemory<double> *a, int lda) {
3263   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3264             PARAM(a), PARAM(lda));
3265 
3266   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3267                int, DeviceMemory<double> *, int> impl;
3268   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3269               lda);
3270 }
3271 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)3272 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
3273                              const DeviceMemory<float> &x, int incx,
3274                              const DeviceMemory<float> &y, int incy,
3275                              DeviceMemory<float> *a, int lda) {
3276   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3277             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3278 
3279   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3280                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3281                int> impl;
3282   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3283               incy, a, lda);
3284 }
3285 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)3286 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
3287                              const DeviceMemory<double> &x, int incx,
3288                              const DeviceMemory<double> &y, int incy,
3289                              DeviceMemory<double> *a, int lda) {
3290   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3291             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3292 
3293   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3294                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
3295                int> impl;
3296   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3297               incy, a, lda);
3298 }
3299 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3300 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3301                              blas::Diagonal diag, uint64 n, uint64 k,
3302                              const DeviceMemory<float> &a, int lda,
3303                              DeviceMemory<float> *x, int incx) {
3304   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3305             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3306 
3307   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3308                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3309                int> impl;
3310   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3311               lda, x, incx);
3312 }
3313 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3314 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3315                              blas::Diagonal diag, uint64 n, uint64 k,
3316                              const DeviceMemory<double> &a, int lda,
3317                              DeviceMemory<double> *x, int incx) {
3318   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3319             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3320 
3321   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3322                uint64, const DeviceMemory<double> &, int,
3323                DeviceMemory<double> *, int> impl;
3324   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3325               lda, x, incx);
3326 }
3327 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3328 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3329                              blas::Diagonal diag, uint64 n, uint64 k,
3330                              const DeviceMemory<std::complex<float>> &a,
3331                              int lda, DeviceMemory<std::complex<float>> *x,
3332                              int incx) {
3333   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3334             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3335 
3336   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3337                uint64, const DeviceMemory<std::complex<float>> &, int,
3338                DeviceMemory<std::complex<float>> *, int> impl;
3339   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3340               lda, x, incx);
3341 }
3342 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3343 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3344                              blas::Diagonal diag, uint64 n, uint64 k,
3345                              const DeviceMemory<std::complex<double>> &a,
3346                              int lda, DeviceMemory<std::complex<double>> *x,
3347                              int incx) {
3348   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3349             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3350 
3351   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3352                uint64, const DeviceMemory<std::complex<double>> &, int,
3353                DeviceMemory<std::complex<double>> *, int> impl;
3354   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3355               lda, x, incx);
3356 }
3357 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3358 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3359                              blas::Diagonal diag, uint64 n, uint64 k,
3360                              const DeviceMemory<float> &a, int lda,
3361                              DeviceMemory<float> *x, int incx) {
3362   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3363             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3364 
3365   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3366                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3367                int> impl;
3368   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3369               lda, x, incx);
3370 }
3371 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3372 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3373                              blas::Diagonal diag, uint64 n, uint64 k,
3374                              const DeviceMemory<double> &a, int lda,
3375                              DeviceMemory<double> *x, int incx) {
3376   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3377             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3378 
3379   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3380                uint64, const DeviceMemory<double> &, int,
3381                DeviceMemory<double> *, int> impl;
3382   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3383               lda, x, incx);
3384 }
3385 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3386 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3387                              blas::Diagonal diag, uint64 n, uint64 k,
3388                              const DeviceMemory<std::complex<float>> &a,
3389                              int lda, DeviceMemory<std::complex<float>> *x,
3390                              int incx) {
3391   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3392             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3393 
3394   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3395                uint64, const DeviceMemory<std::complex<float>> &, int,
3396                DeviceMemory<std::complex<float>> *, int> impl;
3397   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3398               lda, x, incx);
3399 }
3400 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3401 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3402                              blas::Diagonal diag, uint64 n, uint64 k,
3403                              const DeviceMemory<std::complex<double>> &a,
3404                              int lda, DeviceMemory<std::complex<double>> *x,
3405                              int incx) {
3406   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3407             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3408 
3409   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3410                uint64, const DeviceMemory<std::complex<double>> &, int,
3411                DeviceMemory<std::complex<double>> *, int> impl;
3412   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3413               lda, x, incx);
3414 }
3415 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3416 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3417                              blas::Diagonal diag, uint64 n,
3418                              const DeviceMemory<float> &ap,
3419                              DeviceMemory<float> *x, int incx) {
3420   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3421             PARAM(x), PARAM(incx));
3422 
3423   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3424                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3425   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3426               incx);
3427 }
3428 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3429 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3430                              blas::Diagonal diag, uint64 n,
3431                              const DeviceMemory<double> &ap,
3432                              DeviceMemory<double> *x, int incx) {
3433   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3434             PARAM(x), PARAM(incx));
3435 
3436   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3437                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3438   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3439               incx);
3440 }
3441 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3442 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3443                              blas::Diagonal diag, uint64 n,
3444                              const DeviceMemory<std::complex<float>> &ap,
3445                              DeviceMemory<std::complex<float>> *x, int incx) {
3446   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3447             PARAM(x), PARAM(incx));
3448 
3449   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3450                const DeviceMemory<std::complex<float>> &,
3451                DeviceMemory<std::complex<float>> *, int> impl;
3452   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3453               incx);
3454 }
3455 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3456 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3457                              blas::Diagonal diag, uint64 n,
3458                              const DeviceMemory<std::complex<double>> &ap,
3459                              DeviceMemory<std::complex<double>> *x, int incx) {
3460   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3461             PARAM(x), PARAM(incx));
3462 
3463   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3464                const DeviceMemory<std::complex<double>> &,
3465                DeviceMemory<std::complex<double>> *, int> impl;
3466   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3467               incx);
3468 }
3469 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3470 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3471                              blas::Diagonal diag, uint64 n,
3472                              const DeviceMemory<float> &ap,
3473                              DeviceMemory<float> *x, int incx) {
3474   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3475             PARAM(x), PARAM(incx));
3476 
3477   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3478                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3479   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3480               incx);
3481 }
3482 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3483 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3484                              blas::Diagonal diag, uint64 n,
3485                              const DeviceMemory<double> &ap,
3486                              DeviceMemory<double> *x, int incx) {
3487   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3488             PARAM(x), PARAM(incx));
3489 
3490   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3491                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3492   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3493               incx);
3494 }
3495 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3496 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3497                              blas::Diagonal diag, uint64 n,
3498                              const DeviceMemory<std::complex<float>> &ap,
3499                              DeviceMemory<std::complex<float>> *x, int incx) {
3500   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3501             PARAM(x), PARAM(incx));
3502 
3503   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3504                const DeviceMemory<std::complex<float>> &,
3505                DeviceMemory<std::complex<float>> *, int> impl;
3506   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3507               incx);
3508 }
3509 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3510 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3511                              blas::Diagonal diag, uint64 n,
3512                              const DeviceMemory<std::complex<double>> &ap,
3513                              DeviceMemory<std::complex<double>> *x, int incx) {
3514   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3515             PARAM(x), PARAM(incx));
3516 
3517   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3518                const DeviceMemory<std::complex<double>> &,
3519                DeviceMemory<std::complex<double>> *, int> impl;
3520   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3521               incx);
3522 }
3523 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3524 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3525                              blas::Diagonal diag, uint64 n,
3526                              const DeviceMemory<float> &a, int lda,
3527                              DeviceMemory<float> *x, int incx) {
3528   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3529             PARAM(lda), PARAM(x), PARAM(incx));
3530 
3531   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3532                const DeviceMemory<float> &, int, DeviceMemory<float> *,
3533                int> impl;
3534   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3535               lda, x, incx);
3536 }
3537 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3538 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3539                              blas::Diagonal diag, uint64 n,
3540                              const DeviceMemory<double> &a, int lda,
3541                              DeviceMemory<double> *x, int incx) {
3542   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3543             PARAM(lda), PARAM(x), PARAM(incx));
3544 
3545   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3546                const DeviceMemory<double> &, int, DeviceMemory<double> *,
3547                int> impl;
3548   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3549               lda, x, incx);
3550 }
3551 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3552 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3553                              blas::Diagonal diag, uint64 n,
3554                              const DeviceMemory<std::complex<float>> &a,
3555                              int lda, DeviceMemory<std::complex<float>> *x,
3556                              int incx) {
3557   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3558             PARAM(lda), PARAM(x), PARAM(incx));
3559 
3560   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3561                const DeviceMemory<std::complex<float>> &, int,
3562                DeviceMemory<std::complex<float>> *, int> impl;
3563   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3564               lda, x, incx);
3565 }
3566 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3567 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3568                              blas::Diagonal diag, uint64 n,
3569                              const DeviceMemory<std::complex<double>> &a,
3570                              int lda, DeviceMemory<std::complex<double>> *x,
3571                              int incx) {
3572   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3573             PARAM(lda), PARAM(x), PARAM(incx));
3574 
3575   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3576                const DeviceMemory<std::complex<double>> &, int,
3577                DeviceMemory<std::complex<double>> *, int> impl;
3578   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3579               lda, x, incx);
3580 }
3581 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3582 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3583                              blas::Diagonal diag, uint64 n,
3584                              const DeviceMemory<float> &a, int lda,
3585                              DeviceMemory<float> *x, int incx) {
3586   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3587             PARAM(lda), PARAM(x), PARAM(incx));
3588 
3589   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3590                const DeviceMemory<float> &, int, DeviceMemory<float> *,
3591                int> impl;
3592   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3593               lda, x, incx);
3594 }
3595 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3596 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3597                              blas::Diagonal diag, uint64 n,
3598                              const DeviceMemory<double> &a, int lda,
3599                              DeviceMemory<double> *x, int incx) {
3600   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3601             PARAM(lda), PARAM(x), PARAM(incx));
3602 
3603   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3604                const DeviceMemory<double> &, int, DeviceMemory<double> *,
3605                int> impl;
3606   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3607               lda, x, incx);
3608 }
3609 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3610 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3611                              blas::Diagonal diag, uint64 n,
3612                              const DeviceMemory<std::complex<float>> &a,
3613                              int lda, DeviceMemory<std::complex<float>> *x,
3614                              int incx) {
3615   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3616             PARAM(lda), PARAM(x), PARAM(incx));
3617 
3618   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3619                const DeviceMemory<std::complex<float>> &, int,
3620                DeviceMemory<std::complex<float>> *, int> impl;
3621   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3622               lda, x, incx);
3623 }
3624 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3625 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3626                              blas::Diagonal diag, uint64 n,
3627                              const DeviceMemory<std::complex<double>> &a,
3628                              int lda, DeviceMemory<std::complex<double>> *x,
3629                              int incx) {
3630   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3631             PARAM(lda), PARAM(x), PARAM(incx));
3632 
3633   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3634                const DeviceMemory<std::complex<double>> &, int,
3635                DeviceMemory<std::complex<double>> *, int> impl;
3636   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3637               lda, x, incx);
3638 }
3639 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3640 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3641                              uint64 m, uint64 n, uint64 k, float alpha,
3642                              const DeviceMemory<Eigen::half> &a, int lda,
3643                              const DeviceMemory<Eigen::half> &b, int ldb,
3644                              float beta,
3645                              DeviceMemory<Eigen::half> *c, int ldc) {
3646   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3647             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3648             PARAM(beta), PARAM(c), PARAM(ldc));
3649 
3650   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3651                const DeviceMemory<Eigen::half> &, int,
3652                const DeviceMemory<Eigen::half> &, int,
3653                float, DeviceMemory<Eigen::half> *, int> impl;
3654   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3655               alpha, a, lda, b, ldb, beta, c, ldc);
3656 }
3657 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3658 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3659                              uint64 m, uint64 n, uint64 k, float alpha,
3660                              const DeviceMemory<float> &a, int lda,
3661                              const DeviceMemory<float> &b, int ldb, float beta,
3662                              DeviceMemory<float> *c, int ldc) {
3663   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3664             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3665             PARAM(beta), PARAM(c), PARAM(ldc));
3666 
3667   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3668                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3669                int, float, DeviceMemory<float> *, int> impl;
3670   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3671               alpha, a, lda, b, ldb, beta, c, ldc);
3672 }
3673 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3674 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3675                              uint64 m, uint64 n, uint64 k, double alpha,
3676                              const DeviceMemory<double> &a, int lda,
3677                              const DeviceMemory<double> &b, int ldb,
3678                              double beta, DeviceMemory<double> *c, int ldc) {
3679   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3680             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3681             PARAM(beta), PARAM(c), PARAM(ldc));
3682 
3683   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3684                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3685                int, double, DeviceMemory<double> *, int> impl;
3686   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3687               alpha, a, lda, b, ldb, beta, c, ldc);
3688 }
3689 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3690 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3691                              uint64 m, uint64 n, uint64 k,
3692                              std::complex<float> alpha,
3693                              const DeviceMemory<std::complex<float>> &a,
3694                              int lda,
3695                              const DeviceMemory<std::complex<float>> &b,
3696                              int ldb, std::complex<float> beta,
3697                              DeviceMemory<std::complex<float>> *c, int ldc) {
3698   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3699             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3700             PARAM(beta), PARAM(c), PARAM(ldc));
3701 
3702   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3703                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3704                int, const DeviceMemory<std::complex<float>> &, int,
3705                std::complex<float>, DeviceMemory<std::complex<float>> *,
3706                int> impl;
3707   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3708               alpha, a, lda, b, ldb, beta, c, ldc);
3709 }
3710 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3711 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3712                              uint64 m, uint64 n, uint64 k,
3713                              std::complex<double> alpha,
3714                              const DeviceMemory<std::complex<double>> &a,
3715                              int lda,
3716                              const DeviceMemory<std::complex<double>> &b,
3717                              int ldb, std::complex<double> beta,
3718                              DeviceMemory<std::complex<double>> *c, int ldc) {
3719   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3720             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3721             PARAM(beta), PARAM(c), PARAM(ldc));
3722 
3723   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3724                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3725                int, const DeviceMemory<std::complex<double>> &, int,
3726                std::complex<double>, DeviceMemory<std::complex<double>> *,
3727                int> impl;
3728   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3729               alpha, a, lda, b, ldb, beta, c, ldc);
3730 }
3731 
3732 namespace {
3733 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3734 // blas::ProfileResult*.  This functor doesn't put the stream into an error
3735 // state if the op fails and the profile result is non-null.  Instead, the
3736 // error-ness is returned in the profile result itself.
3737 template <typename... Args>
3738 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon537844490211::ThenBlasWithProfileImpl3739   Stream &operator()(Stream *stream,
3740                      bool (blas::BlasSupport::*blas_func)(
3741                          Stream *, Args..., blas::ProfileResult *),
3742                      Args... args, blas::ProfileResult *profile_result) {
3743     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3744     bool record_error = profile_result == nullptr;
3745     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3746   }
3747 };
3748 }  // anonymous namespace
3749 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3750 Stream &Stream::ThenBlasGemvWithProfiling(
3751     blas::Transpose trans, uint64 m, uint64 n, float alpha,
3752     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3753     int incx, float beta, DeviceMemory<float> *y, int incy,
3754     blas::ProfileResult *output_profile_result) {
3755   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3756             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3757             PARAM(incy));
3758 
3759   ThenBlasWithProfileImpl<
3760       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3761       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3762       impl;
3763   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3764               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3765 }
3766 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3767 Stream &Stream::ThenBlasGemvWithProfiling(
3768     blas::Transpose trans, uint64 m, uint64 n, double alpha,
3769     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3770     int incx, double beta, DeviceMemory<double> *y, int incy,
3771     blas::ProfileResult *output_profile_result) {
3772   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3773             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3774             PARAM(incy));
3775 
3776   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3777                           const DeviceMemory<double> &, int,
3778                           const DeviceMemory<double> &, int, double,
3779                           DeviceMemory<double> *, int>
3780       impl;
3781   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3782               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3783 }
3784 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3785 Stream &Stream::ThenBlasGemvWithProfiling(
3786     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3787     const DeviceMemory<std::complex<float>> &a, int lda,
3788     const DeviceMemory<std::complex<float>> &x, int incx,
3789     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3790     blas::ProfileResult *output_profile_result) {
3791   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3792             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3793             PARAM(incy));
3794 
3795   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3796                           const DeviceMemory<std::complex<float>> &, int,
3797                           const DeviceMemory<std::complex<float>> &, int,
3798                           std::complex<float>,
3799                           DeviceMemory<std::complex<float>> *, int>
3800       impl;
3801   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3802               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3803 }
3804 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3805 Stream &Stream::ThenBlasGemvWithProfiling(
3806     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3807     const DeviceMemory<std::complex<double>> &a, int lda,
3808     const DeviceMemory<std::complex<double>> &x, int incx,
3809     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3810     blas::ProfileResult *output_profile_result) {
3811   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3812             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3813             PARAM(incy));
3814 
3815   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3816                           const DeviceMemory<std::complex<double>> &, int,
3817                           const DeviceMemory<std::complex<double>> &, int,
3818                           std::complex<double>,
3819                           DeviceMemory<std::complex<double>> *, int>
3820       impl;
3821   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3822               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3823 }
3824 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3825 Stream &Stream::ThenBlasGemmWithProfiling(
3826     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3827     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3828     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3829     DeviceMemory<Eigen::half> *c, int ldc,
3830     blas::ProfileResult *output_profile_result) {
3831   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3832             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3833             PARAM(beta), PARAM(c), PARAM(ldc));
3834 
3835   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3836                           uint64, float, const DeviceMemory<Eigen::half> &, int,
3837                           const DeviceMemory<Eigen::half> &, int, float,
3838                           DeviceMemory<Eigen::half> *, int>
3839       impl;
3840   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3841               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3842               output_profile_result);
3843 }
3844 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3845 Stream &Stream::ThenBlasGemmWithProfiling(
3846     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3847     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3848     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3849     int ldc, blas::ProfileResult *output_profile_result) {
3850   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3851             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3852             PARAM(beta), PARAM(c), PARAM(ldc));
3853 
3854   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3855                           uint64, float, const DeviceMemory<float> &, int,
3856                           const DeviceMemory<float> &, int, float,
3857                           DeviceMemory<float> *, int>
3858       impl;
3859   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3860               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3861               output_profile_result);
3862 }
3863 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3864 Stream &Stream::ThenBlasGemmWithProfiling(
3865     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3866     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3867     const DeviceMemory<double> &b, int ldb, double beta,
3868     DeviceMemory<double> *c, int ldc,
3869     blas::ProfileResult *output_profile_result) {
3870   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3871             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3872             PARAM(beta), PARAM(c), PARAM(ldc));
3873 
3874   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3875                           uint64, double, const DeviceMemory<double> &, int,
3876                           const DeviceMemory<double> &, int, double,
3877                           DeviceMemory<double> *, int>
3878       impl;
3879   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3880               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3881               output_profile_result);
3882 }
3883 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3884 Stream &Stream::ThenBlasGemmWithProfiling(
3885     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3886     uint64 k, std::complex<float> alpha,
3887     const DeviceMemory<std::complex<float>> &a, int lda,
3888     const DeviceMemory<std::complex<float>> &b, int ldb,
3889     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3890     blas::ProfileResult *output_profile_result) {
3891   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3892             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3893             PARAM(beta), PARAM(c), PARAM(ldc));
3894 
3895   ThenBlasWithProfileImpl<
3896       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3897       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3898       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3899       DeviceMemory<std::complex<float>> *, int>
3900       impl;
3901   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3902               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3903               output_profile_result);
3904 }
3905 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3906 Stream &Stream::ThenBlasGemmWithProfiling(
3907     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3908     uint64 k, std::complex<double> alpha,
3909     const DeviceMemory<std::complex<double>> &a, int lda,
3910     const DeviceMemory<std::complex<double>> &b, int ldb,
3911     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3912     blas::ProfileResult *output_profile_result) {
3913   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3914             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3915             PARAM(beta), PARAM(c), PARAM(ldc));
3916 
3917   ThenBlasWithProfileImpl<
3918       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3919       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3920       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3921       DeviceMemory<std::complex<double>> *, int>
3922       impl;
3923   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3924               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3925               output_profile_result);
3926 }
3927 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3928 Stream &Stream::ThenBlasGemmWithAlgorithm(
3929     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3930     uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
3931     const DeviceMemory<Eigen::half> &a, int lda,
3932     const DeviceMemory<Eigen::half> &b, int ldb,
3933     const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
3934     int ldc, blas::ComputationType computation_type,
3935     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3936   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3937             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3938             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3939             PARAM(algorithm));
3940 
3941   ThenBlasWithProfileImpl<
3942       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3943       const HostOrDeviceScalar<Eigen::half> &,
3944       const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
3945       int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3946       int, blas::ComputationType, blas::AlgorithmType>
3947       impl;
3948   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3949               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3950               algorithm, output_profile_result);
3951 }
3952 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3953 Stream &Stream::ThenBlasGemmWithAlgorithm(
3954     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3955     uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
3956     int lda, const DeviceMemory<int8> &b, int ldb,
3957     const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
3958     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3959     blas::ProfileResult *output_profile_result) {
3960   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3961             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3962             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3963             PARAM(algorithm));
3964 
3965   ThenBlasWithProfileImpl<
3966       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3967       const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
3968       const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
3969       DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3970       impl;
3971   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3972               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3973               algorithm, output_profile_result);
3974 }
3975 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3976 Stream &Stream::ThenBlasGemmWithAlgorithm(
3977     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3978     uint64 k, const HostOrDeviceScalar<float> &alpha,
3979     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
3980     int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
3981     int ldc, blas::ComputationType computation_type,
3982     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3983   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3984             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3985             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3986             PARAM(algorithm));
3987 
3988   ThenBlasWithProfileImpl<
3989       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3990       const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
3991       const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
3992       DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3993       impl;
3994   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3995               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3996               algorithm, output_profile_result);
3997 }
3998 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3999 Stream &Stream::ThenBlasGemmWithAlgorithm(
4000     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4001     uint64 k, const HostOrDeviceScalar<double> &alpha,
4002     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
4003     int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
4004     int ldc, blas::ComputationType computation_type,
4005     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
4006   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4007             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4008             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
4009             PARAM(algorithm));
4010 
4011   ThenBlasWithProfileImpl<
4012       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4013       const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
4014       const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
4015       DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
4016       impl;
4017   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
4018               m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
4019               HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
4020               algorithm, output_profile_result);
4021 }
4022 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)4023 Stream &Stream::ThenBlasGemmWithAlgorithm(
4024     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4025     uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
4026     const DeviceMemory<std::complex<float>> &a, int lda,
4027     const DeviceMemory<std::complex<float>> &b, int ldb,
4028     const HostOrDeviceScalar<std::complex<float>> &beta,
4029     DeviceMemory<std::complex<float>> *c, int ldc,
4030     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
4031     blas::ProfileResult *output_profile_result) {
4032   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4033             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4034             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
4035             PARAM(algorithm));
4036 
4037   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
4038                           uint64,
4039                           const HostOrDeviceScalar<std::complex<float>> &,
4040                           const DeviceMemory<std::complex<float>> &, int,
4041                           const DeviceMemory<std::complex<float>> &, int,
4042                           const HostOrDeviceScalar<std::complex<float>> &,
4043                           DeviceMemory<std::complex<float>> *, int,
4044                           blas::ComputationType, blas::AlgorithmType>
4045       impl;
4046   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
4047               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
4048               algorithm, output_profile_result);
4049 }
4050 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)4051 Stream &Stream::ThenBlasGemmWithAlgorithm(
4052     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4053     uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
4054     const DeviceMemory<std::complex<double>> &a, int lda,
4055     const DeviceMemory<std::complex<double>> &b, int ldb,
4056     const HostOrDeviceScalar<std::complex<double>> &beta,
4057     DeviceMemory<std::complex<double>> *c, int ldc,
4058     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
4059     blas::ProfileResult *output_profile_result) {
4060   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4061             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4062             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
4063             PARAM(algorithm));
4064 
4065   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
4066                           uint64,
4067                           const HostOrDeviceScalar<std::complex<double>> &,
4068                           const DeviceMemory<std::complex<double>> &, int,
4069                           const DeviceMemory<std::complex<double>> &, int,
4070                           const HostOrDeviceScalar<std::complex<double>> &,
4071                           DeviceMemory<std::complex<double>> *, int,
4072                           blas::ComputationType, blas::AlgorithmType>
4073       impl;
4074   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
4075               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
4076               algorithm, output_profile_result);
4077 }
4078 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4079 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
4080                              uint64 n, std::complex<float> alpha,
4081                              const DeviceMemory<std::complex<float>> &a,
4082                              int lda,
4083                              const DeviceMemory<std::complex<float>> &b,
4084                              int ldb, std::complex<float> beta,
4085                              DeviceMemory<std::complex<float>> *c, int ldc) {
4086   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4087             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4088             PARAM(ldc));
4089 
4090   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4091                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4092                int, const DeviceMemory<std::complex<float>> &, int,
4093                std::complex<float>, DeviceMemory<std::complex<float>> *,
4094                int> impl;
4095   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
4096               lda, b, ldb, beta, c, ldc);
4097 }
4098 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4099 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
4100                              uint64 n, std::complex<double> alpha,
4101                              const DeviceMemory<std::complex<double>> &a,
4102                              int lda,
4103                              const DeviceMemory<std::complex<double>> &b,
4104                              int ldb, std::complex<double> beta,
4105                              DeviceMemory<std::complex<double>> *c, int ldc) {
4106   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4107             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4108             PARAM(ldc));
4109 
4110   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4111                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4112                int, const DeviceMemory<std::complex<double>> &, int,
4113                std::complex<double>, DeviceMemory<std::complex<double>> *,
4114                int> impl;
4115   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
4116               lda, b, ldb, beta, c, ldc);
4117 }
4118 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)4119 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
4120                              uint64 n, uint64 k, float alpha,
4121                              const DeviceMemory<std::complex<float>> &a,
4122                              int lda, float beta,
4123                              DeviceMemory<std::complex<float>> *c, int ldc) {
4124   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4125             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4126 
4127   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4128                const DeviceMemory<std::complex<float>> &, int, float,
4129                DeviceMemory<std::complex<float>> *, int> impl;
4130   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
4131               lda, beta, c, ldc);
4132 }
4133 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)4134 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
4135                              uint64 n, uint64 k, double alpha,
4136                              const DeviceMemory<std::complex<double>> &a,
4137                              int lda, double beta,
4138                              DeviceMemory<std::complex<double>> *c, int ldc) {
4139   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4140             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4141 
4142   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4143                const DeviceMemory<std::complex<double>> &, int, double,
4144                DeviceMemory<std::complex<double>> *, int> impl;
4145   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
4146               lda, beta, c, ldc);
4147 }
4148 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)4149 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4150                               uint64 n, uint64 k, std::complex<float> alpha,
4151                               const DeviceMemory<std::complex<float>> &a,
4152                               int lda,
4153                               const DeviceMemory<std::complex<float>> &b,
4154                               int ldb, float beta,
4155                               DeviceMemory<std::complex<float>> *c, int ldc) {
4156   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4157             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4158             PARAM(ldc));
4159 
4160   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4161                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4162                int, const DeviceMemory<std::complex<float>> &, int, float,
4163                DeviceMemory<std::complex<float>> *, int> impl;
4164   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4165               a, lda, b, ldb, beta, c, ldc);
4166 }
4167 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)4168 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4169                               uint64 n, uint64 k, std::complex<double> alpha,
4170                               const DeviceMemory<std::complex<double>> &a,
4171                               int lda,
4172                               const DeviceMemory<std::complex<double>> &b,
4173                               int ldb, double beta,
4174                               DeviceMemory<std::complex<double>> *c, int ldc) {
4175   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4176             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4177             PARAM(ldc));
4178 
4179   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4180                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4181                int, const DeviceMemory<std::complex<double>> &, int, double,
4182                DeviceMemory<std::complex<double>> *, int> impl;
4183   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4184               a, lda, b, ldb, beta, c, ldc);
4185 }
4186 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4187 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4188                              uint64 n, float alpha,
4189                              const DeviceMemory<float> &a, int lda,
4190                              const DeviceMemory<float> &b, int ldb, float beta,
4191                              DeviceMemory<float> *c, int ldc) {
4192   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4193             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4194             PARAM(ldc));
4195 
4196   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
4197                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4198                int, float, DeviceMemory<float> *, int> impl;
4199   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4200               lda, b, ldb, beta, c, ldc);
4201 }
4202 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4203 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4204                              uint64 n, double alpha,
4205                              const DeviceMemory<double> &a, int lda,
4206                              const DeviceMemory<double> &b, int ldb,
4207                              double beta, DeviceMemory<double> *c, int ldc) {
4208   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4209             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4210             PARAM(ldc));
4211 
4212   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
4213                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4214                int, double, DeviceMemory<double> *, int> impl;
4215   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4216               lda, b, ldb, beta, c, ldc);
4217 }
4218 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4219 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4220                              uint64 n, std::complex<float> alpha,
4221                              const DeviceMemory<std::complex<float>> &a,
4222                              int lda,
4223                              const DeviceMemory<std::complex<float>> &b,
4224                              int ldb, std::complex<float> beta,
4225                              DeviceMemory<std::complex<float>> *c, int ldc) {
4226   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4227             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4228             PARAM(ldc));
4229 
4230   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4231                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4232                int, const DeviceMemory<std::complex<float>> &, int,
4233                std::complex<float>, DeviceMemory<std::complex<float>> *,
4234                int> impl;
4235   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4236               lda, b, ldb, beta, c, ldc);
4237 }
4238 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4239 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4240                              uint64 n, std::complex<double> alpha,
4241                              const DeviceMemory<std::complex<double>> &a,
4242                              int lda,
4243                              const DeviceMemory<std::complex<double>> &b,
4244                              int ldb, std::complex<double> beta,
4245                              DeviceMemory<std::complex<double>> *c, int ldc) {
4246   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4247             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4248             PARAM(ldc));
4249 
4250   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4251                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4252                int, const DeviceMemory<std::complex<double>> &, int,
4253                std::complex<double>, DeviceMemory<std::complex<double>> *,
4254                int> impl;
4255   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4256               lda, b, ldb, beta, c, ldc);
4257 }
4258 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)4259 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4260                              uint64 n, uint64 k, float alpha,
4261                              const DeviceMemory<float> &a, int lda, float beta,
4262                              DeviceMemory<float> *c, int ldc) {
4263   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4264             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4265 
4266   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4267                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
4268                int> impl;
4269   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4270               lda, beta, c, ldc);
4271 }
4272 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)4273 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4274                              uint64 n, uint64 k, double alpha,
4275                              const DeviceMemory<double> &a, int lda,
4276                              double beta, DeviceMemory<double> *c, int ldc) {
4277   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4278             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4279 
4280   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4281                const DeviceMemory<double> &, int, double,
4282                DeviceMemory<double> *, int> impl;
4283   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4284               lda, beta, c, ldc);
4285 }
4286 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4287 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4288                              uint64 n, uint64 k, std::complex<float> alpha,
4289                              const DeviceMemory<std::complex<float>> &a,
4290                              int lda, std::complex<float> beta,
4291                              DeviceMemory<std::complex<float>> *c, int ldc) {
4292   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4293             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4294 
4295   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4296                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4297                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
4298                int> impl;
4299   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4300               lda, beta, c, ldc);
4301 }
4302 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4303 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4304                              uint64 n, uint64 k, std::complex<double> alpha,
4305                              const DeviceMemory<std::complex<double>> &a,
4306                              int lda, std::complex<double> beta,
4307                              DeviceMemory<std::complex<double>> *c, int ldc) {
4308   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4309             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4310 
4311   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4312                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4313                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
4314                int> impl;
4315   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4316               lda, beta, c, ldc);
4317 }
4318 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4319 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4320                               uint64 n, uint64 k, float alpha,
4321                               const DeviceMemory<float> &a, int lda,
4322                               const DeviceMemory<float> &b, int ldb, float beta,
4323                               DeviceMemory<float> *c, int ldc) {
4324   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4325             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4326             PARAM(ldc));
4327 
4328   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4329                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4330                int, float, DeviceMemory<float> *, int> impl;
4331   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4332               a, lda, b, ldb, beta, c, ldc);
4333 }
4334 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4335 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4336                               uint64 n, uint64 k, double alpha,
4337                               const DeviceMemory<double> &a, int lda,
4338                               const DeviceMemory<double> &b, int ldb,
4339                               double beta, DeviceMemory<double> *c, int ldc) {
4340   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4341             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4342             PARAM(ldc));
4343 
4344   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4345                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4346                int, double, DeviceMemory<double> *, int> impl;
4347   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4348               a, lda, b, ldb, beta, c, ldc);
4349 }
4350 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4351 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4352                               uint64 n, uint64 k, std::complex<float> alpha,
4353                               const DeviceMemory<std::complex<float>> &a,
4354                               int lda,
4355                               const DeviceMemory<std::complex<float>> &b,
4356                               int ldb, std::complex<float> beta,
4357                               DeviceMemory<std::complex<float>> *c, int ldc) {
4358   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4359             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4360             PARAM(ldc));
4361 
4362   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4363                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4364                int, const DeviceMemory<std::complex<float>> &, int,
4365                std::complex<float>, DeviceMemory<std::complex<float>> *,
4366                int> impl;
4367   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4368               a, lda, b, ldb, beta, c, ldc);
4369 }
4370 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4371 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4372                               uint64 n, uint64 k, std::complex<double> alpha,
4373                               const DeviceMemory<std::complex<double>> &a,
4374                               int lda,
4375                               const DeviceMemory<std::complex<double>> &b,
4376                               int ldb, std::complex<double> beta,
4377                               DeviceMemory<std::complex<double>> *c, int ldc) {
4378   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4379             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4380             PARAM(ldc));
4381 
4382   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4383                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4384                int, const DeviceMemory<std::complex<double>> &, int,
4385                std::complex<double>, DeviceMemory<std::complex<double>> *,
4386                int> impl;
4387   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4388               a, lda, b, ldb, beta, c, ldc);
4389 }
4390 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4391 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4392                              blas::Transpose transa, blas::Diagonal diag,
4393                              uint64 m, uint64 n, float alpha,
4394                              const DeviceMemory<float> &a, int lda,
4395                              DeviceMemory<float> *b, int ldb) {
4396   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4397             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4398 
4399   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4400                uint64, uint64, float, const DeviceMemory<float> &, int,
4401                DeviceMemory<float> *, int> impl;
4402   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4403               n, alpha, a, lda, b, ldb);
4404 }
4405 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4406 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4407                              blas::Transpose transa, blas::Diagonal diag,
4408                              uint64 m, uint64 n, double alpha,
4409                              const DeviceMemory<double> &a, int lda,
4410                              DeviceMemory<double> *b, int ldb) {
4411   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4412             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4413 
4414   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4415                uint64, uint64, double, const DeviceMemory<double> &, int,
4416                DeviceMemory<double> *, int> impl;
4417   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4418               n, alpha, a, lda, b, ldb);
4419 }
4420 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4421 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4422                              blas::Transpose transa, blas::Diagonal diag,
4423                              uint64 m, uint64 n, std::complex<float> alpha,
4424                              const DeviceMemory<std::complex<float>> &a,
4425                              int lda, DeviceMemory<std::complex<float>> *b,
4426                              int ldb) {
4427   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4428             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4429 
4430   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4431                uint64, uint64, std::complex<float>,
4432                const DeviceMemory<std::complex<float>> &, int,
4433                DeviceMemory<std::complex<float>> *, int> impl;
4434   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4435               n, alpha, a, lda, b, ldb);
4436 }
4437 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4438 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4439                              blas::Transpose transa, blas::Diagonal diag,
4440                              uint64 m, uint64 n, std::complex<double> alpha,
4441                              const DeviceMemory<std::complex<double>> &a,
4442                              int lda, DeviceMemory<std::complex<double>> *b,
4443                              int ldb) {
4444   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4445             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4446 
4447   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4448                uint64, uint64, std::complex<double>,
4449                const DeviceMemory<std::complex<double>> &, int,
4450                DeviceMemory<std::complex<double>> *, int> impl;
4451   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4452               n, alpha, a, lda, b, ldb);
4453 }
4454 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4455 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4456                              blas::Transpose transa, blas::Diagonal diag,
4457                              uint64 m, uint64 n, float alpha,
4458                              const DeviceMemory<float> &a, int lda,
4459                              DeviceMemory<float> *b, int ldb) {
4460   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4461             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4462 
4463   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4464                uint64, uint64, float, const DeviceMemory<float> &, int,
4465                DeviceMemory<float> *, int> impl;
4466   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4467               n, alpha, a, lda, b, ldb);
4468 }
4469 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4470 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4471                              blas::Transpose transa, blas::Diagonal diag,
4472                              uint64 m, uint64 n, double alpha,
4473                              const DeviceMemory<double> &a, int lda,
4474                              DeviceMemory<double> *b, int ldb) {
4475   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4476             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4477 
4478   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4479                uint64, uint64, double, const DeviceMemory<double> &, int,
4480                DeviceMemory<double> *, int> impl;
4481   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4482               n, alpha, a, lda, b, ldb);
4483 }
4484 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4485 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4486                              blas::Transpose transa, blas::Diagonal diag,
4487                              uint64 m, uint64 n, std::complex<float> alpha,
4488                              const DeviceMemory<std::complex<float>> &a,
4489                              int lda, DeviceMemory<std::complex<float>> *b,
4490                              int ldb) {
4491   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4492             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4493 
4494   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4495                uint64, uint64, std::complex<float>,
4496                const DeviceMemory<std::complex<float>> &, int,
4497                DeviceMemory<std::complex<float>> *, int> impl;
4498   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4499               n, alpha, a, lda, b, ldb);
4500 }
4501 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4502 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4503                              blas::Transpose transa, blas::Diagonal diag,
4504                              uint64 m, uint64 n, std::complex<double> alpha,
4505                              const DeviceMemory<std::complex<double>> &a,
4506                              int lda, DeviceMemory<std::complex<double>> *b,
4507                              int ldb) {
4508   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4509             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4510 
4511   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4512                uint64, uint64, std::complex<double>,
4513                const DeviceMemory<std::complex<double>> &, int,
4514                DeviceMemory<std::complex<double>> *, int> impl;
4515   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4516               n, alpha, a, lda, b, ldb);
4517 }
4518 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)4519 Stream &Stream::ThenBlasGemmBatched(
4520     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4521     uint64 k, float alpha,
4522     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4523     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4524     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4525     int batch_count) {
4526   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4527                                         b, ldb, beta, c, ldc, batch_count,
4528                                         /*scratch_allocator=*/nullptr);
4529 }
4530 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4531 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4532     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4533     uint64 k, float alpha,
4534     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4535     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4536     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4537     int batch_count, ScratchAllocator *scratch_allocator) {
4538   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4539             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4540             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4541 
4542   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4543                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4544                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4545                float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
4546                int, int, ScratchAllocator *>
4547       impl;
4548   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4549               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4550               scratch_allocator);
4551 }
4552 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4553 Stream &Stream::ThenBlasGemmBatched(
4554     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4555     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4556     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4557     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4558     int batch_count) {
4559   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4560                                         b, ldb, beta, c, ldc, batch_count,
4561                                         /*scratch_allocator=*/nullptr);
4562 }
4563 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4564 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4565     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4566     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4567     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4568     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4569     int batch_count, ScratchAllocator *scratch_allocator) {
4570   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4571             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4572             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4573 
4574   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4575                const port::ArraySlice<DeviceMemory<float> *> &, int,
4576                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4577                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4578                ScratchAllocator *>
4579       impl;
4580   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4581               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4582               scratch_allocator);
4583 }
4584 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4585 Stream &Stream::ThenBlasGemmBatched(
4586     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4587     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4588     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4589     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4590     int batch_count) {
4591   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4592                                         b, ldb, beta, c, ldc, batch_count,
4593                                         /*scratch_allocator=*/nullptr);
4594 }
4595 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4596 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4597     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4598     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4599     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4600     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4601     int batch_count, ScratchAllocator *scratch_allocator) {
4602   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4603             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4604             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4605 
4606   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4607                const port::ArraySlice<DeviceMemory<double> *> &, int,
4608                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4609                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4610                ScratchAllocator *>
4611       impl;
4612   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4613               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4614               scratch_allocator);
4615 }
4616 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4617 Stream &Stream::ThenBlasGemmBatched(
4618     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4619     uint64 k, std::complex<float> alpha,
4620     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4621     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4622     std::complex<float> beta,
4623     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4624     int batch_count) {
4625   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4626                                         b, ldb, beta, c, ldc, batch_count,
4627                                         /*scratch_allocator=*/nullptr);
4628 }
4629 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4630 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4631     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4632     uint64 k, std::complex<float> alpha,
4633     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4634     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4635     std::complex<float> beta,
4636     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4637     int batch_count, ScratchAllocator *scratch_allocator) {
4638   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4639             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4640             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4641 
4642   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4643                std::complex<float>,
4644                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4645                int,
4646                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4647                int, std::complex<float>,
4648                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4649                int, int, ScratchAllocator *>
4650       impl;
4651   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4652               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4653               scratch_allocator);
4654 }
4655 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4656 Stream &Stream::ThenBlasGemmBatched(
4657     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4658     uint64 k, std::complex<double> alpha,
4659     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4660     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4661     std::complex<double> beta,
4662     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4663     int batch_count) {
4664   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4665                                         b, ldb, beta, c, ldc, batch_count,
4666                                         /*scratch_allocator=*/nullptr);
4667 }
4668 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4669 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4670     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4671     uint64 k, std::complex<double> alpha,
4672     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4673     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4674     std::complex<double> beta,
4675     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4676     int batch_count, ScratchAllocator *scratch_allocator) {
4677   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4678             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4679             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4680 
4681   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4682                std::complex<double>,
4683                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4684                int,
4685                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4686                int, std::complex<double>,
4687                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4688                int, int, ScratchAllocator *>
4689       impl;
4690   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4691               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4692               scratch_allocator);
4693 }
4694 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)4695 Stream &Stream::ThenBlasGemmStridedBatched(
4696     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4697     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
4698     int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
4699     float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
4700     int batch_count) {
4701   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4702             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4703             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4704             PARAM(stride_c), PARAM(batch_count));
4705 
4706   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4707                const DeviceMemory<Eigen::half> &, int, int64,
4708                const DeviceMemory<Eigen::half> &, int, int64, float,
4709                DeviceMemory<Eigen::half> *, int, int64, int>
4710       impl;
4711   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4712               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4713               c, ldc, stride_c, batch_count);
4714 }
4715 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)4716 Stream &Stream::ThenBlasGemmStridedBatched(
4717     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4718     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
4719     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
4720     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
4721     int batch_count) {
4722   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4723             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4724             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4725             PARAM(stride_c), PARAM(batch_count));
4726 
4727   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4728                const DeviceMemory<float> &, int, int64,
4729                const DeviceMemory<float> &, int, int64, float,
4730                DeviceMemory<float> *, int, int64, int>
4731       impl;
4732   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4733               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4734               c, ldc, stride_c, batch_count);
4735 }
4736 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)4737 Stream &Stream::ThenBlasGemmStridedBatched(
4738     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4739     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
4740     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
4741     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
4742     int batch_count) {
4743   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4744             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4745             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4746             PARAM(stride_c), PARAM(batch_count));
4747 
4748   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4749                const DeviceMemory<double> &, int, int64,
4750                const DeviceMemory<double> &, int, int64, double,
4751                DeviceMemory<double> *, int, int64, int>
4752       impl;
4753   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4754               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4755               c, ldc, stride_c, batch_count);
4756 }
4757 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)4758 Stream &Stream::ThenBlasGemmStridedBatched(
4759     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4760     uint64 k, std::complex<float> alpha,
4761     const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
4762     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
4763     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
4764     int64 stride_c, int batch_count) {
4765   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4766             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4767             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4768             PARAM(stride_c), PARAM(batch_count));
4769 
4770   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4771                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4772                int, int64, const DeviceMemory<std::complex<float>> &, int,
4773                int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
4774                int, int64, int>
4775       impl;
4776   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4777               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4778               c, ldc, stride_c, batch_count);
4779 }
4780 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)4781 Stream &Stream::ThenBlasGemmStridedBatched(
4782     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4783     uint64 k, std::complex<double> alpha,
4784     const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
4785     const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
4786     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
4787     int64 stride_c, int batch_count) {
4788   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4789             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4790             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4791             PARAM(stride_c), PARAM(batch_count));
4792 
4793   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4794                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4795                int, int64, const DeviceMemory<std::complex<double>> &, int,
4796                int64, std::complex<double>,
4797                DeviceMemory<std::complex<double>> *, int, int64, int>
4798       impl;
4799   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4800               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4801               c, ldc, stride_c, batch_count);
4802 }
4803 
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4804 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4805   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4806 
4807   if (ok()) {
4808     if (rng::RngSupport *rng = parent_->AsRng()) {
4809       CheckError(rng->SetSeed(this, seed, seed_bytes));
4810     } else {
4811       SetError();
4812       LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
4813     }
4814   } else {
4815     LOG(INFO) << DebugStreamPointers()
4816               << " did not set RNG seed: " << static_cast<const void *>(seed)
4817               << "; bytes: " << seed_bytes;
4818   }
4819   return *this;
4820 }
4821 
ThenPopulateRandUniform(DeviceMemory<float> * values)4822 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4823   VLOG_CALL(PARAM(values));
4824 
4825   if (ok()) {
4826     if (rng::RngSupport *rng = parent_->AsRng()) {
4827       CheckError(rng->DoPopulateRandUniform(this, values));
4828     } else {
4829       SetError();
4830       LOG(INFO) << DebugStreamPointers()
4831                 << " attempting to perform RNG operation using StreamExecutor"
4832                    " without RNG support.";
4833     }
4834   }
4835   return *this;
4836 }
4837 
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4838 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4839                                          DeviceMemory<float> *values) {
4840   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4841 
4842   if (ok()) {
4843     if (rng::RngSupport *rng = parent_->AsRng()) {
4844       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4845     } else {
4846       SetError();
4847       LOG(INFO) << DebugStreamPointers()
4848                 << " attempting to perform RNG operation using StreamExecutor"
4849                    " without RNG support.";
4850     }
4851   }
4852   return *this;
4853 }
4854 
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4855 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4856                                          DeviceMemory<double> *values) {
4857   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4858 
4859   if (ok()) {
4860     if (rng::RngSupport *rng = parent_->AsRng()) {
4861       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4862     } else {
4863       SetError();
4864       LOG(INFO) << DebugStreamPointers()
4865                 << " attempting to perform RNG operation using StreamExecutor"
4866                    " without RNG support.";
4867     }
4868   }
4869   return *this;
4870 }
4871 
ThenPopulateRandUniform(DeviceMemory<double> * values)4872 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4873   VLOG_CALL(PARAM(values));
4874 
4875   if (ok()) {
4876     if (rng::RngSupport *rng = parent_->AsRng()) {
4877       CheckError(rng->DoPopulateRandUniform(this, values));
4878     } else {
4879       SetError();
4880       LOG(INFO) << DebugStreamPointers()
4881                 << " attempting to perform RNG operation using StreamExecutor"
4882                    " without RNG support.";
4883     }
4884   }
4885   return *this;
4886 }
4887 
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4888 Stream &Stream::ThenPopulateRandUniform(
4889     DeviceMemory<std::complex<float>> *values) {
4890   VLOG_CALL(PARAM(values));
4891 
4892   if (ok()) {
4893     if (rng::RngSupport *rng = parent_->AsRng()) {
4894       CheckError(rng->DoPopulateRandUniform(this, values));
4895     } else {
4896       SetError();
4897       LOG(INFO) << DebugStreamPointers()
4898                 << " attempting to perform RNG operation using StreamExecutor"
4899                    " without RNG support.";
4900     }
4901   }
4902   return *this;
4903 }
4904 
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4905 Stream &Stream::ThenPopulateRandUniform(
4906     DeviceMemory<std::complex<double>> *values) {
4907   VLOG_CALL(PARAM(values));
4908 
4909   if (ok()) {
4910     if (rng::RngSupport *rng = parent_->AsRng()) {
4911       CheckError(rng->DoPopulateRandUniform(this, values));
4912     } else {
4913       SetError();
4914       LOG(INFO) << DebugStreamPointers()
4915                 << " attempting to perform RNG operation using StreamExecutor"
4916                    " without RNG support.";
4917     }
4918   }
4919   return *this;
4920 }
4921 
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4922 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4923                            uint64 size) {
4924   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4925 
4926   if (ok()) {
4927     CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4928   } else {
4929     LOG(INFO) << DebugStreamPointers()
4930               << " did not memcpy device-to-host; source: " << gpu_src.opaque();
4931   }
4932   return *this;
4933 }
4934 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4935 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4936                            uint64 size) {
4937   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4938 
4939   if (ok()) {
4940     CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4941   } else {
4942     LOG(INFO) << DebugStreamPointers()
4943               << " did not memcpy host-to-device; source: " << host_src;
4944   }
4945   return *this;
4946 }
4947 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4948 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4949                            const DeviceMemoryBase &gpu_src, uint64 size) {
4950   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4951 
4952   if (ok()) {
4953     CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4954   } else {
4955     LOG(INFO) << DebugStreamPointers()
4956               << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
4957   }
4958   return *this;
4959 }
4960 
ThenMemZero(DeviceMemoryBase * location,uint64 size)4961 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4962   VLOG_CALL(PARAM(location), PARAM(size));
4963 
4964   if (ok()) {
4965     CheckStatus(parent_->MemZero(this, location, size));
4966   } else {
4967     LOG(INFO) << DebugStreamPointers()
4968               << " did not memzero GPU location; source: " << location;
4969   }
4970   return *this;
4971 }
4972 
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4973 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4974                              uint64 size) {
4975   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4976 
4977   if (ok()) {
4978     CheckStatus(parent_->Memset32(this, location, pattern, size));
4979   } else {
4980     LOG(INFO) << DebugStreamPointers()
4981               << " did not memset GPU location; source: " << location
4982               << "; size: " << size << "; pattern: " << std::hex << pattern;
4983   }
4984   return *this;
4985 }
4986 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4987 Stream &Stream::ThenRnnForward(
4988     const dnn::RnnDescriptor &rnn_desc,
4989     const dnn::RnnSequenceTensorDescriptor &input_desc,
4990     const DeviceMemory<Eigen::half> &input_data,
4991     const dnn::RnnStateTensorDescriptor &input_h_desc,
4992     const DeviceMemory<Eigen::half> &input_h_data,
4993     const dnn::RnnStateTensorDescriptor &input_c_desc,
4994     const DeviceMemory<Eigen::half> &input_c_data,
4995     const DeviceMemory<Eigen::half> &params,
4996     const dnn::RnnSequenceTensorDescriptor &output_desc,
4997     DeviceMemory<Eigen::half> *output_data,
4998     const dnn::RnnStateTensorDescriptor &output_h_desc,
4999     DeviceMemory<Eigen::half> *output_h_data,
5000     const dnn::RnnStateTensorDescriptor &output_c_desc,
5001     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
5002     ScratchAllocator *reserve_space_allocator,
5003     ScratchAllocator *workspace_allocator,
5004     dnn::ProfileResult *output_profile_result) {
5005   // TODO(zhengxq): add VLOG PARAM calls.
5006   if (ok()) {
5007     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5008       auto status = dnn->DoRnnForward(
5009           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5010           input_c_desc, input_c_data, params, output_desc, output_data,
5011           output_h_desc, output_h_data, output_c_desc, output_c_data,
5012           is_training, reserve_space_allocator, workspace_allocator,
5013           output_profile_result);
5014       if (!status && !output_profile_result) {
5015         SetError();
5016       }
5017     } else {
5018       SetErrorAndLogNoDnnSupport();
5019     }
5020   }
5021   return *this;
5022 }
5023 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5024 Stream &Stream::ThenRnnForward(
5025     const dnn::RnnDescriptor &rnn_desc,
5026     const dnn::RnnSequenceTensorDescriptor &input_desc,
5027     const DeviceMemory<float> &input_data,
5028     const dnn::RnnStateTensorDescriptor &input_h_desc,
5029     const DeviceMemory<float> &input_h_data,
5030     const dnn::RnnStateTensorDescriptor &input_c_desc,
5031     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
5032     const dnn::RnnSequenceTensorDescriptor &output_desc,
5033     DeviceMemory<float> *output_data,
5034     const dnn::RnnStateTensorDescriptor &output_h_desc,
5035     DeviceMemory<float> *output_h_data,
5036     const dnn::RnnStateTensorDescriptor &output_c_desc,
5037     DeviceMemory<float> *output_c_data, bool is_training,
5038     ScratchAllocator *reserve_space_allocator,
5039     ScratchAllocator *workspace_allocator,
5040     dnn::ProfileResult *output_profile_result) {
5041   // TODO(zhengxq): add VLOG PARAM calls.
5042   if (ok()) {
5043     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5044       auto status = dnn->DoRnnForward(
5045           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5046           input_c_desc, input_c_data, params, output_desc, output_data,
5047           output_h_desc, output_h_data, output_c_desc, output_c_data,
5048           is_training, reserve_space_allocator, workspace_allocator,
5049           output_profile_result);
5050       if (!status && !output_profile_result) {
5051         SetError();
5052       }
5053     } else {
5054       SetErrorAndLogNoDnnSupport();
5055     }
5056   }
5057   return *this;
5058 }
5059 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5060 Stream &Stream::ThenRnnForward(
5061     const dnn::RnnDescriptor &rnn_desc,
5062     const dnn::RnnSequenceTensorDescriptor &input_desc,
5063     const DeviceMemory<double> &input_data,
5064     const dnn::RnnStateTensorDescriptor &input_h_desc,
5065     const DeviceMemory<double> &input_h_data,
5066     const dnn::RnnStateTensorDescriptor &input_c_desc,
5067     const DeviceMemory<double> &input_c_data,
5068     const DeviceMemory<double> &params,
5069     const dnn::RnnSequenceTensorDescriptor &output_desc,
5070     DeviceMemory<double> *output_data,
5071     const dnn::RnnStateTensorDescriptor &output_h_desc,
5072     DeviceMemory<double> *output_h_data,
5073     const dnn::RnnStateTensorDescriptor &output_c_desc,
5074     DeviceMemory<double> *output_c_data, bool is_training,
5075     ScratchAllocator *reserve_space_allocator,
5076     ScratchAllocator *workspace_allocator,
5077     dnn::ProfileResult *output_profile_result) {
5078   // TODO(zhengxq): add VLOG PARAM calls.
5079   if (ok()) {
5080     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5081       auto status = dnn->DoRnnForward(
5082           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5083           input_c_desc, input_c_data, params, output_desc, output_data,
5084           output_h_desc, output_h_data, output_c_desc, output_c_data,
5085           is_training, reserve_space_allocator, workspace_allocator,
5086           output_profile_result);
5087       if (!status && !output_profile_result) {
5088         SetError();
5089       }
5090     } else {
5091       SetErrorAndLogNoDnnSupport();
5092     }
5093   }
5094   return *this;
5095 }
5096 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5097 Stream &Stream::ThenRnnBackward(
5098     const dnn::RnnDescriptor &rnn_desc,
5099     const dnn::RnnSequenceTensorDescriptor &input_desc,
5100     const DeviceMemory<Eigen::half> &input_data,
5101     const dnn::RnnStateTensorDescriptor &input_h_desc,
5102     const DeviceMemory<Eigen::half> &input_h_data,
5103     const dnn::RnnStateTensorDescriptor &input_c_desc,
5104     const DeviceMemory<Eigen::half> &input_c_data,
5105     const DeviceMemory<Eigen::half> &params,
5106     const dnn::RnnSequenceTensorDescriptor &output_desc,
5107     const DeviceMemory<Eigen::half> &output_data,
5108     const dnn::RnnStateTensorDescriptor &output_h_desc,
5109     const DeviceMemory<Eigen::half> &output_h_data,
5110     const dnn::RnnStateTensorDescriptor &output_c_desc,
5111     const DeviceMemory<Eigen::half> &output_c_data,
5112     const DeviceMemory<Eigen::half> &output_backprop_data,
5113     const DeviceMemory<Eigen::half> &output_h_backprop_data,
5114     const DeviceMemory<Eigen::half> &output_c_backprop_data,
5115     DeviceMemory<Eigen::half> *input_backprop_data,
5116     DeviceMemory<Eigen::half> *input_h_backprop_data,
5117     DeviceMemory<Eigen::half> *input_c_backprop_data,
5118     DeviceMemory<Eigen::half> *params_backprop_data,
5119     DeviceMemory<uint8> *reserve_space_data,
5120     ScratchAllocator *workspace_allocator,
5121     dnn::ProfileResult *output_profile_result) {
5122   // TODO(zhengxq): add VLOG PARAM calls.
5123   if (ok()) {
5124     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5125       auto status = dnn->DoRnnBackward(
5126           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5127           input_c_desc, input_c_data, params, output_desc, output_data,
5128           output_h_desc, output_h_data, output_c_desc, output_c_data,
5129           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5130           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5131           params_backprop_data, reserve_space_data, workspace_allocator,
5132           output_profile_result);
5133       if (!status && !output_profile_result) {
5134         SetError();
5135       }
5136     } else {
5137       SetError();
5138       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5139     }
5140   }
5141   return *this;
5142 }
5143 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5144 Stream &Stream::ThenRnnBackward(
5145     const dnn::RnnDescriptor &rnn_desc,
5146     const dnn::RnnSequenceTensorDescriptor &input_desc,
5147     const DeviceMemory<float> &input_data,
5148     const dnn::RnnStateTensorDescriptor &input_h_desc,
5149     const DeviceMemory<float> &input_h_data,
5150     const dnn::RnnStateTensorDescriptor &input_c_desc,
5151     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
5152     const dnn::RnnSequenceTensorDescriptor &output_desc,
5153     const DeviceMemory<float> &output_data,
5154     const dnn::RnnStateTensorDescriptor &output_h_desc,
5155     const DeviceMemory<float> &output_h_data,
5156     const dnn::RnnStateTensorDescriptor &output_c_desc,
5157     const DeviceMemory<float> &output_c_data,
5158     const DeviceMemory<float> &output_backprop_data,
5159     const DeviceMemory<float> &output_h_backprop_data,
5160     const DeviceMemory<float> &output_c_backprop_data,
5161     DeviceMemory<float> *input_backprop_data,
5162     DeviceMemory<float> *input_h_backprop_data,
5163     DeviceMemory<float> *input_c_backprop_data,
5164     DeviceMemory<float> *params_backprop_data,
5165     DeviceMemory<uint8> *reserve_space_data,
5166     ScratchAllocator *workspace_allocator,
5167     dnn::ProfileResult *output_profile_result) {
5168   // TODO(zhengxq): add VLOG PARAM calls.
5169   if (ok()) {
5170     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5171       auto status = dnn->DoRnnBackward(
5172           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5173           input_c_desc, input_c_data, params, output_desc, output_data,
5174           output_h_desc, output_h_data, output_c_desc, output_c_data,
5175           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5176           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5177           params_backprop_data, reserve_space_data, workspace_allocator,
5178           output_profile_result);
5179       if (!status && !output_profile_result) {
5180         SetError();
5181       }
5182     } else {
5183       SetError();
5184       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5185     }
5186   }
5187   return *this;
5188 }
5189 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)5190 Stream &Stream::ThenRnnBackward(
5191     const dnn::RnnDescriptor &rnn_desc,
5192     const dnn::RnnSequenceTensorDescriptor &input_desc,
5193     const DeviceMemory<double> &input_data,
5194     const dnn::RnnStateTensorDescriptor &input_h_desc,
5195     const DeviceMemory<double> &input_h_data,
5196     const dnn::RnnStateTensorDescriptor &input_c_desc,
5197     const DeviceMemory<double> &input_c_data,
5198     const DeviceMemory<double> &params,
5199     const dnn::RnnSequenceTensorDescriptor &output_desc,
5200     const DeviceMemory<double> &output_data,
5201     const dnn::RnnStateTensorDescriptor &output_h_desc,
5202     const DeviceMemory<double> &output_h_data,
5203     const dnn::RnnStateTensorDescriptor &output_c_desc,
5204     const DeviceMemory<double> &output_c_data,
5205     const DeviceMemory<double> &output_backprop_data,
5206     const DeviceMemory<double> &output_h_backprop_data,
5207     const DeviceMemory<double> &output_c_backprop_data,
5208     DeviceMemory<double> *input_backprop_data,
5209     DeviceMemory<double> *input_h_backprop_data,
5210     DeviceMemory<double> *input_c_backprop_data,
5211     DeviceMemory<double> *params_backprop_data,
5212     DeviceMemory<uint8> *reserve_space_data,
5213     ScratchAllocator *workspace_allocator,
5214     dnn::ProfileResult *output_profile_result) {
5215   // TODO(zhengxq): add VLOG PARAM calls.
5216   if (ok()) {
5217     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5218       auto status = dnn->DoRnnBackward(
5219           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
5220           input_c_desc, input_c_data, params, output_desc, output_data,
5221           output_h_desc, output_h_data, output_c_desc, output_c_data,
5222           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
5223           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
5224           params_backprop_data, reserve_space_data, workspace_allocator,
5225           output_profile_result);
5226       if (!status && !output_profile_result) {
5227         SetError();
5228       }
5229     } else {
5230       SetError();
5231       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
5232     }
5233   }
5234   return *this;
5235 }
5236 
ThenCtcLoss(const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<float> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<float> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<float> * grads_data,ScratchAllocator * workspace_allocator)5237 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
5238                             const DeviceMemory<float> &probs_data,
5239                             absl::Span<const int> labels_data,
5240                             absl::Span<const int> labels_lengths_data,
5241                             absl::Span<const int> input_lengths_data,
5242                             DeviceMemory<float> *costs_data,
5243                             const dnn::RnnStateTensorDescriptor &grads_desc,
5244                             DeviceMemory<float> *grads_data,
5245                             ScratchAllocator *workspace_allocator) {
5246   if (ok()) {
5247     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5248       DeviceMemory<uint8> scratch_memory;
5249       auto status = dnn->PrepareForCtcLoss(
5250                            this, probs_desc, probs_data, grads_desc,
5251                            labels_data, labels_lengths_data, input_lengths_data,
5252                            workspace_allocator, &scratch_memory)
5253                         .ok();
5254       if (status) {
5255         status =
5256             dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
5257                            labels_lengths_data, input_lengths_data, costs_data,
5258                            grads_desc, grads_data, &scratch_memory);
5259       }
5260       if (!status) {
5261         SetError();
5262       }
5263     } else {
5264       SetErrorAndLogNoDnnSupport();
5265     }
5266   }
5267   return *this;
5268 }
5269 
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)5270 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
5271                                     dnn::DataType input_type,
5272                                     const DeviceMemoryBase &input_data,
5273                                     const dnn::BatchDescriptor &output_desc,
5274                                     dnn::DataType output_type, float scale,
5275                                     DeviceMemoryBase *output_data) {
5276   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
5277             PARAM(output_desc), PARAM(output_type), PARAM(scale),
5278             PARAM(output_data));
5279   if (ok()) {
5280     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
5281       CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
5282                                         input_data, output_desc, output_type,
5283                                         scale, output_data));
5284     } else {
5285       SetErrorAndLogNoDnnSupport();
5286     }
5287   }
5288   return *this;
5289 }
5290 
ThenDoHostCallback(std::function<void ()> callback)5291 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
5292   VLOG_CALL(PARAM(callback));
5293 
5294   if (!ok()) {
5295     LOG(INFO) << DebugStreamPointers()
5296               << " was in error state before adding host callback";
5297   }
5298   CheckError(parent_->HostCallback(this, std::move(callback)));
5299   return *this;
5300 }
5301 
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)5302 Stream &Stream::ThenDoHostCallbackWithStatus(
5303     std::function<port::Status()> callback) {
5304   VLOG_CALL(PARAM(callback));
5305 
5306   if (!ok()) {
5307     LOG(INFO) << DebugStreamPointers()
5308               << " was in error state before adding host callback";
5309   }
5310   CheckError(parent_->HostCallback(this, std::move(callback)));
5311   return *this;
5312 }
5313 
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)5314 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
5315     std::function<void()> callback) {
5316   VLOG_CALL(PARAM(callback));
5317 
5318   if (!ok()) {
5319     LOG(INFO) << DebugStreamPointers()
5320               << " was in error state before adding callback to be run after "
5321                  "next block-host-until-done.";
5322   }
5323   absl::MutexLock lock(&mu_);
5324   after_block_host_until_done_callbacks_.push_back(std::move(callback));
5325   return *this;
5326 }
5327 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)5328 Stream &Stream::ThenFft(fft::Plan *plan,
5329                         const DeviceMemory<std::complex<float>> &input,
5330                         DeviceMemory<std::complex<float>> *output) {
5331   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5332 
5333   if (ok()) {
5334     if (fft::FftSupport *fft = parent_->AsFft()) {
5335       CheckError(fft->DoFft(this, plan, input, output));
5336     } else {
5337       SetError();
5338       LOG(INFO) << DebugStreamPointers()
5339                 << " attempting to perform FFT operation using StreamExecutor"
5340                    " without FFT support";
5341     }
5342   }
5343   return *this;
5344 }
5345 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)5346 Stream &Stream::ThenFft(fft::Plan *plan,
5347                         const DeviceMemory<std::complex<double>> &input,
5348                         DeviceMemory<std::complex<double>> *output) {
5349   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5350 
5351   if (ok()) {
5352     if (fft::FftSupport *fft = parent_->AsFft()) {
5353       CheckError(fft->DoFft(this, plan, input, output));
5354     } else {
5355       SetError();
5356       LOG(INFO) << DebugStreamPointers()
5357                 << " attempting to perform FFT operation using StreamExecutor"
5358                    " without FFT support";
5359     }
5360   }
5361   return *this;
5362 }
5363 
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)5364 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
5365                         DeviceMemory<std::complex<float>> *output) {
5366   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5367 
5368   if (ok()) {
5369     if (fft::FftSupport *fft = parent_->AsFft()) {
5370       CheckError(fft->DoFft(this, plan, input, output));
5371     } else {
5372       SetError();
5373       LOG(INFO) << DebugStreamPointers()
5374                 << " attempting to perform FFT operation using StreamExecutor"
5375                    " without FFT support";
5376     }
5377   }
5378   return *this;
5379 }
5380 
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)5381 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
5382                         DeviceMemory<std::complex<double>> *output) {
5383   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5384 
5385   if (ok()) {
5386     if (fft::FftSupport *fft = parent_->AsFft()) {
5387       CheckError(fft->DoFft(this, plan, input, output));
5388     } else {
5389       SetError();
5390       LOG(INFO) << DebugStreamPointers()
5391                 << " attempting to perform FFT operation using StreamExecutor"
5392                    " without FFT support";
5393     }
5394   }
5395   return *this;
5396 }
5397 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)5398 Stream &Stream::ThenFft(fft::Plan *plan,
5399                         const DeviceMemory<std::complex<float>> &input,
5400                         DeviceMemory<float> *output) {
5401   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5402 
5403   if (ok()) {
5404     if (fft::FftSupport *fft = parent_->AsFft()) {
5405       CheckError(fft->DoFft(this, plan, input, output));
5406     } else {
5407       SetError();
5408       LOG(INFO) << DebugStreamPointers()
5409                 << " attempting to perform FFT operation using StreamExecutor"
5410                    " without FFT support";
5411     }
5412   }
5413   return *this;
5414 }
5415 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)5416 Stream &Stream::ThenFft(fft::Plan *plan,
5417                         const DeviceMemory<std::complex<double>> &input,
5418                         DeviceMemory<double> *output) {
5419   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5420 
5421   if (ok()) {
5422     if (fft::FftSupport *fft = parent_->AsFft()) {
5423       CheckError(fft->DoFft(this, plan, input, output));
5424     } else {
5425       SetError();
5426       LOG(INFO) << DebugStreamPointers()
5427                 << " attempting to perform FFT operation using StreamExecutor"
5428                    " without FFT support";
5429     }
5430   }
5431   return *this;
5432 }
5433 
5434 // It looks confusing, but all this is doing is inserting a callback at the
5435 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)5436 Stream &Stream::ThenEnqueueOnBackgroundThread(
5437     std::function<void(StreamExecutor *)> task) {
5438   VLOG_CALL(PARAM(task));
5439 
5440   StreamExecutor *stream_executor = this->parent_;
5441   std::function<void()> bound_task = std::bind(task, stream_executor);
5442 
5443   return ThenDoHostCallback([stream_executor, bound_task]() {
5444     stream_executor->EnqueueOnBackgroundThread(bound_task);
5445   });
5446 }
5447 
BlockHostUntilDone()5448 port::Status Stream::BlockHostUntilDone() {
5449   VLOG_CALL();
5450 
5451   if (!ok()) {
5452     port::Status status = port::Status(
5453         port::error::INTERNAL,
5454         "stream did not block host until done; was already in an error state");
5455     LOG(INFO) << DebugStreamPointers() << " " << status;
5456     return status;
5457   }
5458 
5459   temporary_memory_manager_.DeallocateFinalizedTemporaries();
5460 
5461   port::Status error = parent_->BlockHostUntilDone(this);
5462   CheckError(error.ok());
5463 
5464   RunAfterBlockHostUntilDoneCallbacks();
5465   return error;
5466 }
5467 
RunAfterBlockHostUntilDoneCallbacks()5468 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
5469   std::vector<std::function<void()>> callbacks;
5470   {
5471     absl::MutexLock lock(&mu_);
5472     std::swap(callbacks, after_block_host_until_done_callbacks_);
5473   }
5474   for (const auto &fn : callbacks) {
5475     fn();
5476   }
5477 }
5478 
DebugStreamPointers() const5479 string Stream::DebugStreamPointers() const {
5480   // Relies on the ToVlogString(const void*) overload above.
5481   return absl::StrCat("[stream=", ToVlogString(this),
5482                       ",impl=", ToVlogString(implementation_.get()), "]");
5483 }
5484 
CheckStatus(port::Status status)5485 void Stream::CheckStatus(port::Status status) {
5486   if (status.ok()) {
5487     return;
5488   }
5489   LOG(ERROR) << status;
5490   absl::MutexLock lock(&mu_);
5491   ok_ = false;
5492 }
5493 
5494 }  // namespace stream_executor
5495