• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/stream.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "third_party/eigen3/Eigen/Core"
20 #include "tensorflow/stream_executor/blas.h"
21 #include "tensorflow/stream_executor/host_or_device_scalar.h"
22 #include "tensorflow/stream_executor/lib/stacktrace.h"
23 #include "tensorflow/stream_executor/platform.h"
24 #include "tensorflow/stream_executor/platform/logging.h"
25 #include "tensorflow/stream_executor/platform/port.h"
26 #include "tensorflow/stream_executor/rng.h"
27 #include "tensorflow/stream_executor/stream_executor_internal.h"
28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
29 
30 namespace stream_executor {
31 
32 namespace {
33 // Code to turn parameters to functions on stream into strings that
34 // will be VLOG'ed. We need overloads, instead of
35 // e.g. BatchDescriptorToVlogString(), as the code that calls these
36 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)37 std::string ToVlogString(const dnn::BatchDescriptor &descriptor) {
38   return descriptor.ToShortString();
39 }
40 
ToVlogString(const dnn::FilterDescriptor & descriptor)41 std::string ToVlogString(const dnn::FilterDescriptor &descriptor) {
42   return descriptor.ToShortString();
43 }
44 
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)45 std::string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
46   return descriptor.ToShortString();
47 }
48 
ToVlogString(const dnn::PoolingDescriptor & descriptor)49 std::string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
50   return descriptor.ToShortString();
51 }
52 
ToVlogString(const dnn::NormalizeDescriptor & descriptor)53 std::string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
54   return descriptor.ToShortString();
55 }
56 
ToVlogString(dnn::ActivationMode mode)57 std::string ToVlogString(dnn::ActivationMode mode) {
58   return dnn::ActivationModeString(mode);
59 }
60 
ToVlogString(const dnn::AlgorithmConfig & algo_config)61 std::string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
62   return algo_config.ToString();
63 }
64 
ToVlogString(dnn::ElementwiseOperation op)65 std::string ToVlogString(dnn::ElementwiseOperation op) {
66   return dnn::ElementwiseOperationString(op);
67 }
68 
ToVlogString(dnn::QuantizedActivationMode mode)69 std::string ToVlogString(dnn::QuantizedActivationMode mode) {
70   return dnn::QuantizedActivationModeString(mode);
71 }
72 
ToVlogString(blas::Transpose t)73 std::string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
74 
ToVlogString(blas::UpperLower ul)75 std::string ToVlogString(blas::UpperLower ul) {
76   return blas::UpperLowerString(ul);
77 }
78 
ToVlogString(blas::Diagonal d)79 std::string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
80 
ToVlogString(blas::Side s)81 std::string ToVlogString(blas::Side s) { return blas::SideString(s); }
82 
ToVlogString(blas::ComputationType ty)83 std::string ToVlogString(blas::ComputationType ty) {
84   return blas::ComputationTypeString(ty);
85 }
86 
// VLOG rendering of a raw pointer: "null" for nullptr, otherwise whatever
// operator<< produces for the pointer value.
std::string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    // A stream is used because StrCat does not convert pointers to text.
    std::ostringstream stream;
    stream << ptr;
    return stream.str();
  }
  return "null";
}
97 
// VLOG rendering of a complex number, using the stream insertion operator
// because StrCat does not convert std::complex to text.
template <class T>
std::string ToVlogString(const std::complex<T> &c) {
  std::ostringstream stream;
  stream << c;
  return stream.str();
}
105 
// VLOG rendering of a std::function: only reports whether a target is set.
template <class T>
std::string ToVlogString(const std::function<T> &f) {
  if (f) {
    return "<non-null function>";
  }
  return "null";
}
110 
ToVlogString(const DeviceMemoryBase & memory)111 std::string ToVlogString(const DeviceMemoryBase &memory) {
112   return ToVlogString(memory.opaque());
113 }
114 
ToVlogString(const DeviceMemoryBase * memory)115 std::string ToVlogString(const DeviceMemoryBase *memory) {
116   return memory == nullptr ? "null" : ToVlogString(*memory);
117 }
118 
ToVlogString(const Eigen::half & h)119 std::string ToVlogString(const Eigen::half &h) {
120   return absl::StrCat(static_cast<float>(h));
121 }
122 
ToVlogString(int i)123 std::string ToVlogString(int i) { return absl::StrCat(i); }
124 
ToVlogString(uint32 i)125 std::string ToVlogString(uint32 i) { return absl::StrCat(i); }
126 
ToVlogString(uint64 i)127 std::string ToVlogString(uint64 i) { return absl::StrCat(i); }
128 
ToVlogString(int64_t i)129 std::string ToVlogString(int64_t i) { return absl::StrCat(i); }
130 
ToVlogString(float f)131 std::string ToVlogString(float f) { return absl::StrCat(f); }
132 
ToVlogString(double d)133 std::string ToVlogString(double d) { return absl::StrCat(d); }
134 
135 template <typename T>
ToVlogString(const HostOrDeviceScalar<T> & memory_or_constant)136 std::string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
137   if (memory_or_constant.is_pointer()) {
138     return ToVlogString(memory_or_constant.pointer());
139   }
140   return ToVlogString(memory_or_constant.value());
141 }
142 
143 template <class T>
ToVlogString(port::ArraySlice<T> elements)144 std::string ToVlogString(port::ArraySlice<T> elements) {
145   std::string str = absl::StrCat(
146       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
147       elements.size(), "]{");
148   const char *separator = "";
149   size_t max_to_show = std::numeric_limits<size_t>::max();
150   if (!VLOG_IS_ON(2)) {
151     max_to_show = 5;
152   } else if (!VLOG_IS_ON(3)) {
153     max_to_show = 20;
154   } else if (!VLOG_IS_ON(11)) {
155     max_to_show = 1000;
156   }
157   for (size_t i = 0; i < elements.size(); ++i) {
158     if (i == max_to_show) {
159       str += ", ...";
160       break;
161     }
162     absl::StrAppend(&str, separator, ToVlogString(elements[i]));
163     separator = ", ";
164   }
165   str += "}";
166   return str;
167 }
168 
169 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)170 std::string ToVlogString(port::MutableArraySlice<T> elements) {
171   return ToVlogString(port::ArraySlice<T>(elements));
172 }
173 
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)174 std::string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
175   switch (depth_to_space_layout) {
176     case dnn::DepthToSpaceLayout::DepthHeightWidth:
177       return "DepthToSpaceLayout::DepthHeightWidth";
178   }
179   return "unknown DepthToSpaceLayout";
180 }
181 
ToVlogString(dnn::DataType data_type)182 std::string ToVlogString(dnn::DataType data_type) {
183   switch (data_type) {
184     case dnn::DataType::kFloat:
185       return "dnn::DataType::kFloat";
186     case dnn::DataType::kDouble:
187       return "dnn::DataType::kDouble";
188     case dnn::DataType::kHalf:
189       return "dnn::DataType::kHalf";
190     case dnn::DataType::kInt8:
191       return "dnn::DataType::kInt8";
192     case dnn::DataType::kInt32:
193       return "dnn::DataType::kInt32";
194     default:
195       return "unknown DataType";
196   }
197 }
198 
199 // Used together with PARAM to VLOG calls made to the stream. Intended
200 // to be used like this:
201 //
202 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
203 //
204 // where a and b are the parameters to MyFunction.
205 //
206 // See VLOG_CALL for a short-hand for this. This way of doing it saves
207 // a tremendous amount of boilerplate code given how many functions
208 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,std::string>> params)209 std::string CallStr(const char *function_name, Stream *stream,
210                     std::vector<std::pair<const char *, std::string>> params) {
211   // Do not call this function unless VLOG is on since just
212   // constructing all the strings in params is expensive.
213   CHECK(VLOG_IS_ON(1));
214 
215   std::string str = absl::StrCat(stream->DebugStreamPointers(),
216                                  " Called Stream::", function_name, "(");
217   const char *separator = "";
218   for (const auto &param : params) {
219     absl::StrAppend(&str, separator, param.first, "=", param.second);
220     separator = ", ";
221   }
222   absl::StrAppend(&str, ")");
223   if (VLOG_IS_ON(10)) {
224     absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
225   }
226   return str;
227 }
228 
// Expands to a {name, rendered value} pair for CallStr, so that a parameter
// only has to be typed once when logging it through VLOG.
#define PARAM(parameter) {#parameter, ToVlogString(parameter)}
233 
// Logs the enclosing Stream member function and its parameters at VLOG
// level 1 without spelling out the function name. Typical use:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// The hand-written equivalent would be:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// which adds up quickly, since most parameter names are long and most
// functions take far more than two parameters.
#define VLOG_CALL(...) \
  VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
