• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/stream.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "third_party/eigen3/Eigen/Core"
20 #include "tensorflow/stream_executor/blas.h"
21 #include "tensorflow/stream_executor/host_or_device_scalar.h"
22 #include "tensorflow/stream_executor/lib/stacktrace.h"
23 #include "tensorflow/stream_executor/platform.h"
24 #include "tensorflow/stream_executor/platform/logging.h"
25 #include "tensorflow/stream_executor/platform/port.h"
26 #include "tensorflow/stream_executor/rng.h"
27 #include "tensorflow/stream_executor/stream_executor_internal.h"
28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
29 
30 namespace stream_executor {
31 
32 namespace {
33 // Code to turn parameters to functions on stream into strings that
34 // will be VLOG'ed. We need overloads, instead of
35 // e.g. BatchDescriptorToVlogString(), as the code that calls these
36 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)37 std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
38   return descriptor.ToShortString();
39 }
40 
ToVlogString(const dnn::FilterDescriptor & descriptor)41 std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
42   return descriptor.ToShortString();
43 }
44 
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)45 std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
46   return descriptor.ToShortString();
47 }
48 
ToVlogString(const dnn::PoolingDescriptor & descriptor)49 std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
50   return descriptor.ToShortString();
51 }
52 
ToVlogString(const dnn::NormalizeDescriptor & descriptor)53 std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
54   return descriptor.ToShortString();
55 }
56 
ToVlogString(dnn::ActivationMode mode)57 std::string ToVlogString(dnn::ActivationMode mode) {
58   return dnn::ActivationModeString(mode);
59 }
60 
ToVlogString(const dnn::AlgorithmConfig & algo_config)61 std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
62   return algo_config.ToString();
63 }
64 
ToVlogString(dnn::ElementwiseOperation op)65 std::string ToVlogString(dnn::ElementwiseOperation op) {
66   return dnn::ElementwiseOperationString(op);
67 }
68 
ToVlogString(dnn::QuantizedActivationMode mode)69 std::string ToVlogString(dnn::QuantizedActivationMode mode) {
70   return dnn::QuantizedActivationModeString(mode);
71 }
72 
ToVlogString(blas::Transpose t)73 std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
74 
ToVlogString(blas::UpperLower ul)75 std::string ToVlogString(blas::UpperLower ul) {
76   return blas::UpperLowerString(ul);
77 }
78 
ToVlogString(blas::Diagonal d)79 std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
80 
ToVlogString(blas::Side s)81 std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
82 
ToVlogString(blas::ComputationType ty)83 std::string ToVlogString(blas::ComputationType ty) {
84   return blas::ComputationTypeString(ty);
85 }
86 
ToVlogString(const void * ptr)87 std::string ToVlogString(const void *ptr) {
88   if (ptr == nullptr) {
89     return "null";
90   }
91 
92   // StrCat does not convert pointers to text.
93   std::ostringstream out;
94   out << ptr;
95   return out.str();
96 }
97 
98 template <class T>
ToVlogString(const std::complex<T> & c)99 std::string ToVlogString(const std::complex<T> &c) {
100   // StrCat does not convert std::complex to text.
101   std::ostringstream out;
102   out << c;
103   return out.str();
104 }
105 
106 template <class T>
ToVlogString(const std::function<T> & f)107 std::string ToVlogString(const std::function<T> &f) {
108   return f == nullptr ? "null" : "<non-null function>";
109 }
110 
ToVlogString(const DeviceMemoryBase & memory)111 std::string ToVlogString(const DeviceMemoryBase &memory) {
112   return ToVlogString(memory.opaque());
113 }
114 
ToVlogString(const DeviceMemoryBase * memory)115 std::string ToVlogString(const DeviceMemoryBase *memory) {
116   return memory == nullptr ? "null" : ToVlogString(*memory);
117 }
118 
ToVlogString(const Eigen::half & h)119 std::string ToVlogString(const Eigen::half &h) {
120   return absl::StrCat(static_cast<float>(h));
121 }
122 
ToVlogString(int i)123 std::string ToVlogString(int i) { return absl::StrCat(i); }
124 
ToVlogString(uint32 i)125 std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
126 
ToVlogString(uint64 i)127 std::string ToVlogString(uint64 i) { return absl::StrCat(i); }
128 
ToVlogString(int64 i)129 std::string ToVlogString(int64 i) { return absl::StrCat(i); }
130 
ToVlogString(float f)131 std::string ToVlogString(float f) { return absl::StrCat(f); }
132 
ToVlogString(double d)133 std::string ToVlogString(double d) { return absl::StrCat(d); }
134 
135 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)136 std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
137   if (memory_or_constant.is_pointer()) {
138     return ToVlogString(memory_or_constant.pointer());
139   }
140   return ToVlogString(memory_or_constant.value());
141 }
142 
143 template <class T>
ToVlogString(port::ArraySlice<T> elements)144 std::string ToVlogString(port::ArraySlice<T> elements) {
145   std::string str = absl::StrCat(
146       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
147       elements.size(), "]{");
148   const char *separator = "";
149   size_t max_to_show = std::numeric_limits<size_t>::max();
150   if (!VLOG_IS_ON(2)) {
151     max_to_show = 5;
152   } else if (!VLOG_IS_ON(3)) {
153     max_to_show = 20;
154   } else if (!VLOG_IS_ON(11)) {
155     max_to_show = 1000;
156   }
157   for (size_t i = 0; i < elements.size(); ++i) {
158     if (i == max_to_show) {
159       str += ", ...";
160       break;
161     }
162     absl::StrAppend(&str, separator, ToVlogString(elements[i]));
163     separator = ", ";
164   }
165   str += "}";
166   return str;
167 }
168 
169 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)170 std::string ToVlogString(port::MutableArraySlice<T> elements) {
171   return ToVlogString(port::ArraySlice<T>(elements));
172 }
173 
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)174 std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
175   switch (depth_to_space_layout) {
176     case dnn::DepthToSpaceLayout::DepthHeightWidth:
177       return "DepthToSpaceLayout::DepthHeightWidth";
178   }
179   return "unknown DepthToSpaceLayout";
180 }
181 
ToVlogString(dnn::DataType data_type)182 std::string ToVlogString(dnn::DataType data_type) {
183   switch (data_type) {
184     case dnn::DataType::kFloat:
185       return "dnn::DataType::kFloat";
186     case dnn::DataType::kDouble:
187       return "dnn::DataType::kDouble";
188     case dnn::DataType::kHalf:
189       return "dnn::DataType::kHalf";
190     case dnn::DataType::kInt8:
191       return "dnn::DataType::kInt8";
192     case dnn::DataType::kInt32:
193       return "dnn::DataType::kInt32";
194     default:
195       return "unknown DataType";
196   }
197 }
198 
199 // Used together with PARAM to VLOG calls made to the stream. Intended
200 // to be used like this:
201 //
202 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
203 //
204 // where a and b are the parameters to MyFunction.
205 //
206 // See VLOG_CALL for a short-hand for this. This way of doing it saves
207 // a tremendous amount of boilerplate code given how many functions
208 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,std::string>> params)209 std::string CallStr(const char *function_name, Stream *stream,
210                     std::vector<std::pair<const char *, std::string>> params) {
211   // Do not call this function unless VLOG is on since just
212   // constructing all the strings in params is expensive.
213   CHECK(VLOG_IS_ON(1));
214 
215   std::string str = absl::StrCat(stream->DebugStreamPointers(),
216                                  " Called Stream::", function_name, "(");
217   const char *separator = "";
218   for (const auto &param : params) {
219     absl::StrAppend(&str, separator, param.first, "=", param.second);
220     separator = ", ";
221   }
222   absl::StrAppend(&str, ")");
223   if (VLOG_IS_ON(10)) {
224     absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
225   }
226   return str;
227 }
228 
229 // Use this macro to avoid having to type every parameter twice to log
230 // it with VLOG and CallStr.
231 #define PARAM(parameter) \
232   { #parameter, ToVlogString(parameter) }
233 
234 // Use this macro to avoid having to type out the name of each
235 // function and to save some boilerplate. Intended to be used like this:
236 //
237 //   VLOG_CALL(PARAM(a), PARAM(b))
238 //
239 // This saves a tremendous amount of boilerplate compared to the alternative:
240 //
241 //   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
242 //           << ", b=" << ToVlogString(b);
243 //
244 // Note here that most of the parameter names are not short and that
245 // most of the functions take many more than 2 parameters.
246 #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
247 
248 }  // namespace
249 
Stream(StreamExecutor * parent)250 Stream::Stream(StreamExecutor *parent)
251     : parent_(parent),
252       implementation_(parent->implementation()->GetStreamImplementation()),
253       allocated_(false),
254       status_(port::InternalError("Uninitialized stream")),
255       temporary_memory_manager_(this) {
256   VLOG_CALL(PARAM(parent));
257 }
258 
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)259 Stream::Stream(StreamExecutor *parent,
260                internal::StreamInterface *implementation)
261     : parent_(parent),
262       implementation_(implementation),
263       allocated_(false),
264       status_(port::InternalError("Uninitialized stream")),
265       temporary_memory_manager_(this) {
266   VLOG_CALL(PARAM(parent), PARAM(implementation));
267 }
268 
~Stream()269 Stream::~Stream() {
270   VLOG_CALL();
271 
272   // Ensure the stream is completed.
273   auto status = BlockHostUntilDone();
274   if (!status.ok()) {
275     LOG(WARNING) << "Error blocking host until done in stream destructor: "
276                  << status;
277   }
278   temporary_memory_manager_.ForceDeallocateAll();
279   RunAfterBlockHostUntilDoneCallbacks();
280 
281   if (allocated_) {
282     parent_->DeallocateStream(this);
283   }
284 }
285 
RefreshStatus()286 port::Status Stream::RefreshStatus() {
287   port::Status status = parent_->GetStatus(this);
288   // We should not put the stream in an error state, just because the GetStatus
289   // method is unimplemented.
290   if (status != port::Status(port::error::UNIMPLEMENTED,
291                              "GetStatus is not supported on this executor.")) {
292     CheckStatus(status);
293   }
294   return status;
295 }
296 
Init()297 Stream &Stream::Init() {
298   VLOG_CALL();
299 
300   absl::MutexLock lock(&mu_);
301   CHECK_EQ(false, allocated_)
302       << "stream appears to already have been initialized";
303   CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization";
304 
305   if (parent_->AllocateStream(this)) {
306     // Successful initialization!
307     allocated_ = true;
308     status_ = port::Status::OK();
309   } else {
310     LOG(ERROR) << "failed to allocate stream during initialization";
311   }
312 
313   return *this;
314 }
315 
InitTimer(Timer * timer)316 Stream &Stream::InitTimer(Timer *timer) {
317   VLOG_CALL(PARAM(timer));
318 
319     CheckError(parent_->AllocateTimer(timer));
320   return *this;
321 }
322 
InitWithTimer(Timer * timer)323 Stream &Stream::InitWithTimer(Timer *timer) {
324   VLOG_CALL(PARAM(timer));
325 
326   return Init().InitTimer(timer);
327 }
328 
ThenRecordEvent(Event * event)329 Stream &Stream::ThenRecordEvent(Event *event) {
330   VLOG_CALL(PARAM(event));
331 
332   port::Status status = parent_->RecordEvent(this, event);
333   if (!status.ok()) {
334     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
335                << "; not marking stream as bad, as the Event object may be "
336                << "at fault. Monitor for further errors.";
337   }
338 
339   return *this;
340 }
341 
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)342 Stream &Stream::ThenBatchNormalizationForward(
343     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
344     const DeviceMemory<float> &offset,
345     const DeviceMemory<float> &estimated_mean,
346     const DeviceMemory<float> &estimated_variance,
347     const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
348     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
349     const double exponential_average_factor,
350     dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
351     DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
352     DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
353     bool is_training,
354     ScratchAllocator *reserve_space_allocator,
355     ScratchAllocator *workspace_allocator) {
356   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
357             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
358   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
359     CheckError(dnn->DoBatchNormalizationForward(
360         this, x, scale, offset, estimated_mean, estimated_variance, side_input,
361         x_desc, scale_offset_desc, epsilon, exponential_average_factor,
362         activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
363         is_training, reserve_space_allocator, workspace_allocator));
364   } else {
365     SetErrorAndLogNoDnnSupport();
366   }
367   return *this;
368 }
369 
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)370 Stream &Stream::ThenBatchNormalizationBackward(
371     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
372     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
373     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
374     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
375     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
376     DeviceMemory<float> *offset_backprop,
377     DeviceMemory<uint8> *reserve_space_data,
378     ScratchAllocator *workspace_allocator) {
379   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
380             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
381             PARAM(scale_backprop), PARAM(offset_backprop));
382   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
383     CheckError(dnn->DoBatchNormalizationBackward(
384         this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
385         epsilon, x_backprop, scale_backprop, offset_backprop,
386         reserve_space_data, workspace_allocator));
387   } else {
388     SetErrorAndLogNoDnnSupport();
389   }
390   return *this;
391 }
392 
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)393 Stream &Stream::ThenBatchNormalizationForward(
394     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
395     const DeviceMemory<float> &offset,
396     const DeviceMemory<float> &estimated_mean,
397     const DeviceMemory<float> &estimated_variance,
398     const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
399     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
400     const double exponential_average_factor,
401     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
402     DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
403     DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
404     bool is_training,
405     ScratchAllocator *reserve_space_allocator,
406     ScratchAllocator *workspace_allocator) {
407   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
408             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
409   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
410     CheckError(dnn->DoBatchNormalizationForward(
411         this, x, scale, offset, estimated_mean, estimated_variance, side_input,
412         x_desc, scale_offset_desc, epsilon, exponential_average_factor,
413         activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
414         is_training, reserve_space_allocator, workspace_allocator));
415   } else {
416     SetErrorAndLogNoDnnSupport();
417   }
418   return *this;
419 }
420 
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)421 Stream &Stream::ThenBatchNormalizationBackward(
422     const DeviceMemory<Eigen::half> &y_backprop,
423     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
424     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
425     const dnn::BatchDescriptor &x_desc,
426     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
427     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
428     DeviceMemory<float> *offset_backprop,
429     DeviceMemory<uint8> *reserve_space_data,
430     ScratchAllocator *workspace_allocator) {
431   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
432             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
433             PARAM(scale_backprop), PARAM(offset_backprop));
434   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
435     CheckError(dnn->DoBatchNormalizationBackward(
436         this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
437         epsilon, x_backprop, scale_backprop, offset_backprop,
438         reserve_space_data, workspace_allocator));
439 
440   } else {
441     SetErrorAndLogNoDnnSupport();
442   }
443   return *this;
444 }
445 
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<double> & conv_input_data,double conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<double> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<double> & side_input_data,double side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<double> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<double> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)446 port::Status Stream::FusedConvolveWithAlgorithm(
447     const dnn::BatchDescriptor &conv_input_descriptor,
448     const DeviceMemory<double> &conv_input_data, double conv_input_scale,
449     const dnn::FilterDescriptor &filter_descriptor,
450     const DeviceMemory<double> &filter_data,
451     const dnn::ConvolutionDescriptor &convolution_descriptor,
452     const DeviceMemory<double> &side_input_data, double side_input_scale,
453     const dnn::BatchDescriptor &bias_descriptor,
454     const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
455     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
456     ScratchAllocator *scratch_allocator,
457     const dnn::AlgorithmConfig &algorithm_config,
458     dnn::ProfileResult *output_profile_result) {
459   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
460             PARAM(conv_input_scale), PARAM(filter_descriptor),
461             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
462             PARAM(side_input_data), PARAM(side_input_scale),
463             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
464             PARAM(algorithm_config));
465 
466   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
467     return dnn->DoFusedConvolve(
468         this, conv_input_descriptor, conv_input_data, conv_input_scale,
469         filter_descriptor, filter_data, convolution_descriptor, side_input_data,
470         side_input_scale, bias_descriptor, biases, activation_mode,
471         output_descriptor, output, scratch_allocator, algorithm_config,
472         output_profile_result);
473   }
474   return port::UnimplementedError("DNN library is not found.");
475 }
476 
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)477 port::Status Stream::FusedConvolveWithAlgorithm(
478     const dnn::BatchDescriptor &conv_input_descriptor,
479     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
480     const dnn::FilterDescriptor &filter_descriptor,
481     const DeviceMemory<float> &filter_data,
482     const dnn::ConvolutionDescriptor &convolution_descriptor,
483     const DeviceMemory<float> &side_input_data, float side_input_scale,
484     const dnn::BatchDescriptor &bias_descriptor,
485     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
486     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
487     ScratchAllocator *scratch_allocator,
488     const dnn::AlgorithmConfig &algorithm_config,
489     dnn::ProfileResult *output_profile_result) {
490   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
491             PARAM(conv_input_scale), PARAM(filter_descriptor),
492             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
493             PARAM(side_input_data), PARAM(side_input_scale),
494             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
495             PARAM(algorithm_config));
496 
497   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
498     return dnn->DoFusedConvolve(
499         this, conv_input_descriptor, conv_input_data, conv_input_scale,
500         filter_descriptor, filter_data, convolution_descriptor, side_input_data,
501         side_input_scale, bias_descriptor, biases, activation_mode,
502         output_descriptor, output, scratch_allocator, algorithm_config,
503         output_profile_result);
504   }
505   return port::UnimplementedError("DNN library is not found.");
506 }
507 
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)508 port::Status Stream::FusedConvolveWithAlgorithm(
509     const dnn::BatchDescriptor &conv_input_descriptor,
510     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
511     const dnn::FilterDescriptor &filter_descriptor,
512     const DeviceMemory<Eigen::half> &filter_data,
513     const dnn::ConvolutionDescriptor &convolution_descriptor,
514     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
515     const dnn::BatchDescriptor &bias_descriptor,
516     const DeviceMemory<Eigen::half> &biases,
517     dnn::ActivationMode activation_mode,
518     const dnn::BatchDescriptor &output_descriptor,
519     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
520     const dnn::AlgorithmConfig &algorithm_config,
521     dnn::ProfileResult *output_profile_result) {
522   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
523             PARAM(conv_input_scale), PARAM(filter_descriptor),
524             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
525             PARAM(side_input_data), PARAM(side_input_scale),
526             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
527             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
528 
529   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
530     return dnn->DoFusedConvolve(
531         this, conv_input_descriptor, conv_input_data, conv_input_scale,
532         filter_descriptor, filter_data, convolution_descriptor, side_input_data,
533         side_input_scale, bias_descriptor, biases, activation_mode,
534         output_descriptor, output, scratch_allocator, algorithm_config,
535         output_profile_result);
536   }
537   return port::UnimplementedError("DNN library is not found.");
538 }
539 
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)540 port::Status Stream::FusedConvolveWithAlgorithm(
541     const dnn::BatchDescriptor &conv_input_descriptor,
542     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
543     const dnn::FilterDescriptor &filter_descriptor,
544     const DeviceMemory<int8> &filter_data,
545     const dnn::ConvolutionDescriptor &convolution_descriptor,
546     const DeviceMemory<int8> &side_input_data, float side_input_scale,
547     const dnn::BatchDescriptor &bias_descriptor,
548     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
549     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
550     ScratchAllocator *scratch_allocator,
551     const dnn::AlgorithmConfig &algorithm_config,
552     dnn::ProfileResult *output_profile_result) {
553   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
554             PARAM(conv_input_scale), PARAM(filter_descriptor),
555             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
556             PARAM(side_input_data), PARAM(side_input_scale),
557             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
558             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
559 
560   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
561     return dnn->DoFusedConvolve(
562         this, conv_input_descriptor, conv_input_data, conv_input_scale,
563         filter_descriptor, filter_data, convolution_descriptor, side_input_data,
564         side_input_scale, bias_descriptor, biases, activation_mode,
565         output_descriptor, output, scratch_allocator, algorithm_config,
566         output_profile_result);
567   }
568   return port::UnimplementedError("DNN library is not found.");
569 }
570 
FusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)571 port::Status Stream::FusedConvolveWithAlgorithm(
572     const dnn::BatchDescriptor &conv_input_descriptor,
573     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
574     const dnn::FilterDescriptor &filter_descriptor,
575     const DeviceMemory<int8> &filter_data,
576     const dnn::ConvolutionDescriptor &convolution_descriptor,
577     const DeviceMemory<float> &side_input_data, float side_input_scale,
578     const dnn::BatchDescriptor &bias_descriptor,
579     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
580     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
581     ScratchAllocator *scratch_allocator,
582     const dnn::AlgorithmConfig &algorithm_config,
583     dnn::ProfileResult *output_profile_result) {
584   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
585             PARAM(conv_input_scale), PARAM(filter_descriptor),
586             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
587             PARAM(side_input_data), PARAM(side_input_scale),
588             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
589             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
590 
591   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
592     return dnn->DoFusedConvolve(
593         this, conv_input_descriptor, conv_input_data, conv_input_scale,
594         filter_descriptor, filter_data, convolution_descriptor, side_input_data,
595         side_input_scale, bias_descriptor, biases, activation_mode,
596         output_descriptor, output, scratch_allocator, algorithm_config,
597         output_profile_result);
598   }
599   return port::UnimplementedError("DNN library is not found.");
600 }
601 
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)602 Stream &Stream::ThenConvolve(
603     const dnn::BatchDescriptor &input_descriptor,
604     const DeviceMemory<float> &input_data,
605     const dnn::FilterDescriptor &filter_descriptor,
606     const DeviceMemory<float> &filter_data,
607     const dnn::ConvolutionDescriptor &convolution_descriptor,
608     const dnn::BatchDescriptor &output_descriptor,
609     DeviceMemory<float> *output) {
610   if (ok()) {
611     CheckError(ConvolveWithAlgorithm(
612                    input_descriptor, input_data, filter_descriptor, filter_data,
613                    convolution_descriptor, output_descriptor, output,
614                    /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
615                    /*output_profile_result=*/nullptr)
616                    .ok());
617   }
618   return *this;
619 }
620 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)621 Stream &Stream::ThenConvolveQuantized(
622     const dnn::BatchDescriptor &input_descriptor,
623     const DeviceMemory<float> &input_data,
624     const dnn::FilterDescriptor &filter_descriptor,
625     const DeviceMemory<int8> &filter_coefficients,
626     const DeviceMemory<float> &coefficient_scales,
627     const dnn::ConvolutionDescriptor &convolution_descriptor,
628     const dnn::BatchDescriptor &output_descriptor,
629     DeviceMemory<float> *output) {
630   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
631             PARAM(filter_descriptor), PARAM(filter_coefficients),
632             PARAM(coefficient_scales), PARAM(convolution_descriptor),
633             PARAM(output_descriptor), PARAM(output));
634 
635   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
636     CheckError(dnn->DoConvolveQuantized(
637         this, input_descriptor, input_data, filter_descriptor,
638         filter_coefficients, coefficient_scales, convolution_descriptor,
639         output_descriptor, output));
640   } else {
641     SetError();
642     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
643                     "without DNN support";
644   }
645   return *this;
646 }
647 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)648 Stream &Stream::ThenConvolveQuantized(
649     const dnn::BatchDescriptor &input_descriptor,
650     const DeviceMemory<float> &input_data,
651     const dnn::FilterDescriptor &filter_descriptor,
652     const DeviceMemory<int16> &filter_coefficients,
653     const DeviceMemory<float> &coefficient_scales,
654     const dnn::ConvolutionDescriptor &convolution_descriptor,
655     const dnn::BatchDescriptor &output_descriptor,
656     DeviceMemory<float> *output) {
657   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
658             PARAM(filter_descriptor), PARAM(filter_coefficients),
659             PARAM(coefficient_scales), PARAM(convolution_descriptor),
660             PARAM(output_descriptor), PARAM(output));
661 
662   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
663     CheckError(dnn->DoConvolveQuantized(
664         this, input_descriptor, input_data, filter_descriptor,
665         filter_coefficients, coefficient_scales, convolution_descriptor,
666         output_descriptor, output));
667   } else {
668     SetError();
669     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
670                     "without DNN support";
671   }
672   return *this;
673 }
674 
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)675 Stream &Stream::ThenSeparableConvolve(
676     const dnn::BatchDescriptor &batch_descriptor,
677     const DeviceMemory<float> &input_data,
678     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
679     const DeviceMemory<float> &first_weights,
680     const DeviceMemory<float> &second_weights,
681     const dnn::ConvolutionDescriptor &convolution_descriptor,
682     const dnn::BatchDescriptor &output_descriptor,
683     DeviceMemory<float> *output) {
684   VLOG_CALL(
685       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
686       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
687       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
688 
689   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
690     CheckError(dnn->DoSeparableConvolve(
691         this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
692         first_weights, second_weights, convolution_descriptor,
693         output_descriptor, output));
694   } else {
695     SetErrorAndLogNoDnnSupport();
696   }
697   return *this;
698 }
699 
700 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)701 Stream &Stream::ThenConvolveBackwardBiasImpl(
702     const dnn::BatchDescriptor &input_descriptor,
703     const DeviceMemory<T> &input_data,
704     const dnn::BatchDescriptor &bias_descriptor,
705     DeviceMemory<T> *backward_bias_data) {
706   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
707             PARAM(backward_bias_data));
708 
709   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
710     CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
711                                            bias_descriptor,
712                                            backward_bias_data));
713   } else {
714     SetErrorAndLogNoDnnSupport();
715   }
716   return *this;
717 }
718 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)719 Stream &Stream::ThenConvolveBackwardBias(
720     const dnn::BatchDescriptor &input_descriptor,
721     const DeviceMemory<double> &input_data,
722     const dnn::BatchDescriptor &bias_descriptor,
723     DeviceMemory<double> *backward_bias_data) {
724   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
725                                       bias_descriptor, backward_bias_data);
726 }
727 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)728 Stream &Stream::ThenConvolveBackwardBias(
729     const dnn::BatchDescriptor &input_descriptor,
730     const DeviceMemory<float> &input_data,
731     const dnn::BatchDescriptor &bias_descriptor,
732     DeviceMemory<float> *backward_bias_data) {
733   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
734                                       bias_descriptor, backward_bias_data);
735 }
736 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)737 Stream &Stream::ThenConvolveBackwardBias(
738     const dnn::BatchDescriptor &input_descriptor,
739     const DeviceMemory<Eigen::half> &input_data,
740     const dnn::BatchDescriptor &bias_descriptor,
741     DeviceMemory<Eigen::half> *backward_bias_data) {
742   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
743                                       bias_descriptor, backward_bias_data);
744 }
745 
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)746 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
747                            const DeviceMemory<float> &weights,
748                            const dnn::BatchDescriptor &input_dimensions,
749                            const dnn::BatchDescriptor &output_dimensions,
750                            DeviceMemory<float> *output_data) {
751   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
752             PARAM(output_dimensions), PARAM(output_data));
753 
754   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
755     CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
756                              output_dimensions, output_data));
757   } else {
758     SetErrorAndLogNoDnnSupport();
759   }
760   return *this;
761 }
762 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)763 Stream &Stream::ThenMatMulQuantized(
764     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
765     const DeviceMemory<float> &weight_scales,
766     const dnn::BatchDescriptor &input_dimensions,
767     const dnn::BatchDescriptor &output_dimensions,
768     DeviceMemory<float> *output_data) {
769   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
770             PARAM(input_dimensions), PARAM(output_dimensions),
771             PARAM(output_data));
772 
773   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
774     CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
775                                       input_dimensions, output_dimensions,
776                                       output_data));
777   } else {
778     SetErrorAndLogNoDnnSupport();
779   }
780   return *this;
781 }
782 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)783 Stream &Stream::ThenMatMulQuantized(
784     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
785     const DeviceMemory<float> &weight_scales,
786     const dnn::BatchDescriptor &input_dimensions,
787     const dnn::BatchDescriptor &output_dimensions,
788     DeviceMemory<float> *output_data) {
789   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
790             PARAM(input_dimensions), PARAM(output_dimensions),
791             PARAM(output_data));
792 
793   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
794     CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
795                                       input_dimensions, output_dimensions,
796                                       output_data));
797   } else {
798     SetErrorAndLogNoDnnSupport();
799   }
800   return *this;
801 }
802 
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)803 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
804                             const DeviceMemory<float> &biases,
805                             const dnn::BatchDescriptor &dimensions,
806                             DeviceMemory<float> *output_data) {
807   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
808             PARAM(output_data));
809 
810   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
811     CheckError(
812         dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
813   } else {
814     SetErrorAndLogNoDnnSupport();
815   }
816   return *this;
817 }
818 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)819 Stream &Stream::ThenPoolForward(
820     const dnn::PoolingDescriptor &pooling_dimensions,
821     const dnn::BatchDescriptor &input_dimensions,
822     const DeviceMemory<double> &input_data,
823     const dnn::BatchDescriptor &output_dimensions,
824     DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
825   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
826             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
827             PARAM(workspace_allocator));
828 
829   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
830     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
831                                   input_data, output_dimensions, output_data,
832                                   workspace_allocator));
833   } else {
834     SetError();
835     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
836                     "without DNN support";
837   }
838   return *this;
839 }
840 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)841 Stream &Stream::ThenPoolForward(
842     const dnn::PoolingDescriptor &pooling_dimensions,
843     const dnn::BatchDescriptor &input_dimensions,
844     const DeviceMemory<float> &input_data,
845     const dnn::BatchDescriptor &output_dimensions,
846     DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
847   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
848             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
849             PARAM(workspace_allocator));
850 
851   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
852     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
853                                   input_data, output_dimensions, output_data,
854                                   workspace_allocator));
855   } else {
856     SetErrorAndLogNoDnnSupport();
857   }
858   return *this;
859 }
860 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)861 Stream &Stream::ThenPoolForward(
862     const dnn::PoolingDescriptor &pooling_dimensions,
863     const dnn::BatchDescriptor &input_dimensions,
864     const DeviceMemory<Eigen::half> &input_data,
865     const dnn::BatchDescriptor &output_dimensions,
866     DeviceMemory<Eigen::half> *output_data,
867     ScratchAllocator *workspace_allocator) {
868   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
869             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
870             PARAM(workspace_allocator));
871 
872   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
873     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
874                                   input_data, output_dimensions, output_data,
875                                   workspace_allocator));
876   } else {
877     SetErrorAndLogNoDnnSupport();
878   }
879   return *this;
880 }
881 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)882 Stream &Stream::ThenPoolForward(
883     const dnn::PoolingDescriptor &pooling_dimensions,
884     const dnn::BatchDescriptor &input_dimensions,
885     const DeviceMemory<int8> &input_data,
886     const dnn::BatchDescriptor &output_dimensions,
887     DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
888   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
889             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
890             PARAM(workspace_allocator));
891 
892   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
893     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
894                                   input_data, output_dimensions, output_data,
895                                   workspace_allocator));
896   } else {
897     SetErrorAndLogNoDnnSupport();
898   }
899   return *this;
900 }
901 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)902 Stream &Stream::ThenPoolBackward(
903     const dnn::PoolingDescriptor &pooling_dimensions,
904     const dnn::BatchDescriptor &input_dimensions,
905     const DeviceMemory<double> &input_data,
906     const dnn::BatchDescriptor &output_dimensions,
907     const DeviceMemory<double> &output_data,
908     const DeviceMemory<double> &input_diff_data,
909     DeviceMemory<double> *output_diff_data,
910     ScratchAllocator *workspace_allocator) {
911   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
912             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
913             PARAM(input_diff_data), PARAM(output_diff_data),
914             PARAM(workspace_allocator));
915 
916   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
917     CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
918                                    input_data, output_dimensions, output_data,
919                                    input_diff_data, output_diff_data,
920                                    workspace_allocator));
921   } else {
922     SetError();
923     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
924                     "without DNN support";
925   }
926   return *this;
927 }
928 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)929 Stream &Stream::ThenPoolBackward(
930     const dnn::PoolingDescriptor &pooling_dimensions,
931     const dnn::BatchDescriptor &input_dimensions,
932     const DeviceMemory<float> &input_data,
933     const dnn::BatchDescriptor &output_dimensions,
934     const DeviceMemory<float> &output_data,
935     const DeviceMemory<float> &input_diff_data,
936     DeviceMemory<float> *output_diff_data,
937     ScratchAllocator *workspace_allocator) {
938   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
939             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
940             PARAM(input_diff_data), PARAM(output_diff_data),
941             PARAM(workspace_allocator));
942 
943   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
944     CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
945                                    input_data, output_dimensions, output_data,
946                                    input_diff_data, output_diff_data,
947                                    workspace_allocator));
948   } else {
949     SetErrorAndLogNoDnnSupport();
950   }
951   return *this;
952 }
953 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)954 Stream &Stream::ThenPoolBackward(
955     const dnn::PoolingDescriptor &pooling_dimensions,
956     const dnn::BatchDescriptor &input_dimensions,
957     const DeviceMemory<Eigen::half> &input_data,
958     const dnn::BatchDescriptor &output_dimensions,
959     const DeviceMemory<Eigen::half> &output_data,
960     const DeviceMemory<Eigen::half> &input_diff_data,
961     DeviceMemory<Eigen::half> *output_diff_data,
962     ScratchAllocator *workspace_allocator) {
963   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
964             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
965             PARAM(input_diff_data), PARAM(output_diff_data),
966             PARAM(workspace_allocator));
967 
968   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
969     CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
970                                    input_data, output_dimensions, output_data,
971                                    input_diff_data, output_diff_data,
972                                    workspace_allocator));
973   } else {
974     SetErrorAndLogNoDnnSupport();
975   }
976   return *this;
977 }
978 
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)979 Stream &Stream::ThenNormalizeWithDimensions(
980     const dnn::NormalizeDescriptor &normalize_descriptor,
981     const dnn::BatchDescriptor &dimensions,
982     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
983   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
984             PARAM(output_data));
985 
986   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
987     CheckError(dnn->DoNormalizeWithDimensions(
988         this, normalize_descriptor, dimensions, input_data, output_data));
989   } else {
990     SetErrorAndLogNoDnnSupport();
991   }
992   return *this;
993 }
994 
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)995 Stream &Stream::ThenNormalizeBackwardWithDimensions(
996     const dnn::NormalizeDescriptor &normalize_descriptor,
997     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
998     const DeviceMemory<float> &normalized_data,
999     const DeviceMemory<float> &normalized_variable_gradient,
1000     DeviceMemory<float> *raw_variable_gradient,
1001     ScratchAllocator *workspace_allocator) {
1002   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1003             PARAM(normalized_data), PARAM(normalized_variable_gradient),
1004             PARAM(raw_variable_gradient), PARAM(workspace_allocator));
1005 
1006   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1007     CheckError(dnn->DoNormalizeBackwardWithDimensions(
1008         this, normalize_descriptor, dimensions, raw_data, normalized_data,
1009         normalized_variable_gradient, raw_variable_gradient,
1010         workspace_allocator));
1011   } else {
1012     SetErrorAndLogNoDnnSupport();
1013   }
1014   return *this;
1015 }
1016 
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1017 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1018                              const dnn::BatchDescriptor &dimensions,
1019                              const DeviceMemory<float> &input_data,
1020                              DeviceMemory<float> *output_data) {
1021   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1022                                  output_data, /*options=*/0);
1023 }
1024 
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1025 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1026                                         const dnn::BatchDescriptor &dimensions,
1027                                         const DeviceMemory<float> &input_data,
1028                                         DeviceMemory<float> *output_data,
1029                                         uint64 options) {
1030   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1031             PARAM(output_data), PARAM(options));
1032 
1033   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1034     CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1035                                output_data, options));
1036   } else {
1037     SetErrorAndLogNoDnnSupport();
1038   }
1039   return *this;
1040 }
1041 
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1042 Stream &Stream::ThenDepthConcatenate(
1043     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1044     port::ArraySlice<const DeviceMemory<float> *> input_data,
1045     DeviceMemory<float> *output_data) {
1046   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1047 
1048   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1049     if (input_dimensions[i].count() != input_dimensions[0].count() ||
1050         input_dimensions[i].height() != input_dimensions[0].height() ||
1051         input_dimensions[i].width() != input_dimensions[0].width()) {
1052       SetError();
1053       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1054                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1055                  << "input_dimensions[" << i
1056                  << "]: " << input_dimensions[i].ToString();
1057       return *this;
1058     }
1059   }
1060 
1061   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1062     CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1063                                        output_data));
1064   } else {
1065     SetErrorAndLogNoDnnSupport();
1066   }
1067   return *this;
1068 }
1069 
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1070 Stream &Stream::ThenSpaceConcatenate(
1071     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1072     port::ArraySlice<const DeviceMemory<float> *> input_data,
1073     DeviceMemory<float> *output_data,
1074     dnn::SpaceConcatenateMode concat_direction) {
1075   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1076 
1077   // Check that the input dimensions of all the other batches match those of the
1078   // first batch.
1079   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1080     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1081         (input_dimensions[i].count() != input_dimensions[0].count() ||
1082          input_dimensions[i].height() != input_dimensions[0].height() ||
1083          input_dimensions[i].feature_map_count() !=
1084              input_dimensions[0].feature_map_count())) {
1085       SetError();
1086       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1087                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1088                  << "input_dimensions[" << i
1089                  << "]: " << input_dimensions[i].ToString();
1090       return *this;
1091     }
1092 
1093     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1094         (input_dimensions[i].count() != input_dimensions[0].count() ||
1095          input_dimensions[i].width() != input_dimensions[0].width() ||
1096          input_dimensions[i].feature_map_count() !=
1097              input_dimensions[0].feature_map_count())) {
1098       SetError();
1099       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1100                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1101                  << "input_dimensions[" << i
1102                  << "]: " << input_dimensions[i].ToString();
1103       return *this;
1104     }
1105   }
1106   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1107     CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1108                                        output_data, concat_direction));
1109   } else {
1110     SetErrorAndLogNoDnnSupport();
1111   }
1112   return *this;
1113 }
1114 
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1115 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1116                             const DeviceMemory<float> &input_data,
1117                             const dnn::BatchDescriptor &output_dimensions,
1118                             DeviceMemory<float> *output_data) {
1119   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1120             PARAM(output_dimensions), PARAM(output_data));
1121 
1122   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1123     CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1124                               output_dimensions, output_data));
1125   } else {
1126     SetErrorAndLogNoDnnSupport();
1127   }
1128   return *this;
1129 }
1130 
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1131 Stream &Stream::ThenDepthToSpace(
1132     const dnn::BatchDescriptor &input_dimensions,
1133     const DeviceMemory<float> &input_data,
1134     const dnn::DepthToSpaceLayout &depth_to_space_layout,
1135     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1136   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1137             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1138             PARAM(output_data));
1139 
1140   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1141     CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1142                                    depth_to_space_layout, sqrt_depth_reduction,
1143                                    output_data));
1144   } else {
1145     SetErrorAndLogNoDnnSupport();
1146   }
1147   return *this;
1148 }
1149 
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1150 Stream &Stream::ThenSpaceToDepth(
1151     const dnn::BatchDescriptor &input_dimensions,
1152     const DeviceMemory<float> &input_data,
1153     const dnn::DepthToSpaceLayout &space_to_depth_layout,
1154     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1155   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1156             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1157             PARAM(output_data));
1158 
1159   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1160     CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1161                                    space_to_depth_layout, sqrt_depth_increase,
1162                                    output_data));
1163   } else {
1164     SetErrorAndLogNoDnnSupport();
1165   }
1166   return *this;
1167 }
1168 
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1169 Stream &Stream::ThenElementwiseOperate(
1170     dnn::ElementwiseOperation operation,
1171     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1172     port::ArraySlice<const DeviceMemory<float> *> input_data,
1173     const dnn::BatchDescriptor &output_dimensions,
1174     DeviceMemory<float> *output_data) {
1175   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1176             PARAM(output_dimensions), PARAM(output_data));
1177 
1178   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1179     CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1180                                          input_data, output_dimensions,
1181                                          output_data));
1182   } else {
1183     SetErrorAndLogNoDnnSupport();
1184   }
1185   return *this;
1186 }
1187 
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1188 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1189     dnn::ElementwiseOperation operation,
1190     port::ArraySlice<int> input_multiplicands, int output_divisor,
1191     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1192     port::ArraySlice<const DeviceMemory<float> *> input_data,
1193     const dnn::BatchDescriptor &output_dimensions,
1194     DeviceMemory<float> *output_data) {
1195   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1196             PARAM(input_dimensions), PARAM(input_data),
1197             PARAM(output_dimensions), PARAM(output_data));
1198 
1199   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1200     CheckError(dnn->DoElementwiseOperateScaledQuantized(
1201         this, operation, input_multiplicands, output_divisor, input_dimensions,
1202         input_data, output_dimensions, output_data));
1203   } else {
1204     SetErrorAndLogNoDnnSupport();
1205   }
1206   return *this;
1207 }
1208 
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1209 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1210                           const DeviceMemory<float> &input_data, int64 left_pad,
1211                           int64 right_pad, int64 top_pad, int64 bottom_pad,
1212                           DeviceMemory<float> *output_data) {
1213   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1214             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1215             PARAM(output_data));
1216 
1217   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1218     CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1219                             top_pad, bottom_pad, output_data));
1220   } else {
1221     SetErrorAndLogNoDnnSupport();
1222   }
1223   return *this;
1224 }
1225 
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1226 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1227                             const DeviceMemory<float> &input_data,
1228                             int64 left_trim, int64 right_trim, int64 top_trim,
1229                             int64 bottom_trim,
1230                             DeviceMemory<float> *output_data) {
1231   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1232             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1233             PARAM(output_data));
1234 
1235   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1236     CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1237                               right_trim, top_trim, bottom_trim, output_data));
1238   } else {
1239     SetErrorAndLogNoDnnSupport();
1240   }
1241   return *this;
1242 }
1243 
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1244 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1245                                 const DeviceMemory<float> &input_data,
1246                                 int64 replicate_x, int64 replicate_y,
1247                                 DeviceMemory<float> *output_data) {
1248   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1249             PARAM(replicate_y), PARAM(output_data));
1250 
1251   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1252     CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1253                                   replicate_y, output_data));
1254   } else {
1255     SetErrorAndLogNoDnnSupport();
1256   }
1257   return *this;
1258 }
1259 
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1260 Stream &Stream::ThenMemcpyD2HQuantized(
1261     const DeviceMemory<float> &gpu_unquantized_src,
1262     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1263   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1264             PARAM(size));
1265 
1266   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1267     CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1268                                          host_dst, size));
1269   } else {
1270     SetErrorAndLogNoDnnSupport();
1271   }
1272   return *this;
1273 }
1274 
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1275 Stream &Stream::ThenMemcpyH2DQuantized(
1276     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1277     DeviceMemory<float> *gpu_unquantized_dst) {
1278   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1279             PARAM(gpu_unquantized_dst));
1280 
1281   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1282     CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1283                                          gpu_unquantized_dst));
1284   } else {
1285     SetErrorAndLogNoDnnSupport();
1286   }
1287   return *this;
1288 }
1289 
GetOrCreateSubStream()1290 Stream *Stream::GetOrCreateSubStream() {
1291   // Do not destroy bad streams when holding mu_ because ~Stream() may
1292   // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1293   std::vector<std::unique_ptr<Stream>> bad_streams;
1294 
1295   absl::MutexLock lock(&mu_);
1296 
1297   // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1298   // we encounter along the way.
1299   for (size_t index = 0; index < sub_streams_.size();) {
1300     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1301     if (pair.second) {
1302       // The sub_stream is reusable.
1303       Stream *sub_stream = pair.first.get();
1304       if (sub_stream->ok()) {
1305         VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1306                 << sub_stream->DebugStreamPointers();
1307         pair.second = false;
1308         return sub_stream;
1309       }
1310 
1311       // The stream is reusable and not ok. Streams have a monotonic state
1312       // machine; the stream will remain in !ok forever. Swap it with the last
1313       // stream and pop it off.
1314       const int64 last = sub_streams_.size() - 1;
1315       if (index != last) {
1316         std::swap(pair, sub_streams_[last]);
1317       }
1318       bad_streams.push_back(std::move(sub_streams_.back().first));
1319       sub_streams_.pop_back();
1320       VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1321               << sub_stream->DebugStreamPointers();
1322     } else {
1323       // The sub_stream is not reusable, move on to the next one.
1324       ++index;
1325     }
1326   }
1327 
1328   // No streams are reusable; create a new stream.
1329   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1330                             false);
1331   Stream *sub_stream = sub_streams_.back().first.get();
1332   sub_stream->Init();
1333   if (!sub_stream->ok()) {
1334     LOG(ERROR) << "sub-stream failed to be initialized";
1335   }
1336   VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1337           << sub_stream->DebugStreamPointers();
1338 
1339   return sub_stream;
1340 }
1341 
ReturnSubStream(Stream * sub_stream)1342 void Stream::ReturnSubStream(Stream *sub_stream) {
1343   // Do not destroy bad streams when holding mu_ because ~Stream() may
1344   // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1345   std::unique_ptr<Stream> bad_stream;
1346 
1347   absl::MutexLock lock(&mu_);
1348 
1349   // Look for the sub-stream.
1350   for (int64 index = 0, end = sub_streams_.size(); index < end; ++index) {
1351     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1352     if (pair.first.get() != sub_stream) {
1353       continue;
1354     }
1355 
1356     // Found the sub_stream.
1357     if (sub_stream->ok()) {
1358       VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1359               << sub_stream->DebugStreamPointers();
1360       pair.second = true;
1361     } else {
1362       // The returned stream is not ok. Streams have a monotonic state
1363       // machine; the stream will remain in !ok forever. Swap it with the last
1364       // stream and pop it off.
1365       VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1366               << sub_stream->DebugStreamPointers();
1367       const int64 last = sub_streams_.size() - 1;
1368       if (index != last) {
1369         std::swap(pair, sub_streams_[last]);
1370       }
1371       std::swap(bad_stream, sub_streams_.back().first);
1372       sub_streams_.pop_back();
1373     }
1374     return;
1375   }
1376 
1377   LOG(FATAL) << DebugStreamPointers()
1378              << " did not create the returned sub-stream "
1379              << sub_stream->DebugStreamPointers();
1380 }
1381 
ThenStartTimer(Timer * t)1382 Stream &Stream::ThenStartTimer(Timer *t) {
1383   VLOG_CALL(PARAM(t));
1384 
1385   CheckError(parent_->StartTimer(this, t));
1386   return *this;
1387 }
1388 
ThenStopTimer(Timer * t)1389 Stream &Stream::ThenStopTimer(Timer *t) {
1390   VLOG_CALL(PARAM(t));
1391 
1392   CheckError(parent_->StopTimer(this, t));
1393   return *this;
1394 }
1395 
ThenWaitFor(Stream * other)1396 Stream &Stream::ThenWaitFor(Stream *other) {
1397   VLOG_CALL(PARAM(other));
1398 
1399   CHECK(this != other) << "stream cannot wait for itself";
1400   if (ok() && other->ok()) {
1401     CheckError(parent_->CreateStreamDependency(this, other));
1402   } else {
1403     SetError();
1404     LOG(INFO) << DebugStreamPointers() << " did not wait for "
1405               << other->DebugStreamPointers();
1406   }
1407   return *this;
1408 }
1409 
ThenWaitFor(Event * event)1410 Stream &Stream::ThenWaitFor(Event *event) {
1411   VLOG_CALL(PARAM(event));
1412 
1413   if (ok()) {
1414     port::Status status = parent_->WaitForEvent(this, event);
1415     if (!status.ok()) {
1416       LOG(ERROR) << "Error waiting for event in stream: "
1417                  << status.error_message()
1418                  << "; not marking stream as bad, as the Event object may be "
1419                  << "at fault. Monitor for further errors.";
1420     }
1421   } else {
1422     LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1423   }
1424   return *this;
1425 }
1426 
1427 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1428 // functions and logs for errors.
1429 template <typename... Args>
1430 struct ThenBlasImpl {
1431   // blas_func is the DoBlasXXX member function pointer, and args are its
1432   // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl1433   Stream &operator()(Stream *stream,
1434                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1435                      Args... args) {
1436     return Run(stream, blas_func, /*record_error=*/true, args...);
1437   }
1438 
1439   // Like operator(), but only calls stream->CheckError() if record_error is
1440   // true.
1441   Stream &Run(Stream *stream,
1442               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1443               bool record_error, Args... args);
1444 };
1445 
1446 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1447 Stream &ThenBlasImpl<Args...>::Run(
1448     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1449     bool record_error, Args... args) {
1450   if (stream->ok()) {
1451     bool ok;
1452     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1453       ok = (blas->*blas_func)(stream, args...);
1454     } else {
1455       LOG(WARNING)
1456           << "attempting to perform BLAS operation using StreamExecutor "
1457              "without BLAS support";
1458       ok = false;
1459     }
1460     if (record_error) {
1461       stream->CheckError(ok);
1462     }
1463   }
1464   return *stream;
1465 }
1466 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1467 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1468                              int incx, DeviceMemory<float> *result) {
1469   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1470 
1471   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1472       impl;
1473   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1474               result);
1475 }
1476 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1477 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1478                              int incx, DeviceMemory<double> *result) {
1479   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1480 
1481   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1482                DeviceMemory<double> *>
1483       impl;
1484   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1485               result);
1486 }
1487 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1488 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1489                              const DeviceMemory<std::complex<float>> &x,
1490                              int incx, DeviceMemory<float> *result) {
1491   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1492 
1493   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1494                DeviceMemory<float> *>
1495       impl;
1496   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1497               result);
1498 }
1499 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1500 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1501                              const DeviceMemory<std::complex<double>> &x,
1502                              int incx, DeviceMemory<double> *result) {
1503   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1504 
1505   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1506                DeviceMemory<double> *>
1507       impl;
1508   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1509               result);
1510 }
1511 
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1512 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1513                              const DeviceMemory<float> &x, int incx,
1514                              DeviceMemory<float> *y, int incy) {
1515   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1516             PARAM(incy));
1517 
1518   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1519                DeviceMemory<float> *, int>
1520       impl;
1521   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1522               y, incy);
1523 }
1524 
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1525 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1526                              const DeviceMemory<double> &x, int incx,
1527                              DeviceMemory<double> *y, int incy) {
1528   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1529             PARAM(incy));
1530 
1531   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1532                DeviceMemory<double> *, int>
1533       impl;
1534   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1535               y, incy);
1536 }
1537 
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1538 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1539                              const DeviceMemory<std::complex<float>> &x,
1540                              int incx, DeviceMemory<std::complex<float>> *y,
1541                              int incy) {
1542   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1543             PARAM(incy));
1544 
1545   ThenBlasImpl<uint64, std::complex<float>,
1546                const DeviceMemory<std::complex<float>> &, int,
1547                DeviceMemory<std::complex<float>> *, int>
1548       impl;
1549   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1550               y, incy);
1551 }
1552 
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1553 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1554                              const DeviceMemory<std::complex<double>> &x,
1555                              int incx, DeviceMemory<std::complex<double>> *y,
1556                              int incy) {
1557   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1558             PARAM(incy));
1559 
1560   ThenBlasImpl<uint64, std::complex<double>,
1561                const DeviceMemory<std::complex<double>> &, int,
1562                DeviceMemory<std::complex<double>> *, int>
1563       impl;
1564   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1565               y, incy);
1566 }
1567 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1568 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
1569                              int incx, DeviceMemory<float> *y, int incy) {
1570   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1571 
1572   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
1573                int>
1574       impl;
1575   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1576               incy);
1577 }
1578 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1579 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
1580                              int incx, DeviceMemory<double> *y, int incy) {
1581   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1582 
1583   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1584                DeviceMemory<double> *, int>
1585       impl;
1586   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1587               incy);
1588 }
1589 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1590 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1591                              const DeviceMemory<std::complex<float>> &x,
1592                              int incx, DeviceMemory<std::complex<float>> *y,
1593                              int incy) {
1594   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1595 
1596   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1597                DeviceMemory<std::complex<float>> *, int>
1598       impl;
1599   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1600               incy);
1601 }
1602 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1603 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1604                              const DeviceMemory<std::complex<double>> &x,
1605                              int incx, DeviceMemory<std::complex<double>> *y,
1606                              int incy) {
1607   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1608 
1609   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1610                DeviceMemory<std::complex<double>> *, int>
1611       impl;
1612   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1613               incy);
1614 }
1615 
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)1616 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
1617                             int incx, const DeviceMemory<float> &y, int incy,
1618                             DeviceMemory<float> *result) {
1619   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1620             PARAM(result));
1621 
1622   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
1623                const DeviceMemory<float> &, int, DeviceMemory<float> *>
1624       impl;
1625   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1626               result);
1627 }
1628 
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)1629 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
1630                             int incx, const DeviceMemory<double> &y, int incy,
1631                             DeviceMemory<double> *result) {
1632   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1633             PARAM(result));
1634 
1635   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1636                const DeviceMemory<double> &, int, DeviceMemory<double> *>
1637       impl;
1638   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1639               result);
1640 }
1641 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1642 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1643                              const DeviceMemory<std::complex<float>> &x,
1644                              int incx,
1645                              const DeviceMemory<std::complex<float>> &y,
1646                              int incy,
1647                              DeviceMemory<std::complex<float>> *result) {
1648   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1649             PARAM(result));
1650 
1651   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1652                const DeviceMemory<std::complex<float>> &, int,
1653                DeviceMemory<std::complex<float>> *>
1654       impl;
1655   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1656               incy, result);
1657 }
1658 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1659 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1660                              const DeviceMemory<std::complex<double>> &x,
1661                              int incx,
1662                              const DeviceMemory<std::complex<double>> &y,
1663                              int incy,
1664                              DeviceMemory<std::complex<double>> *result) {
1665   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1666             PARAM(result));
1667 
1668   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1669                const DeviceMemory<std::complex<double>> &, int,
1670                DeviceMemory<std::complex<double>> *>
1671       impl;
1672   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1673               incy, result);
1674 }
1675 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1676 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1677                              const DeviceMemory<std::complex<float>> &x,
1678                              int incx,
1679                              const DeviceMemory<std::complex<float>> &y,
1680                              int incy,
1681                              DeviceMemory<std::complex<float>> *result) {
1682   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1683             PARAM(result));
1684 
1685   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1686                const DeviceMemory<std::complex<float>> &, int,
1687                DeviceMemory<std::complex<float>> *>
1688       impl;
1689   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1690               incy, result);
1691 }
1692 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1693 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1694                              const DeviceMemory<std::complex<double>> &x,
1695                              int incx,
1696                              const DeviceMemory<std::complex<double>> &y,
1697                              int incy,
1698                              DeviceMemory<std::complex<double>> *result) {
1699   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1700             PARAM(result));
1701 
1702   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1703                const DeviceMemory<std::complex<double>> &, int,
1704                DeviceMemory<std::complex<double>> *>
1705       impl;
1706   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1707               incy, result);
1708 }
1709 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1710 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
1711                              int incx, DeviceMemory<float> *result) {
1712   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1713 
1714   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1715       impl;
1716   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1717               result);
1718 }
1719 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1720 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
1721                              int incx, DeviceMemory<double> *result) {
1722   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1723 
1724   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1725                DeviceMemory<double> *>
1726       impl;
1727   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1728               result);
1729 }
1730 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1731 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1732                              const DeviceMemory<std::complex<float>> &x,
1733                              int incx, DeviceMemory<float> *result) {
1734   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1735 
1736   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1737                DeviceMemory<float> *>
1738       impl;
1739   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1740               result);
1741 }
1742 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1743 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1744                              const DeviceMemory<std::complex<double>> &x,
1745                              int incx, DeviceMemory<double> *result) {
1746   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1747 
1748   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1749                DeviceMemory<double> *>
1750       impl;
1751   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1752               result);
1753 }
1754 
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)1755 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
1756                             DeviceMemory<float> *y, int incy, float c,
1757                             float s) {
1758   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1759             PARAM(c), PARAM(s));
1760 
1761   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1762                float, float>
1763       impl;
1764   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1765               c, s);
1766 }
1767 
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)1768 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
1769                             int incx, DeviceMemory<double> *y, int incy,
1770                             double c, double s) {
1771   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1772             PARAM(c), PARAM(s));
1773 
1774   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1775                double, double>
1776       impl;
1777   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1778               c, s);
1779 }
1780 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)1781 Stream &Stream::ThenBlasRot(uint64 elem_count,
1782                             DeviceMemory<std::complex<float>> *x, int incx,
1783                             DeviceMemory<std::complex<float>> *y, int incy,
1784                             float c, float s) {
1785   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1786             PARAM(c), PARAM(s));
1787 
1788   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1789                DeviceMemory<std::complex<float>> *, int, float, float>
1790       impl;
1791   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1792               c, s);
1793 }
1794 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)1795 Stream &Stream::ThenBlasRot(uint64 elem_count,
1796                             DeviceMemory<std::complex<double>> *x, int incx,
1797                             DeviceMemory<std::complex<double>> *y, int incy,
1798                             double c, double s) {
1799   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1800             PARAM(c), PARAM(s));
1801 
1802   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1803                DeviceMemory<std::complex<double>> *, int, double, double>
1804       impl;
1805   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1806               c, s);
1807 }
1808 
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)1809 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
1810                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
1811   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1812 
1813   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1814                DeviceMemory<float> *, DeviceMemory<float> *>
1815       impl;
1816   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1817 }
1818 
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)1819 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
1820                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
1821   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1822 
1823   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1824                DeviceMemory<double> *, DeviceMemory<double> *>
1825       impl;
1826   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1827 }
1828 
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)1829 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
1830                              DeviceMemory<std::complex<float>> *b,
1831                              DeviceMemory<float> *c,
1832                              DeviceMemory<std::complex<float>> *s) {
1833   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1834 
1835   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
1836                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
1837                DeviceMemory<std::complex<float>> *>
1838       impl;
1839   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1840 }
1841 
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)1842 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
1843                              DeviceMemory<std::complex<double>> *b,
1844                              DeviceMemory<double> *c,
1845                              DeviceMemory<std::complex<double>> *s) {
1846   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1847 
1848   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
1849                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
1850                DeviceMemory<std::complex<double>> *>
1851       impl;
1852   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1853 }
1854 
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)1855 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
1856                              int incx, DeviceMemory<float> *y, int incy,
1857                              const DeviceMemory<float> &param) {
1858   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1859             PARAM(param));
1860 
1861   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1862                const DeviceMemory<float> &>
1863       impl;
1864   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1865               incy, param);
1866 }
1867 
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)1868 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
1869                              int incx, DeviceMemory<double> *y, int incy,
1870                              const DeviceMemory<double> &param) {
1871   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1872             PARAM(param));
1873 
1874   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1875                const DeviceMemory<double> &>
1876       impl;
1877   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1878               incy, param);
1879 }
1880 
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)1881 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
1882                               DeviceMemory<float> *x1,
1883                               const DeviceMemory<float> &y1,
1884                               DeviceMemory<float> *param) {
1885   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1886 
1887   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1888                DeviceMemory<float> *, const DeviceMemory<float> &,
1889                DeviceMemory<float> *>
1890       impl;
1891   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1892 }
1893 
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)1894 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
1895                               DeviceMemory<double> *d2,
1896                               DeviceMemory<double> *x1,
1897                               const DeviceMemory<double> &y1,
1898                               DeviceMemory<double> *param) {
1899   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1900 
1901   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1902                DeviceMemory<double> *, const DeviceMemory<double> &,
1903                DeviceMemory<double> *>
1904       impl;
1905   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1906 }
1907 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)1908 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1909                              DeviceMemory<float> *x, int incx) {
1910   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1911 
1912   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
1913   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1914 }
1915 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)1916 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1917                              DeviceMemory<double> *x, int incx) {
1918   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1919 
1920   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
1921   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1922 }
1923 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)1924 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1925                              DeviceMemory<std::complex<float>> *x, int incx) {
1926   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1927 
1928   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
1929   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1930 }
1931 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)1932 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1933                              DeviceMemory<std::complex<double>> *x, int incx) {
1934   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1935 
1936   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
1937   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1938 }
1939 
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)1940 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
1941                              DeviceMemory<std::complex<float>> *x, int incx) {
1942   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1943 
1944   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
1945                int>
1946       impl;
1947   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1948 }
1949 
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)1950 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
1951                              DeviceMemory<std::complex<double>> *x, int incx) {
1952   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1953 
1954   ThenBlasImpl<uint64, std::complex<double>,
1955                DeviceMemory<std::complex<double>> *, int>
1956       impl;
1957   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1958 }
1959 
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)1960 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
1961                              int incx, DeviceMemory<float> *y, int incy) {
1962   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1963 
1964   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
1965       impl;
1966   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1967               incy);
1968 }
1969 
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)1970 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
1971                              int incx, DeviceMemory<double> *y, int incy) {
1972   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1973 
1974   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
1975       impl;
1976   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1977               incy);
1978 }
1979 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1980 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1981                              DeviceMemory<std::complex<float>> *x, int incx,
1982                              DeviceMemory<std::complex<float>> *y, int incy) {
1983   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1984 
1985   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1986                DeviceMemory<std::complex<float>> *, int>
1987       impl;
1988   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1989               incy);
1990 }
1991 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1992 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1993                              DeviceMemory<std::complex<double>> *x, int incx,
1994                              DeviceMemory<std::complex<double>> *y, int incy) {
1995   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1996 
1997   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1998                DeviceMemory<std::complex<double>> *, int>
1999       impl;
2000   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2001               incy);
2002 }
2003 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2004 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2005                               int incx, DeviceMemory<int> *result) {
2006   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2007 
2008   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2009       impl;
2010   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2011               result);
2012 }
2013 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2014 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2015                               int incx, DeviceMemory<int> *result) {
2016   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2017 
2018   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2019       impl;
2020   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2021               result);
2022 }
2023 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2024 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2025                               const DeviceMemory<std::complex<float>> &x,
2026                               int incx, DeviceMemory<int> *result) {
2027   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2028 
2029   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2030                DeviceMemory<int> *>
2031       impl;
2032   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2033               result);
2034 }
2035 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2036 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2037                               const DeviceMemory<std::complex<double>> &x,
2038                               int incx, DeviceMemory<int> *result) {
2039   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2040 
2041   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2042                DeviceMemory<int> *>
2043       impl;
2044   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2045               result);
2046 }
2047 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2048 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2049                               int incx, DeviceMemory<int> *result) {
2050   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2051 
2052   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2053       impl;
2054   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2055               result);
2056 }
2057 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2058 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2059                               int incx, DeviceMemory<int> *result) {
2060   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2061 
2062   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2063       impl;
2064   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2065               result);
2066 }
2067 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2068 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2069                               const DeviceMemory<std::complex<float>> &x,
2070                               int incx, DeviceMemory<int> *result) {
2071   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2072 
2073   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2074                DeviceMemory<int> *>
2075       impl;
2076   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2077               result);
2078 }
2079 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2080 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2081                               const DeviceMemory<std::complex<double>> &x,
2082                               int incx, DeviceMemory<int> *result) {
2083   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2084 
2085   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2086                DeviceMemory<int> *>
2087       impl;
2088   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2089               result);
2090 }
2091 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2092 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2093                              uint64 kl, uint64 ku, float alpha,
2094                              const DeviceMemory<float> &a, int lda,
2095                              const DeviceMemory<float> &x, int incx, float beta,
2096                              DeviceMemory<float> *y, int incy) {
2097   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2098             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2099             PARAM(beta), PARAM(y), PARAM(incy));
2100 
2101   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2102                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2103                int, float, DeviceMemory<float> *, int>
2104       impl;
2105   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2106               a, lda, x, incx, beta, y, incy);
2107 }
2108 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2109 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2110                              uint64 kl, uint64 ku, double alpha,
2111                              const DeviceMemory<double> &a, int lda,
2112                              const DeviceMemory<double> &x, int incx,
2113                              double beta, DeviceMemory<double> *y, int incy) {
2114   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2115             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2116             PARAM(beta), PARAM(y), PARAM(incy));
2117 
2118   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2119                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2120                int, double, DeviceMemory<double> *, int>
2121       impl;
2122   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2123               a, lda, x, incx, beta, y, incy);
2124 }
2125 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2126 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2127                              uint64 kl, uint64 ku, std::complex<float> alpha,
2128                              const DeviceMemory<std::complex<float>> &a,
2129                              int lda,
2130                              const DeviceMemory<std::complex<float>> &x,
2131                              int incx, std::complex<float> beta,
2132                              DeviceMemory<std::complex<float>> *y, int incy) {
2133   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2134             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2135             PARAM(beta), PARAM(y), PARAM(incy));
2136 
2137   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2138                std::complex<float>, const DeviceMemory<std::complex<float>> &,
2139                int, const DeviceMemory<std::complex<float>> &, int,
2140                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2141       impl;
2142   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2143               a, lda, x, incx, beta, y, incy);
2144 }
2145 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2146 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2147                              uint64 kl, uint64 ku, std::complex<double> alpha,
2148                              const DeviceMemory<std::complex<double>> &a,
2149                              int lda,
2150                              const DeviceMemory<std::complex<double>> &x,
2151                              int incx, std::complex<double> beta,
2152                              DeviceMemory<std::complex<double>> *y, int incy) {
2153   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2154             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2155             PARAM(beta), PARAM(y), PARAM(incy));
2156 
2157   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2158                std::complex<double>, const DeviceMemory<std::complex<double>> &,
2159                int, const DeviceMemory<std::complex<double>> &, int,
2160                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2161       impl;
2162   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2163               a, lda, x, incx, beta, y, incy);
2164 }
2165 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2166 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2167                              float alpha, const DeviceMemory<float> &a, int lda,
2168                              const DeviceMemory<float> &x, int incx, float beta,
2169                              DeviceMemory<float> *y, int incy) {
2170   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2171             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2172             PARAM(incy));
2173 
2174   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2175                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2176                int, float, DeviceMemory<float> *, int>
2177       impl;
2178   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2179               x, incx, beta, y, incy);
2180 }
2181 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2182 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2183                              double alpha, const DeviceMemory<double> &a,
2184                              int lda, const DeviceMemory<double> &x, int incx,
2185                              double beta, DeviceMemory<double> *y, int incy) {
2186   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2187             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2188             PARAM(incy));
2189 
2190   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2191                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2192                int, double, DeviceMemory<double> *, int>
2193       impl;
2194   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2195               x, incx, beta, y, incy);
2196 }
2197 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2198 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2199                              std::complex<float> alpha,
2200                              const DeviceMemory<std::complex<float>> &a,
2201                              int lda,
2202                              const DeviceMemory<std::complex<float>> &x,
2203                              int incx, std::complex<float> beta,
2204                              DeviceMemory<std::complex<float>> *y, int incy) {
2205   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2206             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2207             PARAM(incy));
2208 
2209   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2210                const DeviceMemory<std::complex<float>> &, int,
2211                const DeviceMemory<std::complex<float>> &, int,
2212                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2213       impl;
2214   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2215               x, incx, beta, y, incy);
2216 }
2217 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2218 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2219                              std::complex<double> alpha,
2220                              const DeviceMemory<std::complex<double>> &a,
2221                              int lda,
2222                              const DeviceMemory<std::complex<double>> &x,
2223                              int incx, std::complex<double> beta,
2224                              DeviceMemory<std::complex<double>> *y, int incy) {
2225   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2226             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2227             PARAM(incy));
2228 
2229   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2230                const DeviceMemory<std::complex<double>> &, int,
2231                const DeviceMemory<std::complex<double>> &, int,
2232                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2233       impl;
2234   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2235               x, incx, beta, y, incy);
2236 }
2237 
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2238 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2239                             const DeviceMemory<float> &x, int incx,
2240                             const DeviceMemory<float> &y, int incy,
2241                             DeviceMemory<float> *a, int lda) {
2242   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2243             PARAM(incy), PARAM(a), PARAM(lda));
2244 
2245   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2246                const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2247       impl;
2248   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2249               incy, a, lda);
2250 }
2251 
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2252 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2253                             const DeviceMemory<double> &x, int incx,
2254                             const DeviceMemory<double> &y, int incy,
2255                             DeviceMemory<double> *a, int lda) {
2256   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2257             PARAM(incy), PARAM(a), PARAM(lda));
2258 
2259   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2260                const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2261       impl;
2262   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2263               incy, a, lda);
2264 }
2265 
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2266 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2267                              const DeviceMemory<std::complex<float>> &x,
2268                              int incx,
2269                              const DeviceMemory<std::complex<float>> &y,
2270                              int incy, DeviceMemory<std::complex<float>> *a,
2271                              int lda) {
2272   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2273             PARAM(incy), PARAM(a), PARAM(lda));
2274 
2275   ThenBlasImpl<uint64, uint64, std::complex<float>,
2276                const DeviceMemory<std::complex<float>> &, int,
2277                const DeviceMemory<std::complex<float>> &, int,
2278                DeviceMemory<std::complex<float>> *, int>
2279       impl;
2280   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2281               incy, a, lda);
2282 }
2283 
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2284 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2285                              const DeviceMemory<std::complex<double>> &x,
2286                              int incx,
2287                              const DeviceMemory<std::complex<double>> &y,
2288                              int incy, DeviceMemory<std::complex<double>> *a,
2289                              int lda) {
2290   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2291             PARAM(incy), PARAM(a), PARAM(lda));
2292 
2293   ThenBlasImpl<uint64, uint64, std::complex<double>,
2294                const DeviceMemory<std::complex<double>> &, int,
2295                const DeviceMemory<std::complex<double>> &, int,
2296                DeviceMemory<std::complex<double>> *, int>
2297       impl;
2298   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2299               incy, a, lda);
2300 }
2301 
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2302 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2303                              const DeviceMemory<std::complex<float>> &x,
2304                              int incx,
2305                              const DeviceMemory<std::complex<float>> &y,
2306                              int incy, DeviceMemory<std::complex<float>> *a,
2307                              int lda) {
2308   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2309             PARAM(incy), PARAM(a), PARAM(lda));
2310 
2311   ThenBlasImpl<uint64, uint64, std::complex<float>,
2312                const DeviceMemory<std::complex<float>> &, int,
2313                const DeviceMemory<std::complex<float>> &, int,
2314                DeviceMemory<std::complex<float>> *, int>
2315       impl;
2316   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2317               incy, a, lda);
2318 }
2319 
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2320 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2321                              const DeviceMemory<std::complex<double>> &x,
2322                              int incx,
2323                              const DeviceMemory<std::complex<double>> &y,
2324                              int incy, DeviceMemory<std::complex<double>> *a,
2325                              int lda) {
2326   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2327             PARAM(incy), PARAM(a), PARAM(lda));
2328 
2329   ThenBlasImpl<uint64, uint64, std::complex<double>,
2330                const DeviceMemory<std::complex<double>> &, int,
2331                const DeviceMemory<std::complex<double>> &, int,
2332                DeviceMemory<std::complex<double>> *, int>
2333       impl;
2334   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2335               incy, a, lda);
2336 }
2337 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2338 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2339                              std::complex<float> alpha,
2340                              const DeviceMemory<std::complex<float>> &a,
2341                              int lda,
2342                              const DeviceMemory<std::complex<float>> &x,
2343                              int incx, std::complex<float> beta,
2344                              DeviceMemory<std::complex<float>> *y, int incy) {
2345   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2346             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2347 
2348   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2349                const DeviceMemory<std::complex<float>> &, int,
2350                const DeviceMemory<std::complex<float>> &, int,
2351                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2352       impl;
2353   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2354               x, incx, beta, y, incy);
2355 }
2356 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2357 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2358                              std::complex<double> alpha,
2359                              const DeviceMemory<std::complex<double>> &a,
2360                              int lda,
2361                              const DeviceMemory<std::complex<double>> &x,
2362                              int incx, std::complex<double> beta,
2363                              DeviceMemory<std::complex<double>> *y, int incy) {
2364   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2365             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2366 
2367   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2368                const DeviceMemory<std::complex<double>> &, int,
2369                const DeviceMemory<std::complex<double>> &, int,
2370                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2371       impl;
2372   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2373               x, incx, beta, y, incy);
2374 }
2375 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2376 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2377                              std::complex<float> alpha,
2378                              const DeviceMemory<std::complex<float>> &a,
2379                              int lda,
2380                              const DeviceMemory<std::complex<float>> &x,
2381                              int incx, std::complex<float> beta,
2382                              DeviceMemory<std::complex<float>> *y, int incy) {
2383   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2384             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2385 
2386   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2387                const DeviceMemory<std::complex<float>> &, int,
2388                const DeviceMemory<std::complex<float>> &, int,
2389                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2390       impl;
2391   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2392               incx, beta, y, incy);
2393 }
2394 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2395 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2396                              std::complex<double> alpha,
2397                              const DeviceMemory<std::complex<double>> &a,
2398                              int lda,
2399                              const DeviceMemory<std::complex<double>> &x,
2400                              int incx, std::complex<double> beta,
2401                              DeviceMemory<std::complex<double>> *y, int incy) {
2402   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2403             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2404 
2405   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2406                const DeviceMemory<std::complex<double>> &, int,
2407                const DeviceMemory<std::complex<double>> &, int,
2408                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2409       impl;
2410   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2411               incx, beta, y, incy);
2412 }
2413 
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2414 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2415                             const DeviceMemory<std::complex<float>> &x,
2416                             int incx, DeviceMemory<std::complex<float>> *a,
2417                             int lda) {
2418   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2419             PARAM(a), PARAM(lda));
2420 
2421   ThenBlasImpl<blas::UpperLower, uint64, float,
2422                const DeviceMemory<std::complex<float>> &, int,
2423                DeviceMemory<std::complex<float>> *, int>
2424       impl;
2425   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2426               lda);
2427 }
2428 
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2429 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2430                             const DeviceMemory<std::complex<double>> &x,
2431                             int incx, DeviceMemory<std::complex<double>> *a,
2432                             int lda) {
2433   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2434             PARAM(a), PARAM(lda));
2435 
2436   ThenBlasImpl<blas::UpperLower, uint64, double,
2437                const DeviceMemory<std::complex<double>> &, int,
2438                DeviceMemory<std::complex<double>> *, int>
2439       impl;
2440   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2441               lda);
2442 }
2443 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2444 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2445                              std::complex<float> alpha,
2446                              const DeviceMemory<std::complex<float>> &x,
2447                              int incx,
2448                              const DeviceMemory<std::complex<float>> &y,
2449                              int incy, DeviceMemory<std::complex<float>> *a,
2450                              int lda) {
2451   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2452             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2453 
2454   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2455                const DeviceMemory<std::complex<float>> &, int,
2456                const DeviceMemory<std::complex<float>> &, int,
2457                DeviceMemory<std::complex<float>> *, int>
2458       impl;
2459   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2460               incy, a, lda);
2461 }
2462 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2463 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2464                              std::complex<double> alpha,
2465                              const DeviceMemory<std::complex<double>> &x,
2466                              int incx,
2467                              const DeviceMemory<std::complex<double>> &y,
2468                              int incy, DeviceMemory<std::complex<double>> *a,
2469                              int lda) {
2470   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2471             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2472 
2473   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2474                const DeviceMemory<std::complex<double>> &, int,
2475                const DeviceMemory<std::complex<double>> &, int,
2476                DeviceMemory<std::complex<double>> *, int>
2477       impl;
2478   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2479               incy, a, lda);
2480 }
2481 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2482 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2483                              std::complex<float> alpha,
2484                              const DeviceMemory<std::complex<float>> &ap,
2485                              const DeviceMemory<std::complex<float>> &x,
2486                              int incx, std::complex<float> beta,
2487                              DeviceMemory<std::complex<float>> *y, int incy) {
2488   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2489             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2490 
2491   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2492                const DeviceMemory<std::complex<float>> &,
2493                const DeviceMemory<std::complex<float>> &, int,
2494                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2495       impl;
2496   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2497               beta, y, incy);
2498 }
2499 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2500 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2501                              std::complex<double> alpha,
2502                              const DeviceMemory<std::complex<double>> &ap,
2503                              const DeviceMemory<std::complex<double>> &x,
2504                              int incx, std::complex<double> beta,
2505                              DeviceMemory<std::complex<double>> *y, int incy) {
2506   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2507             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2508 
2509   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2510                const DeviceMemory<std::complex<double>> &,
2511                const DeviceMemory<std::complex<double>> &, int,
2512                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2513       impl;
2514   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2515               beta, y, incy);
2516 }
2517 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2518 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2519                             const DeviceMemory<std::complex<float>> &x,
2520                             int incx, DeviceMemory<std::complex<float>> *ap) {
2521   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2522             PARAM(ap));
2523 
2524   ThenBlasImpl<blas::UpperLower, uint64, float,
2525                const DeviceMemory<std::complex<float>> &, int,
2526                DeviceMemory<std::complex<float>> *>
2527       impl;
2528   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2529 }
2530 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2531 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2532                             const DeviceMemory<std::complex<double>> &x,
2533                             int incx, DeviceMemory<std::complex<double>> *ap) {
2534   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2535             PARAM(ap));
2536 
2537   ThenBlasImpl<blas::UpperLower, uint64, double,
2538                const DeviceMemory<std::complex<double>> &, int,
2539                DeviceMemory<std::complex<double>> *>
2540       impl;
2541   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2542 }
2543 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2544 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2545                              std::complex<float> alpha,
2546                              const DeviceMemory<std::complex<float>> &x,
2547                              int incx,
2548                              const DeviceMemory<std::complex<float>> &y,
2549                              int incy, DeviceMemory<std::complex<float>> *ap) {
2550   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2551             PARAM(y), PARAM(incy), PARAM(ap));
2552 
2553   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2554                const DeviceMemory<std::complex<float>> &, int,
2555                const DeviceMemory<std::complex<float>> &, int,
2556                DeviceMemory<std::complex<float>> *>
2557       impl;
2558   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2559               incy, ap);
2560 }
2561 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2562 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2563                              std::complex<double> alpha,
2564                              const DeviceMemory<std::complex<double>> &x,
2565                              int incx,
2566                              const DeviceMemory<std::complex<double>> &y,
2567                              int incy, DeviceMemory<std::complex<double>> *ap) {
2568   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2569             PARAM(y), PARAM(incy), PARAM(ap));
2570 
2571   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2572                const DeviceMemory<std::complex<double>> &, int,
2573                const DeviceMemory<std::complex<double>> &, int,
2574                DeviceMemory<std::complex<double>> *>
2575       impl;
2576   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2577               incy, ap);
2578 }
2579 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2580 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2581                              float alpha, const DeviceMemory<float> &a, int lda,
2582                              const DeviceMemory<float> &x, int incx, float beta,
2583                              DeviceMemory<float> *y, int incy) {
2584   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2585             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2586 
2587   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2588                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2589                int, float, DeviceMemory<float> *, int>
2590       impl;
2591   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2592               x, incx, beta, y, incy);
2593 }
2594 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2595 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2596                              double alpha, const DeviceMemory<double> &a,
2597                              int lda, const DeviceMemory<double> &x, int incx,
2598                              double beta, DeviceMemory<double> *y, int incy) {
2599   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2600             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2601 
2602   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2603                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2604                int, double, DeviceMemory<double> *, int>
2605       impl;
2606   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2607               x, incx, beta, y, incy);
2608 }
2609 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2610 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
2611                              const DeviceMemory<float> &ap,
2612                              const DeviceMemory<float> &x, int incx, float beta,
2613                              DeviceMemory<float> *y, int incy) {
2614   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2615             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2616 
2617   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2618                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
2619                int>
2620       impl;
2621   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2622               beta, y, incy);
2623 }
2624 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2625 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
2626                              const DeviceMemory<double> &ap,
2627                              const DeviceMemory<double> &x, int incx,
2628                              double beta, DeviceMemory<double> *y, int incy) {
2629   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2630             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2631 
2632   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2633                const DeviceMemory<double> &, int, double,
2634                DeviceMemory<double> *, int>
2635       impl;
2636   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2637               beta, y, incy);
2638 }
2639 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)2640 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
2641                             const DeviceMemory<float> &x, int incx,
2642                             DeviceMemory<float> *ap) {
2643   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2644             PARAM(ap));
2645 
2646   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2647                int, DeviceMemory<float> *>
2648       impl;
2649   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2650 }
2651 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)2652 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
2653                             const DeviceMemory<double> &x, int incx,
2654                             DeviceMemory<double> *ap) {
2655   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2656             PARAM(ap));
2657 
2658   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2659                int, DeviceMemory<double> *>
2660       impl;
2661   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2662 }
2663 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)2664 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
2665                              const DeviceMemory<float> &x, int incx,
2666                              const DeviceMemory<float> &y, int incy,
2667                              DeviceMemory<float> *ap) {
2668   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2669             PARAM(y), PARAM(incy), PARAM(ap));
2670 
2671   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2672                int, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2673       impl;
2674   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2675               incy, ap);
2676 }
2677 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)2678 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
2679                              const DeviceMemory<double> &x, int incx,
2680                              const DeviceMemory<double> &y, int incy,
2681                              DeviceMemory<double> *ap) {
2682   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2683             PARAM(y), PARAM(incy), PARAM(ap));
2684 
2685   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2686                int, const DeviceMemory<double> &, int, DeviceMemory<double> *>
2687       impl;
2688   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2689               incy, ap);
2690 }
2691 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2692 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
2693                              const DeviceMemory<float> &a, int lda,
2694                              const DeviceMemory<float> &x, int incx, float beta,
2695                              DeviceMemory<float> *y, int incy) {
2696   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2697             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2698 
2699   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2700                int, const DeviceMemory<float> &, int, float,
2701                DeviceMemory<float> *, int>
2702       impl;
2703   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2704               incx, beta, y, incy);
2705 }
2706 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2707 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
2708                              const DeviceMemory<double> &a, int lda,
2709                              const DeviceMemory<double> &x, int incx,
2710                              double beta, DeviceMemory<double> *y, int incy) {
2711   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2712             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2713 
2714   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2715                int, const DeviceMemory<double> &, int, double,
2716                DeviceMemory<double> *, int>
2717       impl;
2718   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2719               incx, beta, y, incy);
2720 }
2721 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)2722 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
2723                             const DeviceMemory<float> &x, int incx,
2724                             DeviceMemory<float> *a, int lda) {
2725   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2726             PARAM(a), PARAM(lda));
2727 
2728   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2729                int, DeviceMemory<float> *, int>
2730       impl;
2731   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2732               lda);
2733 }
2734 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)2735 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
2736                             const DeviceMemory<double> &x, int incx,
2737                             DeviceMemory<double> *a, int lda) {
2738   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2739             PARAM(a), PARAM(lda));
2740 
2741   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2742                int, DeviceMemory<double> *, int>
2743       impl;
2744   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2745               lda);
2746 }
2747 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2748 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
2749                              const DeviceMemory<float> &x, int incx,
2750                              const DeviceMemory<float> &y, int incy,
2751                              DeviceMemory<float> *a, int lda) {
2752   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2753             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2754 
2755   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2756                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2757                int>
2758       impl;
2759   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2760               incy, a, lda);
2761 }
2762 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2763 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
2764                              const DeviceMemory<double> &x, int incx,
2765                              const DeviceMemory<double> &y, int incy,
2766                              DeviceMemory<double> *a, int lda) {
2767   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2768             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2769 
2770   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2771                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
2772                int>
2773       impl;
2774   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2775               incy, a, lda);
2776 }
2777 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2778 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2779                              blas::Diagonal diag, uint64 n, uint64 k,
2780                              const DeviceMemory<float> &a, int lda,
2781                              DeviceMemory<float> *x, int incx) {
2782   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2783             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2784 
2785   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2786                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2787                int>
2788       impl;
2789   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2790               lda, x, incx);
2791 }
2792 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2793 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2794                              blas::Diagonal diag, uint64 n, uint64 k,
2795                              const DeviceMemory<double> &a, int lda,
2796                              DeviceMemory<double> *x, int incx) {
2797   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2798             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2799 
2800   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2801                uint64, const DeviceMemory<double> &, int,
2802                DeviceMemory<double> *, int>
2803       impl;
2804   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2805               lda, x, incx);
2806 }
2807 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2808 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2809                              blas::Diagonal diag, uint64 n, uint64 k,
2810                              const DeviceMemory<std::complex<float>> &a,
2811                              int lda, DeviceMemory<std::complex<float>> *x,
2812                              int incx) {
2813   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2814             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2815 
2816   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2817                uint64, const DeviceMemory<std::complex<float>> &, int,
2818                DeviceMemory<std::complex<float>> *, int>
2819       impl;
2820   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2821               lda, x, incx);
2822 }
2823 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2824 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2825                              blas::Diagonal diag, uint64 n, uint64 k,
2826                              const DeviceMemory<std::complex<double>> &a,
2827                              int lda, DeviceMemory<std::complex<double>> *x,
2828                              int incx) {
2829   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2830             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2831 
2832   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2833                uint64, const DeviceMemory<std::complex<double>> &, int,
2834                DeviceMemory<std::complex<double>> *, int>
2835       impl;
2836   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2837               lda, x, incx);
2838 }
2839 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2840 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2841                              blas::Diagonal diag, uint64 n, uint64 k,
2842                              const DeviceMemory<float> &a, int lda,
2843                              DeviceMemory<float> *x, int incx) {
2844   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2845             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2846 
2847   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2848                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2849                int>
2850       impl;
2851   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2852               lda, x, incx);
2853 }
2854 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2855 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2856                              blas::Diagonal diag, uint64 n, uint64 k,
2857                              const DeviceMemory<double> &a, int lda,
2858                              DeviceMemory<double> *x, int incx) {
2859   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2860             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2861 
2862   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2863                uint64, const DeviceMemory<double> &, int,
2864                DeviceMemory<double> *, int>
2865       impl;
2866   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2867               lda, x, incx);
2868 }
2869 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2870 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2871                              blas::Diagonal diag, uint64 n, uint64 k,
2872                              const DeviceMemory<std::complex<float>> &a,
2873                              int lda, DeviceMemory<std::complex<float>> *x,
2874                              int incx) {
2875   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2876             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2877 
2878   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2879                uint64, const DeviceMemory<std::complex<float>> &, int,
2880                DeviceMemory<std::complex<float>> *, int>
2881       impl;
2882   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2883               lda, x, incx);
2884 }
2885 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2886 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2887                              blas::Diagonal diag, uint64 n, uint64 k,
2888                              const DeviceMemory<std::complex<double>> &a,
2889                              int lda, DeviceMemory<std::complex<double>> *x,
2890                              int incx) {
2891   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2892             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2893 
2894   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2895                uint64, const DeviceMemory<std::complex<double>> &, int,
2896                DeviceMemory<std::complex<double>> *, int>
2897       impl;
2898   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2899               lda, x, incx);
2900 }
2901 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2902 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2903                              blas::Diagonal diag, uint64 n,
2904                              const DeviceMemory<float> &ap,
2905                              DeviceMemory<float> *x, int incx) {
2906   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2907             PARAM(x), PARAM(incx));
2908 
2909   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2910                const DeviceMemory<float> &, DeviceMemory<float> *, int>
2911       impl;
2912   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2913               incx);
2914 }
2915 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2916 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2917                              blas::Diagonal diag, uint64 n,
2918                              const DeviceMemory<double> &ap,
2919                              DeviceMemory<double> *x, int incx) {
2920   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2921             PARAM(x), PARAM(incx));
2922 
2923   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2924                const DeviceMemory<double> &, DeviceMemory<double> *, int>
2925       impl;
2926   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2927               incx);
2928 }
2929 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2930 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2931                              blas::Diagonal diag, uint64 n,
2932                              const DeviceMemory<std::complex<float>> &ap,
2933                              DeviceMemory<std::complex<float>> *x, int incx) {
2934   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2935             PARAM(x), PARAM(incx));
2936 
2937   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2938                const DeviceMemory<std::complex<float>> &,
2939                DeviceMemory<std::complex<float>> *, int>
2940       impl;
2941   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2942               incx);
2943 }
2944 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)2945 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2946                              blas::Diagonal diag, uint64 n,
2947                              const DeviceMemory<std::complex<double>> &ap,
2948                              DeviceMemory<std::complex<double>> *x, int incx) {
2949   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2950             PARAM(x), PARAM(incx));
2951 
2952   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2953                const DeviceMemory<std::complex<double>> &,
2954                DeviceMemory<std::complex<double>> *, int>
2955       impl;
2956   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2957               incx);
2958 }
2959 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2960 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2961                              blas::Diagonal diag, uint64 n,
2962                              const DeviceMemory<float> &ap,
2963                              DeviceMemory<float> *x, int incx) {
2964   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2965             PARAM(x), PARAM(incx));
2966 
2967   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2968                const DeviceMemory<float> &, DeviceMemory<float> *, int>
2969       impl;
2970   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2971               incx);
2972 }
2973 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2974 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2975                              blas::Diagonal diag, uint64 n,
2976                              const DeviceMemory<double> &ap,
2977                              DeviceMemory<double> *x, int incx) {
2978   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2979             PARAM(x), PARAM(incx));
2980 
2981   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2982                const DeviceMemory<double> &, DeviceMemory<double> *, int>
2983       impl;
2984   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2985               incx);
2986 }
2987 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2988 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2989                              blas::Diagonal diag, uint64 n,
2990                              const DeviceMemory<std::complex<float>> &ap,
2991                              DeviceMemory<std::complex<float>> *x, int incx) {
2992   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2993             PARAM(x), PARAM(incx));
2994 
2995   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2996                const DeviceMemory<std::complex<float>> &,
2997                DeviceMemory<std::complex<float>> *, int>
2998       impl;
2999   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3000               incx);
3001 }
3002 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3003 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3004                              blas::Diagonal diag, uint64 n,
3005                              const DeviceMemory<std::complex<double>> &ap,
3006                              DeviceMemory<std::complex<double>> *x, int incx) {
3007   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3008             PARAM(x), PARAM(incx));
3009 
3010   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3011                const DeviceMemory<std::complex<double>> &,
3012                DeviceMemory<std::complex<double>> *, int>
3013       impl;
3014   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3015               incx);
3016 }
3017 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3018 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3019                              blas::Diagonal diag, uint64 n,
3020                              const DeviceMemory<float> &a, int lda,
3021                              DeviceMemory<float> *x, int incx) {
3022   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3023             PARAM(lda), PARAM(x), PARAM(incx));
3024 
3025   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3026                const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
3027       impl;
3028   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3029               lda, x, incx);
3030 }
3031 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3032 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3033                              blas::Diagonal diag, uint64 n,
3034                              const DeviceMemory<double> &a, int lda,
3035                              DeviceMemory<double> *x, int incx) {
3036   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3037             PARAM(lda), PARAM(x), PARAM(incx));
3038 
3039   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3040                const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
3041       impl;
3042   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3043               lda, x, incx);
3044 }
3045 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3046 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3047                              blas::Diagonal diag, uint64 n,
3048                              const DeviceMemory<std::complex<float>> &a,
3049                              int lda, DeviceMemory<std::complex<float>> *x,
3050                              int incx) {
3051   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3052             PARAM(lda), PARAM(x), PARAM(incx));
3053 
3054   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3055                const DeviceMemory<std::complex<float>> &, int,
3056                DeviceMemory<std::complex<float>> *, int>
3057       impl;
3058   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3059               lda, x, incx);
3060 }
3061 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3062 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3063                              blas::Diagonal diag, uint64 n,
3064                              const DeviceMemory<std::complex<double>> &a,
3065                              int lda, DeviceMemory<std::complex<double>> *x,
3066                              int incx) {
3067   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3068             PARAM(lda), PARAM(x), PARAM(incx));
3069 
3070   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3071                const DeviceMemory<std::complex<double>> &, int,
3072                DeviceMemory<std::complex<double>> *, int>
3073       impl;
3074   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3075               lda, x, incx);
3076 }
3077 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3078 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3079                              blas::Diagonal diag, uint64 n,
3080                              const DeviceMemory<float> &a, int lda,
3081                              DeviceMemory<float> *x, int incx) {
3082   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3083             PARAM(lda), PARAM(x), PARAM(incx));
3084 
3085   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3086                const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
3087       impl;
3088   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3089               lda, x, incx);
3090 }
3091 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3092 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3093                              blas::Diagonal diag, uint64 n,
3094                              const DeviceMemory<double> &a, int lda,
3095                              DeviceMemory<double> *x, int incx) {
3096   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3097             PARAM(lda), PARAM(x), PARAM(incx));
3098 
3099   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3100                const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
3101       impl;
3102   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3103               lda, x, incx);
3104 }
3105 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3106 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3107                              blas::Diagonal diag, uint64 n,
3108                              const DeviceMemory<std::complex<float>> &a,
3109                              int lda, DeviceMemory<std::complex<float>> *x,
3110                              int incx) {
3111   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3112             PARAM(lda), PARAM(x), PARAM(incx));
3113 
3114   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3115                const DeviceMemory<std::complex<float>> &, int,
3116                DeviceMemory<std::complex<float>> *, int>
3117       impl;
3118   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3119               lda, x, incx);
3120 }
3121 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3122 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3123                              blas::Diagonal diag, uint64 n,
3124                              const DeviceMemory<std::complex<double>> &a,
3125                              int lda, DeviceMemory<std::complex<double>> *x,
3126                              int incx) {
3127   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3128             PARAM(lda), PARAM(x), PARAM(incx));
3129 
3130   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3131                const DeviceMemory<std::complex<double>> &, int,
3132                DeviceMemory<std::complex<double>> *, int>
3133       impl;
3134   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3135               lda, x, incx);
3136 }
3137 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3138 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3139                              uint64 m, uint64 n, uint64 k, float alpha,
3140                              const DeviceMemory<Eigen::half> &a, int lda,
3141                              const DeviceMemory<Eigen::half> &b, int ldb,
3142                              float beta, DeviceMemory<Eigen::half> *c,
3143                              int ldc) {
3144   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3145             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3146             PARAM(beta), PARAM(c), PARAM(ldc));
3147 
3148   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3149                const DeviceMemory<Eigen::half> &, int,
3150                const DeviceMemory<Eigen::half> &, int, float,
3151                DeviceMemory<Eigen::half> *, int>
3152       impl;
3153   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3154               alpha, a, lda, b, ldb, beta, c, ldc);
3155 }
3156 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3157 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3158                              uint64 m, uint64 n, uint64 k, float alpha,
3159                              const DeviceMemory<float> &a, int lda,
3160                              const DeviceMemory<float> &b, int ldb, float beta,
3161                              DeviceMemory<float> *c, int ldc) {
3162   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3163             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3164             PARAM(beta), PARAM(c), PARAM(ldc));
3165 
3166   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3167                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3168                int, float, DeviceMemory<float> *, int>
3169       impl;
3170   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3171               alpha, a, lda, b, ldb, beta, c, ldc);
3172 }
3173 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3174 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3175                              uint64 m, uint64 n, uint64 k, double alpha,
3176                              const DeviceMemory<double> &a, int lda,
3177                              const DeviceMemory<double> &b, int ldb,
3178                              double beta, DeviceMemory<double> *c, int ldc) {
3179   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3180             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3181             PARAM(beta), PARAM(c), PARAM(ldc));
3182 
3183   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3184                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3185                int, double, DeviceMemory<double> *, int>
3186       impl;
3187   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3188               alpha, a, lda, b, ldb, beta, c, ldc);
3189 }
3190 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3191 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3192                              uint64 m, uint64 n, uint64 k,
3193                              std::complex<float> alpha,
3194                              const DeviceMemory<std::complex<float>> &a,
3195                              int lda,
3196                              const DeviceMemory<std::complex<float>> &b,
3197                              int ldb, std::complex<float> beta,
3198                              DeviceMemory<std::complex<float>> *c, int ldc) {
3199   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3200             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3201             PARAM(beta), PARAM(c), PARAM(ldc));
3202 
3203   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3204                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3205                int, const DeviceMemory<std::complex<float>> &, int,
3206                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3207       impl;
3208   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3209               alpha, a, lda, b, ldb, beta, c, ldc);
3210 }
3211 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3212 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3213                              uint64 m, uint64 n, uint64 k,
3214                              std::complex<double> alpha,
3215                              const DeviceMemory<std::complex<double>> &a,
3216                              int lda,
3217                              const DeviceMemory<std::complex<double>> &b,
3218                              int ldb, std::complex<double> beta,
3219                              DeviceMemory<std::complex<double>> *c, int ldc) {
3220   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3221             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3222             PARAM(beta), PARAM(c), PARAM(ldc));
3223 
3224   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3225                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3226                int, const DeviceMemory<std::complex<double>> &, int,
3227                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3228       impl;
3229   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3230               alpha, a, lda, b, ldb, beta, c, ldc);
3231 }
3232 
3233 namespace {
3234 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3235 // blas::ProfileResult*.  This functor doesn't put the stream into an error
3236 // state if the op fails and the profile result is non-null.  Instead, the
3237 // error-ness is returned in the profile result itself.
3238 template <typename... Args>
3239 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon852f14a50211::ThenBlasWithProfileImpl3240   Stream &operator()(Stream *stream,
3241                      bool (blas::BlasSupport::*blas_func)(
3242                          Stream *, Args..., blas::ProfileResult *),
3243                      Args... args, blas::ProfileResult *profile_result) {
3244     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3245     bool record_error = profile_result == nullptr;
3246     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3247   }
3248 };
3249 }  // anonymous namespace
3250 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3251 Stream &Stream::ThenBlasGemvWithProfiling(
3252     blas::Transpose trans, uint64 m, uint64 n, float alpha,
3253     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3254     int incx, float beta, DeviceMemory<float> *y, int incy,
3255     blas::ProfileResult *output_profile_result) {
3256   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3257             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3258             PARAM(incy));
3259 
3260   ThenBlasWithProfileImpl<
3261       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3262       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3263       impl;
3264   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3265               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3266 }
3267 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3268 Stream &Stream::ThenBlasGemvWithProfiling(
3269     blas::Transpose trans, uint64 m, uint64 n, double alpha,
3270     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3271     int incx, double beta, DeviceMemory<double> *y, int incy,
3272     blas::ProfileResult *output_profile_result) {
3273   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3274             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3275             PARAM(incy));
3276 
3277   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3278                           const DeviceMemory<double> &, int,
3279                           const DeviceMemory<double> &, int, double,
3280                           DeviceMemory<double> *, int>
3281       impl;
3282   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3283               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3284 }
3285 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3286 Stream &Stream::ThenBlasGemvWithProfiling(
3287     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3288     const DeviceMemory<std::complex<float>> &a, int lda,
3289     const DeviceMemory<std::complex<float>> &x, int incx,
3290     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3291     blas::ProfileResult *output_profile_result) {
3292   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3293             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3294             PARAM(incy));
3295 
3296   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3297                           const DeviceMemory<std::complex<float>> &, int,
3298                           const DeviceMemory<std::complex<float>> &, int,
3299                           std::complex<float>,
3300                           DeviceMemory<std::complex<float>> *, int>
3301       impl;
3302   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3303               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3304 }
3305 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3306 Stream &Stream::ThenBlasGemvWithProfiling(
3307     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3308     const DeviceMemory<std::complex<double>> &a, int lda,
3309     const DeviceMemory<std::complex<double>> &x, int incx,
3310     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3311     blas::ProfileResult *output_profile_result) {
3312   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3313             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3314             PARAM(incy));
3315 
3316   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3317                           const DeviceMemory<std::complex<double>> &, int,
3318                           const DeviceMemory<std::complex<double>> &, int,
3319                           std::complex<double>,
3320                           DeviceMemory<std::complex<double>> *, int>
3321       impl;
3322   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3323               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3324 }
3325 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3326 Stream &Stream::ThenBlasGemmWithProfiling(
3327     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3328     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3329     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3330     DeviceMemory<Eigen::half> *c, int ldc,
3331     blas::ProfileResult *output_profile_result) {
3332   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3333             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3334             PARAM(beta), PARAM(c), PARAM(ldc));
3335 
3336   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3337                           uint64, float, const DeviceMemory<Eigen::half> &, int,
3338                           const DeviceMemory<Eigen::half> &, int, float,
3339                           DeviceMemory<Eigen::half> *, int>
3340       impl;
3341   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3342               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3343               output_profile_result);
3344 }
3345 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3346 Stream &Stream::ThenBlasGemmWithProfiling(
3347     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3348     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3349     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3350     int ldc, blas::ProfileResult *output_profile_result) {
3351   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3352             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3353             PARAM(beta), PARAM(c), PARAM(ldc));
3354 
3355   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3356                           uint64, float, const DeviceMemory<float> &, int,
3357                           const DeviceMemory<float> &, int, float,
3358                           DeviceMemory<float> *, int>
3359       impl;
3360   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3361               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3362               output_profile_result);
3363 }
3364 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3365 Stream &Stream::ThenBlasGemmWithProfiling(
3366     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3367     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3368     const DeviceMemory<double> &b, int ldb, double beta,
3369     DeviceMemory<double> *c, int ldc,
3370     blas::ProfileResult *output_profile_result) {
3371   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3372             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3373             PARAM(beta), PARAM(c), PARAM(ldc));
3374 
3375   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3376                           uint64, double, const DeviceMemory<double> &, int,
3377                           const DeviceMemory<double> &, int, double,
3378                           DeviceMemory<double> *, int>
3379       impl;
3380   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3381               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3382               output_profile_result);
3383 }
3384 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3385 Stream &Stream::ThenBlasGemmWithProfiling(
3386     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3387     uint64 k, std::complex<float> alpha,
3388     const DeviceMemory<std::complex<float>> &a, int lda,
3389     const DeviceMemory<std::complex<float>> &b, int ldb,
3390     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3391     blas::ProfileResult *output_profile_result) {
3392   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3393             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3394             PARAM(beta), PARAM(c), PARAM(ldc));
3395 
3396   ThenBlasWithProfileImpl<
3397       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3398       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3399       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3400       DeviceMemory<std::complex<float>> *, int>
3401       impl;
3402   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3403               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3404               output_profile_result);
3405 }
3406 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3407 Stream &Stream::ThenBlasGemmWithProfiling(
3408     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3409     uint64 k, std::complex<double> alpha,
3410     const DeviceMemory<std::complex<double>> &a, int lda,
3411     const DeviceMemory<std::complex<double>> &b, int ldb,
3412     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3413     blas::ProfileResult *output_profile_result) {
3414   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3415             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3416             PARAM(beta), PARAM(c), PARAM(ldc));
3417 
3418   ThenBlasWithProfileImpl<
3419       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3420       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3421       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3422       DeviceMemory<std::complex<double>> *, int>
3423       impl;
3424   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3425               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3426               output_profile_result);
3427 }
3428 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<Eigen::half> & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const HostOrDeviceScalar<Eigen::half> & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3429 Stream &Stream::ThenBlasGemmWithAlgorithm(
3430     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3431     uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
3432     const DeviceMemory<Eigen::half> &a, int lda,
3433     const DeviceMemory<Eigen::half> &b, int ldb,
3434     const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
3435     int ldc, blas::ComputationType computation_type,
3436     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3437   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3438             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3439             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3440             PARAM(algorithm));
3441 
3442   ThenBlasWithProfileImpl<
3443       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3444       const HostOrDeviceScalar<Eigen::half> &,
3445       const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
3446       int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3447       int, blas::ComputationType, blas::AlgorithmType>
3448       impl;
3449   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3450               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3451               algorithm, output_profile_result);
3452 }
3453 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<int> & alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,const HostOrDeviceScalar<int> & beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3454 Stream &Stream::ThenBlasGemmWithAlgorithm(
3455     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3456     uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
3457     int lda, const DeviceMemory<int8> &b, int ldb,
3458     const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
3459     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3460     blas::ProfileResult *output_profile_result) {
3461   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3462             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3463             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3464             PARAM(algorithm));
3465 
3466   ThenBlasWithProfileImpl<
3467       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3468       const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
3469       const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
3470       DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3471       impl;
3472   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3473               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3474               algorithm, output_profile_result);
3475 }
3476 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<float> & alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,const HostOrDeviceScalar<float> & beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3477 Stream &Stream::ThenBlasGemmWithAlgorithm(
3478     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3479     uint64 k, const HostOrDeviceScalar<float> &alpha,
3480     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
3481     int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
3482     int ldc, blas::ComputationType computation_type,
3483     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3484   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3485             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3486             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3487             PARAM(algorithm));
3488 
3489   ThenBlasWithProfileImpl<
3490       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3491       const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
3492       const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
3493       DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3494       impl;
3495   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3496               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3497               algorithm, output_profile_result);
3498 }
3499 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<double> & alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,const HostOrDeviceScalar<double> & beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3500 Stream &Stream::ThenBlasGemmWithAlgorithm(
3501     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3502     uint64 k, const HostOrDeviceScalar<double> &alpha,
3503     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
3504     int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
3505     int ldc, blas::ComputationType computation_type,
3506     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3507   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3508             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3509             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3510             PARAM(algorithm));
3511 
3512   ThenBlasWithProfileImpl<
3513       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3514       const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
3515       const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
3516       DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
3517       impl;
3518   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3519               m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
3520               HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
3521               algorithm, output_profile_result);
3522 }
3523 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<float>> & alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,const HostOrDeviceScalar<std::complex<float>> & beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3524 Stream &Stream::ThenBlasGemmWithAlgorithm(
3525     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3526     uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
3527     const DeviceMemory<std::complex<float>> &a, int lda,
3528     const DeviceMemory<std::complex<float>> &b, int ldb,
3529     const HostOrDeviceScalar<std::complex<float>> &beta,
3530     DeviceMemory<std::complex<float>> *c, int ldc,
3531     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3532     blas::ProfileResult *output_profile_result) {
3533   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3534             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3535             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3536             PARAM(algorithm));
3537 
3538   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3539                           uint64,
3540                           const HostOrDeviceScalar<std::complex<float>> &,
3541                           const DeviceMemory<std::complex<float>> &, int,
3542                           const DeviceMemory<std::complex<float>> &, int,
3543                           const HostOrDeviceScalar<std::complex<float>> &,
3544                           DeviceMemory<std::complex<float>> *, int,
3545                           blas::ComputationType, blas::AlgorithmType>
3546       impl;
3547   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3548               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3549               algorithm, output_profile_result);
3550 }
3551 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const HostOrDeviceScalar<std::complex<double>> & alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,const HostOrDeviceScalar<std::complex<double>> & beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3552 Stream &Stream::ThenBlasGemmWithAlgorithm(
3553     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3554     uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
3555     const DeviceMemory<std::complex<double>> &a, int lda,
3556     const DeviceMemory<std::complex<double>> &b, int ldb,
3557     const HostOrDeviceScalar<std::complex<double>> &beta,
3558     DeviceMemory<std::complex<double>> *c, int ldc,
3559     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3560     blas::ProfileResult *output_profile_result) {
3561   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3562             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3563             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3564             PARAM(algorithm));
3565 
3566   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3567                           uint64,
3568                           const HostOrDeviceScalar<std::complex<double>> &,
3569                           const DeviceMemory<std::complex<double>> &, int,
3570                           const DeviceMemory<std::complex<double>> &, int,
3571                           const HostOrDeviceScalar<std::complex<double>> &,
3572                           DeviceMemory<std::complex<double>> *, int,
3573                           blas::ComputationType, blas::AlgorithmType>
3574       impl;
3575   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3576               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3577               algorithm, output_profile_result);
3578 }
3579 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3580 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3581                              uint64 n, std::complex<float> alpha,
3582                              const DeviceMemory<std::complex<float>> &a,
3583                              int lda,
3584                              const DeviceMemory<std::complex<float>> &b,
3585                              int ldb, std::complex<float> beta,
3586                              DeviceMemory<std::complex<float>> *c, int ldc) {
3587   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3588             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3589             PARAM(ldc));
3590 
3591   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3592                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3593                int, const DeviceMemory<std::complex<float>> &, int,
3594                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3595       impl;
3596   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3597               lda, b, ldb, beta, c, ldc);
3598 }
3599 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3600 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3601                              uint64 n, std::complex<double> alpha,
3602                              const DeviceMemory<std::complex<double>> &a,
3603                              int lda,
3604                              const DeviceMemory<std::complex<double>> &b,
3605                              int ldb, std::complex<double> beta,
3606                              DeviceMemory<std::complex<double>> *c, int ldc) {
3607   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3608             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3609             PARAM(ldc));
3610 
3611   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3612                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3613                int, const DeviceMemory<std::complex<double>> &, int,
3614                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3615       impl;
3616   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3617               lda, b, ldb, beta, c, ldc);
3618 }
3619 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3620 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3621                              uint64 n, uint64 k, float alpha,
3622                              const DeviceMemory<std::complex<float>> &a,
3623                              int lda, float beta,
3624                              DeviceMemory<std::complex<float>> *c, int ldc) {
3625   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3626             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3627 
3628   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3629                const DeviceMemory<std::complex<float>> &, int, float,
3630                DeviceMemory<std::complex<float>> *, int>
3631       impl;
3632   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3633               lda, beta, c, ldc);
3634 }
3635 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3636 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3637                              uint64 n, uint64 k, double alpha,
3638                              const DeviceMemory<std::complex<double>> &a,
3639                              int lda, double beta,
3640                              DeviceMemory<std::complex<double>> *c, int ldc) {
3641   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3642             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3643 
3644   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3645                const DeviceMemory<std::complex<double>> &, int, double,
3646                DeviceMemory<std::complex<double>> *, int>
3647       impl;
3648   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3649               lda, beta, c, ldc);
3650 }
3651 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3652 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3653                               uint64 n, uint64 k, std::complex<float> alpha,
3654                               const DeviceMemory<std::complex<float>> &a,
3655                               int lda,
3656                               const DeviceMemory<std::complex<float>> &b,
3657                               int ldb, float beta,
3658                               DeviceMemory<std::complex<float>> *c, int ldc) {
3659   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3660             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3661             PARAM(ldc));
3662 
3663   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3664                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3665                int, const DeviceMemory<std::complex<float>> &, int, float,
3666                DeviceMemory<std::complex<float>> *, int>
3667       impl;
3668   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3669               a, lda, b, ldb, beta, c, ldc);
3670 }
3671 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3672 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3673                               uint64 n, uint64 k, std::complex<double> alpha,
3674                               const DeviceMemory<std::complex<double>> &a,
3675                               int lda,
3676                               const DeviceMemory<std::complex<double>> &b,
3677                               int ldb, double beta,
3678                               DeviceMemory<std::complex<double>> *c, int ldc) {
3679   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3680             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3681             PARAM(ldc));
3682 
3683   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3684                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3685                int, const DeviceMemory<std::complex<double>> &, int, double,
3686                DeviceMemory<std::complex<double>> *, int>
3687       impl;
3688   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3689               a, lda, b, ldb, beta, c, ldc);
3690 }
3691 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3692 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3693                              uint64 n, float alpha,
3694                              const DeviceMemory<float> &a, int lda,
3695                              const DeviceMemory<float> &b, int ldb, float beta,
3696                              DeviceMemory<float> *c, int ldc) {
3697   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3698             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3699             PARAM(ldc));
3700 
3701   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
3702                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3703                int, float, DeviceMemory<float> *, int>
3704       impl;
3705   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3706               lda, b, ldb, beta, c, ldc);
3707 }
3708 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3709 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3710                              uint64 n, double alpha,
3711                              const DeviceMemory<double> &a, int lda,
3712                              const DeviceMemory<double> &b, int ldb,
3713                              double beta, DeviceMemory<double> *c, int ldc) {
3714   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3715             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3716             PARAM(ldc));
3717 
3718   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
3719                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3720                int, double, DeviceMemory<double> *, int>
3721       impl;
3722   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3723               lda, b, ldb, beta, c, ldc);
3724 }
3725 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3726 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3727                              uint64 n, std::complex<float> alpha,
3728                              const DeviceMemory<std::complex<float>> &a,
3729                              int lda,
3730                              const DeviceMemory<std::complex<float>> &b,
3731                              int ldb, std::complex<float> beta,
3732                              DeviceMemory<std::complex<float>> *c, int ldc) {
3733   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3734             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3735             PARAM(ldc));
3736 
3737   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3738                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3739                int, const DeviceMemory<std::complex<float>> &, int,
3740                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3741       impl;
3742   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3743               lda, b, ldb, beta, c, ldc);
3744 }
3745 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3746 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3747                              uint64 n, std::complex<double> alpha,
3748                              const DeviceMemory<std::complex<double>> &a,
3749                              int lda,
3750                              const DeviceMemory<std::complex<double>> &b,
3751                              int ldb, std::complex<double> beta,
3752                              DeviceMemory<std::complex<double>> *c, int ldc) {
3753   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3754             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3755             PARAM(ldc));
3756 
3757   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3758                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3759                int, const DeviceMemory<std::complex<double>> &, int,
3760                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3761       impl;
3762   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3763               lda, b, ldb, beta, c, ldc);
3764 }
3765 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)3766 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3767                              uint64 n, uint64 k, float alpha,
3768                              const DeviceMemory<float> &a, int lda, float beta,
3769                              DeviceMemory<float> *c, int ldc) {
3770   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3771             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3772 
3773   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3774                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3775                int>
3776       impl;
3777   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3778               lda, beta, c, ldc);
3779 }
3780 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)3781 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3782                              uint64 n, uint64 k, double alpha,
3783                              const DeviceMemory<double> &a, int lda,
3784                              double beta, DeviceMemory<double> *c, int ldc) {
3785   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3786             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3787 
3788   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3789                const DeviceMemory<double> &, int, double,
3790                DeviceMemory<double> *, int>
3791       impl;
3792   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3793               lda, beta, c, ldc);
3794 }
3795 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3796 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3797                              uint64 n, uint64 k, std::complex<float> alpha,
3798                              const DeviceMemory<std::complex<float>> &a,
3799                              int lda, std::complex<float> beta,
3800                              DeviceMemory<std::complex<float>> *c, int ldc) {
3801   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3802             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3803 
3804   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3805                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3806                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
3807                int>
3808       impl;
3809   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3810               lda, beta, c, ldc);
3811 }
3812 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3813 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3814                              uint64 n, uint64 k, std::complex<double> alpha,
3815                              const DeviceMemory<std::complex<double>> &a,
3816                              int lda, std::complex<double> beta,
3817                              DeviceMemory<std::complex<double>> *c, int ldc) {
3818   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3819             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3820 
3821   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3822                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3823                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
3824                int>
3825       impl;
3826   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3827               lda, beta, c, ldc);
3828 }
3829 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3830 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3831                               uint64 n, uint64 k, float alpha,
3832                               const DeviceMemory<float> &a, int lda,
3833                               const DeviceMemory<float> &b, int ldb, float beta,
3834                               DeviceMemory<float> *c, int ldc) {
3835   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3836             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3837             PARAM(ldc));
3838 
3839   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3840                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3841                int, float, DeviceMemory<float> *, int>
3842       impl;
3843   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3844               a, lda, b, ldb, beta, c, ldc);
3845 }
3846 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3847 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3848                               uint64 n, uint64 k, double alpha,
3849                               const DeviceMemory<double> &a, int lda,
3850                               const DeviceMemory<double> &b, int ldb,
3851                               double beta, DeviceMemory<double> *c, int ldc) {
3852   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3853             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3854             PARAM(ldc));
3855 
3856   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3857                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3858                int, double, DeviceMemory<double> *, int>
3859       impl;
3860   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3861               a, lda, b, ldb, beta, c, ldc);
3862 }
3863 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3864 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3865                               uint64 n, uint64 k, std::complex<float> alpha,
3866                               const DeviceMemory<std::complex<float>> &a,
3867                               int lda,
3868                               const DeviceMemory<std::complex<float>> &b,
3869                               int ldb, std::complex<float> beta,
3870                               DeviceMemory<std::complex<float>> *c, int ldc) {
3871   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3872             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3873             PARAM(ldc));
3874 
3875   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3876                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3877                int, const DeviceMemory<std::complex<float>> &, int,
3878                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3879       impl;
3880   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3881               a, lda, b, ldb, beta, c, ldc);
3882 }
3883 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3884 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3885                               uint64 n, uint64 k, std::complex<double> alpha,
3886                               const DeviceMemory<std::complex<double>> &a,
3887                               int lda,
3888                               const DeviceMemory<std::complex<double>> &b,
3889                               int ldb, std::complex<double> beta,
3890                               DeviceMemory<std::complex<double>> *c, int ldc) {
3891   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3892             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3893             PARAM(ldc));
3894 
3895   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3896                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3897                int, const DeviceMemory<std::complex<double>> &, int,
3898                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3899       impl;
3900   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3901               a, lda, b, ldb, beta, c, ldc);
3902 }
3903 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3904 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3905                              blas::Transpose transa, blas::Diagonal diag,
3906                              uint64 m, uint64 n, float alpha,
3907                              const DeviceMemory<float> &a, int lda,
3908                              DeviceMemory<float> *b, int ldb) {
3909   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3910             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3911 
3912   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3913                uint64, uint64, float, const DeviceMemory<float> &, int,
3914                DeviceMemory<float> *, int>
3915       impl;
3916   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3917               n, alpha, a, lda, b, ldb);
3918 }
3919 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3920 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3921                              blas::Transpose transa, blas::Diagonal diag,
3922                              uint64 m, uint64 n, double alpha,
3923                              const DeviceMemory<double> &a, int lda,
3924                              DeviceMemory<double> *b, int ldb) {
3925   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3926             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3927 
3928   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3929                uint64, uint64, double, const DeviceMemory<double> &, int,
3930                DeviceMemory<double> *, int>
3931       impl;
3932   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3933               n, alpha, a, lda, b, ldb);
3934 }
3935 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)3936 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3937                              blas::Transpose transa, blas::Diagonal diag,
3938                              uint64 m, uint64 n, std::complex<float> alpha,
3939                              const DeviceMemory<std::complex<float>> &a,
3940                              int lda, DeviceMemory<std::complex<float>> *b,
3941                              int ldb) {
3942   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3943             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3944 
3945   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3946                uint64, uint64, std::complex<float>,
3947                const DeviceMemory<std::complex<float>> &, int,
3948                DeviceMemory<std::complex<float>> *, int>
3949       impl;
3950   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3951               n, alpha, a, lda, b, ldb);
3952 }
3953 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)3954 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3955                              blas::Transpose transa, blas::Diagonal diag,
3956                              uint64 m, uint64 n, std::complex<double> alpha,
3957                              const DeviceMemory<std::complex<double>> &a,
3958                              int lda, DeviceMemory<std::complex<double>> *b,
3959                              int ldb) {
3960   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3961             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3962 
3963   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3964                uint64, uint64, std::complex<double>,
3965                const DeviceMemory<std::complex<double>> &, int,
3966                DeviceMemory<std::complex<double>> *, int>
3967       impl;
3968   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3969               n, alpha, a, lda, b, ldb);
3970 }
3971 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3972 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3973                              blas::Transpose transa, blas::Diagonal diag,
3974                              uint64 m, uint64 n, float alpha,
3975                              const DeviceMemory<float> &a, int lda,
3976                              DeviceMemory<float> *b, int ldb) {
3977   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3978             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3979 
3980   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3981                uint64, uint64, float, const DeviceMemory<float> &, int,
3982                DeviceMemory<float> *, int>
3983       impl;
3984   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3985               n, alpha, a, lda, b, ldb);
3986 }
3987 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3988 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3989                              blas::Transpose transa, blas::Diagonal diag,
3990                              uint64 m, uint64 n, double alpha,
3991                              const DeviceMemory<double> &a, int lda,
3992                              DeviceMemory<double> *b, int ldb) {
3993   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3994             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3995 
3996   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3997                uint64, uint64, double, const DeviceMemory<double> &, int,
3998                DeviceMemory<double> *, int>
3999       impl;
4000   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4001               n, alpha, a, lda, b, ldb);
4002 }
4003 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4004 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4005                              blas::Transpose transa, blas::Diagonal diag,
4006                              uint64 m, uint64 n, std::complex<float> alpha,
4007                              const DeviceMemory<std::complex<float>> &a,
4008                              int lda, DeviceMemory<std::complex<float>> *b,
4009                              int ldb) {
4010   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4011             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4012 
4013   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4014                uint64, uint64, std::complex<float>,
4015                const DeviceMemory<std::complex<float>> &, int,
4016                DeviceMemory<std::complex<float>> *, int>
4017       impl;
4018   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4019               n, alpha, a, lda, b, ldb);
4020 }
4021 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4022 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4023                              blas::Transpose transa, blas::Diagonal diag,
4024                              uint64 m, uint64 n, std::complex<double> alpha,
4025                              const DeviceMemory<std::complex<double>> &a,
4026                              int lda, DeviceMemory<std::complex<double>> *b,
4027                              int ldb) {
4028   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4029             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4030 
4031   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4032                uint64, uint64, std::complex<double>,
4033                const DeviceMemory<std::complex<double>> &, int,
4034                DeviceMemory<std::complex<double>> *, int>
4035       impl;
4036   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4037               n, alpha, a, lda, b, ldb);
4038 }
4039 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)4040 Stream &Stream::ThenBlasGemmBatched(
4041     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4042     uint64 k, float alpha,
4043     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4044     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4045     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4046     int batch_count) {
4047   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4048                                         b, ldb, beta, c, ldc, batch_count,
4049                                         /*scratch_allocator=*/nullptr);
4050 }
4051 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4052 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4053     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4054     uint64 k, float alpha,
4055     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
4056     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
4057     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
4058     int batch_count, ScratchAllocator *scratch_allocator) {
4059   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4060             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4061             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4062 
4063   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4064                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4065                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
4066                float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
4067                int, int, ScratchAllocator *>
4068       impl;
4069   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4070               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4071               scratch_allocator);
4072 }
4073 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4074 Stream &Stream::ThenBlasGemmBatched(
4075     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4076     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4077     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4078     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4079     int batch_count) {
4080   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4081                                         b, ldb, beta, c, ldc, batch_count,
4082                                         /*scratch_allocator=*/nullptr);
4083 }
4084 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4085 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4086     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4087     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4088     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4089     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4090     int batch_count, ScratchAllocator *scratch_allocator) {
4091   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4092             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4093             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4094 
4095   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4096                const port::ArraySlice<DeviceMemory<float> *> &, int,
4097                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4098                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4099                ScratchAllocator *>
4100       impl;
4101   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4102               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4103               scratch_allocator);
4104 }
4105 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4106 Stream &Stream::ThenBlasGemmBatched(
4107     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4108     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4109     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4110     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4111     int batch_count) {
4112   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4113                                         b, ldb, beta, c, ldc, batch_count,
4114                                         /*scratch_allocator=*/nullptr);
4115 }
4116 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4117 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4118     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4119     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4120     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4121     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4122     int batch_count, ScratchAllocator *scratch_allocator) {
4123   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4124             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4125             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4126 
4127   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4128                const port::ArraySlice<DeviceMemory<double> *> &, int,
4129                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4130                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4131                ScratchAllocator *>
4132       impl;
4133   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4134               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4135               scratch_allocator);
4136 }
4137 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4138 Stream &Stream::ThenBlasGemmBatched(
4139     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4140     uint64 k, std::complex<float> alpha,
4141     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4142     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4143     std::complex<float> beta,
4144     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4145     int batch_count) {
4146   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4147                                         b, ldb, beta, c, ldc, batch_count,
4148                                         /*scratch_allocator=*/nullptr);
4149 }
4150 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4151 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4152     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4153     uint64 k, std::complex<float> alpha,
4154     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4155     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4156     std::complex<float> beta,
4157     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4158     int batch_count, ScratchAllocator *scratch_allocator) {
4159   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4160             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4161             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4162 
4163   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4164                std::complex<float>,
4165                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4166                int,
4167                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4168                int, std::complex<float>,
4169                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4170                int, int, ScratchAllocator *>
4171       impl;
4172   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4173               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4174               scratch_allocator);
4175 }
4176 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4177 Stream &Stream::ThenBlasGemmBatched(
4178     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4179     uint64 k, std::complex<double> alpha,
4180     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4181     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4182     std::complex<double> beta,
4183     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4184     int batch_count) {
4185   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4186                                         b, ldb, beta, c, ldc, batch_count,
4187                                         /*scratch_allocator=*/nullptr);
4188 }
4189 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4190 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4191     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4192     uint64 k, std::complex<double> alpha,
4193     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4194     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4195     std::complex<double> beta,
4196     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4197     int batch_count, ScratchAllocator *scratch_allocator) {
4198   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4199             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4200             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4201 
4202   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4203                std::complex<double>,
4204                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4205                int,
4206                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4207                int, std::complex<double>,
4208                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4209                int, int, ScratchAllocator *>
4210       impl;
4211   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4212               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4213               scratch_allocator);
4214 }
4215 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,int64 stride_a,const DeviceMemory<Eigen::half> & b,int ldb,int64 stride_b,float beta,DeviceMemory<Eigen::half> * c,int ldc,int64 stride_c,int batch_count)4216 Stream &Stream::ThenBlasGemmStridedBatched(
4217     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4218     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
4219     int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
4220     float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
4221     int batch_count) {
4222   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4223             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4224             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4225             PARAM(stride_c), PARAM(batch_count));
4226 
4227   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4228                const DeviceMemory<Eigen::half> &, int, int64,
4229                const DeviceMemory<Eigen::half> &, int, int64, float,
4230                DeviceMemory<Eigen::half> *, int, int64, int>
4231       impl;
4232   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4233               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4234               c, ldc, stride_c, batch_count);
4235 }
4236 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,int64 stride_a,const DeviceMemory<float> & b,int ldb,int64 stride_b,float beta,DeviceMemory<float> * c,int ldc,int64 stride_c,int batch_count)4237 Stream &Stream::ThenBlasGemmStridedBatched(
4238     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4239     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
4240     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
4241     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
4242     int batch_count) {
4243   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4244             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4245             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4246             PARAM(stride_c), PARAM(batch_count));
4247 
4248   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4249                const DeviceMemory<float> &, int, int64,
4250                const DeviceMemory<float> &, int, int64, float,
4251                DeviceMemory<float> *, int, int64, int>
4252       impl;
4253   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4254               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4255               c, ldc, stride_c, batch_count);
4256 }
4257 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,int64 stride_a,const DeviceMemory<double> & b,int ldb,int64 stride_b,double beta,DeviceMemory<double> * c,int ldc,int64 stride_c,int batch_count)4258 Stream &Stream::ThenBlasGemmStridedBatched(
4259     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4260     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
4261     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
4262     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
4263     int batch_count) {
4264   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4265             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4266             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4267             PARAM(stride_c), PARAM(batch_count));
4268 
4269   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4270                const DeviceMemory<double> &, int, int64,
4271                const DeviceMemory<double> &, int, int64, double,
4272                DeviceMemory<double> *, int, int64, int>
4273       impl;
4274   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4275               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4276               c, ldc, stride_c, batch_count);
4277 }
4278 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<float>> & b,int ldb,int64 stride_b,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,int64 stride_c,int batch_count)4279 Stream &Stream::ThenBlasGemmStridedBatched(
4280     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4281     uint64 k, std::complex<float> alpha,
4282     const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
4283     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
4284     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
4285     int64 stride_c, int batch_count) {
4286   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4287             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4288             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4289             PARAM(stride_c), PARAM(batch_count));
4290 
4291   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4292                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4293                int, int64, const DeviceMemory<std::complex<float>> &, int,
4294                int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
4295                int, int64, int>
4296       impl;
4297   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4298               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4299               c, ldc, stride_c, batch_count);
4300 }
4301 
ThenBlasGemmStridedBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,int64 stride_a,const DeviceMemory<std::complex<double>> & b,int ldb,int64 stride_b,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,int64 stride_c,int batch_count)4302 Stream &Stream::ThenBlasGemmStridedBatched(
4303     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4304     uint64 k, std::complex<double> alpha,
4305     const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
4306     const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
4307     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
4308     int64 stride_c, int batch_count) {
4309   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4310             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
4311             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
4312             PARAM(stride_c), PARAM(batch_count));
4313 
4314   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4315                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4316                int, int64, const DeviceMemory<std::complex<double>> &, int,
4317                int64, std::complex<double>,
4318                DeviceMemory<std::complex<double>> *, int, int64, int>
4319       impl;
4320   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
4321               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
4322               c, ldc, stride_c, batch_count);
4323 }
4324 
4325 template <typename ABType, typename CType>
ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan * plan,const HostOrDeviceScalar<CType> & alpha,const DeviceMemory<ABType> & a,const DeviceMemory<ABType> & b,const HostOrDeviceScalar<CType> & beta,DeviceMemory<CType> * c,ScratchAllocator * scratch_allocator,const blas::IBlasLtMatmulAlgorithm * algorithm,const DeviceMemory<CType> & bias,blas::ProfileResult * output_profile_result)4326 Stream &Stream::ThenBlasLtMatmulImpl(
4327     const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
4328     const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
4329     const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
4330     ScratchAllocator *scratch_allocator,
4331     const blas::IBlasLtMatmulAlgorithm *algorithm,
4332     const DeviceMemory<CType> &bias,
4333     blas::ProfileResult *output_profile_result) {
4334   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
4335             PARAM(c), PARAM(algorithm), PARAM(bias));
4336 
4337   ThenBlasWithProfileImpl<
4338       const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
4339       const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
4340       const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
4341       ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
4342       const DeviceMemory<CType> &>
4343       impl;
4344   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
4345               c, scratch_allocator, algorithm, bias, output_profile_result);
4346 }
4347 
4348 // Explicit template instantiations for each supported type combination.
4349 template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
4350     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
4351     const DeviceMemory<int8> &, const DeviceMemory<int8> &,
4352     const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
4353     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
4354     const DeviceMemory<int32> &, blas::ProfileResult *);
4355 
4356 template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
4357     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
4358     const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
4359     const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
4360     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
4361     const DeviceMemory<Eigen::half> &, blas::ProfileResult *);
4362 
4363 template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
4364     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
4365     const DeviceMemory<float> &, const DeviceMemory<float> &,
4366     const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
4367     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
4368     const DeviceMemory<float> &, blas::ProfileResult *);
4369 
4370 template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
4371     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
4372     const DeviceMemory<double> &, const DeviceMemory<double> &,
4373     const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
4374     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
4375     const DeviceMemory<double> &, blas::ProfileResult *);
4376 
4377 template Stream &
4378 Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
4379     const blas::IBlasLtMatmulPlan *,
4380     const HostOrDeviceScalar<std::complex<float>> &,
4381     const DeviceMemory<std::complex<float>> &,
4382     const DeviceMemory<std::complex<float>> &,
4383     const HostOrDeviceScalar<std::complex<float>> &,
4384     DeviceMemory<std::complex<float>> *, ScratchAllocator *,
4385     const blas::IBlasLtMatmulAlgorithm *,
4386     const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);
4387 
4388 template Stream &
4389 Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
4390     const blas::IBlasLtMatmulPlan *,
4391     const HostOrDeviceScalar<std::complex<double>> &,
4392     const DeviceMemory<std::complex<double>> &,
4393     const DeviceMemory<std::complex<double>> &,
4394     const HostOrDeviceScalar<std::complex<double>> &,
4395     DeviceMemory<std::complex<double>> *, ScratchAllocator *,
4396     const blas::IBlasLtMatmulAlgorithm *,
4397     const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);
4398 
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4399 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4400   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4401 
4402   if (rng::RngSupport *rng = parent_->AsRng()) {
4403     CheckError(rng->SetSeed(this, seed, seed_bytes));
4404   } else {
4405     SetError();
4406     LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
4407   }
4408   return *this;
4409 }
4410 
ThenPopulateRandUniform(DeviceMemory<float> * values)4411 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4412   VLOG_CALL(PARAM(values));
4413 
4414   if (rng::RngSupport *rng = parent_->AsRng()) {
4415     CheckError(rng->DoPopulateRandUniform(this, values));
4416   } else {
4417     SetError();
4418     LOG(INFO) << DebugStreamPointers()
4419               << " attempting to perform RNG operation using StreamExecutor"
4420                  " without RNG support.";
4421   }
4422   return *this;
4423 }
4424 
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4425 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4426                                          DeviceMemory<float> *values) {
4427   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4428 
4429   if (rng::RngSupport *rng = parent_->AsRng()) {
4430     CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4431   } else {
4432     SetError();
4433     LOG(INFO) << DebugStreamPointers()
4434               << " attempting to perform RNG operation using StreamExecutor"
4435                  " without RNG support.";
4436   }
4437   return *this;
4438 }
4439 
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4440 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4441                                          DeviceMemory<double> *values) {
4442   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4443 
4444   if (rng::RngSupport *rng = parent_->AsRng()) {
4445     CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4446   } else {
4447     SetError();
4448     LOG(INFO) << DebugStreamPointers()
4449               << " attempting to perform RNG operation using StreamExecutor"
4450                  " without RNG support.";
4451   }
4452   return *this;
4453 }
4454 
ThenPopulateRandUniform(DeviceMemory<double> * values)4455 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4456   VLOG_CALL(PARAM(values));
4457 
4458   if (rng::RngSupport *rng = parent_->AsRng()) {
4459     CheckError(rng->DoPopulateRandUniform(this, values));
4460   } else {
4461     SetError();
4462     LOG(INFO) << DebugStreamPointers()
4463               << " attempting to perform RNG operation using StreamExecutor"
4464                  " without RNG support.";
4465   }
4466   return *this;
4467 }
4468 
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4469 Stream &Stream::ThenPopulateRandUniform(
4470     DeviceMemory<std::complex<float>> *values) {
4471   VLOG_CALL(PARAM(values));
4472 
4473   if (rng::RngSupport *rng = parent_->AsRng()) {
4474     CheckError(rng->DoPopulateRandUniform(this, values));
4475   } else {
4476     SetError();
4477     LOG(INFO) << DebugStreamPointers()
4478               << " attempting to perform RNG operation using StreamExecutor"
4479                  " without RNG support.";
4480   }
4481   return *this;
4482 }
4483 
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4484 Stream &Stream::ThenPopulateRandUniform(
4485     DeviceMemory<std::complex<double>> *values) {
4486   VLOG_CALL(PARAM(values));
4487 
4488   if (rng::RngSupport *rng = parent_->AsRng()) {
4489     CheckError(rng->DoPopulateRandUniform(this, values));
4490   } else {
4491     SetError();
4492     LOG(INFO) << DebugStreamPointers()
4493               << " attempting to perform RNG operation using StreamExecutor"
4494                  " without RNG support.";
4495   }
4496   return *this;
4497 }
4498 
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4499 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4500                            uint64 size) {
4501   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4502 
4503   CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4504   return *this;
4505 }
4506 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4507 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4508                            uint64 size) {
4509   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4510 
4511   CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4512   return *this;
4513 }
4514 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4515 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4516                            const DeviceMemoryBase &gpu_src, uint64 size) {
4517   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4518 
4519   CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4520   return *this;
4521 }
4522 
ThenMemZero(DeviceMemoryBase * location,uint64 size)4523 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4524   VLOG_CALL(PARAM(location), PARAM(size));
4525 
4526   CheckStatus(parent_->MemZero(this, location, size));
4527   return *this;
4528 }
4529 
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4530 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4531                              uint64 size) {
4532   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4533 
4534   CheckStatus(parent_->Memset32(this, location, pattern, size));
4535   return *this;
4536 }
4537 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4538 Stream &Stream::ThenRnnForward(
4539     const dnn::RnnDescriptor &rnn_desc,
4540     const dnn::RnnSequenceTensorDescriptor &input_desc,
4541     const DeviceMemory<Eigen::half> &input_data,
4542     const dnn::RnnStateTensorDescriptor &input_h_desc,
4543     const DeviceMemory<Eigen::half> &input_h_data,
4544     const dnn::RnnStateTensorDescriptor &input_c_desc,
4545     const DeviceMemory<Eigen::half> &input_c_data,
4546     const DeviceMemory<Eigen::half> &params,
4547     const dnn::RnnSequenceTensorDescriptor &output_desc,
4548     DeviceMemory<Eigen::half> *output_data,
4549     const dnn::RnnStateTensorDescriptor &output_h_desc,
4550     DeviceMemory<Eigen::half> *output_h_data,
4551     const dnn::RnnStateTensorDescriptor &output_c_desc,
4552     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
4553     ScratchAllocator *reserve_space_allocator,
4554     ScratchAllocator *workspace_allocator,
4555     dnn::ProfileResult *output_profile_result) {
4556   // TODO(zhengxq): add VLOG PARAM calls.
4557   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4558     auto status = dnn->DoRnnForward(
4559         this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4560         input_c_desc, input_c_data, params, output_desc, output_data,
4561         output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
4562         reserve_space_allocator, workspace_allocator, output_profile_result);
4563     if (!status && !output_profile_result) {
4564       SetError();
4565     }
4566   } else {
4567     SetErrorAndLogNoDnnSupport();
4568   }
4569   return *this;
4570 }
4571 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4572 Stream &Stream::ThenRnnForward(
4573     const dnn::RnnDescriptor &rnn_desc,
4574     const dnn::RnnSequenceTensorDescriptor &input_desc,
4575     const DeviceMemory<float> &input_data,
4576     const dnn::RnnStateTensorDescriptor &input_h_desc,
4577     const DeviceMemory<float> &input_h_data,
4578     const dnn::RnnStateTensorDescriptor &input_c_desc,
4579     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4580     const dnn::RnnSequenceTensorDescriptor &output_desc,
4581     DeviceMemory<float> *output_data,
4582     const dnn::RnnStateTensorDescriptor &output_h_desc,
4583     DeviceMemory<float> *output_h_data,
4584     const dnn::RnnStateTensorDescriptor &output_c_desc,
4585     DeviceMemory<float> *output_c_data, bool is_training,
4586     ScratchAllocator *reserve_space_allocator,
4587     ScratchAllocator *workspace_allocator,
4588     dnn::ProfileResult *output_profile_result) {
4589   // TODO(zhengxq): add VLOG PARAM calls.
4590   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4591     auto status = dnn->DoRnnForward(
4592         this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4593         input_c_desc, input_c_data, params, output_desc, output_data,
4594         output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
4595         reserve_space_allocator, workspace_allocator, output_profile_result);
4596     if (!status && !output_profile_result) {
4597       SetError();
4598     }
4599   } else {
4600     SetErrorAndLogNoDnnSupport();
4601   }
4602   return *this;
4603 }
4604 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4605 Stream &Stream::ThenRnnForward(
4606     const dnn::RnnDescriptor &rnn_desc,
4607     const dnn::RnnSequenceTensorDescriptor &input_desc,
4608     const DeviceMemory<double> &input_data,
4609     const dnn::RnnStateTensorDescriptor &input_h_desc,
4610     const DeviceMemory<double> &input_h_data,
4611     const dnn::RnnStateTensorDescriptor &input_c_desc,
4612     const DeviceMemory<double> &input_c_data,
4613     const DeviceMemory<double> &params,
4614     const dnn::RnnSequenceTensorDescriptor &output_desc,
4615     DeviceMemory<double> *output_data,
4616     const dnn::RnnStateTensorDescriptor &output_h_desc,
4617     DeviceMemory<double> *output_h_data,
4618     const dnn::RnnStateTensorDescriptor &output_c_desc,
4619     DeviceMemory<double> *output_c_data, bool is_training,
4620     ScratchAllocator *reserve_space_allocator,
4621     ScratchAllocator *workspace_allocator,
4622     dnn::ProfileResult *output_profile_result) {
4623   // TODO(zhengxq): add VLOG PARAM calls.
4624   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4625     auto status = dnn->DoRnnForward(
4626         this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4627         input_c_desc, input_c_data, params, output_desc, output_data,
4628         output_h_desc, output_h_data, output_c_desc, output_c_data, is_training,
4629         reserve_space_allocator, workspace_allocator, output_profile_result);
4630     if (!status && !output_profile_result) {
4631       SetError();
4632     }
4633   } else {
4634     SetErrorAndLogNoDnnSupport();
4635   }
4636   return *this;
4637 }
4638 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4639 Stream &Stream::ThenRnnBackward(
4640     const dnn::RnnDescriptor &rnn_desc,
4641     const dnn::RnnSequenceTensorDescriptor &input_desc,
4642     const DeviceMemory<Eigen::half> &input_data,
4643     const dnn::RnnStateTensorDescriptor &input_h_desc,
4644     const DeviceMemory<Eigen::half> &input_h_data,
4645     const dnn::RnnStateTensorDescriptor &input_c_desc,
4646     const DeviceMemory<Eigen::half> &input_c_data,
4647     const DeviceMemory<Eigen::half> &params,
4648     const dnn::RnnSequenceTensorDescriptor &output_desc,
4649     const DeviceMemory<Eigen::half> &output_data,
4650     const dnn::RnnStateTensorDescriptor &output_h_desc,
4651     const DeviceMemory<Eigen::half> &output_h_data,
4652     const dnn::RnnStateTensorDescriptor &output_c_desc,
4653     const DeviceMemory<Eigen::half> &output_c_data,
4654     const DeviceMemory<Eigen::half> &output_backprop_data,
4655     const DeviceMemory<Eigen::half> &output_h_backprop_data,
4656     const DeviceMemory<Eigen::half> &output_c_backprop_data,
4657     DeviceMemory<Eigen::half> *input_backprop_data,
4658     DeviceMemory<Eigen::half> *input_h_backprop_data,
4659     DeviceMemory<Eigen::half> *input_c_backprop_data,
4660     DeviceMemory<Eigen::half> *params_backprop_data,
4661     DeviceMemory<uint8> *reserve_space_data,
4662     ScratchAllocator *workspace_allocator,
4663     dnn::ProfileResult *output_profile_result) {
4664   // TODO(zhengxq): add VLOG PARAM calls.
4665   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4666     auto status = dnn->DoRnnBackward(
4667         this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4668         input_c_desc, input_c_data, params, output_desc, output_data,
4669         output_h_desc, output_h_data, output_c_desc, output_c_data,
4670         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4671         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4672         params_backprop_data, reserve_space_data, workspace_allocator,
4673         output_profile_result);
4674     if (!status && !output_profile_result) {
4675       SetError();
4676     }
4677   } else {
4678     SetError();
4679     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4680   }
4681   return *this;
4682 }
4683 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4684 Stream &Stream::ThenRnnBackward(
4685     const dnn::RnnDescriptor &rnn_desc,
4686     const dnn::RnnSequenceTensorDescriptor &input_desc,
4687     const DeviceMemory<float> &input_data,
4688     const dnn::RnnStateTensorDescriptor &input_h_desc,
4689     const DeviceMemory<float> &input_h_data,
4690     const dnn::RnnStateTensorDescriptor &input_c_desc,
4691     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4692     const dnn::RnnSequenceTensorDescriptor &output_desc,
4693     const DeviceMemory<float> &output_data,
4694     const dnn::RnnStateTensorDescriptor &output_h_desc,
4695     const DeviceMemory<float> &output_h_data,
4696     const dnn::RnnStateTensorDescriptor &output_c_desc,
4697     const DeviceMemory<float> &output_c_data,
4698     const DeviceMemory<float> &output_backprop_data,
4699     const DeviceMemory<float> &output_h_backprop_data,
4700     const DeviceMemory<float> &output_c_backprop_data,
4701     DeviceMemory<float> *input_backprop_data,
4702     DeviceMemory<float> *input_h_backprop_data,
4703     DeviceMemory<float> *input_c_backprop_data,
4704     DeviceMemory<float> *params_backprop_data,
4705     DeviceMemory<uint8> *reserve_space_data,
4706     ScratchAllocator *workspace_allocator,
4707     dnn::ProfileResult *output_profile_result) {
4708   // TODO(zhengxq): add VLOG PARAM calls.
4709   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4710     auto status = dnn->DoRnnBackward(
4711         this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4712         input_c_desc, input_c_data, params, output_desc, output_data,
4713         output_h_desc, output_h_data, output_c_desc, output_c_data,
4714         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4715         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4716         params_backprop_data, reserve_space_data, workspace_allocator,
4717         output_profile_result);
4718     if (!status && !output_profile_result) {
4719       SetError();
4720     }
4721   } else {
4722     SetError();
4723     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4724   }
4725   return *this;
4726 }
4727 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4728 Stream &Stream::ThenRnnBackward(
4729     const dnn::RnnDescriptor &rnn_desc,
4730     const dnn::RnnSequenceTensorDescriptor &input_desc,
4731     const DeviceMemory<double> &input_data,
4732     const dnn::RnnStateTensorDescriptor &input_h_desc,
4733     const DeviceMemory<double> &input_h_data,
4734     const dnn::RnnStateTensorDescriptor &input_c_desc,
4735     const DeviceMemory<double> &input_c_data,
4736     const DeviceMemory<double> &params,
4737     const dnn::RnnSequenceTensorDescriptor &output_desc,
4738     const DeviceMemory<double> &output_data,
4739     const dnn::RnnStateTensorDescriptor &output_h_desc,
4740     const DeviceMemory<double> &output_h_data,
4741     const dnn::RnnStateTensorDescriptor &output_c_desc,
4742     const DeviceMemory<double> &output_c_data,
4743     const DeviceMemory<double> &output_backprop_data,
4744     const DeviceMemory<double> &output_h_backprop_data,
4745     const DeviceMemory<double> &output_c_backprop_data,
4746     DeviceMemory<double> *input_backprop_data,
4747     DeviceMemory<double> *input_h_backprop_data,
4748     DeviceMemory<double> *input_c_backprop_data,
4749     DeviceMemory<double> *params_backprop_data,
4750     DeviceMemory<uint8> *reserve_space_data,
4751     ScratchAllocator *workspace_allocator,
4752     dnn::ProfileResult *output_profile_result) {
4753   // TODO(zhengxq): add VLOG PARAM calls.
4754   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4755     auto status = dnn->DoRnnBackward(
4756         this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4757         input_c_desc, input_c_data, params, output_desc, output_data,
4758         output_h_desc, output_h_data, output_c_desc, output_c_data,
4759         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4760         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4761         params_backprop_data, reserve_space_data, workspace_allocator,
4762         output_profile_result);
4763     if (!status && !output_profile_result) {
4764       SetError();
4765     }
4766   } else {
4767     SetError();
4768     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4769   }
4770   return *this;
4771 }
4772 
ThenCtcLoss(const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<float> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<float> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<float> * grads_data,ScratchAllocator * workspace_allocator)4773 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
4774                             const DeviceMemory<float> &probs_data,
4775                             absl::Span<const int> labels_data,
4776                             absl::Span<const int> labels_lengths_data,
4777                             absl::Span<const int> input_lengths_data,
4778                             DeviceMemory<float> *costs_data,
4779                             const dnn::RnnStateTensorDescriptor &grads_desc,
4780                             DeviceMemory<float> *grads_data,
4781                             ScratchAllocator *workspace_allocator) {
4782   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4783     DeviceMemory<uint8> scratch_memory;
4784     int ctc_loss_algo_id;
4785     auto status =
4786         dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
4787                                labels_data, labels_lengths_data,
4788                                input_lengths_data, workspace_allocator,
4789                                &scratch_memory, &ctc_loss_algo_id)
4790             .ok();
4791     if (status) {
4792       status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
4793                               labels_lengths_data, input_lengths_data,
4794                               costs_data, grads_desc, grads_data,
4795                               &scratch_memory, ctc_loss_algo_id);
4796     }
4797     if (!status) {
4798       SetError();
4799     }
4800   } else {
4801     SetErrorAndLogNoDnnSupport();
4802   }
4803   return *this;
4804 }
4805 
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)4806 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
4807                                     dnn::DataType input_type,
4808                                     const DeviceMemoryBase &input_data,
4809                                     const dnn::BatchDescriptor &output_desc,
4810                                     dnn::DataType output_type, float scale,
4811                                     DeviceMemoryBase *output_data) {
4812   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
4813             PARAM(output_desc), PARAM(output_type), PARAM(scale),
4814             PARAM(output_data));
4815   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4816     CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data,
4817                                       output_desc, output_type, scale,
4818                                       output_data));
4819   } else {
4820     SetErrorAndLogNoDnnSupport();
4821   }
4822   return *this;
4823 }
4824 
ThenDoHostCallback(std::function<void ()> callback)4825 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
4826   VLOG_CALL(PARAM(callback));
4827 
4828   if (!ok()) {
4829     LOG(INFO) << DebugStreamPointers()
4830               << " was in error state before adding host callback";
4831   }
4832   CheckError(parent_->HostCallback(this, std::move(callback)));
4833   return *this;
4834 }
4835 
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)4836 Stream &Stream::ThenDoHostCallbackWithStatus(
4837     std::function<port::Status()> callback) {
4838   VLOG_CALL(PARAM(callback));
4839 
4840   if (!ok()) {
4841     LOG(INFO) << DebugStreamPointers()
4842               << " was in error state before adding host callback";
4843   }
4844   CheckError(parent_->HostCallback(this, std::move(callback)));
4845   return *this;
4846 }
4847 
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)4848 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
4849     std::function<void()> callback) {
4850   VLOG_CALL(PARAM(callback));
4851 
4852   if (!ok()) {
4853     LOG(INFO) << DebugStreamPointers()
4854               << " was in error state before adding callback to be run after "
4855                  "next block-host-until-done.";
4856   }
4857   absl::MutexLock lock(&mu_);
4858   after_block_host_until_done_callbacks_.push_back(std::move(callback));
4859   return *this;
4860 }
4861 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)4862 Stream &Stream::ThenFft(fft::Plan *plan,
4863                         const DeviceMemory<std::complex<float>> &input,
4864                         DeviceMemory<std::complex<float>> *output) {
4865   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4866 
4867   if (fft::FftSupport *fft = parent_->AsFft()) {
4868     CheckError(fft->DoFft(this, plan, input, output));
4869   } else {
4870     SetError();
4871     LOG(INFO) << DebugStreamPointers()
4872               << " attempting to perform FFT operation using StreamExecutor"
4873                  " without FFT support";
4874   }
4875   return *this;
4876 }
4877 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)4878 Stream &Stream::ThenFft(fft::Plan *plan,
4879                         const DeviceMemory<std::complex<double>> &input,
4880                         DeviceMemory<std::complex<double>> *output) {
4881   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4882 
4883   if (fft::FftSupport *fft = parent_->AsFft()) {
4884     CheckError(fft->DoFft(this, plan, input, output));
4885   } else {
4886     SetError();
4887     LOG(INFO) << DebugStreamPointers()
4888               << " attempting to perform FFT operation using StreamExecutor"
4889                  " without FFT support";
4890   }
4891   return *this;
4892 }
4893 
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)4894 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
4895                         DeviceMemory<std::complex<float>> *output) {
4896   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4897 
4898   if (fft::FftSupport *fft = parent_->AsFft()) {
4899     CheckError(fft->DoFft(this, plan, input, output));
4900   } else {
4901     SetError();
4902     LOG(INFO) << DebugStreamPointers()
4903               << " attempting to perform FFT operation using StreamExecutor"
4904                  " without FFT support";
4905   }
4906   return *this;
4907 }
4908 
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)4909 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
4910                         DeviceMemory<std::complex<double>> *output) {
4911   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4912 
4913   if (fft::FftSupport *fft = parent_->AsFft()) {
4914     CheckError(fft->DoFft(this, plan, input, output));
4915   } else {
4916     SetError();
4917     LOG(INFO) << DebugStreamPointers()
4918               << " attempting to perform FFT operation using StreamExecutor"
4919                  " without FFT support";
4920   }
4921   return *this;
4922 }
4923 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)4924 Stream &Stream::ThenFft(fft::Plan *plan,
4925                         const DeviceMemory<std::complex<float>> &input,
4926                         DeviceMemory<float> *output) {
4927   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4928 
4929   if (fft::FftSupport *fft = parent_->AsFft()) {
4930     CheckError(fft->DoFft(this, plan, input, output));
4931   } else {
4932     SetError();
4933     LOG(INFO) << DebugStreamPointers()
4934               << " attempting to perform FFT operation using StreamExecutor"
4935                  " without FFT support";
4936   }
4937   return *this;
4938 }
4939 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)4940 Stream &Stream::ThenFft(fft::Plan *plan,
4941                         const DeviceMemory<std::complex<double>> &input,
4942                         DeviceMemory<double> *output) {
4943   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4944 
4945   if (fft::FftSupport *fft = parent_->AsFft()) {
4946     CheckError(fft->DoFft(this, plan, input, output));
4947   } else {
4948     SetError();
4949     LOG(INFO) << DebugStreamPointers()
4950               << " attempting to perform FFT operation using StreamExecutor"
4951                  " without FFT support";
4952   }
4953   return *this;
4954 }
4955 
4956 // It looks confusing, but all this is doing is inserting a callback at the
4957 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)4958 Stream &Stream::ThenEnqueueOnBackgroundThread(
4959     std::function<void(StreamExecutor *)> task) {
4960   VLOG_CALL(PARAM(task));
4961 
4962   StreamExecutor *stream_executor = this->parent_;
4963   std::function<void()> bound_task = std::bind(task, stream_executor);
4964 
4965   return ThenDoHostCallback([stream_executor, bound_task]() {
4966     stream_executor->EnqueueOnBackgroundThread(bound_task);
4967   });
4968 }
4969 
BlockHostUntilDone()4970 port::Status Stream::BlockHostUntilDone() {
4971   VLOG_CALL();
4972 
4973   if (!ok()) {
4974     port::Status status = port::Status(
4975         port::error::INTERNAL,
4976         "stream did not block host until done; was already in an error state");
4977     LOG(INFO) << DebugStreamPointers() << " " << status;
4978     return status;
4979   }
4980 
4981   temporary_memory_manager_.DeallocateFinalizedTemporaries();
4982 
4983   port::Status error = parent_->BlockHostUntilDone(this);
4984   CheckError(error.ok());
4985 
4986   RunAfterBlockHostUntilDoneCallbacks();
4987   return error;
4988 }
4989 
RunAfterBlockHostUntilDoneCallbacks()4990 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
4991   std::vector<std::function<void()>> callbacks;
4992   {
4993     absl::MutexLock lock(&mu_);
4994     std::swap(callbacks, after_block_host_until_done_callbacks_);
4995   }
4996   for (const auto &fn : callbacks) {
4997     fn();
4998   }
4999 }
5000 
DebugStreamPointers() const5001 std::string Stream::DebugStreamPointers() const {
5002   // Relies on the ToVlogString(const void*) overload above.
5003   return absl::StrCat("[stream=", ToVlogString(this),
5004                       ",impl=", ToVlogString(implementation_.get()), "]");
5005 }
5006 
CheckStatus(port::Status status)5007 void Stream::CheckStatus(port::Status status) {
5008   if (status.ok()) {
5009     return;
5010   }
5011   LOG(ERROR) << status;
5012   absl::MutexLock lock(&mu_);
5013   status_ = status;
5014 }
5015 
5016 }  // namespace stream_executor
5017