247 
248 }  // namespace
249 
Stream(StreamExecutor * parent)250 Stream::Stream(StreamExecutor *parent)
251     : parent_(parent),
252       implementation_(parent->implementation()->GetStreamImplementation()),
253       allocated_(false),
254       status_(port::InternalError("Uninitialized stream")),
255       temporary_memory_manager_(this) {
256   VLOG_CALL(PARAM(parent));
257 }
258 
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)259 Stream::Stream(StreamExecutor *parent,
260                internal::StreamInterface *implementation)
261     : parent_(parent),
262       implementation_(implementation),
263       allocated_(false),
264       status_(port::InternalError("Uninitialized stream")),
265       temporary_memory_manager_(this) {
266   VLOG_CALL(PARAM(parent), PARAM(implementation));
267 }
268 
~Stream()269 Stream::~Stream() {
270   VLOG_CALL();
271 
272   // Ensure the stream is completed.
273   auto status = BlockHostUntilDone();
274   if (!status.ok()) {
275     LOG(WARNING) << "Error blocking host until done in stream destructor: "
276                  << status;
277   }
278   temporary_memory_manager_.ForceDeallocateAll();
279   RunAfterBlockHostUntilDoneCallbacks();
280 
281   if (allocated_) {
282     parent_->DeallocateStream(this);
283   }
284 }
285 
RefreshStatus()286 port::Status Stream::RefreshStatus() {
287   port::Status status = parent_->GetStatus(this);
288   // We should not put the stream in an error state, just because the GetStatus
289   // method is unimplemented.
290   if (status != port::Status(port::error::UNIMPLEMENTED,
291                              "GetStatus is not supported on this executor.")) {
292     CheckStatus(status);
293   }
294   return status;
295 }
296 
Init()297 Stream &Stream::Init() {
298   VLOG_CALL();
299 
300   absl::MutexLock lock(&mu_);
301   CHECK_EQ(false, allocated_)
302       << "stream appears to already have been initialized";
303   CHECK(!status_.ok()) << "stream should be in !ok() state pre-initialization";
304 
305   if (parent_->AllocateStream(this)) {
306     // Successful initialization!
307     allocated_ = true;
308     status_ = port::Status::OK();
309   } else {
310     LOG(ERROR) << "failed to allocate stream during initialization";
311   }
312 
313   return *this;
314 }
315 
InitTimer(Timer * timer)316 Stream &Stream::InitTimer(Timer *timer) {
317   VLOG_CALL(PARAM(timer));
318 
319   CheckError(parent_->AllocateTimer(timer));
320   return *this;
321 }
322 
InitWithTimer(Timer * timer)323 Stream &Stream::InitWithTimer(Timer *timer) {
324   VLOG_CALL(PARAM(timer));
325 
326   return Init().InitTimer(timer);
327 }
328 
ThenRecordEvent(Event * event)329 Stream &Stream::ThenRecordEvent(Event *event) {
330   VLOG_CALL(PARAM(event));
331 
332   port::Status status = parent_->RecordEvent(this, event);
333   if (!status.ok()) {
334     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
335                << "; not marking stream as bad, as the Event object may be "
336                << "at fault. Monitor for further errors.";
337   }
338 
339   return *this;
340 }
341 
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<float> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)342 Stream &Stream::ThenBatchNormalizationForward(
343     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
344     const DeviceMemory<float> &offset,
345     const DeviceMemory<float> &estimated_mean,
346     const DeviceMemory<float> &estimated_variance,
347     const DeviceMemory<float> &side_input, const dnn::BatchDescriptor &x_desc,
348     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
349     const double exponential_average_factor,
350     dnn::ActivationMode activation_mode, DeviceMemory<float> *y,
351     DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
352     DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
353     bool is_training, ScratchAllocator *reserve_space_allocator,
354     ScratchAllocator *workspace_allocator) {
355   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
356             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
357   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
358     CheckError(dnn->DoBatchNormalizationForward(
359         this, x, scale, offset, estimated_mean, estimated_variance, side_input,
360         x_desc, scale_offset_desc, epsilon, exponential_average_factor,
361         activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
362         is_training, reserve_space_allocator, workspace_allocator));
363   } else {
364     SetErrorAndLogNoDnnSupport();
365   }
366   return *this;
367 }
368 
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)369 Stream &Stream::ThenBatchNormalizationBackward(
370     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
371     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
372     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
373     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
374     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
375     DeviceMemory<float> *offset_backprop,
376     DeviceMemory<uint8> *reserve_space_data,
377     ScratchAllocator *workspace_allocator) {
378   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
379             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
380             PARAM(scale_backprop), PARAM(offset_backprop));
381   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
382     CheckError(dnn->DoBatchNormalizationBackward(
383         this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
384         epsilon, x_backprop, scale_backprop, offset_backprop,
385         reserve_space_data, workspace_allocator));
386   } else {
387     SetErrorAndLogNoDnnSupport();
388   }
389   return *this;
390 }
391 
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const DeviceMemory<Eigen::half> & side_input,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,const double exponential_average_factor,dnn::ActivationMode activation_mode,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)392 Stream &Stream::ThenBatchNormalizationForward(
393     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
394     const DeviceMemory<float> &offset,
395     const DeviceMemory<float> &estimated_mean,
396     const DeviceMemory<float> &estimated_variance,
397     const DeviceMemory<Eigen::half> &side_input,
398     const dnn::BatchDescriptor &x_desc,
399     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
400     const double exponential_average_factor,
401     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half> *y,
402     DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
403     DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
404     bool is_training, ScratchAllocator *reserve_space_allocator,
405     ScratchAllocator *workspace_allocator) {
406   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
407             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
408   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
409     CheckError(dnn->DoBatchNormalizationForward(
410         this, x, scale, offset, estimated_mean, estimated_variance, side_input,
411         x_desc, scale_offset_desc, epsilon, exponential_average_factor,
412         activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
413         is_training, reserve_space_allocator, workspace_allocator));
414   } else {
415     SetErrorAndLogNoDnnSupport();
416   }
417   return *this;
418 }
419 
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)420 Stream &Stream::ThenBatchNormalizationBackward(
421     const DeviceMemory<Eigen::half> &y_backprop,
422     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
423     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
424     const dnn::BatchDescriptor &x_desc,
425     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
426     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
427     DeviceMemory<float> *offset_backprop,
428     DeviceMemory<uint8> *reserve_space_data,
429     ScratchAllocator *workspace_allocator) {
430   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
431             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
432             PARAM(scale_backprop), PARAM(offset_backprop));
433   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
434     CheckError(dnn->DoBatchNormalizationBackward(
435         this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
436         epsilon, x_backprop, scale_backprop, offset_backprop,
437         reserve_space_data, workspace_allocator));
438 
439   } else {
440     SetErrorAndLogNoDnnSupport();
441   }
442   return *this;
443 }
444 
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)445 Stream &Stream::ThenConvolve(
446     const dnn::BatchDescriptor &input_descriptor,
447     const DeviceMemory<float> &input_data,
448     const dnn::FilterDescriptor &filter_descriptor,
449     const DeviceMemory<float> &filter_data,
450     const dnn::ConvolutionDescriptor &convolution_descriptor,
451     const dnn::BatchDescriptor &output_descriptor,
452     DeviceMemory<float> *output) {
453   if (ok()) {
454     CheckError(ConvolveWithAlgorithm(
455                    input_descriptor, input_data, filter_descriptor, filter_data,
456                    convolution_descriptor, output_descriptor, output,
457                    /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
458                    /*output_profile_result=*/nullptr)
459                    .ok());
460   }
461   return *this;
462 }
463 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)464 Stream &Stream::ThenConvolveQuantized(
465     const dnn::BatchDescriptor &input_descriptor,
466     const DeviceMemory<float> &input_data,
467     const dnn::FilterDescriptor &filter_descriptor,
468     const DeviceMemory<int8> &filter_coefficients,
469     const DeviceMemory<float> &coefficient_scales,
470     const dnn::ConvolutionDescriptor &convolution_descriptor,
471     const dnn::BatchDescriptor &output_descriptor,
472     DeviceMemory<float> *output) {
473   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
474             PARAM(filter_descriptor), PARAM(filter_coefficients),
475             PARAM(coefficient_scales), PARAM(convolution_descriptor),
476             PARAM(output_descriptor), PARAM(output));
477 
478   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
479     CheckError(dnn->DoConvolveQuantized(
480         this, input_descriptor, input_data, filter_descriptor,
481         filter_coefficients, coefficient_scales, convolution_descriptor,
482         output_descriptor, output));
483   } else {
484     SetError();
485     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
486                     "without DNN support";
487   }
488   return *this;
489 }
490 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)491 Stream &Stream::ThenConvolveQuantized(
492     const dnn::BatchDescriptor &input_descriptor,
493     const DeviceMemory<float> &input_data,
494     const dnn::FilterDescriptor &filter_descriptor,
495     const DeviceMemory<int16> &filter_coefficients,
496     const DeviceMemory<float> &coefficient_scales,
497     const dnn::ConvolutionDescriptor &convolution_descriptor,
498     const dnn::BatchDescriptor &output_descriptor,
499     DeviceMemory<float> *output) {
500   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
501             PARAM(filter_descriptor), PARAM(filter_coefficients),
502             PARAM(coefficient_scales), PARAM(convolution_descriptor),
503             PARAM(output_descriptor), PARAM(output));
504 
505   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
506     CheckError(dnn->DoConvolveQuantized(
507         this, input_descriptor, input_data, filter_descriptor,
508         filter_coefficients, coefficient_scales, convolution_descriptor,
509         output_descriptor, output));
510   } else {
511     SetError();
512     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
513                     "without DNN support";
514   }
515   return *this;
516 }
517 
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)518 Stream &Stream::ThenSeparableConvolve(
519     const dnn::BatchDescriptor &batch_descriptor,
520     const DeviceMemory<float> &input_data,
521     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
522     const DeviceMemory<float> &first_weights,
523     const DeviceMemory<float> &second_weights,
524     const dnn::ConvolutionDescriptor &convolution_descriptor,
525     const dnn::BatchDescriptor &output_descriptor,
526     DeviceMemory<float> *output) {
527   VLOG_CALL(
528       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
529       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
530       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
531 
532   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
533     CheckError(dnn->DoSeparableConvolve(
534         this, batch_descriptor, input_data, filter_descriptor, depth_multiplier,
535         first_weights, second_weights, convolution_descriptor,
536         output_descriptor, output));
537   } else {
538     SetErrorAndLogNoDnnSupport();
539   }
540   return *this;
541 }
542 
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)543 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
544                            const DeviceMemory<float> &weights,
545                            const dnn::BatchDescriptor &input_dimensions,
546                            const dnn::BatchDescriptor &output_dimensions,
547                            DeviceMemory<float> *output_data) {
548   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
549             PARAM(output_dimensions), PARAM(output_data));
550 
551   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
552     CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
553                              output_dimensions, output_data));
554   } else {
555     SetErrorAndLogNoDnnSupport();
556   }
557   return *this;
558 }
559 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)560 Stream &Stream::ThenMatMulQuantized(
561     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
562     const DeviceMemory<float> &weight_scales,
563     const dnn::BatchDescriptor &input_dimensions,
564     const dnn::BatchDescriptor &output_dimensions,
565     DeviceMemory<float> *output_data) {
566   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
567             PARAM(input_dimensions), PARAM(output_dimensions),
568             PARAM(output_data));
569 
570   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
571     CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
572                                       input_dimensions, output_dimensions,
573                                       output_data));
574   } else {
575     SetErrorAndLogNoDnnSupport();
576   }
577   return *this;
578 }
579 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)580 Stream &Stream::ThenMatMulQuantized(
581     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
582     const DeviceMemory<float> &weight_scales,
583     const dnn::BatchDescriptor &input_dimensions,
584     const dnn::BatchDescriptor &output_dimensions,
585     DeviceMemory<float> *output_data) {
586   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
587             PARAM(input_dimensions), PARAM(output_dimensions),
588             PARAM(output_data));
589 
590   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
591     CheckError(dnn->DoMatMulQuantized(this, input_data, weights, weight_scales,
592                                       input_dimensions, output_dimensions,
593                                       output_data));
594   } else {
595     SetErrorAndLogNoDnnSupport();
596   }
597   return *this;
598 }
599 
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)600 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
601                             const DeviceMemory<float> &biases,
602                             const dnn::BatchDescriptor &dimensions,
603                             DeviceMemory<float> *output_data) {
604   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
605             PARAM(output_data));
606 
607   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
608     CheckError(
609         dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
610   } else {
611     SetErrorAndLogNoDnnSupport();
612   }
613   return *this;
614 }
615 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data,ScratchAllocator * workspace_allocator)616 Stream &Stream::ThenPoolForward(
617     const dnn::PoolingDescriptor &pooling_dimensions,
618     const dnn::BatchDescriptor &input_dimensions,
619     const DeviceMemory<double> &input_data,
620     const dnn::BatchDescriptor &output_dimensions,
621     DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
622   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
623             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
624             PARAM(workspace_allocator));
625 
626   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
627     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
628                                   input_data, output_dimensions, output_data,
629                                   workspace_allocator));
630   } else {
631     SetError();
632     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
633                     "without DNN support";
634   }
635   return *this;
636 }
637 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data,ScratchAllocator * workspace_allocator)638 Stream &Stream::ThenPoolForward(
639     const dnn::PoolingDescriptor &pooling_dimensions,
640     const dnn::BatchDescriptor &input_dimensions,
641     const DeviceMemory<float> &input_data,
642     const dnn::BatchDescriptor &output_dimensions,
643     DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
644   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
645             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
646             PARAM(workspace_allocator));
647 
648   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
649     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
650                                   input_data, output_dimensions, output_data,
651                                   workspace_allocator));
652   } else {
653     SetErrorAndLogNoDnnSupport();
654   }
655   return *this;
656 }
657 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data,ScratchAllocator * workspace_allocator)658 Stream &Stream::ThenPoolForward(
659     const dnn::PoolingDescriptor &pooling_dimensions,
660     const dnn::BatchDescriptor &input_dimensions,
661     const DeviceMemory<Eigen::half> &input_data,
662     const dnn::BatchDescriptor &output_dimensions,
663     DeviceMemory<Eigen::half> *output_data,
664     ScratchAllocator *workspace_allocator) {
665   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
666             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
667             PARAM(workspace_allocator));
668 
669   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
670     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
671                                   input_data, output_dimensions, output_data,
672                                   workspace_allocator));
673   } else {
674     SetErrorAndLogNoDnnSupport();
675   }
676   return *this;
677 }
678 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<int8> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<int8> * output_data,ScratchAllocator * workspace_allocator)679 Stream &Stream::ThenPoolForward(
680     const dnn::PoolingDescriptor &pooling_dimensions,
681     const dnn::BatchDescriptor &input_dimensions,
682     const DeviceMemory<int8> &input_data,
683     const dnn::BatchDescriptor &output_dimensions,
684     DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
685   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
686             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
687             PARAM(workspace_allocator));
688 
689   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
690     CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
691                                   input_data, output_dimensions, output_data,
692                                   workspace_allocator));
693   } else {
694     SetErrorAndLogNoDnnSupport();
695   }
696   return *this;
697 }
698 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data,ScratchAllocator * workspace_allocator)699 Stream &Stream::ThenPoolBackward(
700     const dnn::PoolingDescriptor &pooling_dimensions,
701     const dnn::BatchDescriptor &input_dimensions,
702     const DeviceMemory<double> &input_data,
703     const dnn::BatchDescriptor &output_dimensions,
704     const DeviceMemory<double> &output_data,
705     const DeviceMemory<double> &input_diff_data,
706     DeviceMemory<double> *output_diff_data,
707     ScratchAllocator *workspace_allocator) {
708   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
709             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
710             PARAM(input_diff_data), PARAM(output_diff_data),
711             PARAM(workspace_allocator));
712 
713   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
714     CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
715                                    input_data, output_dimensions, output_data,
716                                    input_diff_data, output_diff_data,
717                                    workspace_allocator));
718   } else {
719     SetError();
720     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
721                     "without DNN support";
722   }
723   return *this;
724 }
725 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data,ScratchAllocator * workspace_allocator)726 Stream &Stream::ThenPoolBackward(
727     const dnn::PoolingDescriptor &pooling_dimensions,
728     const dnn::BatchDescriptor &input_dimensions,
729     const DeviceMemory<float> &input_data,
730     const dnn::BatchDescriptor &output_dimensions,
731     const DeviceMemory<float> &output_data,
732     const DeviceMemory<float> &input_diff_data,
733     DeviceMemory<float> *output_diff_data,
734     ScratchAllocator *workspace_allocator) {
735   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
736             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
737             PARAM(input_diff_data), PARAM(output_diff_data),
738             PARAM(workspace_allocator));
739 
740   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
741     CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
742                                    input_data, output_dimensions, output_data,
743                                    input_diff_data, output_diff_data,
744                                    workspace_allocator));
745   } else {
746     SetErrorAndLogNoDnnSupport();
747   }
748   return *this;
749 }
750 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data,ScratchAllocator * workspace_allocator)751 Stream &Stream::ThenPoolBackward(
752     const dnn::PoolingDescriptor &pooling_dimensions,
753     const dnn::BatchDescriptor &input_dimensions,
754     const DeviceMemory<Eigen::half> &input_data,
755     const dnn::BatchDescriptor &output_dimensions,
756     const DeviceMemory<Eigen::half> &output_data,
757     const DeviceMemory<Eigen::half> &input_diff_data,
758     DeviceMemory<Eigen::half> *output_diff_data,
759     ScratchAllocator *workspace_allocator) {
760   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
761             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
762             PARAM(input_diff_data), PARAM(output_diff_data),
763             PARAM(workspace_allocator));
764 
765   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
766     CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
767                                    input_data, output_dimensions, output_data,
768                                    input_diff_data, output_diff_data,
769                                    workspace_allocator));
770   } else {
771     SetErrorAndLogNoDnnSupport();
772   }
773   return *this;
774 }
775 
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)776 Stream &Stream::ThenNormalizeWithDimensions(
777     const dnn::NormalizeDescriptor &normalize_descriptor,
778     const dnn::BatchDescriptor &dimensions,
779     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
780   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
781             PARAM(output_data));
782 
783   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
784     CheckError(dnn->DoNormalizeWithDimensions(
785         this, normalize_descriptor, dimensions, input_data, output_data));
786   } else {
787     SetErrorAndLogNoDnnSupport();
788   }
789   return *this;
790 }
791 
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient,ScratchAllocator * workspace_allocator)792 Stream &Stream::ThenNormalizeBackwardWithDimensions(
793     const dnn::NormalizeDescriptor &normalize_descriptor,
794     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
795     const DeviceMemory<float> &normalized_data,
796     const DeviceMemory<float> &normalized_variable_gradient,
797     DeviceMemory<float> *raw_variable_gradient,
798     ScratchAllocator *workspace_allocator) {
799   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
800             PARAM(normalized_data), PARAM(normalized_variable_gradient),
801             PARAM(raw_variable_gradient), PARAM(workspace_allocator));
802 
803   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
804     CheckError(dnn->DoNormalizeBackwardWithDimensions(
805         this, normalize_descriptor, dimensions, raw_data, normalized_data,
806         normalized_variable_gradient, raw_variable_gradient,
807         workspace_allocator));
808   } else {
809     SetErrorAndLogNoDnnSupport();
810   }
811   return *this;
812 }
813 
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)814 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
815                              const dnn::BatchDescriptor &dimensions,
816                              const DeviceMemory<float> &input_data,
817                              DeviceMemory<float> *output_data) {
818   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
819                                  output_data, /*options=*/0);
820 }
821 
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)822 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
823                                         const dnn::BatchDescriptor &dimensions,
824                                         const DeviceMemory<float> &input_data,
825                                         DeviceMemory<float> *output_data,
826                                         uint64 options) {
827   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
828             PARAM(output_data), PARAM(options));
829 
830   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
831     CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
832                                output_data, options));
833   } else {
834     SetErrorAndLogNoDnnSupport();
835   }
836   return *this;
837 }
838 
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)839 Stream &Stream::ThenDepthConcatenate(
840     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
841     port::ArraySlice<const DeviceMemory<float> *> input_data,
842     DeviceMemory<float> *output_data) {
843   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
844 
845   for (size_t i = 1; i < input_dimensions.size(); ++i) {
846     if (input_dimensions[i].count() != input_dimensions[0].count() ||
847         input_dimensions[i].height() != input_dimensions[0].height() ||
848         input_dimensions[i].width() != input_dimensions[0].width()) {
849       SetError();
850       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
851                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
852                  << "input_dimensions[" << i
853                  << "]: " << input_dimensions[i].ToString();
854       return *this;
855     }
856   }
857 
858   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
859     CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
860                                        output_data));
861   } else {
862     SetErrorAndLogNoDnnSupport();
863   }
864   return *this;
865 }
866 
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)867 Stream &Stream::ThenSpaceConcatenate(
868     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
869     port::ArraySlice<const DeviceMemory<float> *> input_data,
870     DeviceMemory<float> *output_data,
871     dnn::SpaceConcatenateMode concat_direction) {
872   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
873 
874   // Check that the input dimensions of all the other batches match those of the
875   // first batch.
876   for (size_t i = 1; i < input_dimensions.size(); ++i) {
877     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
878         (input_dimensions[i].count() != input_dimensions[0].count() ||
879          input_dimensions[i].height() != input_dimensions[0].height() ||
880          input_dimensions[i].feature_map_count() !=
881              input_dimensions[0].feature_map_count())) {
882       SetError();
883       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
884                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
885                  << "input_dimensions[" << i
886                  << "]: " << input_dimensions[i].ToString();
887       return *this;
888     }
889 
890     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
891         (input_dimensions[i].count() != input_dimensions[0].count() ||
892          input_dimensions[i].width() != input_dimensions[0].width() ||
893          input_dimensions[i].feature_map_count() !=
894              input_dimensions[0].feature_map_count())) {
895       SetError();
896       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
897                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
898                  << "input_dimensions[" << i
899                  << "]: " << input_dimensions[i].ToString();
900       return *this;
901     }
902   }
903   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
904     CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
905                                        output_data, concat_direction));
906   } else {
907     SetErrorAndLogNoDnnSupport();
908   }
909   return *this;
910 }
911 
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)912 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
913                             const DeviceMemory<float> &input_data,
914                             const dnn::BatchDescriptor &output_dimensions,
915                             DeviceMemory<float> *output_data) {
916   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
917             PARAM(output_dimensions), PARAM(output_data));
918 
919   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
920     CheckError(dnn->DoReshape(this, input_dimensions, input_data,
921                               output_dimensions, output_data));
922   } else {
923     SetErrorAndLogNoDnnSupport();
924   }
925   return *this;
926 }
927 
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)928 Stream &Stream::ThenDepthToSpace(
929     const dnn::BatchDescriptor &input_dimensions,
930     const DeviceMemory<float> &input_data,
931     const dnn::DepthToSpaceLayout &depth_to_space_layout,
932     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
933   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
934             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
935             PARAM(output_data));
936 
937   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
938     CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
939                                    depth_to_space_layout, sqrt_depth_reduction,
940                                    output_data));
941   } else {
942     SetErrorAndLogNoDnnSupport();
943   }
944   return *this;
945 }
946 
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)947 Stream &Stream::ThenSpaceToDepth(
948     const dnn::BatchDescriptor &input_dimensions,
949     const DeviceMemory<float> &input_data,
950     const dnn::DepthToSpaceLayout &space_to_depth_layout,
951     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
952   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
953             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
954             PARAM(output_data));
955 
956   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
957     CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
958                                    space_to_depth_layout, sqrt_depth_increase,
959                                    output_data));
960   } else {
961     SetErrorAndLogNoDnnSupport();
962   }
963   return *this;
964 }
965 
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)966 Stream &Stream::ThenElementwiseOperate(
967     dnn::ElementwiseOperation operation,
968     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
969     port::ArraySlice<const DeviceMemory<float> *> input_data,
970     const dnn::BatchDescriptor &output_dimensions,
971     DeviceMemory<float> *output_data) {
972   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
973             PARAM(output_dimensions), PARAM(output_data));
974 
975   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
976     CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
977                                          input_data, output_dimensions,
978                                          output_data));
979   } else {
980     SetErrorAndLogNoDnnSupport();
981   }
982   return *this;
983 }
984 
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)985 Stream &Stream::ThenElementwiseOperateScaledQuantized(
986     dnn::ElementwiseOperation operation,
987     port::ArraySlice<int> input_multiplicands, int output_divisor,
988     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
989     port::ArraySlice<const DeviceMemory<float> *> input_data,
990     const dnn::BatchDescriptor &output_dimensions,
991     DeviceMemory<float> *output_data) {
992   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
993             PARAM(input_dimensions), PARAM(input_data),
994             PARAM(output_dimensions), PARAM(output_data));
995 
996   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
997     CheckError(dnn->DoElementwiseOperateScaledQuantized(
998         this, operation, input_multiplicands, output_divisor, input_dimensions,
999         input_data, output_dimensions, output_data));
1000   } else {
1001     SetErrorAndLogNoDnnSupport();
1002   }
1003   return *this;
1004 }
1005 
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_pad,int64_t right_pad,int64_t top_pad,int64_t bottom_pad,DeviceMemory<float> * output_data)1006 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1007                           const DeviceMemory<float> &input_data,
1008                           int64_t left_pad, int64_t right_pad, int64_t top_pad,
1009                           int64_t bottom_pad,
1010                           DeviceMemory<float> *output_data) {
1011   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1012             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1013             PARAM(output_data));
1014 
1015   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1016     CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1017                             top_pad, bottom_pad, output_data));
1018   } else {
1019     SetErrorAndLogNoDnnSupport();
1020   }
1021   return *this;
1022 }
1023 
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t left_trim,int64_t right_trim,int64_t top_trim,int64_t bottom_trim,DeviceMemory<float> * output_data)1024 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1025                             const DeviceMemory<float> &input_data,
1026                             int64_t left_trim, int64_t right_trim,
1027                             int64_t top_trim, int64_t bottom_trim,
1028                             DeviceMemory<float> *output_data) {
1029   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1030             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1031             PARAM(output_data));
1032 
1033   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1034     CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1035                               right_trim, top_trim, bottom_trim, output_data));
1036   } else {
1037     SetErrorAndLogNoDnnSupport();
1038   }
1039   return *this;
1040 }
1041 
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64_t replicate_x,int64_t replicate_y,DeviceMemory<float> * output_data)1042 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1043                                 const DeviceMemory<float> &input_data,
1044                                 int64_t replicate_x, int64_t replicate_y,
1045                                 DeviceMemory<float> *output_data) {
1046   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1047             PARAM(replicate_y), PARAM(output_data));
1048 
1049   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1050     CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1051                                   replicate_y, output_data));
1052   } else {
1053     SetErrorAndLogNoDnnSupport();
1054   }
1055   return *this;
1056 }
1057 
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1058 Stream &Stream::ThenMemcpyD2HQuantized(
1059     const DeviceMemory<float> &gpu_unquantized_src,
1060     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1061   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1062             PARAM(size));
1063 
1064   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1065     CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1066                                          host_dst, size));
1067   } else {
1068     SetErrorAndLogNoDnnSupport();
1069   }
1070   return *this;
1071 }
1072 
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1073 Stream &Stream::ThenMemcpyH2DQuantized(
1074     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1075     DeviceMemory<float> *gpu_unquantized_dst) {
1076   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1077             PARAM(gpu_unquantized_dst));
1078 
1079   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1080     CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1081                                          gpu_unquantized_dst));
1082   } else {
1083     SetErrorAndLogNoDnnSupport();
1084   }
1085   return *this;
1086 }
1087 
GetOrCreateSubStream()1088 Stream *Stream::GetOrCreateSubStream() {
1089   // Do not destroy bad streams when holding mu_ because ~Stream() may
1090   // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1091   std::vector<std::unique_ptr<Stream>> bad_streams;
1092 
1093   absl::MutexLock lock(&mu_);
1094 
1095   // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
1096   // we encounter along the way.
1097   for (size_t index = 0; index < sub_streams_.size();) {
1098     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1099     if (pair.second) {
1100       // The sub_stream is reusable.
1101       Stream *sub_stream = pair.first.get();
1102       if (sub_stream->ok()) {
1103         VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
1104                 << sub_stream->DebugStreamPointers();
1105         pair.second = false;
1106         return sub_stream;
1107       }
1108 
1109       // The stream is reusable and not ok. Streams have a monotonic state
1110       // machine; the stream will remain in !ok forever. Swap it with the last
1111       // stream and pop it off.
1112       const int64_t last = sub_streams_.size() - 1;
1113       if (index != last) {
1114         std::swap(pair, sub_streams_[last]);
1115       }
1116       bad_streams.push_back(std::move(sub_streams_.back().first));
1117       sub_streams_.pop_back();
1118       VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
1119               << sub_stream->DebugStreamPointers();
1120     } else {
1121       // The sub_stream is not reusable, move on to the next one.
1122       ++index;
1123     }
1124   }
1125 
1126   // No streams are reusable; create a new stream.
1127   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1128                             false);
1129   Stream *sub_stream = sub_streams_.back().first.get();
1130   sub_stream->Init();
1131   if (!sub_stream->ok()) {
1132     LOG(ERROR) << "sub-stream failed to be initialized";
1133   }
1134   VLOG(1) << DebugStreamPointers() << " created new sub_stream "
1135           << sub_stream->DebugStreamPointers();
1136 
1137   return sub_stream;
1138 }
1139 
ReturnSubStream(Stream * sub_stream)1140 void Stream::ReturnSubStream(Stream *sub_stream) {
1141   // Do not destroy bad streams when holding mu_ because ~Stream() may
1142   // BlockHostUntilDone and it's host callbacks might attempt to acquire mu_.
1143   std::unique_ptr<Stream> bad_stream;
1144 
1145   absl::MutexLock lock(&mu_);
1146 
1147   // Look for the sub-stream.
1148   for (int64_t index = 0, end = sub_streams_.size(); index < end; ++index) {
1149     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
1150     if (pair.first.get() != sub_stream) {
1151       continue;
1152     }
1153 
1154     // Found the sub_stream.
1155     if (sub_stream->ok()) {
1156       VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
1157               << sub_stream->DebugStreamPointers();
1158       pair.second = true;
1159     } else {
1160       // The returned stream is not ok. Streams have a monotonic state
1161       // machine; the stream will remain in !ok forever. Swap it with the last
1162       // stream and pop it off.
1163       VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
1164               << sub_stream->DebugStreamPointers();
1165       const int64_t last = sub_streams_.size() - 1;
1166       if (index != last) {
1167         std::swap(pair, sub_streams_[last]);
1168       }
1169       std::swap(bad_stream, sub_streams_.back().first);
1170       sub_streams_.pop_back();
1171     }
1172     return;
1173   }
1174 
1175   LOG(FATAL) << DebugStreamPointers()
1176              << " did not create the returned sub-stream "
1177              << sub_stream->DebugStreamPointers();
1178 }
1179 
ThenStartTimer(Timer * t)1180 Stream &Stream::ThenStartTimer(Timer *t) {
1181   VLOG_CALL(PARAM(t));
1182 
1183   CheckError(parent_->StartTimer(this, t));
1184   return *this;
1185 }
1186 
ThenStopTimer(Timer * t)1187 Stream &Stream::ThenStopTimer(Timer *t) {
1188   VLOG_CALL(PARAM(t));
1189 
1190   CheckError(parent_->StopTimer(this, t));
1191   return *this;
1192 }
1193 
ThenWaitFor(Stream * other)1194 Stream &Stream::ThenWaitFor(Stream *other) {
1195   VLOG_CALL(PARAM(other));
1196 
1197   CHECK(this != other) << "stream cannot wait for itself";
1198   if (ok() && other->ok()) {
1199     CheckError(parent_->CreateStreamDependency(this, other));
1200   } else {
1201     SetError();
1202     LOG(INFO) << DebugStreamPointers() << " did not wait for "
1203               << other->DebugStreamPointers();
1204   }
1205   return *this;
1206 }
1207 
ThenWaitFor(Event * event)1208 Stream &Stream::ThenWaitFor(Event *event) {
1209   VLOG_CALL(PARAM(event));
1210 
1211   if (ok()) {
1212     port::Status status = parent_->WaitForEvent(this, event);
1213     if (!status.ok()) {
1214       LOG(ERROR) << "Error waiting for event in stream: "
1215                  << status.error_message()
1216                  << "; not marking stream as bad, as the Event object may be "
1217                  << "at fault. Monitor for further errors.";
1218     }
1219   } else {
1220     LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
1221   }
1222   return *this;
1223 }
1224 
1225 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1226 // functions and logs for errors.
1227 template <typename... Args>
1228 struct ThenBlasImpl {
1229   // blas_func is the DoBlasXXX member function pointer, and args are its
1230   // arguments except the first one of Stream* type.
operator ()stream_executor::ThenBlasImpl1231   Stream &operator()(Stream *stream,
1232                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1233                      Args... args) {
1234     return Run(stream, blas_func, /*record_error=*/true, args...);
1235   }
1236 
1237   // Like operator(), but only calls stream->CheckError() if record_error is
1238   // true.
1239   Stream &Run(Stream *stream,
1240               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1241               bool record_error, Args... args);
1242 };
1243 
1244 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1245 Stream &ThenBlasImpl<Args...>::Run(
1246     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1247     bool record_error, Args... args) {
1248   if (stream->ok()) {
1249     bool ok;
1250     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1251       ok = (blas->*blas_func)(stream, args...);
1252     } else {
1253       LOG(WARNING)
1254           << "attempting to perform BLAS operation using StreamExecutor "
1255              "without BLAS support";
1256       ok = false;
1257     }
1258     if (record_error) {
1259       stream->CheckError(ok);
1260     }
1261   }
1262   return *stream;
1263 }
1264 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1265 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1266                              int incx, DeviceMemory<float> *result) {
1267   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1268 
1269   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1270       impl;
1271   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1272               result);
1273 }
1274 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1275 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1276                              int incx, DeviceMemory<double> *result) {
1277   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1278 
1279   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1280                DeviceMemory<double> *>
1281       impl;
1282   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1283               result);
1284 }
1285 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1286 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1287                              const DeviceMemory<std::complex<float>> &x,
1288                              int incx, DeviceMemory<float> *result) {
1289   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1290 
1291   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1292                DeviceMemory<float> *>
1293       impl;
1294   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1295               result);
1296 }
1297 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1298 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1299                              const DeviceMemory<std::complex<double>> &x,
1300                              int incx, DeviceMemory<double> *result) {
1301   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1302 
1303   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1304                DeviceMemory<double> *>
1305       impl;
1306   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1307               result);
1308 }
1309 
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1310 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1311                              const DeviceMemory<float> &x, int incx,
1312                              DeviceMemory<float> *y, int incy) {
1313   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1314             PARAM(incy));
1315 
1316   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1317                DeviceMemory<float> *, int>
1318       impl;
1319   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1320               y, incy);
1321 }
1322 
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1323 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1324                              const DeviceMemory<double> &x, int incx,
1325                              DeviceMemory<double> *y, int incy) {
1326   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1327             PARAM(incy));
1328 
1329   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1330                DeviceMemory<double> *, int>
1331       impl;
1332   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1333               y, incy);
1334 }
1335 
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1336 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1337                              const DeviceMemory<std::complex<float>> &x,
1338                              int incx, DeviceMemory<std::complex<float>> *y,
1339                              int incy) {
1340   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1341             PARAM(incy));
1342 
1343   ThenBlasImpl<uint64, std::complex<float>,
1344                const DeviceMemory<std::complex<float>> &, int,
1345                DeviceMemory<std::complex<float>> *, int>
1346       impl;
1347   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1348               y, incy);
1349 }
1350 
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1351 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1352                              const DeviceMemory<std::complex<double>> &x,
1353                              int incx, DeviceMemory<std::complex<double>> *y,
1354                              int incy) {
1355   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1356             PARAM(incy));
1357 
1358   ThenBlasImpl<uint64, std::complex<double>,
1359                const DeviceMemory<std::complex<double>> &, int,
1360                DeviceMemory<std::complex<double>> *, int>
1361       impl;
1362   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1363               y, incy);
1364 }
1365 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1366 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
1367                              int incx, DeviceMemory<float> *y, int incy) {
1368   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1369 
1370   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
1371                int>
1372       impl;
1373   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1374               incy);
1375 }
1376 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1377 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
1378                              int incx, DeviceMemory<double> *y, int incy) {
1379   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1380 
1381   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1382                DeviceMemory<double> *, int>
1383       impl;
1384   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1385               incy);
1386 }
1387 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1388 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1389                              const DeviceMemory<std::complex<float>> &x,
1390                              int incx, DeviceMemory<std::complex<float>> *y,
1391                              int incy) {
1392   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1393 
1394   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1395                DeviceMemory<std::complex<float>> *, int>
1396       impl;
1397   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1398               incy);
1399 }
1400 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1401 Stream &Stream::ThenBlasCopy(uint64 elem_count,
1402                              const DeviceMemory<std::complex<double>> &x,
1403                              int incx, DeviceMemory<std::complex<double>> *y,
1404                              int incy) {
1405   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1406 
1407   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1408                DeviceMemory<std::complex<double>> *, int>
1409       impl;
1410   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
1411               incy);
1412 }
1413 
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)1414 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
1415                             int incx, const DeviceMemory<float> &y, int incy,
1416                             DeviceMemory<float> *result) {
1417   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1418             PARAM(result));
1419 
1420   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
1421                const DeviceMemory<float> &, int, DeviceMemory<float> *>
1422       impl;
1423   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1424               result);
1425 }
1426 
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)1427 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
1428                             int incx, const DeviceMemory<double> &y, int incy,
1429                             DeviceMemory<double> *result) {
1430   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1431             PARAM(result));
1432 
1433   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1434                const DeviceMemory<double> &, int, DeviceMemory<double> *>
1435       impl;
1436   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
1437               result);
1438 }
1439 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1440 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1441                              const DeviceMemory<std::complex<float>> &x,
1442                              int incx,
1443                              const DeviceMemory<std::complex<float>> &y,
1444                              int incy,
1445                              DeviceMemory<std::complex<float>> *result) {
1446   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1447             PARAM(result));
1448 
1449   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1450                const DeviceMemory<std::complex<float>> &, int,
1451                DeviceMemory<std::complex<float>> *>
1452       impl;
1453   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1454               incy, result);
1455 }
1456 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1457 Stream &Stream::ThenBlasDotc(uint64 elem_count,
1458                              const DeviceMemory<std::complex<double>> &x,
1459                              int incx,
1460                              const DeviceMemory<std::complex<double>> &y,
1461                              int incy,
1462                              DeviceMemory<std::complex<double>> *result) {
1463   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1464             PARAM(result));
1465 
1466   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1467                const DeviceMemory<std::complex<double>> &, int,
1468                DeviceMemory<std::complex<double>> *>
1469       impl;
1470   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
1471               incy, result);
1472 }
1473 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)1474 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1475                              const DeviceMemory<std::complex<float>> &x,
1476                              int incx,
1477                              const DeviceMemory<std::complex<float>> &y,
1478                              int incy,
1479                              DeviceMemory<std::complex<float>> *result) {
1480   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1481             PARAM(result));
1482 
1483   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1484                const DeviceMemory<std::complex<float>> &, int,
1485                DeviceMemory<std::complex<float>> *>
1486       impl;
1487   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1488               incy, result);
1489 }
1490 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)1491 Stream &Stream::ThenBlasDotu(uint64 elem_count,
1492                              const DeviceMemory<std::complex<double>> &x,
1493                              int incx,
1494                              const DeviceMemory<std::complex<double>> &y,
1495                              int incy,
1496                              DeviceMemory<std::complex<double>> *result) {
1497   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1498             PARAM(result));
1499 
1500   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1501                const DeviceMemory<std::complex<double>> &, int,
1502                DeviceMemory<std::complex<double>> *>
1503       impl;
1504   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
1505               incy, result);
1506 }
1507 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1508 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
1509                              int incx, DeviceMemory<float> *result) {
1510   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1511 
1512   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1513       impl;
1514   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1515               result);
1516 }
1517 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1518 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
1519                              int incx, DeviceMemory<double> *result) {
1520   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1521 
1522   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1523                DeviceMemory<double> *>
1524       impl;
1525   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1526               result);
1527 }
1528 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1529 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1530                              const DeviceMemory<std::complex<float>> &x,
1531                              int incx, DeviceMemory<float> *result) {
1532   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1533 
1534   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1535                DeviceMemory<float> *>
1536       impl;
1537   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1538               result);
1539 }
1540 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1541 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
1542                              const DeviceMemory<std::complex<double>> &x,
1543                              int incx, DeviceMemory<double> *result) {
1544   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1545 
1546   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1547                DeviceMemory<double> *>
1548       impl;
1549   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
1550               result);
1551 }
1552 
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)1553 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
1554                             DeviceMemory<float> *y, int incy, float c,
1555                             float s) {
1556   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1557             PARAM(c), PARAM(s));
1558 
1559   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1560                float, float>
1561       impl;
1562   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1563               c, s);
1564 }
1565 
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)1566 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
1567                             int incx, DeviceMemory<double> *y, int incy,
1568                             double c, double s) {
1569   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1570             PARAM(c), PARAM(s));
1571 
1572   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1573                double, double>
1574       impl;
1575   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1576               c, s);
1577 }
1578 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)1579 Stream &Stream::ThenBlasRot(uint64 elem_count,
1580                             DeviceMemory<std::complex<float>> *x, int incx,
1581                             DeviceMemory<std::complex<float>> *y, int incy,
1582                             float c, float s) {
1583   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1584             PARAM(c), PARAM(s));
1585 
1586   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1587                DeviceMemory<std::complex<float>> *, int, float, float>
1588       impl;
1589   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1590               c, s);
1591 }
1592 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)1593 Stream &Stream::ThenBlasRot(uint64 elem_count,
1594                             DeviceMemory<std::complex<double>> *x, int incx,
1595                             DeviceMemory<std::complex<double>> *y, int incy,
1596                             double c, double s) {
1597   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1598             PARAM(c), PARAM(s));
1599 
1600   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1601                DeviceMemory<std::complex<double>> *, int, double, double>
1602       impl;
1603   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
1604               c, s);
1605 }
1606 
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)1607 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
1608                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
1609   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1610 
1611   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1612                DeviceMemory<float> *, DeviceMemory<float> *>
1613       impl;
1614   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1615 }
1616 
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)1617 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
1618                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
1619   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1620 
1621   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1622                DeviceMemory<double> *, DeviceMemory<double> *>
1623       impl;
1624   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1625 }
1626 
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)1627 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
1628                              DeviceMemory<std::complex<float>> *b,
1629                              DeviceMemory<float> *c,
1630                              DeviceMemory<std::complex<float>> *s) {
1631   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1632 
1633   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
1634                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
1635                DeviceMemory<std::complex<float>> *>
1636       impl;
1637   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1638 }
1639 
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)1640 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
1641                              DeviceMemory<std::complex<double>> *b,
1642                              DeviceMemory<double> *c,
1643                              DeviceMemory<std::complex<double>> *s) {
1644   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
1645 
1646   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
1647                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
1648                DeviceMemory<std::complex<double>> *>
1649       impl;
1650   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
1651 }
1652 
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)1653 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
1654                              int incx, DeviceMemory<float> *y, int incy,
1655                              const DeviceMemory<float> &param) {
1656   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1657             PARAM(param));
1658 
1659   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
1660                const DeviceMemory<float> &>
1661       impl;
1662   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1663               incy, param);
1664 }
1665 
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)1666 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
1667                              int incx, DeviceMemory<double> *y, int incy,
1668                              const DeviceMemory<double> &param) {
1669   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
1670             PARAM(param));
1671 
1672   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
1673                const DeviceMemory<double> &>
1674       impl;
1675   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
1676               incy, param);
1677 }
1678 
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)1679 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
1680                               DeviceMemory<float> *x1,
1681                               const DeviceMemory<float> &y1,
1682                               DeviceMemory<float> *param) {
1683   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1684 
1685   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
1686                DeviceMemory<float> *, const DeviceMemory<float> &,
1687                DeviceMemory<float> *>
1688       impl;
1689   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1690 }
1691 
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)1692 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
1693                               DeviceMemory<double> *d2,
1694                               DeviceMemory<double> *x1,
1695                               const DeviceMemory<double> &y1,
1696                               DeviceMemory<double> *param) {
1697   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
1698 
1699   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
1700                DeviceMemory<double> *, const DeviceMemory<double> &,
1701                DeviceMemory<double> *>
1702       impl;
1703   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
1704 }
1705 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)1706 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1707                              DeviceMemory<float> *x, int incx) {
1708   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1709 
1710   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
1711   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1712 }
1713 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)1714 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1715                              DeviceMemory<double> *x, int incx) {
1716   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1717 
1718   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
1719   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1720 }
1721 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)1722 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
1723                              DeviceMemory<std::complex<float>> *x, int incx) {
1724   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1725 
1726   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
1727   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1728 }
1729 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)1730 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
1731                              DeviceMemory<std::complex<double>> *x, int incx) {
1732   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1733 
1734   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
1735   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1736 }
1737 
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)1738 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
1739                              DeviceMemory<std::complex<float>> *x, int incx) {
1740   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1741 
1742   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
1743                int>
1744       impl;
1745   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1746 }
1747 
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)1748 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
1749                              DeviceMemory<std::complex<double>> *x, int incx) {
1750   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
1751 
1752   ThenBlasImpl<uint64, std::complex<double>,
1753                DeviceMemory<std::complex<double>> *, int>
1754       impl;
1755   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
1756 }
1757 
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)1758 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
1759                              int incx, DeviceMemory<float> *y, int incy) {
1760   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1761 
1762   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
1763       impl;
1764   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1765               incy);
1766 }
1767 
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)1768 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
1769                              int incx, DeviceMemory<double> *y, int incy) {
1770   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1771 
1772   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
1773       impl;
1774   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1775               incy);
1776 }
1777 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1778 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1779                              DeviceMemory<std::complex<float>> *x, int incx,
1780                              DeviceMemory<std::complex<float>> *y, int incy) {
1781   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1782 
1783   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
1784                DeviceMemory<std::complex<float>> *, int>
1785       impl;
1786   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1787               incy);
1788 }
1789 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1790 Stream &Stream::ThenBlasSwap(uint64 elem_count,
1791                              DeviceMemory<std::complex<double>> *x, int incx,
1792                              DeviceMemory<std::complex<double>> *y, int incy) {
1793   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
1794 
1795   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
1796                DeviceMemory<std::complex<double>> *, int>
1797       impl;
1798   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
1799               incy);
1800 }
1801 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)1802 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
1803                               int incx, DeviceMemory<int> *result) {
1804   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1805 
1806   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
1807       impl;
1808   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1809               result);
1810 }
1811 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)1812 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
1813                               int incx, DeviceMemory<int> *result) {
1814   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1815 
1816   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
1817       impl;
1818   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1819               result);
1820 }
1821 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)1822 Stream &Stream::ThenBlasIamax(uint64 elem_count,
1823                               const DeviceMemory<std::complex<float>> &x,
1824                               int incx, DeviceMemory<int> *result) {
1825   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1826 
1827   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1828                DeviceMemory<int> *>
1829       impl;
1830   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1831               result);
1832 }
1833 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)1834 Stream &Stream::ThenBlasIamax(uint64 elem_count,
1835                               const DeviceMemory<std::complex<double>> &x,
1836                               int incx, DeviceMemory<int> *result) {
1837   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1838 
1839   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1840                DeviceMemory<int> *>
1841       impl;
1842   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
1843               result);
1844 }
1845 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)1846 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
1847                               int incx, DeviceMemory<int> *result) {
1848   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1849 
1850   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
1851       impl;
1852   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1853               result);
1854 }
1855 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)1856 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
1857                               int incx, DeviceMemory<int> *result) {
1858   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1859 
1860   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
1861       impl;
1862   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1863               result);
1864 }
1865 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)1866 Stream &Stream::ThenBlasIamin(uint64 elem_count,
1867                               const DeviceMemory<std::complex<float>> &x,
1868                               int incx, DeviceMemory<int> *result) {
1869   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1870 
1871   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1872                DeviceMemory<int> *>
1873       impl;
1874   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1875               result);
1876 }
1877 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)1878 Stream &Stream::ThenBlasIamin(uint64 elem_count,
1879                               const DeviceMemory<std::complex<double>> &x,
1880                               int incx, DeviceMemory<int> *result) {
1881   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1882 
1883   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1884                DeviceMemory<int> *>
1885       impl;
1886   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
1887               result);
1888 }
1889 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1890 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1891                              uint64 kl, uint64 ku, float alpha,
1892                              const DeviceMemory<float> &a, int lda,
1893                              const DeviceMemory<float> &x, int incx, float beta,
1894                              DeviceMemory<float> *y, int incy) {
1895   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1896             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1897             PARAM(beta), PARAM(y), PARAM(incy));
1898 
1899   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
1900                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1901                int, float, DeviceMemory<float> *, int>
1902       impl;
1903   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1904               a, lda, x, incx, beta, y, incy);
1905 }
1906 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1907 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1908                              uint64 kl, uint64 ku, double alpha,
1909                              const DeviceMemory<double> &a, int lda,
1910                              const DeviceMemory<double> &x, int incx,
1911                              double beta, DeviceMemory<double> *y, int incy) {
1912   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1913             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1914             PARAM(beta), PARAM(y), PARAM(incy));
1915 
1916   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
1917                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1918                int, double, DeviceMemory<double> *, int>
1919       impl;
1920   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1921               a, lda, x, incx, beta, y, incy);
1922 }
1923 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1924 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1925                              uint64 kl, uint64 ku, std::complex<float> alpha,
1926                              const DeviceMemory<std::complex<float>> &a,
1927                              int lda,
1928                              const DeviceMemory<std::complex<float>> &x,
1929                              int incx, std::complex<float> beta,
1930                              DeviceMemory<std::complex<float>> *y, int incy) {
1931   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1932             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1933             PARAM(beta), PARAM(y), PARAM(incy));
1934 
1935   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
1936                std::complex<float>, const DeviceMemory<std::complex<float>> &,
1937                int, const DeviceMemory<std::complex<float>> &, int,
1938                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
1939       impl;
1940   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1941               a, lda, x, incx, beta, y, incy);
1942 }
1943 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)1944 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
1945                              uint64 kl, uint64 ku, std::complex<double> alpha,
1946                              const DeviceMemory<std::complex<double>> &a,
1947                              int lda,
1948                              const DeviceMemory<std::complex<double>> &x,
1949                              int incx, std::complex<double> beta,
1950                              DeviceMemory<std::complex<double>> *y, int incy) {
1951   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
1952             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
1953             PARAM(beta), PARAM(y), PARAM(incy));
1954 
1955   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
1956                std::complex<double>, const DeviceMemory<std::complex<double>> &,
1957                int, const DeviceMemory<std::complex<double>> &, int,
1958                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
1959       impl;
1960   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
1961               a, lda, x, incx, beta, y, incy);
1962 }
1963 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)1964 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
1965                              float alpha, const DeviceMemory<float> &a, int lda,
1966                              const DeviceMemory<float> &x, int incx, float beta,
1967                              DeviceMemory<float> *y, int incy) {
1968   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1969             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1970             PARAM(incy));
1971 
1972   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
1973                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
1974                int, float, DeviceMemory<float> *, int>
1975       impl;
1976   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1977               x, incx, beta, y, incy);
1978 }
1979 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)1980 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
1981                              double alpha, const DeviceMemory<double> &a,
1982                              int lda, const DeviceMemory<double> &x, int incx,
1983                              double beta, DeviceMemory<double> *y, int incy) {
1984   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
1985             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
1986             PARAM(incy));
1987 
1988   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
1989                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
1990                int, double, DeviceMemory<double> *, int>
1991       impl;
1992   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
1993               x, incx, beta, y, incy);
1994 }
1995 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)1996 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
1997                              std::complex<float> alpha,
1998                              const DeviceMemory<std::complex<float>> &a,
1999                              int lda,
2000                              const DeviceMemory<std::complex<float>> &x,
2001                              int incx, std::complex<float> beta,
2002                              DeviceMemory<std::complex<float>> *y, int incy) {
2003   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2004             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2005             PARAM(incy));
2006 
2007   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2008                const DeviceMemory<std::complex<float>> &, int,
2009                const DeviceMemory<std::complex<float>> &, int,
2010                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2011       impl;
2012   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2013               x, incx, beta, y, incy);
2014 }
2015 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2016 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2017                              std::complex<double> alpha,
2018                              const DeviceMemory<std::complex<double>> &a,
2019                              int lda,
2020                              const DeviceMemory<std::complex<double>> &x,
2021                              int incx, std::complex<double> beta,
2022                              DeviceMemory<std::complex<double>> *y, int incy) {
2023   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2024             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2025             PARAM(incy));
2026 
2027   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2028                const DeviceMemory<std::complex<double>> &, int,
2029                const DeviceMemory<std::complex<double>> &, int,
2030                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2031       impl;
2032   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2033               x, incx, beta, y, incy);
2034 }
2035 
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2036 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2037                             const DeviceMemory<float> &x, int incx,
2038                             const DeviceMemory<float> &y, int incy,
2039                             DeviceMemory<float> *a, int lda) {
2040   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2041             PARAM(incy), PARAM(a), PARAM(lda));
2042 
2043   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2044                const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2045       impl;
2046   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2047               incy, a, lda);
2048 }
2049 
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2050 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2051                             const DeviceMemory<double> &x, int incx,
2052                             const DeviceMemory<double> &y, int incy,
2053                             DeviceMemory<double> *a, int lda) {
2054   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2055             PARAM(incy), PARAM(a), PARAM(lda));
2056 
2057   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2058                const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2059       impl;
2060   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2061               incy, a, lda);
2062 }
2063 
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2064 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2065                              const DeviceMemory<std::complex<float>> &x,
2066                              int incx,
2067                              const DeviceMemory<std::complex<float>> &y,
2068                              int incy, DeviceMemory<std::complex<float>> *a,
2069                              int lda) {
2070   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2071             PARAM(incy), PARAM(a), PARAM(lda));
2072 
2073   ThenBlasImpl<uint64, uint64, std::complex<float>,
2074                const DeviceMemory<std::complex<float>> &, int,
2075                const DeviceMemory<std::complex<float>> &, int,
2076                DeviceMemory<std::complex<float>> *, int>
2077       impl;
2078   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2079               incy, a, lda);
2080 }
2081 
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2082 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2083                              const DeviceMemory<std::complex<double>> &x,
2084                              int incx,
2085                              const DeviceMemory<std::complex<double>> &y,
2086                              int incy, DeviceMemory<std::complex<double>> *a,
2087                              int lda) {
2088   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2089             PARAM(incy), PARAM(a), PARAM(lda));
2090 
2091   ThenBlasImpl<uint64, uint64, std::complex<double>,
2092                const DeviceMemory<std::complex<double>> &, int,
2093                const DeviceMemory<std::complex<double>> &, int,
2094                DeviceMemory<std::complex<double>> *, int>
2095       impl;
2096   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2097               incy, a, lda);
2098 }
2099 
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2100 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2101                              const DeviceMemory<std::complex<float>> &x,
2102                              int incx,
2103                              const DeviceMemory<std::complex<float>> &y,
2104                              int incy, DeviceMemory<std::complex<float>> *a,
2105                              int lda) {
2106   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2107             PARAM(incy), PARAM(a), PARAM(lda));
2108 
2109   ThenBlasImpl<uint64, uint64, std::complex<float>,
2110                const DeviceMemory<std::complex<float>> &, int,
2111                const DeviceMemory<std::complex<float>> &, int,
2112                DeviceMemory<std::complex<float>> *, int>
2113       impl;
2114   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2115               incy, a, lda);
2116 }
2117 
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2118 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2119                              const DeviceMemory<std::complex<double>> &x,
2120                              int incx,
2121                              const DeviceMemory<std::complex<double>> &y,
2122                              int incy, DeviceMemory<std::complex<double>> *a,
2123                              int lda) {
2124   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2125             PARAM(incy), PARAM(a), PARAM(lda));
2126 
2127   ThenBlasImpl<uint64, uint64, std::complex<double>,
2128                const DeviceMemory<std::complex<double>> &, int,
2129                const DeviceMemory<std::complex<double>> &, int,
2130                DeviceMemory<std::complex<double>> *, int>
2131       impl;
2132   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2133               incy, a, lda);
2134 }
2135 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2136 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2137                              std::complex<float> alpha,
2138                              const DeviceMemory<std::complex<float>> &a,
2139                              int lda,
2140                              const DeviceMemory<std::complex<float>> &x,
2141                              int incx, std::complex<float> beta,
2142                              DeviceMemory<std::complex<float>> *y, int incy) {
2143   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2144             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2145 
2146   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2147                const DeviceMemory<std::complex<float>> &, int,
2148                const DeviceMemory<std::complex<float>> &, int,
2149                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2150       impl;
2151   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2152               x, incx, beta, y, incy);
2153 }
2154 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2155 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2156                              std::complex<double> alpha,
2157                              const DeviceMemory<std::complex<double>> &a,
2158                              int lda,
2159                              const DeviceMemory<std::complex<double>> &x,
2160                              int incx, std::complex<double> beta,
2161                              DeviceMemory<std::complex<double>> *y, int incy) {
2162   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2163             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2164 
2165   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2166                const DeviceMemory<std::complex<double>> &, int,
2167                const DeviceMemory<std::complex<double>> &, int,
2168                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2169       impl;
2170   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2171               x, incx, beta, y, incy);
2172 }
2173 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2174 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2175                              std::complex<float> alpha,
2176                              const DeviceMemory<std::complex<float>> &a,
2177                              int lda,
2178                              const DeviceMemory<std::complex<float>> &x,
2179                              int incx, std::complex<float> beta,
2180                              DeviceMemory<std::complex<float>> *y, int incy) {
2181   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2182             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2183 
2184   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2185                const DeviceMemory<std::complex<float>> &, int,
2186                const DeviceMemory<std::complex<float>> &, int,
2187                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2188       impl;
2189   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2190               incx, beta, y, incy);
2191 }
2192 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2193 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2194                              std::complex<double> alpha,
2195                              const DeviceMemory<std::complex<double>> &a,
2196                              int lda,
2197                              const DeviceMemory<std::complex<double>> &x,
2198                              int incx, std::complex<double> beta,
2199                              DeviceMemory<std::complex<double>> *y, int incy) {
2200   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2201             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2202 
2203   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2204                const DeviceMemory<std::complex<double>> &, int,
2205                const DeviceMemory<std::complex<double>> &, int,
2206                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2207       impl;
2208   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2209               incx, beta, y, incy);
2210 }
2211 
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2212 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2213                             const DeviceMemory<std::complex<float>> &x,
2214                             int incx, DeviceMemory<std::complex<float>> *a,
2215                             int lda) {
2216   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2217             PARAM(a), PARAM(lda));
2218 
2219   ThenBlasImpl<blas::UpperLower, uint64, float,
2220                const DeviceMemory<std::complex<float>> &, int,
2221                DeviceMemory<std::complex<float>> *, int>
2222       impl;
2223   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2224               lda);
2225 }
2226 
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2227 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2228                             const DeviceMemory<std::complex<double>> &x,
2229                             int incx, DeviceMemory<std::complex<double>> *a,
2230                             int lda) {
2231   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2232             PARAM(a), PARAM(lda));
2233 
2234   ThenBlasImpl<blas::UpperLower, uint64, double,
2235                const DeviceMemory<std::complex<double>> &, int,
2236                DeviceMemory<std::complex<double>> *, int>
2237       impl;
2238   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2239               lda);
2240 }
2241 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2242 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2243                              std::complex<float> alpha,
2244                              const DeviceMemory<std::complex<float>> &x,
2245                              int incx,
2246                              const DeviceMemory<std::complex<float>> &y,
2247                              int incy, DeviceMemory<std::complex<float>> *a,
2248                              int lda) {
2249   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2250             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2251 
2252   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2253                const DeviceMemory<std::complex<float>> &, int,
2254                const DeviceMemory<std::complex<float>> &, int,
2255                DeviceMemory<std::complex<float>> *, int>
2256       impl;
2257   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2258               incy, a, lda);
2259 }
2260 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2261 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2262                              std::complex<double> alpha,
2263                              const DeviceMemory<std::complex<double>> &x,
2264                              int incx,
2265                              const DeviceMemory<std::complex<double>> &y,
2266                              int incy, DeviceMemory<std::complex<double>> *a,
2267                              int lda) {
2268   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2269             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2270 
2271   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2272                const DeviceMemory<std::complex<double>> &, int,
2273                const DeviceMemory<std::complex<double>> &, int,
2274                DeviceMemory<std::complex<double>> *, int>
2275       impl;
2276   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2277               incy, a, lda);
2278 }
2279 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2280 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2281                              std::complex<float> alpha,
2282                              const DeviceMemory<std::complex<float>> &ap,
2283                              const DeviceMemory<std::complex<float>> &x,
2284                              int incx, std::complex<float> beta,
2285                              DeviceMemory<std::complex<float>> *y, int incy) {
2286   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2287             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2288 
2289   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2290                const DeviceMemory<std::complex<float>> &,
2291                const DeviceMemory<std::complex<float>> &, int,
2292                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
2293       impl;
2294   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2295               beta, y, incy);
2296 }
2297 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2298 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2299                              std::complex<double> alpha,
2300                              const DeviceMemory<std::complex<double>> &ap,
2301                              const DeviceMemory<std::complex<double>> &x,
2302                              int incx, std::complex<double> beta,
2303                              DeviceMemory<std::complex<double>> *y, int incy) {
2304   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2305             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2306 
2307   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2308                const DeviceMemory<std::complex<double>> &,
2309                const DeviceMemory<std::complex<double>> &, int,
2310                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
2311       impl;
2312   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2313               beta, y, incy);
2314 }
2315 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2316 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2317                             const DeviceMemory<std::complex<float>> &x,
2318                             int incx, DeviceMemory<std::complex<float>> *ap) {
2319   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2320             PARAM(ap));
2321 
2322   ThenBlasImpl<blas::UpperLower, uint64, float,
2323                const DeviceMemory<std::complex<float>> &, int,
2324                DeviceMemory<std::complex<float>> *>
2325       impl;
2326   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2327 }
2328 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2329 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2330                             const DeviceMemory<std::complex<double>> &x,
2331                             int incx, DeviceMemory<std::complex<double>> *ap) {
2332   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2333             PARAM(ap));
2334 
2335   ThenBlasImpl<blas::UpperLower, uint64, double,
2336                const DeviceMemory<std::complex<double>> &, int,
2337                DeviceMemory<std::complex<double>> *>
2338       impl;
2339   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2340 }
2341 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2342 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2343                              std::complex<float> alpha,
2344                              const DeviceMemory<std::complex<float>> &x,
2345                              int incx,
2346                              const DeviceMemory<std::complex<float>> &y,
2347                              int incy, DeviceMemory<std::complex<float>> *ap) {
2348   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2349             PARAM(y), PARAM(incy), PARAM(ap));
2350 
2351   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2352                const DeviceMemory<std::complex<float>> &, int,
2353                const DeviceMemory<std::complex<float>> &, int,
2354                DeviceMemory<std::complex<float>> *>
2355       impl;
2356   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2357               incy, ap);
2358 }
2359 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2360 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2361                              std::complex<double> alpha,
2362                              const DeviceMemory<std::complex<double>> &x,
2363                              int incx,
2364                              const DeviceMemory<std::complex<double>> &y,
2365                              int incy, DeviceMemory<std::complex<double>> *ap) {
2366   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2367             PARAM(y), PARAM(incy), PARAM(ap));
2368 
2369   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2370                const DeviceMemory<std::complex<double>> &, int,
2371                const DeviceMemory<std::complex<double>> &, int,
2372                DeviceMemory<std::complex<double>> *>
2373       impl;
2374   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2375               incy, ap);
2376 }
2377 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2378 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2379                              float alpha, const DeviceMemory<float> &a, int lda,
2380                              const DeviceMemory<float> &x, int incx, float beta,
2381                              DeviceMemory<float> *y, int incy) {
2382   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2383             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2384 
2385   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2386                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2387                int, float, DeviceMemory<float> *, int>
2388       impl;
2389   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2390               x, incx, beta, y, incy);
2391 }
2392 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2393 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2394                              double alpha, const DeviceMemory<double> &a,
2395                              int lda, const DeviceMemory<double> &x, int incx,
2396                              double beta, DeviceMemory<double> *y, int incy) {
2397   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2398             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2399 
2400   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2401                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2402                int, double, DeviceMemory<double> *, int>
2403       impl;
2404   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2405               x, incx, beta, y, incy);
2406 }
2407 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2408 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
2409                              const DeviceMemory<float> &ap,
2410                              const DeviceMemory<float> &x, int incx, float beta,
2411                              DeviceMemory<float> *y, int incy) {
2412   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2413             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2414 
2415   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2416                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
2417                int>
2418       impl;
2419   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2420               beta, y, incy);
2421 }
2422 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2423 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
2424                              const DeviceMemory<double> &ap,
2425                              const DeviceMemory<double> &x, int incx,
2426                              double beta, DeviceMemory<double> *y, int incy) {
2427   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2428             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2429 
2430   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2431                const DeviceMemory<double> &, int, double,
2432                DeviceMemory<double> *, int>
2433       impl;
2434   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
2435               beta, y, incy);
2436 }
2437 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)2438 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
2439                             const DeviceMemory<float> &x, int incx,
2440                             DeviceMemory<float> *ap) {
2441   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2442             PARAM(ap));
2443 
2444   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2445                int, DeviceMemory<float> *>
2446       impl;
2447   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2448 }
2449 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)2450 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
2451                             const DeviceMemory<double> &x, int incx,
2452                             DeviceMemory<double> *ap) {
2453   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2454             PARAM(ap));
2455 
2456   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2457                int, DeviceMemory<double> *>
2458       impl;
2459   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
2460 }
2461 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)2462 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
2463                              const DeviceMemory<float> &x, int incx,
2464                              const DeviceMemory<float> &y, int incy,
2465                              DeviceMemory<float> *ap) {
2466   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2467             PARAM(y), PARAM(incy), PARAM(ap));
2468 
2469   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2470                int, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2471       impl;
2472   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2473               incy, ap);
2474 }
2475 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)2476 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
2477                              const DeviceMemory<double> &x, int incx,
2478                              const DeviceMemory<double> &y, int incy,
2479                              DeviceMemory<double> *ap) {
2480   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2481             PARAM(y), PARAM(incy), PARAM(ap));
2482 
2483   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2484                int, const DeviceMemory<double> &, int, DeviceMemory<double> *>
2485       impl;
2486   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
2487               incy, ap);
2488 }
2489 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2490 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
2491                              const DeviceMemory<float> &a, int lda,
2492                              const DeviceMemory<float> &x, int incx, float beta,
2493                              DeviceMemory<float> *y, int incy) {
2494   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2495             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2496 
2497   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2498                int, const DeviceMemory<float> &, int, float,
2499                DeviceMemory<float> *, int>
2500       impl;
2501   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2502               incx, beta, y, incy);
2503 }
2504 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2505 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
2506                              const DeviceMemory<double> &a, int lda,
2507                              const DeviceMemory<double> &x, int incx,
2508                              double beta, DeviceMemory<double> *y, int incy) {
2509   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2510             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2511 
2512   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2513                int, const DeviceMemory<double> &, int, double,
2514                DeviceMemory<double> *, int>
2515       impl;
2516   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
2517               incx, beta, y, incy);
2518 }
2519 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)2520 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
2521                             const DeviceMemory<float> &x, int incx,
2522                             DeviceMemory<float> *a, int lda) {
2523   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2524             PARAM(a), PARAM(lda));
2525 
2526   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2527                int, DeviceMemory<float> *, int>
2528       impl;
2529   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2530               lda);
2531 }
2532 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)2533 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
2534                             const DeviceMemory<double> &x, int incx,
2535                             DeviceMemory<double> *a, int lda) {
2536   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2537             PARAM(a), PARAM(lda));
2538 
2539   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2540                int, DeviceMemory<double> *, int>
2541       impl;
2542   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
2543               lda);
2544 }
2545 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2546 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
2547                              const DeviceMemory<float> &x, int incx,
2548                              const DeviceMemory<float> &y, int incy,
2549                              DeviceMemory<float> *a, int lda) {
2550   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2551             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2552 
2553   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
2554                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2555                int>
2556       impl;
2557   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2558               incy, a, lda);
2559 }
2560 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2561 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
2562                              const DeviceMemory<double> &x, int incx,
2563                              const DeviceMemory<double> &y, int incy,
2564                              DeviceMemory<double> *a, int lda) {
2565   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2566             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2567 
2568   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
2569                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
2570                int>
2571       impl;
2572   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
2573               incy, a, lda);
2574 }
2575 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2576 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2577                              blas::Diagonal diag, uint64 n, uint64 k,
2578                              const DeviceMemory<float> &a, int lda,
2579                              DeviceMemory<float> *x, int incx) {
2580   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2581             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2582 
2583   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2584                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2585                int>
2586       impl;
2587   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2588               lda, x, incx);
2589 }
2590 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2591 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2592                              blas::Diagonal diag, uint64 n, uint64 k,
2593                              const DeviceMemory<double> &a, int lda,
2594                              DeviceMemory<double> *x, int incx) {
2595   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2596             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2597 
2598   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2599                uint64, const DeviceMemory<double> &, int,
2600                DeviceMemory<double> *, int>
2601       impl;
2602   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2603               lda, x, incx);
2604 }
2605 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2606 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2607                              blas::Diagonal diag, uint64 n, uint64 k,
2608                              const DeviceMemory<std::complex<float>> &a,
2609                              int lda, DeviceMemory<std::complex<float>> *x,
2610                              int incx) {
2611   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2612             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2613 
2614   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2615                uint64, const DeviceMemory<std::complex<float>> &, int,
2616                DeviceMemory<std::complex<float>> *, int>
2617       impl;
2618   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2619               lda, x, incx);
2620 }
2621 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2622 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
2623                              blas::Diagonal diag, uint64 n, uint64 k,
2624                              const DeviceMemory<std::complex<double>> &a,
2625                              int lda, DeviceMemory<std::complex<double>> *x,
2626                              int incx) {
2627   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2628             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2629 
2630   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2631                uint64, const DeviceMemory<std::complex<double>> &, int,
2632                DeviceMemory<std::complex<double>> *, int>
2633       impl;
2634   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
2635               lda, x, incx);
2636 }
2637 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2638 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2639                              blas::Diagonal diag, uint64 n, uint64 k,
2640                              const DeviceMemory<float> &a, int lda,
2641                              DeviceMemory<float> *x, int incx) {
2642   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2643             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2644 
2645   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2646                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2647                int>
2648       impl;
2649   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2650               lda, x, incx);
2651 }
2652 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2653 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2654                              blas::Diagonal diag, uint64 n, uint64 k,
2655                              const DeviceMemory<double> &a, int lda,
2656                              DeviceMemory<double> *x, int incx) {
2657   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2658             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2659 
2660   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2661                uint64, const DeviceMemory<double> &, int,
2662                DeviceMemory<double> *, int>
2663       impl;
2664   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2665               lda, x, incx);
2666 }
2667 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2668 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2669                              blas::Diagonal diag, uint64 n, uint64 k,
2670                              const DeviceMemory<std::complex<float>> &a,
2671                              int lda, DeviceMemory<std::complex<float>> *x,
2672                              int incx) {
2673   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2674             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2675 
2676   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2677                uint64, const DeviceMemory<std::complex<float>> &, int,
2678                DeviceMemory<std::complex<float>> *, int>
2679       impl;
2680   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2681               lda, x, incx);
2682 }
2683 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2684 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
2685                              blas::Diagonal diag, uint64 n, uint64 k,
2686                              const DeviceMemory<std::complex<double>> &a,
2687                              int lda, DeviceMemory<std::complex<double>> *x,
2688                              int incx) {
2689   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
2690             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
2691 
2692   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2693                uint64, const DeviceMemory<std::complex<double>> &, int,
2694                DeviceMemory<std::complex<double>> *, int>
2695       impl;
2696   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
2697               lda, x, incx);
2698 }
2699 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2700 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2701                              blas::Diagonal diag, uint64 n,
2702                              const DeviceMemory<float> &ap,
2703                              DeviceMemory<float> *x, int incx) {
2704   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2705             PARAM(x), PARAM(incx));
2706 
2707   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2708                const DeviceMemory<float> &, DeviceMemory<float> *, int>
2709       impl;
2710   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2711               incx);
2712 }
2713 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2714 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2715                              blas::Diagonal diag, uint64 n,
2716                              const DeviceMemory<double> &ap,
2717                              DeviceMemory<double> *x, int incx) {
2718   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2719             PARAM(x), PARAM(incx));
2720 
2721   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2722                const DeviceMemory<double> &, DeviceMemory<double> *, int>
2723       impl;
2724   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2725               incx);
2726 }
2727 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2728 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2729                              blas::Diagonal diag, uint64 n,
2730                              const DeviceMemory<std::complex<float>> &ap,
2731                              DeviceMemory<std::complex<float>> *x, int incx) {
2732   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2733             PARAM(x), PARAM(incx));
2734 
2735   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2736                const DeviceMemory<std::complex<float>> &,
2737                DeviceMemory<std::complex<float>> *, int>
2738       impl;
2739   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2740               incx);
2741 }
2742 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)2743 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
2744                              blas::Diagonal diag, uint64 n,
2745                              const DeviceMemory<std::complex<double>> &ap,
2746                              DeviceMemory<std::complex<double>> *x, int incx) {
2747   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2748             PARAM(x), PARAM(incx));
2749 
2750   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2751                const DeviceMemory<std::complex<double>> &,
2752                DeviceMemory<std::complex<double>> *, int>
2753       impl;
2754   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
2755               incx);
2756 }
2757 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)2758 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2759                              blas::Diagonal diag, uint64 n,
2760                              const DeviceMemory<float> &ap,
2761                              DeviceMemory<float> *x, int incx) {
2762   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2763             PARAM(x), PARAM(incx));
2764 
2765   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2766                const DeviceMemory<float> &, DeviceMemory<float> *, int>
2767       impl;
2768   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2769               incx);
2770 }
2771 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)2772 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2773                              blas::Diagonal diag, uint64 n,
2774                              const DeviceMemory<double> &ap,
2775                              DeviceMemory<double> *x, int incx) {
2776   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2777             PARAM(x), PARAM(incx));
2778 
2779   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2780                const DeviceMemory<double> &, DeviceMemory<double> *, int>
2781       impl;
2782   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2783               incx);
2784 }
2785 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)2786 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2787                              blas::Diagonal diag, uint64 n,
2788                              const DeviceMemory<std::complex<float>> &ap,
2789                              DeviceMemory<std::complex<float>> *x, int incx) {
2790   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2791             PARAM(x), PARAM(incx));
2792 
2793   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2794                const DeviceMemory<std::complex<float>> &,
2795                DeviceMemory<std::complex<float>> *, int>
2796       impl;
2797   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2798               incx);
2799 }
2800 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)2801 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
2802                              blas::Diagonal diag, uint64 n,
2803                              const DeviceMemory<std::complex<double>> &ap,
2804                              DeviceMemory<std::complex<double>> *x, int incx) {
2805   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
2806             PARAM(x), PARAM(incx));
2807 
2808   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2809                const DeviceMemory<std::complex<double>> &,
2810                DeviceMemory<std::complex<double>> *, int>
2811       impl;
2812   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
2813               incx);
2814 }
2815 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2816 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2817                              blas::Diagonal diag, uint64 n,
2818                              const DeviceMemory<float> &a, int lda,
2819                              DeviceMemory<float> *x, int incx) {
2820   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2821             PARAM(lda), PARAM(x), PARAM(incx));
2822 
2823   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2824                const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2825       impl;
2826   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2827               lda, x, incx);
2828 }
2829 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2830 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2831                              blas::Diagonal diag, uint64 n,
2832                              const DeviceMemory<double> &a, int lda,
2833                              DeviceMemory<double> *x, int incx) {
2834   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2835             PARAM(lda), PARAM(x), PARAM(incx));
2836 
2837   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2838                const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2839       impl;
2840   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2841               lda, x, incx);
2842 }
2843 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2844 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2845                              blas::Diagonal diag, uint64 n,
2846                              const DeviceMemory<std::complex<float>> &a,
2847                              int lda, DeviceMemory<std::complex<float>> *x,
2848                              int incx) {
2849   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2850             PARAM(lda), PARAM(x), PARAM(incx));
2851 
2852   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2853                const DeviceMemory<std::complex<float>> &, int,
2854                DeviceMemory<std::complex<float>> *, int>
2855       impl;
2856   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2857               lda, x, incx);
2858 }
2859 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2860 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
2861                              blas::Diagonal diag, uint64 n,
2862                              const DeviceMemory<std::complex<double>> &a,
2863                              int lda, DeviceMemory<std::complex<double>> *x,
2864                              int incx) {
2865   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2866             PARAM(lda), PARAM(x), PARAM(incx));
2867 
2868   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2869                const DeviceMemory<std::complex<double>> &, int,
2870                DeviceMemory<std::complex<double>> *, int>
2871       impl;
2872   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
2873               lda, x, incx);
2874 }
2875 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)2876 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2877                              blas::Diagonal diag, uint64 n,
2878                              const DeviceMemory<float> &a, int lda,
2879                              DeviceMemory<float> *x, int incx) {
2880   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2881             PARAM(lda), PARAM(x), PARAM(incx));
2882 
2883   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2884                const DeviceMemory<float> &, int, DeviceMemory<float> *, int>
2885       impl;
2886   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2887               lda, x, incx);
2888 }
2889 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)2890 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2891                              blas::Diagonal diag, uint64 n,
2892                              const DeviceMemory<double> &a, int lda,
2893                              DeviceMemory<double> *x, int incx) {
2894   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2895             PARAM(lda), PARAM(x), PARAM(incx));
2896 
2897   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2898                const DeviceMemory<double> &, int, DeviceMemory<double> *, int>
2899       impl;
2900   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2901               lda, x, incx);
2902 }
2903 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)2904 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2905                              blas::Diagonal diag, uint64 n,
2906                              const DeviceMemory<std::complex<float>> &a,
2907                              int lda, DeviceMemory<std::complex<float>> *x,
2908                              int incx) {
2909   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2910             PARAM(lda), PARAM(x), PARAM(incx));
2911 
2912   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2913                const DeviceMemory<std::complex<float>> &, int,
2914                DeviceMemory<std::complex<float>> *, int>
2915       impl;
2916   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2917               lda, x, incx);
2918 }
2919 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)2920 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
2921                              blas::Diagonal diag, uint64 n,
2922                              const DeviceMemory<std::complex<double>> &a,
2923                              int lda, DeviceMemory<std::complex<double>> *x,
2924                              int incx) {
2925   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
2926             PARAM(lda), PARAM(x), PARAM(incx));
2927 
2928   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
2929                const DeviceMemory<std::complex<double>> &, int,
2930                DeviceMemory<std::complex<double>> *, int>
2931       impl;
2932   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
2933               lda, x, incx);
2934 }
2935 
2936 namespace {
2937 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
2938 // blas::ProfileResult*.  This functor doesn't put the stream into an error
2939 // state if the op fails and the profile result is non-null.  Instead, the
2940 // error-ness is returned in the profile result itself.
2941 template <typename... Args>
2942 struct ThenBlasWithProfileImpl {
operator ()stream_executor::__anon9675edea0211::ThenBlasWithProfileImpl2943   Stream &operator()(Stream *stream,
2944                      bool (blas::BlasSupport::*blas_func)(
2945                          Stream *, Args..., blas::ProfileResult *),
2946                      Args... args, blas::ProfileResult *profile_result) {
2947     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
2948     bool record_error = profile_result == nullptr;
2949     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
2950   }
2951 };
2952 }  // anonymous namespace
2953 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)2954 Stream &Stream::ThenBlasGemvWithProfiling(
2955     blas::Transpose trans, uint64 m, uint64 n, float alpha,
2956     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
2957     int incx, float beta, DeviceMemory<float> *y, int incy,
2958     blas::ProfileResult *output_profile_result) {
2959   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2960             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2961             PARAM(incy));
2962 
2963   ThenBlasWithProfileImpl<
2964       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
2965       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
2966       impl;
2967   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
2968               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
2969 }
2970 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)2971 Stream &Stream::ThenBlasGemvWithProfiling(
2972     blas::Transpose trans, uint64 m, uint64 n, double alpha,
2973     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
2974     int incx, double beta, DeviceMemory<double> *y, int incy,
2975     blas::ProfileResult *output_profile_result) {
2976   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2977             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2978             PARAM(incy));
2979 
2980   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
2981                           const DeviceMemory<double> &, int,
2982                           const DeviceMemory<double> &, int, double,
2983                           DeviceMemory<double> *, int>
2984       impl;
2985   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
2986               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
2987 }
2988 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)2989 Stream &Stream::ThenBlasGemvWithProfiling(
2990     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
2991     const DeviceMemory<std::complex<float>> &a, int lda,
2992     const DeviceMemory<std::complex<float>> &x, int incx,
2993     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
2994     blas::ProfileResult *output_profile_result) {
2995   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2996             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2997             PARAM(incy));
2998 
2999   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3000                           const DeviceMemory<std::complex<float>> &, int,
3001                           const DeviceMemory<std::complex<float>> &, int,
3002                           std::complex<float>,
3003                           DeviceMemory<std::complex<float>> *, int>
3004       impl;
3005   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3006               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3007 }
3008 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3009 Stream &Stream::ThenBlasGemvWithProfiling(
3010     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3011     const DeviceMemory<std::complex<double>> &a, int lda,
3012     const DeviceMemory<std::complex<double>> &x, int incx,
3013     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3014     blas::ProfileResult *output_profile_result) {
3015   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3016             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3017             PARAM(incy));
3018 
3019   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3020                           const DeviceMemory<std::complex<double>> &, int,
3021                           const DeviceMemory<std::complex<double>> &, int,
3022                           std::complex<double>,
3023                           DeviceMemory<std::complex<double>> *, int>
3024       impl;
3025   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3026               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3027 }
3028 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3029 Stream &Stream::ThenBlasGemmWithProfiling(
3030     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3031     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3032     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3033     DeviceMemory<Eigen::half> *c, int ldc,
3034     blas::ProfileResult *output_profile_result) {
3035   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3036             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3037             PARAM(beta), PARAM(c), PARAM(ldc));
3038 
3039   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3040                           uint64, float, const DeviceMemory<Eigen::half> &, int,
3041                           const DeviceMemory<Eigen::half> &, int, float,
3042                           DeviceMemory<Eigen::half> *, int>
3043       impl;
3044   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3045               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3046               output_profile_result);
3047 }
3048 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3049 Stream &Stream::ThenBlasGemmWithProfiling(
3050     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3051     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3052     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3053     int ldc, blas::ProfileResult *output_profile_result) {
3054   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3055             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3056             PARAM(beta), PARAM(c), PARAM(ldc));
3057 
3058   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3059                           uint64, float, const DeviceMemory<float> &, int,
3060                           const DeviceMemory<float> &, int, float,
3061                           DeviceMemory<float> *, int>
3062       impl;
3063   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3064               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3065               output_profile_result);
3066 }
3067 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3068 Stream &Stream::ThenBlasGemmWithProfiling(
3069     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3070     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3071     const DeviceMemory<double> &b, int ldb, double beta,
3072     DeviceMemory<double> *c, int ldc,
3073     blas::ProfileResult *output_profile_result) {
3074   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3075             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3076             PARAM(beta), PARAM(c), PARAM(ldc));
3077 
3078   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3079                           uint64, double, const DeviceMemory<double> &, int,
3080                           const DeviceMemory<double> &, int, double,
3081                           DeviceMemory<double> *, int>
3082       impl;
3083   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3084               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3085               output_profile_result);
3086 }
3087 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3088 Stream &Stream::ThenBlasGemmWithProfiling(
3089     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3090     uint64 k, std::complex<float> alpha,
3091     const DeviceMemory<std::complex<float>> &a, int lda,
3092     const DeviceMemory<std::complex<float>> &b, int ldb,
3093     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3094     blas::ProfileResult *output_profile_result) {
3095   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3096             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3097             PARAM(beta), PARAM(c), PARAM(ldc));
3098 
3099   ThenBlasWithProfileImpl<
3100       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3101       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3102       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3103       DeviceMemory<std::complex<float>> *, int>
3104       impl;
3105   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3106               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3107               output_profile_result);
3108 }
3109 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3110 Stream &Stream::ThenBlasGemmWithProfiling(
3111     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3112     uint64 k, std::complex<double> alpha,
3113     const DeviceMemory<std::complex<double>> &a, int lda,
3114     const DeviceMemory<std::complex<double>> &b, int ldb,
3115     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3116     blas::ProfileResult *output_profile_result) {
3117   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3118             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3119             PARAM(beta), PARAM(c), PARAM(ldc));
3120 
3121   ThenBlasWithProfileImpl<
3122       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3123       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3124       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3125       DeviceMemory<std::complex<double>> *, int>
3126       impl;
3127   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3128               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3129               output_profile_result);
3130 }
3131 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3132 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3133                              uint64 n, std::complex<float> alpha,
3134                              const DeviceMemory<std::complex<float>> &a,
3135                              int lda,
3136                              const DeviceMemory<std::complex<float>> &b,
3137                              int ldb, std::complex<float> beta,
3138                              DeviceMemory<std::complex<float>> *c, int ldc) {
3139   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3140             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3141             PARAM(ldc));
3142 
3143   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3144                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3145                int, const DeviceMemory<std::complex<float>> &, int,
3146                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3147       impl;
3148   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3149               lda, b, ldb, beta, c, ldc);
3150 }
3151 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3152 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3153                              uint64 n, std::complex<double> alpha,
3154                              const DeviceMemory<std::complex<double>> &a,
3155                              int lda,
3156                              const DeviceMemory<std::complex<double>> &b,
3157                              int ldb, std::complex<double> beta,
3158                              DeviceMemory<std::complex<double>> *c, int ldc) {
3159   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3160             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3161             PARAM(ldc));
3162 
3163   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3164                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3165                int, const DeviceMemory<std::complex<double>> &, int,
3166                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3167       impl;
3168   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3169               lda, b, ldb, beta, c, ldc);
3170 }
3171 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3172 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3173                              uint64 n, uint64 k, float alpha,
3174                              const DeviceMemory<std::complex<float>> &a,
3175                              int lda, float beta,
3176                              DeviceMemory<std::complex<float>> *c, int ldc) {
3177   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3178             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3179 
3180   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3181                const DeviceMemory<std::complex<float>> &, int, float,
3182                DeviceMemory<std::complex<float>> *, int>
3183       impl;
3184   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3185               lda, beta, c, ldc);
3186 }
3187 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3188 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3189                              uint64 n, uint64 k, double alpha,
3190                              const DeviceMemory<std::complex<double>> &a,
3191                              int lda, double beta,
3192                              DeviceMemory<std::complex<double>> *c, int ldc) {
3193   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3194             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3195 
3196   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3197                const DeviceMemory<std::complex<double>> &, int, double,
3198                DeviceMemory<std::complex<double>> *, int>
3199       impl;
3200   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3201               lda, beta, c, ldc);
3202 }
3203 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3204 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3205                               uint64 n, uint64 k, std::complex<float> alpha,
3206                               const DeviceMemory<std::complex<float>> &a,
3207                               int lda,
3208                               const DeviceMemory<std::complex<float>> &b,
3209                               int ldb, float beta,
3210                               DeviceMemory<std::complex<float>> *c, int ldc) {
3211   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3212             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3213             PARAM(ldc));
3214 
3215   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3216                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3217                int, const DeviceMemory<std::complex<float>> &, int, float,
3218                DeviceMemory<std::complex<float>> *, int>
3219       impl;
3220   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3221               a, lda, b, ldb, beta, c, ldc);
3222 }
3223 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3224 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3225                               uint64 n, uint64 k, std::complex<double> alpha,
3226                               const DeviceMemory<std::complex<double>> &a,
3227                               int lda,
3228                               const DeviceMemory<std::complex<double>> &b,
3229                               int ldb, double beta,
3230                               DeviceMemory<std::complex<double>> *c, int ldc) {
3231   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3232             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3233             PARAM(ldc));
3234 
3235   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3236                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3237                int, const DeviceMemory<std::complex<double>> &, int, double,
3238                DeviceMemory<std::complex<double>> *, int>
3239       impl;
3240   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
3241               a, lda, b, ldb, beta, c, ldc);
3242 }
3243 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3244 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3245                              uint64 n, float alpha,
3246                              const DeviceMemory<float> &a, int lda,
3247                              const DeviceMemory<float> &b, int ldb, float beta,
3248                              DeviceMemory<float> *c, int ldc) {
3249   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3250             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3251             PARAM(ldc));
3252 
3253   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
3254                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3255                int, float, DeviceMemory<float> *, int>
3256       impl;
3257   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3258               lda, b, ldb, beta, c, ldc);
3259 }
3260 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3261 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3262                              uint64 n, double alpha,
3263                              const DeviceMemory<double> &a, int lda,
3264                              const DeviceMemory<double> &b, int ldb,
3265                              double beta, DeviceMemory<double> *c, int ldc) {
3266   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3267             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3268             PARAM(ldc));
3269 
3270   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
3271                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3272                int, double, DeviceMemory<double> *, int>
3273       impl;
3274   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3275               lda, b, ldb, beta, c, ldc);
3276 }
3277 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3278 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3279                              uint64 n, std::complex<float> alpha,
3280                              const DeviceMemory<std::complex<float>> &a,
3281                              int lda,
3282                              const DeviceMemory<std::complex<float>> &b,
3283                              int ldb, std::complex<float> beta,
3284                              DeviceMemory<std::complex<float>> *c, int ldc) {
3285   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3286             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3287             PARAM(ldc));
3288 
3289   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3290                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3291                int, const DeviceMemory<std::complex<float>> &, int,
3292                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3293       impl;
3294   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3295               lda, b, ldb, beta, c, ldc);
3296 }
3297 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3298 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
3299                              uint64 n, std::complex<double> alpha,
3300                              const DeviceMemory<std::complex<double>> &a,
3301                              int lda,
3302                              const DeviceMemory<std::complex<double>> &b,
3303                              int ldb, std::complex<double> beta,
3304                              DeviceMemory<std::complex<double>> *c, int ldc) {
3305   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3306             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3307             PARAM(ldc));
3308 
3309   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3310                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3311                int, const DeviceMemory<std::complex<double>> &, int,
3312                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3313       impl;
3314   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
3315               lda, b, ldb, beta, c, ldc);
3316 }
3317 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)3318 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3319                              uint64 n, uint64 k, float alpha,
3320                              const DeviceMemory<float> &a, int lda, float beta,
3321                              DeviceMemory<float> *c, int ldc) {
3322   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3323             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3324 
3325   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3326                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3327                int>
3328       impl;
3329   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3330               lda, beta, c, ldc);
3331 }
3332 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)3333 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3334                              uint64 n, uint64 k, double alpha,
3335                              const DeviceMemory<double> &a, int lda,
3336                              double beta, DeviceMemory<double> *c, int ldc) {
3337   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3338             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3339 
3340   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3341                const DeviceMemory<double> &, int, double,
3342                DeviceMemory<double> *, int>
3343       impl;
3344   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3345               lda, beta, c, ldc);
3346 }
3347 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3348 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3349                              uint64 n, uint64 k, std::complex<float> alpha,
3350                              const DeviceMemory<std::complex<float>> &a,
3351                              int lda, std::complex<float> beta,
3352                              DeviceMemory<std::complex<float>> *c, int ldc) {
3353   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3354             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3355 
3356   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3357                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3358                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
3359                int>
3360       impl;
3361   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3362               lda, beta, c, ldc);
3363 }
3364 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3365 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
3366                              uint64 n, uint64 k, std::complex<double> alpha,
3367                              const DeviceMemory<std::complex<double>> &a,
3368                              int lda, std::complex<double> beta,
3369                              DeviceMemory<std::complex<double>> *c, int ldc) {
3370   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3371             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3372 
3373   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3374                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3375                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
3376                int>
3377       impl;
3378   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
3379               lda, beta, c, ldc);
3380 }
3381 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3382 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3383                               uint64 n, uint64 k, float alpha,
3384                               const DeviceMemory<float> &a, int lda,
3385                               const DeviceMemory<float> &b, int ldb, float beta,
3386                               DeviceMemory<float> *c, int ldc) {
3387   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3388             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3389             PARAM(ldc));
3390 
3391   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3392                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3393                int, float, DeviceMemory<float> *, int>
3394       impl;
3395   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3396               a, lda, b, ldb, beta, c, ldc);
3397 }
3398 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3399 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3400                               uint64 n, uint64 k, double alpha,
3401                               const DeviceMemory<double> &a, int lda,
3402                               const DeviceMemory<double> &b, int ldb,
3403                               double beta, DeviceMemory<double> *c, int ldc) {
3404   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3405             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3406             PARAM(ldc));
3407 
3408   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3409                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3410                int, double, DeviceMemory<double> *, int>
3411       impl;
3412   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3413               a, lda, b, ldb, beta, c, ldc);
3414 }
3415 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3416 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3417                               uint64 n, uint64 k, std::complex<float> alpha,
3418                               const DeviceMemory<std::complex<float>> &a,
3419                               int lda,
3420                               const DeviceMemory<std::complex<float>> &b,
3421                               int ldb, std::complex<float> beta,
3422                               DeviceMemory<std::complex<float>> *c, int ldc) {
3423   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3424             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3425             PARAM(ldc));
3426 
3427   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3428                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3429                int, const DeviceMemory<std::complex<float>> &, int,
3430                std::complex<float>, DeviceMemory<std::complex<float>> *, int>
3431       impl;
3432   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3433               a, lda, b, ldb, beta, c, ldc);
3434 }
3435 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3436 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
3437                               uint64 n, uint64 k, std::complex<double> alpha,
3438                               const DeviceMemory<std::complex<double>> &a,
3439                               int lda,
3440                               const DeviceMemory<std::complex<double>> &b,
3441                               int ldb, std::complex<double> beta,
3442                               DeviceMemory<std::complex<double>> *c, int ldc) {
3443   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3444             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3445             PARAM(ldc));
3446 
3447   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
3448                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3449                int, const DeviceMemory<std::complex<double>> &, int,
3450                std::complex<double>, DeviceMemory<std::complex<double>> *, int>
3451       impl;
3452   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
3453               a, lda, b, ldb, beta, c, ldc);
3454 }
3455 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3456 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3457                              blas::Transpose transa, blas::Diagonal diag,
3458                              uint64 m, uint64 n, float alpha,
3459                              const DeviceMemory<float> &a, int lda,
3460                              DeviceMemory<float> *b, int ldb) {
3461   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3462             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3463 
3464   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3465                uint64, uint64, float, const DeviceMemory<float> &, int,
3466                DeviceMemory<float> *, int>
3467       impl;
3468   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3469               n, alpha, a, lda, b, ldb);
3470 }
3471 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3472 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3473                              blas::Transpose transa, blas::Diagonal diag,
3474                              uint64 m, uint64 n, double alpha,
3475                              const DeviceMemory<double> &a, int lda,
3476                              DeviceMemory<double> *b, int ldb) {
3477   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3478             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3479 
3480   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3481                uint64, uint64, double, const DeviceMemory<double> &, int,
3482                DeviceMemory<double> *, int>
3483       impl;
3484   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3485               n, alpha, a, lda, b, ldb);
3486 }
3487 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)3488 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3489                              blas::Transpose transa, blas::Diagonal diag,
3490                              uint64 m, uint64 n, std::complex<float> alpha,
3491                              const DeviceMemory<std::complex<float>> &a,
3492                              int lda, DeviceMemory<std::complex<float>> *b,
3493                              int ldb) {
3494   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3495             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3496 
3497   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3498                uint64, uint64, std::complex<float>,
3499                const DeviceMemory<std::complex<float>> &, int,
3500                DeviceMemory<std::complex<float>> *, int>
3501       impl;
3502   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3503               n, alpha, a, lda, b, ldb);
3504 }
3505 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)3506 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
3507                              blas::Transpose transa, blas::Diagonal diag,
3508                              uint64 m, uint64 n, std::complex<double> alpha,
3509                              const DeviceMemory<std::complex<double>> &a,
3510                              int lda, DeviceMemory<std::complex<double>> *b,
3511                              int ldb) {
3512   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3513             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3514 
3515   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3516                uint64, uint64, std::complex<double>,
3517                const DeviceMemory<std::complex<double>> &, int,
3518                DeviceMemory<std::complex<double>> *, int>
3519       impl;
3520   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
3521               n, alpha, a, lda, b, ldb);
3522 }
3523 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)3524 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3525                              blas::Transpose transa, blas::Diagonal diag,
3526                              uint64 m, uint64 n, float alpha,
3527                              const DeviceMemory<float> &a, int lda,
3528                              DeviceMemory<float> *b, int ldb) {
3529   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3530             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3531 
3532   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3533                uint64, uint64, float, const DeviceMemory<float> &, int,
3534                DeviceMemory<float> *, int>
3535       impl;
3536   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3537               n, alpha, a, lda, b, ldb);
3538 }
3539 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)3540 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3541                              blas::Transpose transa, blas::Diagonal diag,
3542                              uint64 m, uint64 n, double alpha,
3543                              const DeviceMemory<double> &a, int lda,
3544                              DeviceMemory<double> *b, int ldb) {
3545   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3546             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3547 
3548   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3549                uint64, uint64, double, const DeviceMemory<double> &, int,
3550                DeviceMemory<double> *, int>
3551       impl;
3552   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3553               n, alpha, a, lda, b, ldb);
3554 }
3555 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)3556 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3557                              blas::Transpose transa, blas::Diagonal diag,
3558                              uint64 m, uint64 n, std::complex<float> alpha,
3559                              const DeviceMemory<std::complex<float>> &a,
3560                              int lda, DeviceMemory<std::complex<float>> *b,
3561                              int ldb) {
3562   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3563             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3564 
3565   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3566                uint64, uint64, std::complex<float>,
3567                const DeviceMemory<std::complex<float>> &, int,
3568                DeviceMemory<std::complex<float>> *, int>
3569       impl;
3570   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3571               n, alpha, a, lda, b, ldb);
3572 }
3573 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)3574 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
3575                              blas::Transpose transa, blas::Diagonal diag,
3576                              uint64 m, uint64 n, std::complex<double> alpha,
3577                              const DeviceMemory<std::complex<double>> &a,
3578                              int lda, DeviceMemory<std::complex<double>> *b,
3579                              int ldb) {
3580   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
3581             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
3582 
3583   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
3584                uint64, uint64, std::complex<double>,
3585                const DeviceMemory<std::complex<double>> &, int,
3586                DeviceMemory<std::complex<double>> *, int>
3587       impl;
3588   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
3589               n, alpha, a, lda, b, ldb);
3590 }
3591 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count)3592 Stream &Stream::ThenBlasGemmBatched(
3593     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3594     uint64 k, float alpha,
3595     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
3596     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
3597     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
3598     int batch_count) {
3599   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3600                                         b, ldb, beta, c, ldc, batch_count,
3601                                         /*scratch_allocator=*/nullptr);
3602 }
3603 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<Eigen::half> * > & a,int lda,const port::ArraySlice<DeviceMemory<Eigen::half> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<Eigen::half> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3604 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3605     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3606     uint64 k, float alpha,
3607     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
3608     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
3609     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
3610     int batch_count, ScratchAllocator *scratch_allocator) {
3611   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3612             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3613             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3614 
3615   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3616                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
3617                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
3618                float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
3619                int, int, ScratchAllocator *>
3620       impl;
3621   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3622               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3623               scratch_allocator);
3624 }
3625 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)3626 Stream &Stream::ThenBlasGemmBatched(
3627     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3628     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
3629     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
3630     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
3631     int batch_count) {
3632   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3633                                         b, ldb, beta, c, ldc, batch_count,
3634                                         /*scratch_allocator=*/nullptr);
3635 }
3636 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3637 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3638     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3639     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
3640     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
3641     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
3642     int batch_count, ScratchAllocator *scratch_allocator) {
3643   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3644             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3645             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3646 
3647   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3648                const port::ArraySlice<DeviceMemory<float> *> &, int,
3649                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
3650                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
3651                ScratchAllocator *>
3652       impl;
3653   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3654               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3655               scratch_allocator);
3656 }
3657 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)3658 Stream &Stream::ThenBlasGemmBatched(
3659     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3660     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
3661     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
3662     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
3663     int batch_count) {
3664   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3665                                         b, ldb, beta, c, ldc, batch_count,
3666                                         /*scratch_allocator=*/nullptr);
3667 }
3668 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3669 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3670     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3671     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
3672     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
3673     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
3674     int batch_count, ScratchAllocator *scratch_allocator) {
3675   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3676             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3677             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3678 
3679   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3680                const port::ArraySlice<DeviceMemory<double> *> &, int,
3681                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
3682                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
3683                ScratchAllocator *>
3684       impl;
3685   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3686               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3687               scratch_allocator);
3688 }
3689 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)3690 Stream &Stream::ThenBlasGemmBatched(
3691     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3692     uint64 k, std::complex<float> alpha,
3693     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
3694     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
3695     std::complex<float> beta,
3696     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
3697     int batch_count) {
3698   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3699                                         b, ldb, beta, c, ldc, batch_count,
3700                                         /*scratch_allocator=*/nullptr);
3701 }
3702 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3703 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3704     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3705     uint64 k, std::complex<float> alpha,
3706     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
3707     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
3708     std::complex<float> beta,
3709     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
3710     int batch_count, ScratchAllocator *scratch_allocator) {
3711   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3712             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3713             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3714 
3715   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3716                std::complex<float>,
3717                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
3718                int,
3719                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
3720                int, std::complex<float>,
3721                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
3722                int, int, ScratchAllocator *>
3723       impl;
3724   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3725               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3726               scratch_allocator);
3727 }
3728 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)3729 Stream &Stream::ThenBlasGemmBatched(
3730     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3731     uint64 k, std::complex<double> alpha,
3732     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
3733     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
3734     std::complex<double> beta,
3735     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
3736     int batch_count) {
3737   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
3738                                         b, ldb, beta, c, ldc, batch_count,
3739                                         /*scratch_allocator=*/nullptr);
3740 }
3741 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)3742 Stream &Stream::ThenBlasGemmBatchedWithScratch(
3743     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3744     uint64 k, std::complex<double> alpha,
3745     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
3746     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
3747     std::complex<double> beta,
3748     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
3749     int batch_count, ScratchAllocator *scratch_allocator) {
3750   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3751             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3752             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
3753 
3754   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3755                std::complex<double>,
3756                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
3757                int,
3758                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
3759                int, std::complex<double>,
3760                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
3761                int, int, ScratchAllocator *>
3762       impl;
3763   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
3764               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
3765               scratch_allocator);
3766 }
3767 
3768 template <typename ABType, typename CType>
ThenBlasLtMatmulImpl(const blas::IBlasLtMatmulPlan * plan,const HostOrDeviceScalar<CType> & alpha,const DeviceMemory<ABType> & a,const DeviceMemory<ABType> & b,const HostOrDeviceScalar<CType> & beta,DeviceMemory<CType> * c,ScratchAllocator * scratch_allocator,const blas::IBlasLtMatmulAlgorithm * algorithm,const DeviceMemory<CType> & bias,blas::ProfileResult * output_profile_result)3769 Stream &Stream::ThenBlasLtMatmulImpl(
3770     const blas::IBlasLtMatmulPlan *plan, const HostOrDeviceScalar<CType> &alpha,
3771     const DeviceMemory<ABType> &a, const DeviceMemory<ABType> &b,
3772     const HostOrDeviceScalar<CType> &beta, DeviceMemory<CType> *c,
3773     ScratchAllocator *scratch_allocator,
3774     const blas::IBlasLtMatmulAlgorithm *algorithm,
3775     const DeviceMemory<CType> &bias,
3776     blas::ProfileResult *output_profile_result) {
3777   VLOG_CALL(PARAM(plan), PARAM(alpha), PARAM(a), PARAM(b), PARAM(beta),
3778             PARAM(c), PARAM(algorithm), PARAM(bias));
3779 
3780   ThenBlasWithProfileImpl<
3781       const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<CType> &,
3782       const DeviceMemory<ABType> &, const DeviceMemory<ABType> &,
3783       const HostOrDeviceScalar<CType> &, DeviceMemory<CType> *,
3784       ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3785       const DeviceMemory<CType> &>
3786       impl;
3787   return impl(this, &blas::BlasSupport::DoBlasLtMatmul, plan, alpha, a, b, beta,
3788               c, scratch_allocator, algorithm, bias, output_profile_result);
3789 }
3790 
3791 // Explicit template instantiations for each supported type combination.
3792 template Stream &Stream::ThenBlasLtMatmulImpl<int8, int32>(
3793     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<int32> &,
3794     const DeviceMemory<int8> &, const DeviceMemory<int8> &,
3795     const HostOrDeviceScalar<int32> &, DeviceMemory<int32> *,
3796     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3797     const DeviceMemory<int32> &, blas::ProfileResult *);
3798 
3799 template Stream &Stream::ThenBlasLtMatmulImpl<Eigen::half, Eigen::half>(
3800     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<Eigen::half> &,
3801     const DeviceMemory<Eigen::half> &, const DeviceMemory<Eigen::half> &,
3802     const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
3803     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3804     const DeviceMemory<Eigen::half> &, blas::ProfileResult *);
3805 
3806 template Stream &Stream::ThenBlasLtMatmulImpl<float, float>(
3807     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<float> &,
3808     const DeviceMemory<float> &, const DeviceMemory<float> &,
3809     const HostOrDeviceScalar<float> &, DeviceMemory<float> *,
3810     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3811     const DeviceMemory<float> &, blas::ProfileResult *);
3812 
3813 template Stream &Stream::ThenBlasLtMatmulImpl<double, double>(
3814     const blas::IBlasLtMatmulPlan *, const HostOrDeviceScalar<double> &,
3815     const DeviceMemory<double> &, const DeviceMemory<double> &,
3816     const HostOrDeviceScalar<double> &, DeviceMemory<double> *,
3817     ScratchAllocator *, const blas::IBlasLtMatmulAlgorithm *,
3818     const DeviceMemory<double> &, blas::ProfileResult *);
3819 
3820 template Stream &
3821 Stream::ThenBlasLtMatmulImpl<std::complex<float>, std::complex<float>>(
3822     const blas::IBlasLtMatmulPlan *,
3823     const HostOrDeviceScalar<std::complex<float>> &,
3824     const DeviceMemory<std::complex<float>> &,
3825     const DeviceMemory<std::complex<float>> &,
3826     const HostOrDeviceScalar<std::complex<float>> &,
3827     DeviceMemory<std::complex<float>> *, ScratchAllocator *,
3828     const blas::IBlasLtMatmulAlgorithm *,
3829     const DeviceMemory<std::complex<float>> &, blas::ProfileResult *);
3830 
3831 template Stream &
3832 Stream::ThenBlasLtMatmulImpl<std::complex<double>, std::complex<double>>(
3833     const blas::IBlasLtMatmulPlan *,
3834     const HostOrDeviceScalar<std::complex<double>> &,
3835     const DeviceMemory<std::complex<double>> &,
3836     const DeviceMemory<std::complex<double>> &,
3837     const HostOrDeviceScalar<std::complex<double>> &,
3838     DeviceMemory<std::complex<double>> *, ScratchAllocator *,
3839     const blas::IBlasLtMatmulAlgorithm *,
3840     const DeviceMemory<std::complex<double>> &, blas::ProfileResult *);
3841 
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)3842 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
3843   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
3844 
3845   if (rng::RngSupport *rng = parent_->AsRng()) {
3846     CheckError(rng->SetSeed(this, seed, seed_bytes));
3847   } else {
3848     SetError();
3849     LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
3850   }
3851   return *this;
3852 }
3853 
ThenPopulateRandUniform(DeviceMemory<float> * values)3854 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
3855   VLOG_CALL(PARAM(values));
3856 
3857   if (rng::RngSupport *rng = parent_->AsRng()) {
3858     CheckError(rng->DoPopulateRandUniform(this, values));
3859   } else {
3860     SetError();
3861     LOG(INFO) << DebugStreamPointers()
3862               << " attempting to perform RNG operation using StreamExecutor"
3863                  " without RNG support.";
3864   }
3865   return *this;
3866 }
3867 
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)3868 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
3869                                          DeviceMemory<float> *values) {
3870   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
3871 
3872   if (rng::RngSupport *rng = parent_->AsRng()) {
3873     CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
3874   } else {
3875     SetError();
3876     LOG(INFO) << DebugStreamPointers()
3877               << " attempting to perform RNG operation using StreamExecutor"
3878                  " without RNG support.";
3879   }
3880   return *this;
3881 }
3882 
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)3883 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
3884                                          DeviceMemory<double> *values) {
3885   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
3886 
3887   if (rng::RngSupport *rng = parent_->AsRng()) {
3888     CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
3889   } else {
3890     SetError();
3891     LOG(INFO) << DebugStreamPointers()
3892               << " attempting to perform RNG operation using StreamExecutor"
3893                  " without RNG support.";
3894   }
3895   return *this;
3896 }
3897 
ThenPopulateRandUniform(DeviceMemory<double> * values)3898 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
3899   VLOG_CALL(PARAM(values));
3900 
3901   if (rng::RngSupport *rng = parent_->AsRng()) {
3902     CheckError(rng->DoPopulateRandUniform(this, values));
3903   } else {
3904     SetError();
3905     LOG(INFO) << DebugStreamPointers()
3906               << " attempting to perform RNG operation using StreamExecutor"
3907                  " without RNG support.";
3908   }
3909   return *this;
3910 }
3911 
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)3912 Stream &Stream::ThenPopulateRandUniform(
3913     DeviceMemory<std::complex<float>> *values) {
3914   VLOG_CALL(PARAM(values));
3915 
3916   if (rng::RngSupport *rng = parent_->AsRng()) {
3917     CheckError(rng->DoPopulateRandUniform(this, values));
3918   } else {
3919     SetError();
3920     LOG(INFO) << DebugStreamPointers()
3921               << " attempting to perform RNG operation using StreamExecutor"
3922                  " without RNG support.";
3923   }
3924   return *this;
3925 }
3926 
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)3927 Stream &Stream::ThenPopulateRandUniform(
3928     DeviceMemory<std::complex<double>> *values) {
3929   VLOG_CALL(PARAM(values));
3930 
3931   if (rng::RngSupport *rng = parent_->AsRng()) {
3932     CheckError(rng->DoPopulateRandUniform(this, values));
3933   } else {
3934     SetError();
3935     LOG(INFO) << DebugStreamPointers()
3936               << " attempting to perform RNG operation using StreamExecutor"
3937                  " without RNG support.";
3938   }
3939   return *this;
3940 }
3941 
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)3942 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
3943                            uint64 size) {
3944   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
3945 
3946   CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
3947   return *this;
3948 }
3949 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)3950 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
3951                            uint64 size) {
3952   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
3953 
3954   CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
3955   return *this;
3956 }
3957 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)3958 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
3959                            const DeviceMemoryBase &gpu_src, uint64 size) {
3960   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
3961 
3962   CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
3963   return *this;
3964 }
3965 
ThenMemZero(DeviceMemoryBase * location,uint64 size)3966 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
3967   VLOG_CALL(PARAM(location), PARAM(size));
3968 
3969   CheckStatus(parent_->MemZero(this, location, size));
3970   return *this;
3971 }
3972 
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)3973 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
3974                              uint64 size) {
3975   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
3976 
3977   CheckStatus(parent_->Memset32(this, location, pattern, size));
3978   return *this;
3979 }
3980 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)3981 Stream &Stream::ThenRnnForward(
3982     const dnn::RnnDescriptor &rnn_desc,
3983     const dnn::RnnSequenceTensorDescriptor &input_desc,
3984     const DeviceMemory<Eigen::half> &input_data,
3985     const DeviceMemory<int> &seq_lengths_data,
3986     const dnn::RnnStateTensorDescriptor &input_h_desc,
3987     const DeviceMemory<Eigen::half> &input_h_data,
3988     const dnn::RnnStateTensorDescriptor &input_c_desc,
3989     const DeviceMemory<Eigen::half> &input_c_data,
3990     const DeviceMemory<Eigen::half> &params,
3991     const dnn::RnnSequenceTensorDescriptor &output_desc,
3992     DeviceMemory<Eigen::half> *output_data,
3993     const dnn::RnnStateTensorDescriptor &output_h_desc,
3994     DeviceMemory<Eigen::half> *output_h_data,
3995     const dnn::RnnStateTensorDescriptor &output_c_desc,
3996     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
3997     ScratchAllocator *reserve_space_allocator,
3998     ScratchAllocator *workspace_allocator,
3999     dnn::ProfileResult *output_profile_result) {
4000   // TODO(zhengxq): add VLOG PARAM calls.
4001   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4002     auto status = dnn->DoRnnForward(
4003         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4004         input_h_data, input_c_desc, input_c_data, params, output_desc,
4005         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4006         is_training, reserve_space_allocator, workspace_allocator,
4007         output_profile_result);
4008     if (!status && !output_profile_result) {
4009       SetError();
4010     }
4011   } else {
4012     SetErrorAndLogNoDnnSupport();
4013   }
4014   return *this;
4015 }
4016 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4017 Stream &Stream::ThenRnnForward(
4018     const dnn::RnnDescriptor &rnn_desc,
4019     const dnn::RnnSequenceTensorDescriptor &input_desc,
4020     const DeviceMemory<float> &input_data,
4021     const DeviceMemory<int> &seq_lengths_data,
4022     const dnn::RnnStateTensorDescriptor &input_h_desc,
4023     const DeviceMemory<float> &input_h_data,
4024     const dnn::RnnStateTensorDescriptor &input_c_desc,
4025     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4026     const dnn::RnnSequenceTensorDescriptor &output_desc,
4027     DeviceMemory<float> *output_data,
4028     const dnn::RnnStateTensorDescriptor &output_h_desc,
4029     DeviceMemory<float> *output_h_data,
4030     const dnn::RnnStateTensorDescriptor &output_c_desc,
4031     DeviceMemory<float> *output_c_data, bool is_training,
4032     ScratchAllocator *reserve_space_allocator,
4033     ScratchAllocator *workspace_allocator,
4034     dnn::ProfileResult *output_profile_result) {
4035   // TODO(zhengxq): add VLOG PARAM calls.
4036   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4037     auto status = dnn->DoRnnForward(
4038         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4039         input_h_data, input_c_desc, input_c_data, params, output_desc,
4040         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4041         is_training, reserve_space_allocator, workspace_allocator,
4042         output_profile_result);
4043     if (!status && !output_profile_result) {
4044       SetError();
4045     }
4046   } else {
4047     SetErrorAndLogNoDnnSupport();
4048   }
4049   return *this;
4050 }
4051 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4052 Stream &Stream::ThenRnnForward(
4053     const dnn::RnnDescriptor &rnn_desc,
4054     const dnn::RnnSequenceTensorDescriptor &input_desc,
4055     const DeviceMemory<double> &input_data,
4056     const DeviceMemory<int> &seq_lengths_data,
4057     const dnn::RnnStateTensorDescriptor &input_h_desc,
4058     const DeviceMemory<double> &input_h_data,
4059     const dnn::RnnStateTensorDescriptor &input_c_desc,
4060     const DeviceMemory<double> &input_c_data,
4061     const DeviceMemory<double> &params,
4062     const dnn::RnnSequenceTensorDescriptor &output_desc,
4063     DeviceMemory<double> *output_data,
4064     const dnn::RnnStateTensorDescriptor &output_h_desc,
4065     DeviceMemory<double> *output_h_data,
4066     const dnn::RnnStateTensorDescriptor &output_c_desc,
4067     DeviceMemory<double> *output_c_data, bool is_training,
4068     ScratchAllocator *reserve_space_allocator,
4069     ScratchAllocator *workspace_allocator,
4070     dnn::ProfileResult *output_profile_result) {
4071   // TODO(zhengxq): add VLOG PARAM calls.
4072   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4073     auto status = dnn->DoRnnForward(
4074         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4075         input_h_data, input_c_desc, input_c_data, params, output_desc,
4076         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4077         is_training, reserve_space_allocator, workspace_allocator,
4078         output_profile_result);
4079     if (!status && !output_profile_result) {
4080       SetError();
4081     }
4082   } else {
4083     SetErrorAndLogNoDnnSupport();
4084   }
4085   return *this;
4086 }
4087 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4088 Stream &Stream::ThenRnnBackward(
4089     const dnn::RnnDescriptor &rnn_desc,
4090     const dnn::RnnSequenceTensorDescriptor &input_desc,
4091     const DeviceMemory<Eigen::half> &input_data,
4092     const DeviceMemory<int> &seq_lengths_data,
4093     const dnn::RnnStateTensorDescriptor &input_h_desc,
4094     const DeviceMemory<Eigen::half> &input_h_data,
4095     const dnn::RnnStateTensorDescriptor &input_c_desc,
4096     const DeviceMemory<Eigen::half> &input_c_data,
4097     const DeviceMemory<Eigen::half> &params,
4098     const dnn::RnnSequenceTensorDescriptor &output_desc,
4099     const DeviceMemory<Eigen::half> &output_data,
4100     const dnn::RnnStateTensorDescriptor &output_h_desc,
4101     const DeviceMemory<Eigen::half> &output_h_data,
4102     const dnn::RnnStateTensorDescriptor &output_c_desc,
4103     const DeviceMemory<Eigen::half> &output_c_data,
4104     const DeviceMemory<Eigen::half> &output_backprop_data,
4105     const DeviceMemory<Eigen::half> &output_h_backprop_data,
4106     const DeviceMemory<Eigen::half> &output_c_backprop_data,
4107     DeviceMemory<Eigen::half> *input_backprop_data,
4108     DeviceMemory<Eigen::half> *input_h_backprop_data,
4109     DeviceMemory<Eigen::half> *input_c_backprop_data,
4110     DeviceMemory<Eigen::half> *params_backprop_data,
4111     DeviceMemory<uint8> *reserve_space_data,
4112     ScratchAllocator *workspace_allocator,
4113     dnn::ProfileResult *output_profile_result) {
4114   // TODO(zhengxq): add VLOG PARAM calls.
4115   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4116     auto status = dnn->DoRnnBackward(
4117         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4118         input_h_data, input_c_desc, input_c_data, params, output_desc,
4119         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4120         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4121         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4122         params_backprop_data, reserve_space_data, workspace_allocator,
4123         output_profile_result);
4124     if (!status && !output_profile_result) {
4125       SetError();
4126     }
4127   } else {
4128     SetError();
4129     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4130   }
4131   return *this;
4132 }
4133 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4134 Stream &Stream::ThenRnnBackward(
4135     const dnn::RnnDescriptor &rnn_desc,
4136     const dnn::RnnSequenceTensorDescriptor &input_desc,
4137     const DeviceMemory<float> &input_data,
4138     const DeviceMemory<int> &seq_lengths_data,
4139     const dnn::RnnStateTensorDescriptor &input_h_desc,
4140     const DeviceMemory<float> &input_h_data,
4141     const dnn::RnnStateTensorDescriptor &input_c_desc,
4142     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4143     const dnn::RnnSequenceTensorDescriptor &output_desc,
4144     const DeviceMemory<float> &output_data,
4145     const dnn::RnnStateTensorDescriptor &output_h_desc,
4146     const DeviceMemory<float> &output_h_data,
4147     const dnn::RnnStateTensorDescriptor &output_c_desc,
4148     const DeviceMemory<float> &output_c_data,
4149     const DeviceMemory<float> &output_backprop_data,
4150     const DeviceMemory<float> &output_h_backprop_data,
4151     const DeviceMemory<float> &output_c_backprop_data,
4152     DeviceMemory<float> *input_backprop_data,
4153     DeviceMemory<float> *input_h_backprop_data,
4154     DeviceMemory<float> *input_c_backprop_data,
4155     DeviceMemory<float> *params_backprop_data,
4156     DeviceMemory<uint8> *reserve_space_data,
4157     ScratchAllocator *workspace_allocator,
4158     dnn::ProfileResult *output_profile_result) {
4159   // TODO(zhengxq): add VLOG PARAM calls.
4160   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4161     auto status = dnn->DoRnnBackward(
4162         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4163         input_h_data, input_c_desc, input_c_data, params, output_desc,
4164         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4165         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4166         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4167         params_backprop_data, reserve_space_data, workspace_allocator,
4168         output_profile_result);
4169     if (!status && !output_profile_result) {
4170       SetError();
4171     }
4172   } else {
4173     SetError();
4174     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4175   }
4176   return *this;
4177 }
4178 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const DeviceMemory<int> & seq_lengths_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator,dnn::ProfileResult * output_profile_result)4179 Stream &Stream::ThenRnnBackward(
4180     const dnn::RnnDescriptor &rnn_desc,
4181     const dnn::RnnSequenceTensorDescriptor &input_desc,
4182     const DeviceMemory<double> &input_data,
4183     const DeviceMemory<int> &seq_lengths_data,
4184     const dnn::RnnStateTensorDescriptor &input_h_desc,
4185     const DeviceMemory<double> &input_h_data,
4186     const dnn::RnnStateTensorDescriptor &input_c_desc,
4187     const DeviceMemory<double> &input_c_data,
4188     const DeviceMemory<double> &params,
4189     const dnn::RnnSequenceTensorDescriptor &output_desc,
4190     const DeviceMemory<double> &output_data,
4191     const dnn::RnnStateTensorDescriptor &output_h_desc,
4192     const DeviceMemory<double> &output_h_data,
4193     const dnn::RnnStateTensorDescriptor &output_c_desc,
4194     const DeviceMemory<double> &output_c_data,
4195     const DeviceMemory<double> &output_backprop_data,
4196     const DeviceMemory<double> &output_h_backprop_data,
4197     const DeviceMemory<double> &output_c_backprop_data,
4198     DeviceMemory<double> *input_backprop_data,
4199     DeviceMemory<double> *input_h_backprop_data,
4200     DeviceMemory<double> *input_c_backprop_data,
4201     DeviceMemory<double> *params_backprop_data,
4202     DeviceMemory<uint8> *reserve_space_data,
4203     ScratchAllocator *workspace_allocator,
4204     dnn::ProfileResult *output_profile_result) {
4205   // TODO(zhengxq): add VLOG PARAM calls.
4206   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4207     auto status = dnn->DoRnnBackward(
4208         this, rnn_desc, input_desc, input_data, seq_lengths_data, input_h_desc,
4209         input_h_data, input_c_desc, input_c_data, params, output_desc,
4210         output_data, output_h_desc, output_h_data, output_c_desc, output_c_data,
4211         output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4212         input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4213         params_backprop_data, reserve_space_data, workspace_allocator,
4214         output_profile_result);
4215     if (!status && !output_profile_result) {
4216       SetError();
4217     }
4218   } else {
4219     SetError();
4220     LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4221   }
4222   return *this;
4223 }
4224 
ThenCtcLoss(const dnn::RnnStateTensorDescriptor & probs_desc,const DeviceMemory<float> & probs_data,absl::Span<const int> labels_data,absl::Span<const int> labels_lengths_data,absl::Span<const int> input_lengths_data,DeviceMemory<float> * costs_data,const dnn::RnnStateTensorDescriptor & grads_desc,DeviceMemory<float> * grads_data,ScratchAllocator * workspace_allocator)4225 Stream &Stream::ThenCtcLoss(const dnn::RnnStateTensorDescriptor &probs_desc,
4226                             const DeviceMemory<float> &probs_data,
4227                             absl::Span<const int> labels_data,
4228                             absl::Span<const int> labels_lengths_data,
4229                             absl::Span<const int> input_lengths_data,
4230                             DeviceMemory<float> *costs_data,
4231                             const dnn::RnnStateTensorDescriptor &grads_desc,
4232                             DeviceMemory<float> *grads_data,
4233                             ScratchAllocator *workspace_allocator) {
4234   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4235     DeviceMemory<uint8> scratch_memory;
4236     int ctc_loss_algo_id;
4237     auto status =
4238         dnn->PrepareForCtcLoss(this, probs_desc, probs_data, grads_desc,
4239                                labels_data, labels_lengths_data,
4240                                input_lengths_data, workspace_allocator,
4241                                &scratch_memory, &ctc_loss_algo_id)
4242             .ok();
4243     if (status) {
4244       status = dnn->DoCtcLoss(this, probs_desc, probs_data, labels_data,
4245                               labels_lengths_data, input_lengths_data,
4246                               costs_data, grads_desc, grads_data,
4247                               &scratch_memory, ctc_loss_algo_id);
4248     }
4249     if (!status) {
4250       SetError();
4251     }
4252   } else {
4253     SetErrorAndLogNoDnnSupport();
4254   }
4255   return *this;
4256 }
4257 
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)4258 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
4259                                     dnn::DataType input_type,
4260                                     const DeviceMemoryBase &input_data,
4261                                     const dnn::BatchDescriptor &output_desc,
4262                                     dnn::DataType output_type, float scale,
4263                                     DeviceMemoryBase *output_data) {
4264   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
4265             PARAM(output_desc), PARAM(output_type), PARAM(scale),
4266             PARAM(output_data));
4267   if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4268     CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data,
4269                                       output_desc, output_type, scale,
4270                                       output_data));
4271   } else {
4272     SetErrorAndLogNoDnnSupport();
4273   }
4274   return *this;
4275 }
4276 
ThenDoHostCallback(std::function<void ()> callback)4277 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
4278   VLOG_CALL(PARAM(callback));
4279 
4280   if (!ok()) {
4281     LOG(INFO) << DebugStreamPointers()
4282               << " was in error state before adding host callback";
4283   }
4284   CheckError(parent_->HostCallback(this, std::move(callback)));
4285   return *this;
4286 }
4287 
ThenDoHostCallbackWithStatus(std::function<port::Status ()> callback)4288 Stream &Stream::ThenDoHostCallbackWithStatus(
4289     std::function<port::Status()> callback) {
4290   VLOG_CALL(PARAM(callback));
4291 
4292   if (!ok()) {
4293     LOG(INFO) << DebugStreamPointers()
4294               << " was in error state before adding host callback";
4295   }
4296   CheckError(parent_->HostCallback(this, std::move(callback)));
4297   return *this;
4298 }
4299 
ThenRunAfterNextBlockHostUntilDone(std::function<void ()> callback)4300 Stream &Stream::ThenRunAfterNextBlockHostUntilDone(
4301     std::function<void()> callback) {
4302   VLOG_CALL(PARAM(callback));
4303 
4304   if (!ok()) {
4305     LOG(INFO) << DebugStreamPointers()
4306               << " was in error state before adding callback to be run after "
4307                  "next block-host-until-done.";
4308   }
4309   absl::MutexLock lock(&mu_);
4310   after_block_host_until_done_callbacks_.push_back(std::move(callback));
4311   return *this;
4312 }
4313 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)4314 Stream &Stream::ThenFft(fft::Plan *plan,
4315                         const DeviceMemory<std::complex<float>> &input,
4316                         DeviceMemory<std::complex<float>> *output) {
4317   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4318 
4319   if (fft::FftSupport *fft = parent_->AsFft()) {
4320     CheckError(fft->DoFft(this, plan, input, output));
4321   } else {
4322     SetError();
4323     LOG(INFO) << DebugStreamPointers()
4324               << " attempting to perform FFT operation using StreamExecutor"
4325                  " without FFT support";
4326   }
4327   return *this;
4328 }
4329 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)4330 Stream &Stream::ThenFft(fft::Plan *plan,
4331                         const DeviceMemory<std::complex<double>> &input,
4332                         DeviceMemory<std::complex<double>> *output) {
4333   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4334 
4335   if (fft::FftSupport *fft = parent_->AsFft()) {
4336     CheckError(fft->DoFft(this, plan, input, output));
4337   } else {
4338     SetError();
4339     LOG(INFO) << DebugStreamPointers()
4340               << " attempting to perform FFT operation using StreamExecutor"
4341                  " without FFT support";
4342   }
4343   return *this;
4344 }
4345 
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)4346 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
4347                         DeviceMemory<std::complex<float>> *output) {
4348   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4349 
4350   if (fft::FftSupport *fft = parent_->AsFft()) {
4351     CheckError(fft->DoFft(this, plan, input, output));
4352   } else {
4353     SetError();
4354     LOG(INFO) << DebugStreamPointers()
4355               << " attempting to perform FFT operation using StreamExecutor"
4356                  " without FFT support";
4357   }
4358   return *this;
4359 }
4360 
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)4361 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
4362                         DeviceMemory<std::complex<double>> *output) {
4363   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4364 
4365   if (fft::FftSupport *fft = parent_->AsFft()) {
4366     CheckError(fft->DoFft(this, plan, input, output));
4367   } else {
4368     SetError();
4369     LOG(INFO) << DebugStreamPointers()
4370               << " attempting to perform FFT operation using StreamExecutor"
4371                  " without FFT support";
4372   }
4373   return *this;
4374 }
4375 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)4376 Stream &Stream::ThenFft(fft::Plan *plan,
4377                         const DeviceMemory<std::complex<float>> &input,
4378                         DeviceMemory<float> *output) {
4379   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4380 
4381   if (fft::FftSupport *fft = parent_->AsFft()) {
4382     CheckError(fft->DoFft(this, plan, input, output));
4383   } else {
4384     SetError();
4385     LOG(INFO) << DebugStreamPointers()
4386               << " attempting to perform FFT operation using StreamExecutor"
4387                  " without FFT support";
4388   }
4389   return *this;
4390 }
4391 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)4392 Stream &Stream::ThenFft(fft::Plan *plan,
4393                         const DeviceMemory<std::complex<double>> &input,
4394                         DeviceMemory<double> *output) {
4395   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4396 
4397   if (fft::FftSupport *fft = parent_->AsFft()) {
4398     CheckError(fft->DoFft(this, plan, input, output));
4399   } else {
4400     SetError();
4401     LOG(INFO) << DebugStreamPointers()
4402               << " attempting to perform FFT operation using StreamExecutor"
4403                  " without FFT support";
4404   }
4405   return *this;
4406 }
4407 
4408 // It looks confusing, but all this is doing is inserting a callback at the
4409 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)4410 Stream &Stream::ThenEnqueueOnBackgroundThread(
4411     std::function<void(StreamExecutor *)> task) {
4412   VLOG_CALL(PARAM(task));
4413 
4414   StreamExecutor *stream_executor = this->parent_;
4415   std::function<void()> bound_task = std::bind(task, stream_executor);
4416 
4417   return ThenDoHostCallback([stream_executor, bound_task]() {
4418     stream_executor->EnqueueOnBackgroundThread(bound_task);
4419   });
4420 }
4421 
BlockHostUntilDone()4422 port::Status Stream::BlockHostUntilDone() {
4423   VLOG_CALL();
4424 
4425   if (!ok()) {
4426     port::Status status = port::Status(
4427         port::error::INTERNAL,
4428         "stream did not block host until done; was already in an error state");
4429     LOG(INFO) << DebugStreamPointers() << " " << status;
4430     return status;
4431   }
4432 
4433   temporary_memory_manager_.DeallocateFinalizedTemporaries();
4434 
4435   port::Status error = parent_->BlockHostUntilDone(this);
4436   CheckError(error.ok());
4437 
4438   RunAfterBlockHostUntilDoneCallbacks();
4439   return error;
4440 }
4441 
RunAfterBlockHostUntilDoneCallbacks()4442 void Stream::RunAfterBlockHostUntilDoneCallbacks() {
4443   std::vector<std::function<void()>> callbacks;
4444   {
4445     absl::MutexLock lock(&mu_);
4446     std::swap(callbacks, after_block_host_until_done_callbacks_);
4447   }
4448   for (const auto &fn : callbacks) {
4449     fn();
4450   }
4451 }
4452 
DebugStreamPointers() const4453 std::string Stream::DebugStreamPointers() const {
4454   // Relies on the ToVlogString(const void*) overload above.
4455   return absl::StrCat("[stream=", ToVlogString(this),
4456                       ",impl=", ToVlogString(implementation_.get()), "]");
4457 }
4458 
CheckStatus(port::Status status)4459 void Stream::CheckStatus(port::Status status) {
4460   if (status.ok()) {
4461     return;
4462   }
4463   LOG(ERROR) << status;
4464   absl::MutexLock lock(&mu_);
4465   status_ = status;
4466 }
4467 
4468 }  // namespace stream_executor
4469