• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/stream.h"
17 
18 #include "tensorflow/stream_executor/platform/port.h"
19 
20 #include "tensorflow/stream_executor/blas.h"
21 #include "tensorflow/stream_executor/host_buffer.h"
22 #include "tensorflow/stream_executor/lib/stacktrace.h"
23 #include "tensorflow/stream_executor/lib/strcat.h"
24 #include "tensorflow/stream_executor/platform.h"
25 #include "tensorflow/stream_executor/platform/logging.h"
26 #include "tensorflow/stream_executor/rng.h"
27 #include "tensorflow/stream_executor/stream_executor_internal.h"
28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
29 
30 namespace perftools {
31 namespace gputools {
32 
33 namespace {
34 // Code to turn parameters to functions on stream into strings that
35 // will be VLOG'ed. We need overloads, instead of
36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
37 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
39   return descriptor.ToShortString();
40 }
41 
ToVlogString(const dnn::FilterDescriptor & descriptor)42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
43   return descriptor.ToShortString();
44 }
45 
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
47   return descriptor.ToShortString();
48 }
49 
ToVlogString(const dnn::PoolingDescriptor & descriptor)50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
51   return descriptor.ToShortString();
52 }
53 
ToVlogString(const dnn::NormalizeDescriptor & descriptor)54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
55   return descriptor.ToShortString();
56 }
57 
ToVlogString(dnn::ActivationMode mode)58 string ToVlogString(dnn::ActivationMode mode) {
59   return dnn::ActivationModeString(mode);
60 }
61 
ToVlogString(const dnn::AlgorithmConfig & algo_config)62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
63   return algo_config.ToString();
64 }
65 
ToVlogString(dnn::ElementwiseOperation op)66 string ToVlogString(dnn::ElementwiseOperation op) {
67   return dnn::ElementwiseOperationString(op);
68 }
69 
ToVlogString(dnn::QuantizedActivationMode mode)70 string ToVlogString(dnn::QuantizedActivationMode mode) {
71   return dnn::QuantizedActivationModeString(mode);
72 }
73 
ToVlogString(blas::Transpose t)74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
75 
ToVlogString(blas::UpperLower ul)76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
77 
ToVlogString(blas::Diagonal d)78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
79 
ToVlogString(blas::Side s)80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
81 
ToVlogString(blas::ComputationType ty)82 string ToVlogString(blas::ComputationType ty) {
83   return blas::ComputationTypeString(ty);
84 }
85 
// VLOG rendering of a raw pointer. A null pointer is shown as "null"; any
// other pointer is shown the way operator<< on a std::ostream formats it.
string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    // StrCat has no pointer overload, so go through a stream instead.
    std::ostringstream formatted;
    formatted << ptr;
    return formatted.str();
  }
  return "null";
}
96 
ToVlogString(const HostBuffer & buffer)97 string ToVlogString(const HostBuffer &buffer) { return buffer.AsString(); }
98 
// VLOG rendering of a complex number, using the standard stream inserter for
// std::complex because StrCat cannot format it.
template <class T>
string ToVlogString(const std::complex<T> &c) {
  std::ostringstream formatted;
  formatted << c;
  string text = formatted.str();
  return text;
}
106 
// VLOG rendering of a std::function: only whether it holds a target is
// reported, as "null" or "<non-null function>".
template <class T>
string ToVlogString(const std::function<T> &f) {
  if (!f) {
    return "null";
  }
  return "<non-null function>";
}
111 
ToVlogString(const DeviceMemoryBase & memory)112 string ToVlogString(const DeviceMemoryBase &memory) {
113   return ToVlogString(memory.opaque());
114 }
115 
ToVlogString(const DeviceMemoryBase * memory)116 string ToVlogString(const DeviceMemoryBase *memory) {
117   return ToVlogString(*memory);
118 }
119 
ToVlogString(const Eigen::half & h)120 string ToVlogString(const Eigen::half &h) { return port::StrCat(h); }
121 
ToVlogString(int i)122 string ToVlogString(int i) { return port::StrCat(i); }
123 
ToVlogString(uint32 i)124 string ToVlogString(uint32 i) { return port::StrCat(i); }
125 
ToVlogString(uint64 i)126 string ToVlogString(uint64 i) { return port::StrCat(i); }
127 
ToVlogString(int64 i)128 string ToVlogString(int64 i) { return port::StrCat(i); }
129 
ToVlogString(float f)130 string ToVlogString(float f) { return port::StrCat(f); }
131 
ToVlogString(double d)132 string ToVlogString(double d) { return port::StrCat(d); }
133 
134 template <class T>
ToVlogString(port::ArraySlice<T> elements)135 string ToVlogString(port::ArraySlice<T> elements) {
136   string str = port::StrCat(
137       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
138       elements.size(), "]{");
139   const char *separator = "";
140   size_t max_to_show = std::numeric_limits<size_t>::max();
141   if (!VLOG_IS_ON(2)) {
142     max_to_show = 5;
143   } else if (!VLOG_IS_ON(3)) {
144     max_to_show = 20;
145   } else if (!VLOG_IS_ON(11)) {
146     max_to_show = 1000;
147   }
148   for (size_t i = 0; i < elements.size(); ++i) {
149     if (i == max_to_show) {
150       str += ", ...";
151       break;
152     }
153     port::StrAppend(&str, separator, ToVlogString(elements[i]));
154     separator = ", ";
155   }
156   str += "}";
157   return str;
158 }
159 
160 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)161 string ToVlogString(port::MutableArraySlice<T> elements) {
162   return ToVlogString(port::ArraySlice<T>(elements));
163 }
164 
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)165 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
166   switch (depth_to_space_layout) {
167     case dnn::DepthToSpaceLayout::DepthHeightWidth:
168       return "DepthToSpaceLayout::DepthHeightWidth";
169   }
170   return "unknown DepthToSpaceLayout";
171 }
172 
ToVlogString(dnn::DataType data_type)173 string ToVlogString(dnn::DataType data_type) {
174   switch (data_type) {
175     case dnn::DataType::kFloat:
176       return "dnn::DataType::kFloat";
177     case dnn::DataType::kDouble:
178       return "dnn::DataType::kDouble";
179     case dnn::DataType::kHalf:
180       return "dnn::DataType::kHalf";
181     case dnn::DataType::kInt8:
182       return "dnn::DataType::kInt8";
183   }
184 }
185 
186 // Used together with PARAM to VLOG calls made to the stream. Intended
187 // to be used like this:
188 //
189 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
190 //
191 // where a and b are the parameters to MyFunction.
192 //
193 // See VLOG_CALL for a short-hand for this. This way of doing it saves
194 // a tremendous amount of boilerplate code given how many functions
195 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,string>> params)196 string CallStr(const char *function_name, Stream *stream,
197                std::vector<std::pair<const char *, string>> params) {
198   // Do not call this function unless VLOG is on since just
199   // constructing all the strings in params is expensive.
200   CHECK(VLOG_IS_ON(1));
201 
202   string str = port::StrCat("Called Stream::", function_name, "(");
203   const char *separator = "";
204   for (const auto &param : params) {
205     port::StrAppend(&str, separator, param.first, "=", param.second);
206     separator = ", ";
207   }
208   port::StrAppend(&str, ") stream=", ToVlogString(stream));
209   if (VLOG_IS_ON(10)) {
210     port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
211   }
212   return str;
213 }
214 
// Expands to a braced {name, value} pair for one parameter: the parameter's
// spelling as a string literal and its ToVlogString() rendering. Saves typing
// every parameter twice when logging it with VLOG and CallStr.
#define PARAM(parameter) {#parameter, ToVlogString(parameter)}
219 
// Logs the enclosing Stream member function call at VLOG level 1 without
// spelling out the function name. Intended use:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// The function name comes from __func__ and the stream from `this`. The
// hand-written alternative would be:
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// which adds up quickly, since most parameter names are not short and most of
// the functions take many more than 2 parameters.
#define VLOG_CALL(...) \
  VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
233 
234 }  // namespace
235 
Stream(StreamExecutor * parent)236 Stream::Stream(StreamExecutor *parent)
237     : parent_(parent),
238       implementation_(parent->implementation()->GetStreamImplementation()),
239       allocated_(false),
240       ok_(false),
241       temporary_memory_manager_(this) {
242   VLOG_CALL(PARAM(parent));
243 }
244 
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)245 Stream::Stream(StreamExecutor *parent,
246                internal::StreamInterface *implementation)
247     : parent_(parent),
248       implementation_(implementation),
249       allocated_(false),
250       ok_(false),
251       temporary_memory_manager_(this) {
252   VLOG_CALL(PARAM(parent), PARAM(implementation));
253 }
254 
~Stream()255 Stream::~Stream() {
256   VLOG_CALL();
257 
258   temporary_memory_manager_.ForceDeallocateAll();
259 
260   if (allocated_) {
261     parent_->DeallocateStream(this);
262   }
263 }
264 
Init()265 Stream &Stream::Init() {
266   VLOG_CALL();
267 
268   mutex_lock lock{mu_};
269   CHECK_EQ(false, allocated_)
270       << "stream appears to already have been initialized";
271   CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
272 
273   if (parent_->AllocateStream(this)) {
274     // Successful initialization!
275     allocated_ = true;
276     ok_ = true;
277   } else {
278     LOG(ERROR) << "failed to allocate stream during initialization";
279   }
280 
281   return *this;
282 }
283 
InitTimer(Timer * timer)284 Stream &Stream::InitTimer(Timer *timer) {
285   VLOG_CALL(PARAM(timer));
286 
287   if (ok()) {
288     CheckError(parent_->AllocateTimer(timer));
289   } else {
290     LOG(INFO) << "did not allocate timer: " << timer;
291   }
292   return *this;
293 }
294 
InitWithTimer(Timer * timer)295 Stream &Stream::InitWithTimer(Timer *timer) {
296   VLOG_CALL(PARAM(timer));
297 
298   return Init().InitTimer(timer);
299 }
300 
ThenRecordEvent(Event * event)301 Stream &Stream::ThenRecordEvent(Event *event) {
302   VLOG_CALL(PARAM(event));
303 
304   port::Status status = parent_->RecordEvent(this, event);
305   if (!status.ok()) {
306     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
307                << "; not marking stream as bad, as the Event object may be "
308                << "at fault. Monitor for further errors.";
309   }
310 
311   return *this;
312 }
313 
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)314 Stream &Stream::ThenBatchNormalizationForward(
315     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
316     const DeviceMemory<float> &offset,
317     const DeviceMemory<float> &estimated_mean,
318     const DeviceMemory<float> &estimated_variance,
319     const dnn::BatchDescriptor &x_desc,
320     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
321     DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
322     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
323     DeviceMemory<float> *saved_inv_var, bool is_training,
324     std::function<const DeviceMemory<float> &()> var_to_inv_var,
325     std::function<void()> inv_var_to_var) {
326   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
327             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
328   if (ok()) {
329     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
330       CheckError(dnn->DoBatchNormalizationForward(
331           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
332           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
333           saved_inv_var, is_training, std::move(var_to_inv_var),
334           std::move(inv_var_to_var)));
335     } else {
336       SetErrorAndLogNoDnnSupport();
337     }
338   }
339   return *this;
340 }
341 
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)342 Stream &Stream::ThenBatchNormalizationBackward(
343     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
344     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
345     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
346     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
347     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
348     DeviceMemory<float> *offset_backprop) {
349   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
350             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
351             PARAM(scale_backprop), PARAM(offset_backprop));
352   if (ok()) {
353     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
354       CheckError(dnn->DoBatchNormalizationBackward(
355           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
356           epsilon, x_backprop, scale_backprop, offset_backprop));
357     } else {
358       SetErrorAndLogNoDnnSupport();
359     }
360   }
361   return *this;
362 }
363 
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)364 Stream &Stream::ThenBatchNormalizationForward(
365     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
366     const DeviceMemory<float> &offset,
367     const DeviceMemory<float> &estimated_mean,
368     const DeviceMemory<float> &estimated_variance,
369     const dnn::BatchDescriptor &x_desc,
370     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
371     DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
372     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
373     DeviceMemory<float> *saved_inv_var, bool is_training,
374     std::function<const DeviceMemory<float> &()> var_to_inv_var,
375     std::function<void()> inv_var_to_var) {
376   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
377             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
378   if (ok()) {
379     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
380       CheckError(dnn->DoBatchNormalizationForward(
381           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
382           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
383           saved_inv_var, is_training, std::move(var_to_inv_var),
384           std::move(inv_var_to_var)));
385     } else {
386       SetErrorAndLogNoDnnSupport();
387     }
388   }
389   return *this;
390 }
391 
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)392 Stream &Stream::ThenBatchNormalizationBackward(
393     const DeviceMemory<Eigen::half> &y_backprop,
394     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
395     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
396     const dnn::BatchDescriptor &x_desc,
397     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
398     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
399     DeviceMemory<float> *offset_backprop) {
400   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
401             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
402             PARAM(scale_backprop), PARAM(offset_backprop));
403   if (ok()) {
404     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
405       CheckError(dnn->DoBatchNormalizationBackward(
406           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
407           epsilon, x_backprop, scale_backprop, offset_backprop));
408     } else {
409       SetErrorAndLogNoDnnSupport();
410     }
411   }
412   return *this;
413 }
414 
ThenFusedConvolveWithScratch(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator)415 Stream &Stream::ThenFusedConvolveWithScratch(
416     const dnn::BatchDescriptor &conv_input_descriptor,
417     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
418     const dnn::FilterDescriptor &filter_descriptor,
419     const DeviceMemory<int8> &filter_data,
420     const dnn::ConvolutionDescriptor &convolution_descriptor,
421     const DeviceMemory<int8> &side_input_data, float side_input_scale,
422     const dnn::BatchDescriptor &bias_descriptor,
423     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
424     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
425     ScratchAllocator *scratch_allocator) {
426   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
427             PARAM(conv_input_scale), PARAM(filter_descriptor),
428             PARAM(filter_data), PARAM(convolution_descriptor),
429             PARAM(side_input_data), PARAM(side_input_scale),
430             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
431             PARAM(output_descriptor), PARAM(output));
432 
433   if (ok()) {
434     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
435       CheckError(dnn->DoFusedConvolve(
436           this, conv_input_descriptor, conv_input_data, conv_input_scale,
437           filter_descriptor, filter_data, convolution_descriptor,
438           side_input_data, side_input_scale, bias_descriptor, biases,
439           activation_mode, output_descriptor, output, scratch_allocator,
440           dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
441     } else {
442       SetErrorAndLogNoDnnSupport();
443     }
444   }
445   return *this;
446 }
447 
ThenFusedConvolveWithScratch(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator)448 Stream &Stream::ThenFusedConvolveWithScratch(
449     const dnn::BatchDescriptor &conv_input_descriptor,
450     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
451     const dnn::FilterDescriptor &filter_descriptor,
452     const DeviceMemory<Eigen::half> &filter_data,
453     const dnn::ConvolutionDescriptor &convolution_descriptor,
454     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
455     const dnn::BatchDescriptor &bias_descriptor,
456     const DeviceMemory<Eigen::half> &biases,
457     dnn::ActivationMode activation_mode,
458     const dnn::BatchDescriptor &output_descriptor,
459     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
460   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
461             PARAM(conv_input_scale), PARAM(filter_descriptor),
462             PARAM(filter_data), PARAM(convolution_descriptor),
463             PARAM(side_input_data), PARAM(side_input_scale),
464             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
465             PARAM(output_descriptor), PARAM(output));
466 
467   if (ok()) {
468     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
469       CheckError(dnn->DoFusedConvolve(
470           this, conv_input_descriptor, conv_input_data, conv_input_scale,
471           filter_descriptor, filter_data, convolution_descriptor,
472           side_input_data, side_input_scale, bias_descriptor, biases,
473           activation_mode, output_descriptor, output, scratch_allocator,
474           dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
475     } else {
476       SetErrorAndLogNoDnnSupport();
477     }
478   }
479   return *this;
480 }
481 
ThenFusedConvolveWithScratch(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator)482 Stream &Stream::ThenFusedConvolveWithScratch(
483     const dnn::BatchDescriptor &conv_input_descriptor,
484     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
485     const dnn::FilterDescriptor &filter_descriptor,
486     const DeviceMemory<float> &filter_data,
487     const dnn::ConvolutionDescriptor &convolution_descriptor,
488     const DeviceMemory<float> &side_input_data, float side_input_scale,
489     const dnn::BatchDescriptor &bias_descriptor,
490     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
491     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
492     ScratchAllocator *scratch_allocator) {
493   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
494             PARAM(conv_input_scale), PARAM(filter_descriptor),
495             PARAM(filter_data), PARAM(convolution_descriptor),
496             PARAM(side_input_data), PARAM(side_input_scale),
497             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
498             PARAM(output_descriptor), PARAM(output));
499 
500   if (ok()) {
501     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
502       CheckError(dnn->DoFusedConvolve(
503           this, conv_input_descriptor, conv_input_data, conv_input_scale,
504           filter_descriptor, filter_data, convolution_descriptor,
505           side_input_data, side_input_scale, bias_descriptor, biases,
506           activation_mode, output_descriptor, output, scratch_allocator,
507           dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
508     } else {
509       SetErrorAndLogNoDnnSupport();
510     }
511   }
512   return *this;
513 }
514 
ThenConvolveWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator)515 Stream &Stream::ThenConvolveWithScratch(
516     const dnn::BatchDescriptor &input_descriptor,
517     const DeviceMemory<Eigen::half> &input_data,
518     const dnn::FilterDescriptor &filter_descriptor,
519     const DeviceMemory<Eigen::half> &filter_data,
520     const dnn::ConvolutionDescriptor &convolution_descriptor,
521     const dnn::BatchDescriptor &output_descriptor,
522     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
523   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
524             PARAM(filter_descriptor), PARAM(filter_data),
525             PARAM(convolution_descriptor), PARAM(output_descriptor),
526             PARAM(output));
527 
528   if (ok()) {
529     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
530       CheckError(dnn->DoConvolve(
531           this, input_descriptor, input_data, filter_descriptor, filter_data,
532           convolution_descriptor, output_descriptor, output, scratch_allocator,
533           dnn::AlgorithmConfig(),
534           /*output_profile_result=*/nullptr));
535     } else {
536       SetErrorAndLogNoDnnSupport();
537     }
538   }
539   return *this;
540 }
541 
ThenConvolveWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator)542 Stream &Stream::ThenConvolveWithScratch(
543     const dnn::BatchDescriptor &input_descriptor,
544     const DeviceMemory<float> &input_data,
545     const dnn::FilterDescriptor &filter_descriptor,
546     const DeviceMemory<float> &filter_data,
547     const dnn::ConvolutionDescriptor &convolution_descriptor,
548     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
549     ScratchAllocator *scratch_allocator) {
550   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
551             PARAM(filter_descriptor), PARAM(filter_data),
552             PARAM(convolution_descriptor), PARAM(output_descriptor),
553             PARAM(output));
554 
555   if (ok()) {
556     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
557       CheckError(dnn->DoConvolve(
558           this, input_descriptor, input_data, filter_descriptor, filter_data,
559           convolution_descriptor, output_descriptor, output, scratch_allocator,
560           dnn::AlgorithmConfig(),
561           /*output_profile_result=*/nullptr));
562     } else {
563       SetErrorAndLogNoDnnSupport();
564     }
565   }
566   return *this;
567 }
568 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)569 Stream &Stream::ThenFusedConvolveWithAlgorithm(
570     const dnn::BatchDescriptor &conv_input_descriptor,
571     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
572     const dnn::FilterDescriptor &filter_descriptor,
573     const DeviceMemory<float> &filter_data,
574     const dnn::ConvolutionDescriptor &convolution_descriptor,
575     const DeviceMemory<float> &side_input_data, float side_input_scale,
576     const dnn::BatchDescriptor &bias_descriptor,
577     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
578     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
579     ScratchAllocator *scratch_allocator,
580     const dnn::AlgorithmConfig &algorithm_config,
581     dnn::ProfileResult *output_profile_result) {
582   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
583             PARAM(conv_input_scale), PARAM(filter_descriptor),
584             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
585             PARAM(side_input_data), PARAM(side_input_scale),
586             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
587             PARAM(algorithm_config));
588 
589   if (ok()) {
590     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
591       auto status = dnn->DoFusedConvolve(
592           this, conv_input_descriptor, conv_input_data, conv_input_scale,
593           filter_descriptor, filter_data, convolution_descriptor,
594           side_input_data, side_input_scale, bias_descriptor, biases,
595           activation_mode, output_descriptor, output, scratch_allocator,
596           algorithm_config, output_profile_result);
597       if (!status && !output_profile_result) {
598         SetError();
599       }
600     } else {
601       SetErrorAndLogNoDnnSupport();
602     }
603   }
604   return *this;
605 }
606 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)607 Stream &Stream::ThenFusedConvolveWithAlgorithm(
608     const dnn::BatchDescriptor &conv_input_descriptor,
609     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
610     const dnn::FilterDescriptor &filter_descriptor,
611     const DeviceMemory<Eigen::half> &filter_data,
612     const dnn::ConvolutionDescriptor &convolution_descriptor,
613     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
614     const dnn::BatchDescriptor &bias_descriptor,
615     const DeviceMemory<Eigen::half> &biases,
616     dnn::ActivationMode activation_mode,
617     const dnn::BatchDescriptor &output_descriptor,
618     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
619     const dnn::AlgorithmConfig &algorithm_config,
620     dnn::ProfileResult *output_profile_result) {
621   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
622             PARAM(conv_input_scale), PARAM(filter_descriptor),
623             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
624             PARAM(side_input_data), PARAM(side_input_scale),
625             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
626             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
627 
628   if (ok()) {
629     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
630       auto status = dnn->DoFusedConvolve(
631           this, conv_input_descriptor, conv_input_data, conv_input_scale,
632           filter_descriptor, filter_data, convolution_descriptor,
633           side_input_data, side_input_scale, bias_descriptor, biases,
634           activation_mode, output_descriptor, output, scratch_allocator,
635           algorithm_config, output_profile_result);
636       if (!status && !output_profile_result) {
637         SetError();
638       }
639     } else {
640       SetErrorAndLogNoDnnSupport();
641     }
642   }
643   return *this;
644 }
645 
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)646 Stream &Stream::ThenFusedConvolveWithAlgorithm(
647     const dnn::BatchDescriptor &conv_input_descriptor,
648     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
649     const dnn::FilterDescriptor &filter_descriptor,
650     const DeviceMemory<int8> &filter_data,
651     const dnn::ConvolutionDescriptor &convolution_descriptor,
652     const DeviceMemory<int8> &side_input_data, float side_input_scale,
653     const dnn::BatchDescriptor &bias_descriptor,
654     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
655     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
656     ScratchAllocator *scratch_allocator,
657     const dnn::AlgorithmConfig &algorithm_config,
658     dnn::ProfileResult *output_profile_result) {
659   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
660             PARAM(conv_input_scale), PARAM(filter_descriptor),
661             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
662             PARAM(side_input_data), PARAM(side_input_scale),
663             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
664             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
665 
666   if (ok()) {
667     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
668       auto status = dnn->DoFusedConvolve(
669           this, conv_input_descriptor, conv_input_data, conv_input_scale,
670           filter_descriptor, filter_data, convolution_descriptor,
671           side_input_data, side_input_scale, bias_descriptor, biases,
672           activation_mode, output_descriptor, output, scratch_allocator,
673           algorithm_config, output_profile_result);
674       if (!status && !output_profile_result) {
675         SetError();
676       }
677     } else {
678       SetErrorAndLogNoDnnSupport();
679     }
680   }
681   return *this;
682 }
683 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)684 Stream &Stream::ThenConvolveWithAlgorithm(
685     const dnn::BatchDescriptor &input_descriptor,
686     const DeviceMemory<float> &input_data,
687     const dnn::FilterDescriptor &filter_descriptor,
688     const DeviceMemory<float> &filter_data,
689     const dnn::ConvolutionDescriptor &convolution_descriptor,
690     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
691     ScratchAllocator *scratch_allocator,
692     const dnn::AlgorithmConfig &algorithm_config,
693     dnn::ProfileResult *output_profile_result) {
694   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
695             PARAM(filter_descriptor), PARAM(filter_data),
696             PARAM(convolution_descriptor), PARAM(output_descriptor),
697             PARAM(output), PARAM(algorithm_config));
698 
699   if (ok()) {
700     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
701       auto status = dnn->DoConvolve(
702           this, input_descriptor, input_data, filter_descriptor, filter_data,
703           convolution_descriptor, output_descriptor, output, scratch_allocator,
704           algorithm_config, output_profile_result);
705       if (!status && !output_profile_result) {
706         SetError();
707       }
708     } else {
709       SetErrorAndLogNoDnnSupport();
710     }
711   }
712   return *this;
713 }
714 
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)715 Stream &Stream::ThenConvolveWithAlgorithm(
716     const dnn::BatchDescriptor &input_descriptor,
717     const DeviceMemory<Eigen::half> &input_data,
718     const dnn::FilterDescriptor &filter_descriptor,
719     const DeviceMemory<Eigen::half> &filter_data,
720     const dnn::ConvolutionDescriptor &convolution_descriptor,
721     const dnn::BatchDescriptor &output_descriptor,
722     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
723     const dnn::AlgorithmConfig &algorithm_config,
724     dnn::ProfileResult *output_profile_result) {
725   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
726             PARAM(filter_descriptor), PARAM(filter_data),
727             PARAM(convolution_descriptor), PARAM(output_descriptor),
728             PARAM(output), PARAM(algorithm_config));
729 
730   if (ok()) {
731     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
732       auto status = dnn->DoConvolve(
733           this, input_descriptor, input_data, filter_descriptor, filter_data,
734           convolution_descriptor, output_descriptor, output, scratch_allocator,
735           algorithm_config, output_profile_result);
736       if (!status && !output_profile_result) {
737         SetError();
738       }
739     } else {
740       SetErrorAndLogNoDnnSupport();
741     }
742   }
743   return *this;
744 }
745 
ThenFusedConvolve(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output)746 Stream &Stream::ThenFusedConvolve(
747     const dnn::BatchDescriptor &conv_input_descriptor,
748     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
749     const dnn::FilterDescriptor &filter_descriptor,
750     const DeviceMemory<int8> &filter_data,
751     const dnn::ConvolutionDescriptor &convolution_descriptor,
752     const DeviceMemory<int8> &side_input_data, float side_input_scale,
753     const dnn::BatchDescriptor &bias_descriptor,
754     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
755     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output) {
756   return ThenFusedConvolveWithScratch(
757       conv_input_descriptor, conv_input_data, conv_input_scale,
758       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
759       side_input_scale, bias_descriptor, biases, activation_mode,
760       output_descriptor, output,
761       /*scratch_allocator=*/nullptr);
762 }
763 
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)764 Stream &Stream::ThenConvolve(
765     const dnn::BatchDescriptor &input_descriptor,
766     const DeviceMemory<float> &input_data,
767     const dnn::FilterDescriptor &filter_descriptor,
768     const DeviceMemory<float> &filter_data,
769     const dnn::ConvolutionDescriptor &convolution_descriptor,
770     const dnn::BatchDescriptor &output_descriptor,
771     DeviceMemory<float> *output) {
772   return ThenConvolveWithScratch(input_descriptor, input_data,
773                                  filter_descriptor, filter_data,
774                                  convolution_descriptor, output_descriptor,
775                                  output, /*scratch_allocator=*/nullptr);
776 }
777 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)778 Stream &Stream::ThenConvolveQuantized(
779     const dnn::BatchDescriptor &input_descriptor,
780     const DeviceMemory<float> &input_data,
781     const dnn::FilterDescriptor &filter_descriptor,
782     const DeviceMemory<int8> &filter_coefficients,
783     const DeviceMemory<float> &coefficient_scales,
784     const dnn::ConvolutionDescriptor &convolution_descriptor,
785     const dnn::BatchDescriptor &output_descriptor,
786     DeviceMemory<float> *output) {
787   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
788             PARAM(filter_descriptor), PARAM(filter_coefficients),
789             PARAM(coefficient_scales), PARAM(convolution_descriptor),
790             PARAM(output_descriptor), PARAM(output));
791 
792   if (ok()) {
793     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
794       CheckError(dnn->DoConvolveQuantized(
795           this, input_descriptor, input_data, filter_descriptor,
796           filter_coefficients, coefficient_scales, convolution_descriptor,
797           output_descriptor, output));
798     } else {
799       SetError();
800       LOG(WARNING)
801           << "attempting to perform DNN operation using StreamExecutor "
802              "without DNN support";
803     }
804   }
805   return *this;
806 }
807 
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)808 Stream &Stream::ThenConvolveQuantized(
809     const dnn::BatchDescriptor &input_descriptor,
810     const DeviceMemory<float> &input_data,
811     const dnn::FilterDescriptor &filter_descriptor,
812     const DeviceMemory<int16> &filter_coefficients,
813     const DeviceMemory<float> &coefficient_scales,
814     const dnn::ConvolutionDescriptor &convolution_descriptor,
815     const dnn::BatchDescriptor &output_descriptor,
816     DeviceMemory<float> *output) {
817   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
818             PARAM(filter_descriptor), PARAM(filter_coefficients),
819             PARAM(coefficient_scales), PARAM(convolution_descriptor),
820             PARAM(output_descriptor), PARAM(output));
821 
822   if (ok()) {
823     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
824       CheckError(dnn->DoConvolveQuantized(
825           this, input_descriptor, input_data, filter_descriptor,
826           filter_coefficients, coefficient_scales, convolution_descriptor,
827           output_descriptor, output));
828     } else {
829       SetError();
830       LOG(WARNING)
831           << "attempting to perform DNN operation using StreamExecutor "
832              "without DNN support";
833     }
834   }
835   return *this;
836 }
837 
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)838 Stream &Stream::ThenSeparableConvolve(
839     const dnn::BatchDescriptor &batch_descriptor,
840     const DeviceMemory<float> &input_data,
841     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
842     const DeviceMemory<float> &first_weights,
843     const DeviceMemory<float> &second_weights,
844     const dnn::ConvolutionDescriptor &convolution_descriptor,
845     const dnn::BatchDescriptor &output_descriptor,
846     DeviceMemory<float> *output) {
847   VLOG_CALL(
848       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
849       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
850       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
851 
852   if (ok()) {
853     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
854       CheckError(dnn->DoSeparableConvolve(
855           this, batch_descriptor, input_data, filter_descriptor,
856           depth_multiplier, first_weights, second_weights,
857           convolution_descriptor, output_descriptor, output));
858     } else {
859       SetErrorAndLogNoDnnSupport();
860     }
861   }
862   return *this;
863 }
864 
ThenConvolveBackwardDataWithScratch(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator)865 Stream &Stream::ThenConvolveBackwardDataWithScratch(
866     const dnn::FilterDescriptor &filter_descriptor,
867     const DeviceMemory<float> &filter_data,
868     const dnn::BatchDescriptor &output_descriptor,
869     DeviceMemory<float> backward_output_data,
870     const dnn::ConvolutionDescriptor &convolution_descriptor,
871     const dnn::BatchDescriptor &input_descriptor,
872     DeviceMemory<float> *backward_input_data,
873     ScratchAllocator *scratch_allocator) {
874   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
875             PARAM(output_descriptor), PARAM(backward_output_data),
876             PARAM(convolution_descriptor), PARAM(input_descriptor),
877             PARAM(backward_input_data));
878 
879   if (ok()) {
880     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
881       CheckError(dnn->DoConvolveBackwardData(
882           this, filter_descriptor, filter_data, output_descriptor,
883           backward_output_data, convolution_descriptor, input_descriptor,
884           backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
885           /*output_profile_result=*/nullptr));
886     } else {
887       SetErrorAndLogNoDnnSupport();
888     }
889   }
890   return *this;
891 }
892 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)893 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
894     const dnn::FilterDescriptor &filter_descriptor,
895     const DeviceMemory<float> &filter_data,
896     const dnn::BatchDescriptor &output_descriptor,
897     DeviceMemory<float> backward_output_data,
898     const dnn::ConvolutionDescriptor &convolution_descriptor,
899     const dnn::BatchDescriptor &input_descriptor,
900     DeviceMemory<float> *backward_input_data,
901     ScratchAllocator *scratch_allocator,
902     const dnn::AlgorithmConfig &algorithm_config,
903     dnn::ProfileResult *output_profile_result) {
904   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
905             PARAM(output_descriptor), PARAM(backward_output_data),
906             PARAM(convolution_descriptor), PARAM(input_descriptor),
907             PARAM(backward_input_data));
908 
909   if (ok()) {
910     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
911       auto status = dnn->DoConvolveBackwardData(
912           this, filter_descriptor, filter_data, output_descriptor,
913           backward_output_data, convolution_descriptor, input_descriptor,
914           backward_input_data, scratch_allocator, algorithm_config,
915           output_profile_result);
916       if (!status && !output_profile_result) {
917         SetError();
918       }
919     } else {
920       SetErrorAndLogNoDnnSupport();
921     }
922   }
923   return *this;
924 }
925 
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)926 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
927     const dnn::FilterDescriptor &filter_descriptor,
928     const DeviceMemory<Eigen::half> &filter_data,
929     const dnn::BatchDescriptor &output_descriptor,
930     DeviceMemory<Eigen::half> backward_output_data,
931     const dnn::ConvolutionDescriptor &convolution_descriptor,
932     const dnn::BatchDescriptor &input_descriptor,
933     DeviceMemory<Eigen::half> *backward_input_data,
934     ScratchAllocator *scratch_allocator,
935     const dnn::AlgorithmConfig &algorithm_config,
936     dnn::ProfileResult *output_profile_result) {
937   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
938             PARAM(output_descriptor), PARAM(backward_output_data),
939             PARAM(convolution_descriptor), PARAM(input_descriptor),
940             PARAM(backward_input_data));
941 
942   if (ok()) {
943     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
944       auto status = dnn->DoConvolveBackwardData(
945           this, filter_descriptor, filter_data, output_descriptor,
946           backward_output_data, convolution_descriptor, input_descriptor,
947           backward_input_data, scratch_allocator, algorithm_config,
948           output_profile_result);
949       if (!status && !output_profile_result) {
950         SetError();
951       }
952     } else {
953       SetErrorAndLogNoDnnSupport();
954     }
955   }
956   return *this;
957 }
958 
ThenConvolveBackwardDataWithScratch(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator)959 Stream &Stream::ThenConvolveBackwardDataWithScratch(
960     const dnn::FilterDescriptor &filter_descriptor,
961     const DeviceMemory<Eigen::half> &filter_data,
962     const dnn::BatchDescriptor &output_descriptor,
963     DeviceMemory<Eigen::half> backward_output_data,
964     const dnn::ConvolutionDescriptor &convolution_descriptor,
965     const dnn::BatchDescriptor &input_descriptor,
966     DeviceMemory<Eigen::half> *backward_input_data,
967     ScratchAllocator *scratch_allocator) {
968   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
969             PARAM(output_descriptor), PARAM(backward_output_data),
970             PARAM(convolution_descriptor), PARAM(input_descriptor),
971             PARAM(backward_input_data));
972 
973   if (ok()) {
974     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
975       CheckError(dnn->DoConvolveBackwardData(
976           this, filter_descriptor, filter_data, output_descriptor,
977           backward_output_data, convolution_descriptor, input_descriptor,
978           backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
979           /*output_profile_result=*/nullptr));
980     } else {
981       SetErrorAndLogNoDnnSupport();
982     }
983   }
984   return *this;
985 }
986 
ThenConvolveBackwardData(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data)987 Stream &Stream::ThenConvolveBackwardData(
988     const dnn::FilterDescriptor &filter_descriptor,
989     const DeviceMemory<float> &filter_data,
990     const dnn::BatchDescriptor &output_descriptor,
991     DeviceMemory<float> backward_output_data,
992     const dnn::ConvolutionDescriptor &convolution_descriptor,
993     const dnn::BatchDescriptor &input_descriptor,
994     DeviceMemory<float> *backward_input_data) {
995   return ThenConvolveBackwardDataWithScratch(
996       filter_descriptor, filter_data, output_descriptor, backward_output_data,
997       convolution_descriptor, input_descriptor, backward_input_data,
998       /*scratch_allocator=*/nullptr);
999 }
1000 
ThenConvolveBackwardFilterWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator)1001 Stream &Stream::ThenConvolveBackwardFilterWithScratch(
1002     const dnn::BatchDescriptor &input_descriptor,
1003     const DeviceMemory<float> &input_data,
1004     const dnn::BatchDescriptor &output_descriptor,
1005     DeviceMemory<float> backward_output_data,
1006     const dnn::ConvolutionDescriptor &convolution_descriptor,
1007     const dnn::FilterDescriptor &filter_descriptor,
1008     DeviceMemory<float> *backward_filter_data,
1009     ScratchAllocator *scratch_allocator) {
1010   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1011             PARAM(output_descriptor), PARAM(backward_output_data),
1012             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1013             PARAM(backward_filter_data));
1014 
1015   if (ok()) {
1016     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1017       CheckError(dnn->DoConvolveBackwardFilter(
1018           this, input_descriptor, input_data, output_descriptor,
1019           backward_output_data, convolution_descriptor, filter_descriptor,
1020           backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
1021           /*output_profile_result=*/nullptr));
1022     } else {
1023       SetErrorAndLogNoDnnSupport();
1024     }
1025   }
1026   return *this;
1027 }
1028 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1029 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1030     const dnn::BatchDescriptor &input_descriptor,
1031     const DeviceMemory<float> &input_data,
1032     const dnn::BatchDescriptor &output_descriptor,
1033     DeviceMemory<float> backward_output_data,
1034     const dnn::ConvolutionDescriptor &convolution_descriptor,
1035     const dnn::FilterDescriptor &filter_descriptor,
1036     DeviceMemory<float> *backward_filter_data,
1037     ScratchAllocator *scratch_allocator,
1038     const dnn::AlgorithmConfig &algorithm_config,
1039     dnn::ProfileResult *output_profile_result) {
1040   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1041             PARAM(output_descriptor), PARAM(backward_output_data),
1042             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1043             PARAM(backward_filter_data));
1044 
1045   if (ok()) {
1046     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1047       auto status = dnn->DoConvolveBackwardFilter(
1048           this, input_descriptor, input_data, output_descriptor,
1049           backward_output_data, convolution_descriptor, filter_descriptor,
1050           backward_filter_data, scratch_allocator, algorithm_config,
1051           output_profile_result);
1052       if (!status && !output_profile_result) {
1053         SetError();
1054       }
1055     } else {
1056       SetErrorAndLogNoDnnSupport();
1057     }
1058   }
1059   return *this;
1060 }
1061 
ThenConvolveBackwardFilterWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator)1062 Stream &Stream::ThenConvolveBackwardFilterWithScratch(
1063     const dnn::BatchDescriptor &input_descriptor,
1064     const DeviceMemory<Eigen::half> &input_data,
1065     const dnn::BatchDescriptor &output_descriptor,
1066     DeviceMemory<Eigen::half> backward_output_data,
1067     const dnn::ConvolutionDescriptor &convolution_descriptor,
1068     const dnn::FilterDescriptor &filter_descriptor,
1069     DeviceMemory<Eigen::half> *backward_filter_data,
1070     ScratchAllocator *scratch_allocator) {
1071   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1072             PARAM(output_descriptor), PARAM(backward_output_data),
1073             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1074             PARAM(backward_filter_data));
1075 
1076   if (ok()) {
1077     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1078       CheckError(dnn->DoConvolveBackwardFilter(
1079           this, input_descriptor, input_data, output_descriptor,
1080           backward_output_data, convolution_descriptor, filter_descriptor,
1081           backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
1082           /*output_profile_result=*/nullptr));
1083     } else {
1084       SetErrorAndLogNoDnnSupport();
1085     }
1086   }
1087   return *this;
1088 }
1089 
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1090 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1091     const dnn::BatchDescriptor &input_descriptor,
1092     const DeviceMemory<Eigen::half> &input_data,
1093     const dnn::BatchDescriptor &output_descriptor,
1094     DeviceMemory<Eigen::half> backward_output_data,
1095     const dnn::ConvolutionDescriptor &convolution_descriptor,
1096     const dnn::FilterDescriptor &filter_descriptor,
1097     DeviceMemory<Eigen::half> *backward_filter_data,
1098     ScratchAllocator *scratch_allocator,
1099     const dnn::AlgorithmConfig &algorithm_config,
1100     dnn::ProfileResult *output_profile_result) {
1101   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1102             PARAM(output_descriptor), PARAM(backward_output_data),
1103             PARAM(convolution_descriptor), PARAM(filter_descriptor),
1104             PARAM(backward_filter_data));
1105 
1106   if (ok()) {
1107     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1108       auto status = dnn->DoConvolveBackwardFilter(
1109           this, input_descriptor, input_data, output_descriptor,
1110           backward_output_data, convolution_descriptor, filter_descriptor,
1111           backward_filter_data, scratch_allocator, algorithm_config,
1112           output_profile_result);
1113       if (!status && !output_profile_result) {
1114         SetError();
1115       }
1116     } else {
1117       SetErrorAndLogNoDnnSupport();
1118     }
1119   }
1120   return *this;
1121 }
1122 
ThenConvolveBackwardFilter(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data)1123 Stream &Stream::ThenConvolveBackwardFilter(
1124     const dnn::BatchDescriptor &input_descriptor,
1125     const DeviceMemory<float> &input_data,
1126     const dnn::BatchDescriptor &output_descriptor,
1127     DeviceMemory<float> backward_output_data,
1128     const dnn::ConvolutionDescriptor &convolution_descriptor,
1129     const dnn::FilterDescriptor &filter_descriptor,
1130     DeviceMemory<float> *backward_filter_data) {
1131   return ThenConvolveBackwardFilterWithScratch(
1132       input_descriptor, input_data, output_descriptor, backward_output_data,
1133       convolution_descriptor, filter_descriptor, backward_filter_data,
1134       /*scratch_allocator=*/nullptr);
1135 }
1136 
1137 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)1138 Stream &Stream::ThenConvolveBackwardBiasImpl(
1139     const dnn::BatchDescriptor &input_descriptor,
1140     const DeviceMemory<T> &input_data,
1141     const dnn::BatchDescriptor &bias_descriptor,
1142     DeviceMemory<T> *backward_bias_data) {
1143   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
1144             PARAM(backward_bias_data));
1145 
1146   if (ok()) {
1147     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1148       CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
1149                                              bias_descriptor,
1150                                              backward_bias_data));
1151     } else {
1152       SetErrorAndLogNoDnnSupport();
1153     }
1154   }
1155   return *this;
1156 }
1157 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)1158 Stream &Stream::ThenConvolveBackwardBias(
1159     const dnn::BatchDescriptor &input_descriptor,
1160     const DeviceMemory<double> &input_data,
1161     const dnn::BatchDescriptor &bias_descriptor,
1162     DeviceMemory<double> *backward_bias_data) {
1163   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1164                                       bias_descriptor, backward_bias_data);
1165 }
1166 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)1167 Stream &Stream::ThenConvolveBackwardBias(
1168     const dnn::BatchDescriptor &input_descriptor,
1169     const DeviceMemory<float> &input_data,
1170     const dnn::BatchDescriptor &bias_descriptor,
1171     DeviceMemory<float> *backward_bias_data) {
1172   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1173                                       bias_descriptor, backward_bias_data);
1174 }
1175 
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)1176 Stream &Stream::ThenConvolveBackwardBias(
1177     const dnn::BatchDescriptor &input_descriptor,
1178     const DeviceMemory<Eigen::half> &input_data,
1179     const dnn::BatchDescriptor &bias_descriptor,
1180     DeviceMemory<Eigen::half> *backward_bias_data) {
1181   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1182                                       bias_descriptor, backward_bias_data);
1183 }
1184 
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1185 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
1186                            const DeviceMemory<float> &weights,
1187                            const dnn::BatchDescriptor &input_dimensions,
1188                            const dnn::BatchDescriptor &output_dimensions,
1189                            DeviceMemory<float> *output_data) {
1190   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
1191             PARAM(output_dimensions), PARAM(output_data));
1192 
1193   if (ok()) {
1194     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1195       CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
1196                                output_dimensions, output_data));
1197     } else {
1198       SetErrorAndLogNoDnnSupport();
1199     }
1200   }
1201   return *this;
1202 }
1203 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1204 Stream &Stream::ThenMatMulQuantized(
1205     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
1206     const DeviceMemory<float> &weight_scales,
1207     const dnn::BatchDescriptor &input_dimensions,
1208     const dnn::BatchDescriptor &output_dimensions,
1209     DeviceMemory<float> *output_data) {
1210   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1211             PARAM(input_dimensions), PARAM(output_dimensions),
1212             PARAM(output_data));
1213 
1214   if (ok()) {
1215     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1216       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1217                                         weight_scales, input_dimensions,
1218                                         output_dimensions, output_data));
1219     } else {
1220       SetErrorAndLogNoDnnSupport();
1221     }
1222   }
1223   return *this;
1224 }
1225 
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1226 Stream &Stream::ThenMatMulQuantized(
1227     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
1228     const DeviceMemory<float> &weight_scales,
1229     const dnn::BatchDescriptor &input_dimensions,
1230     const dnn::BatchDescriptor &output_dimensions,
1231     DeviceMemory<float> *output_data) {
1232   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1233             PARAM(input_dimensions), PARAM(output_dimensions),
1234             PARAM(output_data));
1235 
1236   if (ok()) {
1237     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1238       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1239                                         weight_scales, input_dimensions,
1240                                         output_dimensions, output_data));
1241     } else {
1242       SetErrorAndLogNoDnnSupport();
1243     }
1244   }
1245   return *this;
1246 }
1247 
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)1248 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
1249                             const DeviceMemory<float> &biases,
1250                             const dnn::BatchDescriptor &dimensions,
1251                             DeviceMemory<float> *output_data) {
1252   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
1253             PARAM(output_data));
1254 
1255   if (ok()) {
1256     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1257       CheckError(
1258           dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
1259     } else {
1260       SetErrorAndLogNoDnnSupport();
1261     }
1262   }
1263   return *this;
1264 }
1265 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data)1266 Stream &Stream::ThenPoolForward(
1267     const dnn::PoolingDescriptor &pooling_dimensions,
1268     const dnn::BatchDescriptor &input_dimensions,
1269     const DeviceMemory<double> &input_data,
1270     const dnn::BatchDescriptor &output_dimensions,
1271     DeviceMemory<double> *output_data) {
1272   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1273             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
1274 
1275   if (ok()) {
1276     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1277       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1278                                     input_data, output_dimensions,
1279                                     output_data));
1280     } else {
1281       SetError();
1282       LOG(WARNING)
1283           << "attempting to perform DNN operation using StreamExecutor "
1284              "without DNN support";
1285     }
1286   }
1287   return *this;
1288 }
1289 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1290 Stream &Stream::ThenPoolForward(
1291     const dnn::PoolingDescriptor &pooling_dimensions,
1292     const dnn::BatchDescriptor &input_dimensions,
1293     const DeviceMemory<float> &input_data,
1294     const dnn::BatchDescriptor &output_dimensions,
1295     DeviceMemory<float> *output_data) {
1296   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1297             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
1298 
1299   if (ok()) {
1300     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1301       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1302                                     input_data, output_dimensions,
1303                                     output_data));
1304     } else {
1305       SetErrorAndLogNoDnnSupport();
1306     }
1307   }
1308   return *this;
1309 }
1310 
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data)1311 Stream &Stream::ThenPoolForward(
1312     const dnn::PoolingDescriptor &pooling_dimensions,
1313     const dnn::BatchDescriptor &input_dimensions,
1314     const DeviceMemory<Eigen::half> &input_data,
1315     const dnn::BatchDescriptor &output_dimensions,
1316     DeviceMemory<Eigen::half> *output_data) {
1317   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1318             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
1319 
1320   if (ok()) {
1321     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1322       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1323                                     input_data, output_dimensions,
1324                                     output_data));
1325     } else {
1326       SetErrorAndLogNoDnnSupport();
1327     }
1328   }
1329   return *this;
1330 }
1331 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data)1332 Stream &Stream::ThenPoolBackward(
1333     const dnn::PoolingDescriptor &pooling_dimensions,
1334     const dnn::BatchDescriptor &input_dimensions,
1335     const DeviceMemory<double> &input_data,
1336     const dnn::BatchDescriptor &output_dimensions,
1337     const DeviceMemory<double> &output_data,
1338     const DeviceMemory<double> &input_diff_data,
1339     DeviceMemory<double> *output_diff_data) {
1340   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1341             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1342             PARAM(input_diff_data), PARAM(output_diff_data));
1343 
1344   if (ok()) {
1345     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1346       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1347                                      input_data, output_dimensions, output_data,
1348                                      input_diff_data, output_diff_data));
1349     } else {
1350       SetError();
1351       LOG(WARNING)
1352           << "attempting to perform DNN operation using StreamExecutor "
1353              "without DNN support";
1354     }
1355   }
1356   return *this;
1357 }
1358 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data)1359 Stream &Stream::ThenPoolBackward(
1360     const dnn::PoolingDescriptor &pooling_dimensions,
1361     const dnn::BatchDescriptor &input_dimensions,
1362     const DeviceMemory<float> &input_data,
1363     const dnn::BatchDescriptor &output_dimensions,
1364     const DeviceMemory<float> &output_data,
1365     const DeviceMemory<float> &input_diff_data,
1366     DeviceMemory<float> *output_diff_data) {
1367   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1368             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1369             PARAM(input_diff_data), PARAM(output_diff_data));
1370 
1371   if (ok()) {
1372     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1373       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1374                                      input_data, output_dimensions, output_data,
1375                                      input_diff_data, output_diff_data));
1376     } else {
1377       SetErrorAndLogNoDnnSupport();
1378     }
1379   }
1380   return *this;
1381 }
1382 
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data)1383 Stream &Stream::ThenPoolBackward(
1384     const dnn::PoolingDescriptor &pooling_dimensions,
1385     const dnn::BatchDescriptor &input_dimensions,
1386     const DeviceMemory<Eigen::half> &input_data,
1387     const dnn::BatchDescriptor &output_dimensions,
1388     const DeviceMemory<Eigen::half> &output_data,
1389     const DeviceMemory<Eigen::half> &input_diff_data,
1390     DeviceMemory<Eigen::half> *output_diff_data) {
1391   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1392             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1393             PARAM(input_diff_data), PARAM(output_diff_data));
1394 
1395   if (ok()) {
1396     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1397       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1398                                      input_data, output_dimensions, output_data,
1399                                      input_diff_data, output_diff_data));
1400     } else {
1401       SetErrorAndLogNoDnnSupport();
1402     }
1403   }
1404   return *this;
1405 }
1406 
ThenNormalize(const dnn::NormalizeDescriptor & normalize_descriptor,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1407 Stream &Stream::ThenNormalize(
1408     const dnn::NormalizeDescriptor &normalize_descriptor,
1409     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1410   VLOG_CALL(PARAM(normalize_descriptor), PARAM(input_data), PARAM(output_data));
1411 
1412   if (ok()) {
1413     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1414       CheckError(dnn->DoNormalize(this, normalize_descriptor, input_data,
1415                                   output_data));
1416     } else {
1417       SetErrorAndLogNoDnnSupport();
1418     }
1419   }
1420   return *this;
1421 }
1422 
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1423 Stream &Stream::ThenNormalizeWithDimensions(
1424     const dnn::NormalizeDescriptor &normalize_descriptor,
1425     const dnn::BatchDescriptor &dimensions,
1426     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1427   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
1428             PARAM(output_data));
1429 
1430   if (ok()) {
1431     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1432       CheckError(dnn->DoNormalizeWithDimensions(
1433           this, normalize_descriptor, dimensions, input_data, output_data));
1434     } else {
1435       SetErrorAndLogNoDnnSupport();
1436     }
1437   }
1438   return *this;
1439 }
1440 
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient)1441 Stream &Stream::ThenNormalizeBackwardWithDimensions(
1442     const dnn::NormalizeDescriptor &normalize_descriptor,
1443     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
1444     const DeviceMemory<float> &normalized_data,
1445     const DeviceMemory<float> &normalized_variable_gradient,
1446     DeviceMemory<float> *raw_variable_gradient) {
1447   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1448             PARAM(normalized_data), PARAM(normalized_variable_gradient),
1449             PARAM(raw_variable_gradient));
1450 
1451   if (ok()) {
1452     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1453       CheckError(dnn->DoNormalizeBackwardWithDimensions(
1454           this, normalize_descriptor, dimensions, raw_data, normalized_data,
1455           normalized_variable_gradient, raw_variable_gradient));
1456     } else {
1457       SetErrorAndLogNoDnnSupport();
1458     }
1459   }
1460   return *this;
1461 }
1462 
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1463 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1464                              const dnn::BatchDescriptor &dimensions,
1465                              const DeviceMemory<float> &input_data,
1466                              DeviceMemory<float> *output_data) {
1467   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1468                                  output_data, /*options=*/0);
1469 }
1470 
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1471 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1472                                         const dnn::BatchDescriptor &dimensions,
1473                                         const DeviceMemory<float> &input_data,
1474                                         DeviceMemory<float> *output_data,
1475                                         uint64 options) {
1476   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1477             PARAM(output_data), PARAM(options));
1478 
1479   if (ok()) {
1480     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1481       CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1482                                  output_data, options));
1483     } else {
1484       SetErrorAndLogNoDnnSupport();
1485     }
1486   }
1487   return *this;
1488 }
1489 
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1490 Stream &Stream::ThenDepthConcatenate(
1491     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1492     port::ArraySlice<const DeviceMemory<float> *> input_data,
1493     DeviceMemory<float> *output_data) {
1494   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1495 
1496   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1497     if (input_dimensions[i].count() != input_dimensions[0].count() ||
1498         input_dimensions[i].height() != input_dimensions[0].height() ||
1499         input_dimensions[i].width() != input_dimensions[0].width()) {
1500       SetError();
1501       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1502                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1503                  << "input_dimensions[" << i
1504                  << "]: " << input_dimensions[i].ToString();
1505       return *this;
1506     }
1507   }
1508 
1509   if (ok()) {
1510     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1511       CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1512                                          output_data));
1513     } else {
1514       SetErrorAndLogNoDnnSupport();
1515     }
1516   }
1517   return *this;
1518 }
1519 
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1520 Stream &Stream::ThenSpaceConcatenate(
1521     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1522     port::ArraySlice<const DeviceMemory<float> *> input_data,
1523     DeviceMemory<float> *output_data,
1524     dnn::SpaceConcatenateMode concat_direction) {
1525   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1526 
1527   // Check that the input dimensions of all the other batches match those of the
1528   // first batch.
1529   for (size_t i = 1; i < input_dimensions.size(); ++i) {
1530     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1531         (input_dimensions[i].count() != input_dimensions[0].count() ||
1532          input_dimensions[i].height() != input_dimensions[0].height() ||
1533          input_dimensions[i].feature_map_count() !=
1534              input_dimensions[0].feature_map_count())) {
1535       SetError();
1536       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1537                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1538                  << "input_dimensions[" << i
1539                  << "]: " << input_dimensions[i].ToString();
1540       return *this;
1541     }
1542 
1543     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1544         (input_dimensions[i].count() != input_dimensions[0].count() ||
1545          input_dimensions[i].width() != input_dimensions[0].width() ||
1546          input_dimensions[i].feature_map_count() !=
1547              input_dimensions[0].feature_map_count())) {
1548       SetError();
1549       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1550                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
1551                  << "input_dimensions[" << i
1552                  << "]: " << input_dimensions[i].ToString();
1553       return *this;
1554     }
1555   }
1556   if (ok()) {
1557     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1558       CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1559                                          output_data, concat_direction));
1560     } else {
1561       SetErrorAndLogNoDnnSupport();
1562     }
1563   }
1564   return *this;
1565 }
1566 
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1567 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1568                             const DeviceMemory<float> &input_data,
1569                             const dnn::BatchDescriptor &output_dimensions,
1570                             DeviceMemory<float> *output_data) {
1571   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1572             PARAM(output_dimensions), PARAM(output_data));
1573 
1574   if (ok()) {
1575     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1576       CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1577                                 output_dimensions, output_data));
1578     } else {
1579       SetErrorAndLogNoDnnSupport();
1580     }
1581   }
1582   return *this;
1583 }
1584 
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1585 Stream &Stream::ThenDepthToSpace(
1586     const dnn::BatchDescriptor &input_dimensions,
1587     const DeviceMemory<float> &input_data,
1588     const dnn::DepthToSpaceLayout &depth_to_space_layout,
1589     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1590   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1591             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1592             PARAM(output_data));
1593 
1594   if (ok()) {
1595     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1596       CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1597                                      depth_to_space_layout,
1598                                      sqrt_depth_reduction, output_data));
1599     } else {
1600       SetErrorAndLogNoDnnSupport();
1601     }
1602   }
1603   return *this;
1604 }
1605 
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1606 Stream &Stream::ThenSpaceToDepth(
1607     const dnn::BatchDescriptor &input_dimensions,
1608     const DeviceMemory<float> &input_data,
1609     const dnn::DepthToSpaceLayout &space_to_depth_layout,
1610     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1611   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1612             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1613             PARAM(output_data));
1614 
1615   if (ok()) {
1616     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1617       CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1618                                      space_to_depth_layout, sqrt_depth_increase,
1619                                      output_data));
1620     } else {
1621       SetErrorAndLogNoDnnSupport();
1622     }
1623   }
1624   return *this;
1625 }
1626 
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1627 Stream &Stream::ThenElementwiseOperate(
1628     dnn::ElementwiseOperation operation,
1629     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1630     port::ArraySlice<const DeviceMemory<float> *> input_data,
1631     const dnn::BatchDescriptor &output_dimensions,
1632     DeviceMemory<float> *output_data) {
1633   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1634             PARAM(output_dimensions), PARAM(output_data));
1635 
1636   if (ok()) {
1637     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1638       CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1639                                            input_data, output_dimensions,
1640                                            output_data));
1641     } else {
1642       SetErrorAndLogNoDnnSupport();
1643     }
1644   }
1645   return *this;
1646 }
1647 
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1648 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1649     dnn::ElementwiseOperation operation,
1650     port::ArraySlice<int> input_multiplicands, int output_divisor,
1651     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1652     port::ArraySlice<const DeviceMemory<float> *> input_data,
1653     const dnn::BatchDescriptor &output_dimensions,
1654     DeviceMemory<float> *output_data) {
1655   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1656             PARAM(input_dimensions), PARAM(input_data),
1657             PARAM(output_dimensions), PARAM(output_data));
1658 
1659   if (ok()) {
1660     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1661       CheckError(dnn->DoElementwiseOperateScaledQuantized(
1662           this, operation, input_multiplicands, output_divisor,
1663           input_dimensions, input_data, output_dimensions, output_data));
1664     } else {
1665       SetErrorAndLogNoDnnSupport();
1666     }
1667   }
1668   return *this;
1669 }
1670 
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1671 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1672                           const DeviceMemory<float> &input_data, int64 left_pad,
1673                           int64 right_pad, int64 top_pad, int64 bottom_pad,
1674                           DeviceMemory<float> *output_data) {
1675   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1676             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1677             PARAM(output_data));
1678 
1679   if (ok()) {
1680     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1681       CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1682                               top_pad, bottom_pad, output_data));
1683     } else {
1684       SetErrorAndLogNoDnnSupport();
1685     }
1686   }
1687   return *this;
1688 }
1689 
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1690 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1691                             const DeviceMemory<float> &input_data,
1692                             int64 left_trim, int64 right_trim, int64 top_trim,
1693                             int64 bottom_trim,
1694                             DeviceMemory<float> *output_data) {
1695   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1696             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1697             PARAM(output_data));
1698 
1699   if (ok()) {
1700     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1701       CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1702                                 right_trim, top_trim, bottom_trim,
1703                                 output_data));
1704     } else {
1705       SetErrorAndLogNoDnnSupport();
1706     }
1707   }
1708   return *this;
1709 }
1710 
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1711 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1712                                 const DeviceMemory<float> &input_data,
1713                                 int64 replicate_x, int64 replicate_y,
1714                                 DeviceMemory<float> *output_data) {
1715   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1716             PARAM(replicate_y), PARAM(output_data));
1717 
1718   if (ok()) {
1719     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1720       CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1721                                     replicate_y, output_data));
1722     } else {
1723       SetErrorAndLogNoDnnSupport();
1724     }
1725   }
1726   return *this;
1727 }
1728 
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1729 Stream &Stream::ThenMemcpyD2HQuantized(
1730     const DeviceMemory<float> &gpu_unquantized_src,
1731     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1732   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1733             PARAM(size));
1734 
1735   if (ok()) {
1736     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1737       CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1738                                            host_dst, size));
1739     } else {
1740       SetErrorAndLogNoDnnSupport();
1741     }
1742   }
1743   return *this;
1744 }
1745 
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1746 Stream &Stream::ThenMemcpyH2DQuantized(
1747     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1748     DeviceMemory<float> *gpu_unquantized_dst) {
1749   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1750             PARAM(gpu_unquantized_dst));
1751 
1752   if (ok()) {
1753     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1754       CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1755                                            gpu_unquantized_dst));
1756     } else {
1757       SetErrorAndLogNoDnnSupport();
1758     }
1759   }
1760   return *this;
1761 }
1762 
ThenCopyHostBuffer2Device(HostBuffer * buffer_src,DeviceMemory<float> * gpu_unquantized_dst)1763 Stream &Stream::ThenCopyHostBuffer2Device(
1764     HostBuffer *buffer_src, DeviceMemory<float> *gpu_unquantized_dst) {
1765   VLOG_CALL(PARAM(*buffer_src), PARAM(gpu_unquantized_dst));
1766 
1767   if (ok()) {
1768     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1769       CheckError(
1770           dnn->DoCopyHostBuffer2Device(this, buffer_src, gpu_unquantized_dst));
1771     } else {
1772       SetErrorAndLogNoDnnSupport();
1773     }
1774   }
1775   return *this;
1776 }
1777 
ThenCopyDevice2HostBuffer(const DeviceMemory<float> & gpu_unquantized_src,HostBuffer * buffer_dst)1778 Stream &Stream::ThenCopyDevice2HostBuffer(
1779     const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst) {
1780   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(*buffer_dst));
1781 
1782   if (ok()) {
1783     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1784       CheckError(
1785           dnn->DoCopyDevice2HostBuffer(this, gpu_unquantized_src, buffer_dst));
1786     } else {
1787       SetErrorAndLogNoDnnSupport();
1788     }
1789   }
1790   return *this;
1791 }
1792 
GetOrCreateSubStream()1793 Stream *Stream::GetOrCreateSubStream() {
1794   mutex_lock lock{mu_};
1795   for (auto &stream : sub_streams_) {
1796     if (stream.second) {
1797       stream.second = false;
1798       return stream.first.get();
1799     }
1800   }
1801   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1802                             false);
1803   Stream *sub_stream = sub_streams_.back().first.get();
1804   sub_stream->Init();
1805   CHECK(ok_) << "sub-stream failed to be initialized";
1806 
1807   return sub_stream;
1808 }
1809 
ReturnSubStream(Stream * sub_stream)1810 void Stream::ReturnSubStream(Stream *sub_stream) {
1811   mutex_lock lock{mu_};
1812   for (auto &stream : sub_streams_) {
1813     if (stream.first.get() == sub_stream) {
1814       stream.second = true;
1815       return;
1816     }
1817   }
1818   LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
1819 }
1820 
ThenStartTimer(Timer * t)1821 Stream &Stream::ThenStartTimer(Timer *t) {
1822   VLOG_CALL(PARAM(t));
1823 
1824   if (ok()) {
1825     CheckError(parent_->StartTimer(this, t));
1826   } else {
1827     LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
1828   }
1829   return *this;
1830 }
1831 
ThenStopTimer(Timer * t)1832 Stream &Stream::ThenStopTimer(Timer *t) {
1833   VLOG_CALL(PARAM(t));
1834 
1835   if (ok()) {
1836     CheckError(parent_->StopTimer(this, t));
1837   } else {
1838     LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
1839   }
1840   return *this;
1841 }
1842 
ThenWaitFor(Stream * other)1843 Stream &Stream::ThenWaitFor(Stream *other) {
1844   VLOG_CALL(PARAM(other));
1845 
1846   CHECK(this != other) << "stream cannot wait for itself";
1847   if (ok() && other->ok()) {
1848     CheckError(parent_->CreateStreamDependency(this, other));
1849   } else {
1850     SetError();
1851     LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
1852   }
1853   return *this;
1854 }
1855 
ThenWaitFor(Event * event)1856 Stream &Stream::ThenWaitFor(Event *event) {
1857   VLOG_CALL(PARAM(event));
1858 
1859   if (ok()) {
1860     port::Status status = parent_->WaitForEvent(this, event);
1861     if (!status.ok()) {
1862       LOG(ERROR) << "Error waiting for event in stream: "
1863                  << status.error_message()
1864                  << "; not marking stream as bad, as the Event object may be "
1865                  << "at fault. Monitor for further errors.";
1866     }
1867   } else {
1868     LOG(INFO) << "stream " << this << " did not wait for an event.";
1869   }
1870   return *this;
1871 }
1872 
1873 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1874 // functions and logs for errors.
1875 template <typename... Args>
1876 struct ThenBlasImpl {
1877   // blas_func is the DoBlasXXX member function pointer, and args are its
1878   // arguments except the first one of Stream* type.
operator ()perftools::gputools::ThenBlasImpl1879   Stream &operator()(Stream *stream,
1880                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1881                      Args... args) {
1882     return Run(stream, blas_func, /*record_error=*/true, args...);
1883   }
1884 
1885   // Like operator(), but only calls stream->CheckError() if record_error is
1886   // true.
1887   Stream &Run(Stream *stream,
1888               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1889               bool record_error, Args... args);
1890 };
1891 
1892 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1893 Stream &ThenBlasImpl<Args...>::Run(
1894     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1895     bool record_error, Args... args) {
1896   if (stream->ok()) {
1897     bool ok;
1898     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1899       ok = (blas->*blas_func)(stream, args...);
1900     } else {
1901       LOG(WARNING)
1902           << "attempting to perform BLAS operation using StreamExecutor "
1903              "without BLAS support";
1904       ok = false;
1905     }
1906     if (record_error) {
1907       stream->CheckError(ok);
1908     }
1909   }
1910   return *stream;
1911 }
1912 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1913 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1914                              int incx, DeviceMemory<float> *result) {
1915   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1916 
1917   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1918       impl;
1919   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1920               result);
1921 }
1922 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1923 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1924                              int incx, DeviceMemory<double> *result) {
1925   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1926 
1927   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1928                DeviceMemory<double> *> impl;
1929   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1930               result);
1931 }
1932 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1933 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1934                              const DeviceMemory<std::complex<float>> &x,
1935                              int incx, DeviceMemory<float> *result) {
1936   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1937 
1938   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1939                DeviceMemory<float> *> impl;
1940   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1941               result);
1942 }
1943 
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1944 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1945                              const DeviceMemory<std::complex<double>> &x,
1946                              int incx, DeviceMemory<double> *result) {
1947   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1948 
1949   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1950                DeviceMemory<double> *> impl;
1951   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1952               result);
1953 }
1954 
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1955 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1956                              const DeviceMemory<float> &x, int incx,
1957                              DeviceMemory<float> *y, int incy) {
1958   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1959             PARAM(incy));
1960 
1961   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1962                DeviceMemory<float> *, int> impl;
1963   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1964               y, incy);
1965 }
1966 
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1967 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1968                              const DeviceMemory<double> &x, int incx,
1969                              DeviceMemory<double> *y, int incy) {
1970   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1971             PARAM(incy));
1972 
1973   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1974                DeviceMemory<double> *, int> impl;
1975   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1976               y, incy);
1977 }
1978 
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1979 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1980                              const DeviceMemory<std::complex<float>> &x,
1981                              int incx, DeviceMemory<std::complex<float>> *y,
1982                              int incy) {
1983   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1984             PARAM(incy));
1985 
1986   ThenBlasImpl<uint64, std::complex<float>,
1987                const DeviceMemory<std::complex<float>> &, int,
1988                DeviceMemory<std::complex<float>> *, int> impl;
1989   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1990               y, incy);
1991 }
1992 
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1993 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1994                              const DeviceMemory<std::complex<double>> &x,
1995                              int incx, DeviceMemory<std::complex<double>> *y,
1996                              int incy) {
1997   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1998             PARAM(incy));
1999 
2000   ThenBlasImpl<uint64, std::complex<double>,
2001                const DeviceMemory<std::complex<double>> &, int,
2002                DeviceMemory<std::complex<double>> *, int> impl;
2003   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2004               y, incy);
2005 }
2006 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)2007 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
2008                              int incx, DeviceMemory<float> *y, int incy) {
2009   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2010 
2011   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2012                int> impl;
2013   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2014               incy);
2015 }
2016 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)2017 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
2018                              int incx, DeviceMemory<double> *y, int incy) {
2019   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2020 
2021   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2022                DeviceMemory<double> *, int> impl;
2023   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2024               incy);
2025 }
2026 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2027 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2028                              const DeviceMemory<std::complex<float>> &x,
2029                              int incx, DeviceMemory<std::complex<float>> *y,
2030                              int incy) {
2031   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2032 
2033   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2034                DeviceMemory<std::complex<float>> *, int> impl;
2035   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2036               incy);
2037 }
2038 
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2039 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2040                              const DeviceMemory<std::complex<double>> &x,
2041                              int incx, DeviceMemory<std::complex<double>> *y,
2042                              int incy) {
2043   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2044 
2045   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2046                DeviceMemory<std::complex<double>> *, int> impl;
2047   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2048               incy);
2049 }
2050 
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)2051 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
2052                             int incx, const DeviceMemory<float> &y, int incy,
2053                             DeviceMemory<float> *result) {
2054   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2055             PARAM(result));
2056 
2057   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
2058                const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
2059   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2060               result);
2061 }
2062 
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)2063 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
2064                             int incx, const DeviceMemory<double> &y, int incy,
2065                             DeviceMemory<double> *result) {
2066   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2067             PARAM(result));
2068 
2069   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2070                const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
2071   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2072               result);
2073 }
2074 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2075 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2076                              const DeviceMemory<std::complex<float>> &x,
2077                              int incx,
2078                              const DeviceMemory<std::complex<float>> &y,
2079                              int incy,
2080                              DeviceMemory<std::complex<float>> *result) {
2081   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2082             PARAM(result));
2083 
2084   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2085                const DeviceMemory<std::complex<float>> &, int,
2086                DeviceMemory<std::complex<float>> *> impl;
2087   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2088               incy, result);
2089 }
2090 
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2091 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2092                              const DeviceMemory<std::complex<double>> &x,
2093                              int incx,
2094                              const DeviceMemory<std::complex<double>> &y,
2095                              int incy,
2096                              DeviceMemory<std::complex<double>> *result) {
2097   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2098             PARAM(result));
2099 
2100   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2101                const DeviceMemory<std::complex<double>> &, int,
2102                DeviceMemory<std::complex<double>> *> impl;
2103   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2104               incy, result);
2105 }
2106 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2107 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2108                              const DeviceMemory<std::complex<float>> &x,
2109                              int incx,
2110                              const DeviceMemory<std::complex<float>> &y,
2111                              int incy,
2112                              DeviceMemory<std::complex<float>> *result) {
2113   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2114             PARAM(result));
2115 
2116   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2117                const DeviceMemory<std::complex<float>> &, int,
2118                DeviceMemory<std::complex<float>> *> impl;
2119   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2120               incy, result);
2121 }
2122 
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2123 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2124                              const DeviceMemory<std::complex<double>> &x,
2125                              int incx,
2126                              const DeviceMemory<std::complex<double>> &y,
2127                              int incy,
2128                              DeviceMemory<std::complex<double>> *result) {
2129   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2130             PARAM(result));
2131 
2132   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2133                const DeviceMemory<std::complex<double>> &, int,
2134                DeviceMemory<std::complex<double>> *> impl;
2135   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2136               incy, result);
2137 }
2138 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)2139 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
2140                              int incx, DeviceMemory<float> *result) {
2141   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2142 
2143   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2144       impl;
2145   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2146               result);
2147 }
2148 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)2149 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
2150                              int incx, DeviceMemory<double> *result) {
2151   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2152 
2153   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2154                DeviceMemory<double> *> impl;
2155   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2156               result);
2157 }
2158 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)2159 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2160                              const DeviceMemory<std::complex<float>> &x,
2161                              int incx, DeviceMemory<float> *result) {
2162   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2163 
2164   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2165                DeviceMemory<float> *> impl;
2166   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2167               result);
2168 }
2169 
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)2170 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2171                              const DeviceMemory<std::complex<double>> &x,
2172                              int incx, DeviceMemory<double> *result) {
2173   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2174 
2175   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2176                DeviceMemory<double> *> impl;
2177   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2178               result);
2179 }
2180 
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)2181 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
2182                             DeviceMemory<float> *y, int incy, float c,
2183                             float s) {
2184   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2185             PARAM(c), PARAM(s));
2186 
2187   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2188                float, float> impl;
2189   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2190               c, s);
2191 }
2192 
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)2193 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
2194                             int incx, DeviceMemory<double> *y, int incy,
2195                             double c, double s) {
2196   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2197             PARAM(c), PARAM(s));
2198 
2199   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2200                double, double> impl;
2201   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2202               c, s);
2203 }
2204 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)2205 Stream &Stream::ThenBlasRot(uint64 elem_count,
2206                             DeviceMemory<std::complex<float>> *x, int incx,
2207                             DeviceMemory<std::complex<float>> *y, int incy,
2208                             float c, float s) {
2209   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2210             PARAM(c), PARAM(s));
2211 
2212   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2213                DeviceMemory<std::complex<float>> *, int, float, float> impl;
2214   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2215               c, s);
2216 }
2217 
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)2218 Stream &Stream::ThenBlasRot(uint64 elem_count,
2219                             DeviceMemory<std::complex<double>> *x, int incx,
2220                             DeviceMemory<std::complex<double>> *y, int incy,
2221                             double c, double s) {
2222   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2223             PARAM(c), PARAM(s));
2224 
2225   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2226                DeviceMemory<std::complex<double>> *, int, double, double> impl;
2227   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2228               c, s);
2229 }
2230 
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)2231 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
2232                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
2233   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2234 
2235   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2236                DeviceMemory<float> *, DeviceMemory<float> *> impl;
2237   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2238 }
2239 
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)2240 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
2241                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
2242   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2243 
2244   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2245                DeviceMemory<double> *, DeviceMemory<double> *> impl;
2246   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2247 }
2248 
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)2249 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
2250                              DeviceMemory<std::complex<float>> *b,
2251                              DeviceMemory<float> *c,
2252                              DeviceMemory<std::complex<float>> *s) {
2253   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2254 
2255   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
2256                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
2257                DeviceMemory<std::complex<float>> *> impl;
2258   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2259 }
2260 
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)2261 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
2262                              DeviceMemory<std::complex<double>> *b,
2263                              DeviceMemory<double> *c,
2264                              DeviceMemory<std::complex<double>> *s) {
2265   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2266 
2267   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
2268                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
2269                DeviceMemory<std::complex<double>> *> impl;
2270   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2271 }
2272 
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)2273 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
2274                              int incx, DeviceMemory<float> *y, int incy,
2275                              const DeviceMemory<float> &param) {
2276   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2277             PARAM(param));
2278 
2279   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2280                const DeviceMemory<float> &> impl;
2281   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2282               incy, param);
2283 }
2284 
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)2285 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
2286                              int incx, DeviceMemory<double> *y, int incy,
2287                              const DeviceMemory<double> &param) {
2288   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2289             PARAM(param));
2290 
2291   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2292                const DeviceMemory<double> &> impl;
2293   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2294               incy, param);
2295 }
2296 
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)2297 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
2298                               DeviceMemory<float> *x1,
2299                               const DeviceMemory<float> &y1,
2300                               DeviceMemory<float> *param) {
2301   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2302 
2303   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2304                DeviceMemory<float> *, const DeviceMemory<float> &,
2305                DeviceMemory<float> *> impl;
2306   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2307 }
2308 
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)2309 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
2310                               DeviceMemory<double> *d2,
2311                               DeviceMemory<double> *x1,
2312                               const DeviceMemory<double> &y1,
2313                               DeviceMemory<double> *param) {
2314   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2315 
2316   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2317                DeviceMemory<double> *, const DeviceMemory<double> &,
2318                DeviceMemory<double> *> impl;
2319   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2320 }
2321 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)2322 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2323                              DeviceMemory<float> *x, int incx) {
2324   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2325 
2326   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
2327   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2328 }
2329 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)2330 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2331                              DeviceMemory<double> *x, int incx) {
2332   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2333 
2334   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
2335   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2336 }
2337 
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)2338 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2339                              DeviceMemory<std::complex<float>> *x, int incx) {
2340   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2341 
2342   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
2343   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2344 }
2345 
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)2346 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2347                              DeviceMemory<std::complex<double>> *x, int incx) {
2348   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2349 
2350   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
2351   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2352 }
2353 
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)2354 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
2355                              DeviceMemory<std::complex<float>> *x, int incx) {
2356   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2357 
2358   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
2359                int> impl;
2360   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2361 }
2362 
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)2363 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
2364                              DeviceMemory<std::complex<double>> *x, int incx) {
2365   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2366 
2367   ThenBlasImpl<uint64, std::complex<double>,
2368                DeviceMemory<std::complex<double>> *, int> impl;
2369   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2370 }
2371 
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)2372 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
2373                              int incx, DeviceMemory<float> *y, int incy) {
2374   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2375 
2376   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
2377       impl;
2378   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2379               incy);
2380 }
2381 
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)2382 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
2383                              int incx, DeviceMemory<double> *y, int incy) {
2384   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2385 
2386   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
2387       impl;
2388   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2389               incy);
2390 }
2391 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2392 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2393                              DeviceMemory<std::complex<float>> *x, int incx,
2394                              DeviceMemory<std::complex<float>> *y, int incy) {
2395   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2396 
2397   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2398                DeviceMemory<std::complex<float>> *, int> impl;
2399   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2400               incy);
2401 }
2402 
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2403 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2404                              DeviceMemory<std::complex<double>> *x, int incx,
2405                              DeviceMemory<std::complex<double>> *y, int incy) {
2406   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2407 
2408   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2409                DeviceMemory<std::complex<double>> *, int> impl;
2410   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2411               incy);
2412 }
2413 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2414 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2415                               int incx, DeviceMemory<int> *result) {
2416   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2417 
2418   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2419       impl;
2420   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2421               result);
2422 }
2423 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2424 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2425                               int incx, DeviceMemory<int> *result) {
2426   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2427 
2428   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2429       impl;
2430   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2431               result);
2432 }
2433 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2434 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2435                               const DeviceMemory<std::complex<float>> &x,
2436                               int incx, DeviceMemory<int> *result) {
2437   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2438 
2439   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2440                DeviceMemory<int> *> impl;
2441   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2442               result);
2443 }
2444 
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2445 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2446                               const DeviceMemory<std::complex<double>> &x,
2447                               int incx, DeviceMemory<int> *result) {
2448   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2449 
2450   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2451                DeviceMemory<int> *> impl;
2452   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2453               result);
2454 }
2455 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2456 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2457                               int incx, DeviceMemory<int> *result) {
2458   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2459 
2460   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2461       impl;
2462   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2463               result);
2464 }
2465 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2466 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2467                               int incx, DeviceMemory<int> *result) {
2468   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2469 
2470   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2471       impl;
2472   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2473               result);
2474 }
2475 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2476 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2477                               const DeviceMemory<std::complex<float>> &x,
2478                               int incx, DeviceMemory<int> *result) {
2479   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2480 
2481   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2482                DeviceMemory<int> *> impl;
2483   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2484               result);
2485 }
2486 
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2487 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2488                               const DeviceMemory<std::complex<double>> &x,
2489                               int incx, DeviceMemory<int> *result) {
2490   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2491 
2492   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2493                DeviceMemory<int> *> impl;
2494   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2495               result);
2496 }
2497 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2498 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2499                              uint64 kl, uint64 ku, float alpha,
2500                              const DeviceMemory<float> &a, int lda,
2501                              const DeviceMemory<float> &x, int incx, float beta,
2502                              DeviceMemory<float> *y, int incy) {
2503   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2504             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2505             PARAM(beta), PARAM(y), PARAM(incy));
2506 
2507   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2508                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2509                int, float, DeviceMemory<float> *, int> impl;
2510   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2511               a, lda, x, incx, beta, y, incy);
2512 }
2513 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2514 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2515                              uint64 kl, uint64 ku, double alpha,
2516                              const DeviceMemory<double> &a, int lda,
2517                              const DeviceMemory<double> &x, int incx,
2518                              double beta, DeviceMemory<double> *y, int incy) {
2519   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2520             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2521             PARAM(beta), PARAM(y), PARAM(incy));
2522 
2523   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2524                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2525                int, double, DeviceMemory<double> *, int> impl;
2526   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2527               a, lda, x, incx, beta, y, incy);
2528 }
2529 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2530 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2531                              uint64 kl, uint64 ku, std::complex<float> alpha,
2532                              const DeviceMemory<std::complex<float>> &a,
2533                              int lda,
2534                              const DeviceMemory<std::complex<float>> &x,
2535                              int incx, std::complex<float> beta,
2536                              DeviceMemory<std::complex<float>> *y, int incy) {
2537   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2538             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2539             PARAM(beta), PARAM(y), PARAM(incy));
2540 
2541   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2542                std::complex<float>, const DeviceMemory<std::complex<float>> &,
2543                int, const DeviceMemory<std::complex<float>> &, int,
2544                std::complex<float>, DeviceMemory<std::complex<float>> *,
2545                int> impl;
2546   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2547               a, lda, x, incx, beta, y, incy);
2548 }
2549 
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2550 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2551                              uint64 kl, uint64 ku, std::complex<double> alpha,
2552                              const DeviceMemory<std::complex<double>> &a,
2553                              int lda,
2554                              const DeviceMemory<std::complex<double>> &x,
2555                              int incx, std::complex<double> beta,
2556                              DeviceMemory<std::complex<double>> *y, int incy) {
2557   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2558             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2559             PARAM(beta), PARAM(y), PARAM(incy));
2560 
2561   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2562                std::complex<double>, const DeviceMemory<std::complex<double>> &,
2563                int, const DeviceMemory<std::complex<double>> &, int,
2564                std::complex<double>, DeviceMemory<std::complex<double>> *,
2565                int> impl;
2566   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2567               a, lda, x, incx, beta, y, incy);
2568 }
2569 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2570 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2571                              float alpha, const DeviceMemory<float> &a, int lda,
2572                              const DeviceMemory<float> &x, int incx, float beta,
2573                              DeviceMemory<float> *y, int incy) {
2574   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2575             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2576             PARAM(incy));
2577 
2578   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2579                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2580                int, float, DeviceMemory<float> *, int> impl;
2581   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2582               x, incx, beta, y, incy);
2583 }
2584 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2585 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2586                              double alpha, const DeviceMemory<double> &a,
2587                              int lda, const DeviceMemory<double> &x, int incx,
2588                              double beta, DeviceMemory<double> *y, int incy) {
2589   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2590             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2591             PARAM(incy));
2592 
2593   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2594                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2595                int, double, DeviceMemory<double> *, int> impl;
2596   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2597               x, incx, beta, y, incy);
2598 }
2599 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2600 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2601                              std::complex<float> alpha,
2602                              const DeviceMemory<std::complex<float>> &a,
2603                              int lda,
2604                              const DeviceMemory<std::complex<float>> &x,
2605                              int incx, std::complex<float> beta,
2606                              DeviceMemory<std::complex<float>> *y, int incy) {
2607   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2608             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2609             PARAM(incy));
2610 
2611   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2612                const DeviceMemory<std::complex<float>> &, int,
2613                const DeviceMemory<std::complex<float>> &, int,
2614                std::complex<float>, DeviceMemory<std::complex<float>> *,
2615                int> impl;
2616   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2617               x, incx, beta, y, incy);
2618 }
2619 
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2620 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2621                              std::complex<double> alpha,
2622                              const DeviceMemory<std::complex<double>> &a,
2623                              int lda,
2624                              const DeviceMemory<std::complex<double>> &x,
2625                              int incx, std::complex<double> beta,
2626                              DeviceMemory<std::complex<double>> *y, int incy) {
2627   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2628             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2629             PARAM(incy));
2630 
2631   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2632                const DeviceMemory<std::complex<double>> &, int,
2633                const DeviceMemory<std::complex<double>> &, int,
2634                std::complex<double>, DeviceMemory<std::complex<double>> *,
2635                int> impl;
2636   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2637               x, incx, beta, y, incy);
2638 }
2639 
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2640 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2641                             const DeviceMemory<float> &x, int incx,
2642                             const DeviceMemory<float> &y, int incy,
2643                             DeviceMemory<float> *a, int lda) {
2644   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2645             PARAM(incy), PARAM(a), PARAM(lda));
2646 
2647   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2648                const DeviceMemory<float> &, int, DeviceMemory<float> *,
2649                int> impl;
2650   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2651               incy, a, lda);
2652 }
2653 
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2654 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2655                             const DeviceMemory<double> &x, int incx,
2656                             const DeviceMemory<double> &y, int incy,
2657                             DeviceMemory<double> *a, int lda) {
2658   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2659             PARAM(incy), PARAM(a), PARAM(lda));
2660 
2661   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2662                const DeviceMemory<double> &, int, DeviceMemory<double> *,
2663                int> impl;
2664   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2665               incy, a, lda);
2666 }
2667 
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2668 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2669                              const DeviceMemory<std::complex<float>> &x,
2670                              int incx,
2671                              const DeviceMemory<std::complex<float>> &y,
2672                              int incy, DeviceMemory<std::complex<float>> *a,
2673                              int lda) {
2674   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2675             PARAM(incy), PARAM(a), PARAM(lda));
2676 
2677   ThenBlasImpl<uint64, uint64, std::complex<float>,
2678                const DeviceMemory<std::complex<float>> &, int,
2679                const DeviceMemory<std::complex<float>> &, int,
2680                DeviceMemory<std::complex<float>> *, int> impl;
2681   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2682               incy, a, lda);
2683 }
2684 
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2685 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2686                              const DeviceMemory<std::complex<double>> &x,
2687                              int incx,
2688                              const DeviceMemory<std::complex<double>> &y,
2689                              int incy, DeviceMemory<std::complex<double>> *a,
2690                              int lda) {
2691   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2692             PARAM(incy), PARAM(a), PARAM(lda));
2693 
2694   ThenBlasImpl<uint64, uint64, std::complex<double>,
2695                const DeviceMemory<std::complex<double>> &, int,
2696                const DeviceMemory<std::complex<double>> &, int,
2697                DeviceMemory<std::complex<double>> *, int> impl;
2698   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2699               incy, a, lda);
2700 }
2701 
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2702 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2703                              const DeviceMemory<std::complex<float>> &x,
2704                              int incx,
2705                              const DeviceMemory<std::complex<float>> &y,
2706                              int incy, DeviceMemory<std::complex<float>> *a,
2707                              int lda) {
2708   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2709             PARAM(incy), PARAM(a), PARAM(lda));
2710 
2711   ThenBlasImpl<uint64, uint64, std::complex<float>,
2712                const DeviceMemory<std::complex<float>> &, int,
2713                const DeviceMemory<std::complex<float>> &, int,
2714                DeviceMemory<std::complex<float>> *, int> impl;
2715   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2716               incy, a, lda);
2717 }
2718 
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2719 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2720                              const DeviceMemory<std::complex<double>> &x,
2721                              int incx,
2722                              const DeviceMemory<std::complex<double>> &y,
2723                              int incy, DeviceMemory<std::complex<double>> *a,
2724                              int lda) {
2725   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2726             PARAM(incy), PARAM(a), PARAM(lda));
2727 
2728   ThenBlasImpl<uint64, uint64, std::complex<double>,
2729                const DeviceMemory<std::complex<double>> &, int,
2730                const DeviceMemory<std::complex<double>> &, int,
2731                DeviceMemory<std::complex<double>> *, int> impl;
2732   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2733               incy, a, lda);
2734 }
2735 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2736 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2737                              std::complex<float> alpha,
2738                              const DeviceMemory<std::complex<float>> &a,
2739                              int lda,
2740                              const DeviceMemory<std::complex<float>> &x,
2741                              int incx, std::complex<float> beta,
2742                              DeviceMemory<std::complex<float>> *y, int incy) {
2743   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2744             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2745 
2746   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2747                const DeviceMemory<std::complex<float>> &, int,
2748                const DeviceMemory<std::complex<float>> &, int,
2749                std::complex<float>, DeviceMemory<std::complex<float>> *,
2750                int> impl;
2751   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2752               x, incx, beta, y, incy);
2753 }
2754 
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2755 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2756                              std::complex<double> alpha,
2757                              const DeviceMemory<std::complex<double>> &a,
2758                              int lda,
2759                              const DeviceMemory<std::complex<double>> &x,
2760                              int incx, std::complex<double> beta,
2761                              DeviceMemory<std::complex<double>> *y, int incy) {
2762   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2763             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2764 
2765   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2766                const DeviceMemory<std::complex<double>> &, int,
2767                const DeviceMemory<std::complex<double>> &, int,
2768                std::complex<double>, DeviceMemory<std::complex<double>> *,
2769                int> impl;
2770   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2771               x, incx, beta, y, incy);
2772 }
2773 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2774 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2775                              std::complex<float> alpha,
2776                              const DeviceMemory<std::complex<float>> &a,
2777                              int lda,
2778                              const DeviceMemory<std::complex<float>> &x,
2779                              int incx, std::complex<float> beta,
2780                              DeviceMemory<std::complex<float>> *y, int incy) {
2781   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2782             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2783 
2784   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2785                const DeviceMemory<std::complex<float>> &, int,
2786                const DeviceMemory<std::complex<float>> &, int,
2787                std::complex<float>, DeviceMemory<std::complex<float>> *,
2788                int> impl;
2789   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2790               incx, beta, y, incy);
2791 }
2792 
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2793 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2794                              std::complex<double> alpha,
2795                              const DeviceMemory<std::complex<double>> &a,
2796                              int lda,
2797                              const DeviceMemory<std::complex<double>> &x,
2798                              int incx, std::complex<double> beta,
2799                              DeviceMemory<std::complex<double>> *y, int incy) {
2800   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2801             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2802 
2803   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2804                const DeviceMemory<std::complex<double>> &, int,
2805                const DeviceMemory<std::complex<double>> &, int,
2806                std::complex<double>, DeviceMemory<std::complex<double>> *,
2807                int> impl;
2808   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2809               incx, beta, y, incy);
2810 }
2811 
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2812 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2813                             const DeviceMemory<std::complex<float>> &x,
2814                             int incx, DeviceMemory<std::complex<float>> *a,
2815                             int lda) {
2816   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2817             PARAM(a), PARAM(lda));
2818 
2819   ThenBlasImpl<blas::UpperLower, uint64, float,
2820                const DeviceMemory<std::complex<float>> &, int,
2821                DeviceMemory<std::complex<float>> *, int> impl;
2822   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2823               lda);
2824 }
2825 
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2826 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2827                             const DeviceMemory<std::complex<double>> &x,
2828                             int incx, DeviceMemory<std::complex<double>> *a,
2829                             int lda) {
2830   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2831             PARAM(a), PARAM(lda));
2832 
2833   ThenBlasImpl<blas::UpperLower, uint64, double,
2834                const DeviceMemory<std::complex<double>> &, int,
2835                DeviceMemory<std::complex<double>> *, int> impl;
2836   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2837               lda);
2838 }
2839 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2840 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2841                              std::complex<float> alpha,
2842                              const DeviceMemory<std::complex<float>> &x,
2843                              int incx,
2844                              const DeviceMemory<std::complex<float>> &y,
2845                              int incy, DeviceMemory<std::complex<float>> *a,
2846                              int lda) {
2847   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2848             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2849 
2850   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2851                const DeviceMemory<std::complex<float>> &, int,
2852                const DeviceMemory<std::complex<float>> &, int,
2853                DeviceMemory<std::complex<float>> *, int> impl;
2854   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2855               incy, a, lda);
2856 }
2857 
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2858 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2859                              std::complex<double> alpha,
2860                              const DeviceMemory<std::complex<double>> &x,
2861                              int incx,
2862                              const DeviceMemory<std::complex<double>> &y,
2863                              int incy, DeviceMemory<std::complex<double>> *a,
2864                              int lda) {
2865   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2866             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2867 
2868   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2869                const DeviceMemory<std::complex<double>> &, int,
2870                const DeviceMemory<std::complex<double>> &, int,
2871                DeviceMemory<std::complex<double>> *, int> impl;
2872   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2873               incy, a, lda);
2874 }
2875 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2876 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2877                              std::complex<float> alpha,
2878                              const DeviceMemory<std::complex<float>> &ap,
2879                              const DeviceMemory<std::complex<float>> &x,
2880                              int incx, std::complex<float> beta,
2881                              DeviceMemory<std::complex<float>> *y, int incy) {
2882   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2883             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2884 
2885   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2886                const DeviceMemory<std::complex<float>> &,
2887                const DeviceMemory<std::complex<float>> &, int,
2888                std::complex<float>, DeviceMemory<std::complex<float>> *,
2889                int> impl;
2890   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2891               beta, y, incy);
2892 }
2893 
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2894 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2895                              std::complex<double> alpha,
2896                              const DeviceMemory<std::complex<double>> &ap,
2897                              const DeviceMemory<std::complex<double>> &x,
2898                              int incx, std::complex<double> beta,
2899                              DeviceMemory<std::complex<double>> *y, int incy) {
2900   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2901             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2902 
2903   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2904                const DeviceMemory<std::complex<double>> &,
2905                const DeviceMemory<std::complex<double>> &, int,
2906                std::complex<double>, DeviceMemory<std::complex<double>> *,
2907                int> impl;
2908   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2909               beta, y, incy);
2910 }
2911 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2912 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2913                             const DeviceMemory<std::complex<float>> &x,
2914                             int incx, DeviceMemory<std::complex<float>> *ap) {
2915   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2916             PARAM(ap));
2917 
2918   ThenBlasImpl<blas::UpperLower, uint64, float,
2919                const DeviceMemory<std::complex<float>> &, int,
2920                DeviceMemory<std::complex<float>> *> impl;
2921   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2922 }
2923 
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2924 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2925                             const DeviceMemory<std::complex<double>> &x,
2926                             int incx, DeviceMemory<std::complex<double>> *ap) {
2927   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2928             PARAM(ap));
2929 
2930   ThenBlasImpl<blas::UpperLower, uint64, double,
2931                const DeviceMemory<std::complex<double>> &, int,
2932                DeviceMemory<std::complex<double>> *> impl;
2933   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2934 }
2935 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2936 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2937                              std::complex<float> alpha,
2938                              const DeviceMemory<std::complex<float>> &x,
2939                              int incx,
2940                              const DeviceMemory<std::complex<float>> &y,
2941                              int incy, DeviceMemory<std::complex<float>> *ap) {
2942   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2943             PARAM(y), PARAM(incy), PARAM(ap));
2944 
2945   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2946                const DeviceMemory<std::complex<float>> &, int,
2947                const DeviceMemory<std::complex<float>> &, int,
2948                DeviceMemory<std::complex<float>> *> impl;
2949   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2950               incy, ap);
2951 }
2952 
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2953 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2954                              std::complex<double> alpha,
2955                              const DeviceMemory<std::complex<double>> &x,
2956                              int incx,
2957                              const DeviceMemory<std::complex<double>> &y,
2958                              int incy, DeviceMemory<std::complex<double>> *ap) {
2959   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2960             PARAM(y), PARAM(incy), PARAM(ap));
2961 
2962   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2963                const DeviceMemory<std::complex<double>> &, int,
2964                const DeviceMemory<std::complex<double>> &, int,
2965                DeviceMemory<std::complex<double>> *> impl;
2966   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2967               incy, ap);
2968 }
2969 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2970 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2971                              float alpha, const DeviceMemory<float> &a, int lda,
2972                              const DeviceMemory<float> &x, int incx, float beta,
2973                              DeviceMemory<float> *y, int incy) {
2974   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2975             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2976 
2977   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2978                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2979                int, float, DeviceMemory<float> *, int> impl;
2980   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2981               x, incx, beta, y, incy);
2982 }
2983 
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2984 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2985                              double alpha, const DeviceMemory<double> &a,
2986                              int lda, const DeviceMemory<double> &x, int incx,
2987                              double beta, DeviceMemory<double> *y, int incy) {
2988   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2989             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2990 
2991   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2992                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2993                int, double, DeviceMemory<double> *, int> impl;
2994   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2995               x, incx, beta, y, incy);
2996 }
2997 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2998 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
2999                              const DeviceMemory<float> &ap,
3000                              const DeviceMemory<float> &x, int incx, float beta,
3001                              DeviceMemory<float> *y, int incy) {
3002   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3003             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3004 
3005   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3006                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3007                int> impl;
3008   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3009               beta, y, incy);
3010 }
3011 
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3012 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
3013                              const DeviceMemory<double> &ap,
3014                              const DeviceMemory<double> &x, int incx,
3015                              double beta, DeviceMemory<double> *y, int incy) {
3016   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3017             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3018 
3019   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3020                const DeviceMemory<double> &, int, double,
3021                DeviceMemory<double> *, int> impl;
3022   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3023               beta, y, incy);
3024 }
3025 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)3026 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
3027                             const DeviceMemory<float> &x, int incx,
3028                             DeviceMemory<float> *ap) {
3029   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3030             PARAM(ap));
3031 
3032   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3033                int, DeviceMemory<float> *> impl;
3034   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3035 }
3036 
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)3037 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
3038                             const DeviceMemory<double> &x, int incx,
3039                             DeviceMemory<double> *ap) {
3040   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3041             PARAM(ap));
3042 
3043   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3044                int, DeviceMemory<double> *> impl;
3045   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3046 }
3047 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)3048 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
3049                              const DeviceMemory<float> &x, int incx,
3050                              const DeviceMemory<float> &y, int incy,
3051                              DeviceMemory<float> *ap) {
3052   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3053             PARAM(y), PARAM(incy), PARAM(ap));
3054 
3055   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3056                int, const DeviceMemory<float> &, int,
3057                DeviceMemory<float> *> impl;
3058   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3059               incy, ap);
3060 }
3061 
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)3062 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
3063                              const DeviceMemory<double> &x, int incx,
3064                              const DeviceMemory<double> &y, int incy,
3065                              DeviceMemory<double> *ap) {
3066   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3067             PARAM(y), PARAM(incy), PARAM(ap));
3068 
3069   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3070                int, const DeviceMemory<double> &, int,
3071                DeviceMemory<double> *> impl;
3072   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3073               incy, ap);
3074 }
3075 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3076 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
3077                              const DeviceMemory<float> &a, int lda,
3078                              const DeviceMemory<float> &x, int incx, float beta,
3079                              DeviceMemory<float> *y, int incy) {
3080   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3081             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3082 
3083   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3084                int, const DeviceMemory<float> &, int, float,
3085                DeviceMemory<float> *, int> impl;
3086   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3087               incx, beta, y, incy);
3088 }
3089 
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3090 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
3091                              const DeviceMemory<double> &a, int lda,
3092                              const DeviceMemory<double> &x, int incx,
3093                              double beta, DeviceMemory<double> *y, int incy) {
3094   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3095             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3096 
3097   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3098                int, const DeviceMemory<double> &, int, double,
3099                DeviceMemory<double> *, int> impl;
3100   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3101               incx, beta, y, incy);
3102 }
3103 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)3104 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
3105                             const DeviceMemory<float> &x, int incx,
3106                             DeviceMemory<float> *a, int lda) {
3107   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3108             PARAM(a), PARAM(lda));
3109 
3110   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3111                int, DeviceMemory<float> *, int> impl;
3112   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3113               lda);
3114 }
3115 
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)3116 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
3117                             const DeviceMemory<double> &x, int incx,
3118                             DeviceMemory<double> *a, int lda) {
3119   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3120             PARAM(a), PARAM(lda));
3121 
3122   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3123                int, DeviceMemory<double> *, int> impl;
3124   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3125               lda);
3126 }
3127 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)3128 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
3129                              const DeviceMemory<float> &x, int incx,
3130                              const DeviceMemory<float> &y, int incy,
3131                              DeviceMemory<float> *a, int lda) {
3132   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3133             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3134 
3135   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3136                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3137                int> impl;
3138   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3139               incy, a, lda);
3140 }
3141 
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)3142 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
3143                              const DeviceMemory<double> &x, int incx,
3144                              const DeviceMemory<double> &y, int incy,
3145                              DeviceMemory<double> *a, int lda) {
3146   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3147             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3148 
3149   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3150                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
3151                int> impl;
3152   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3153               incy, a, lda);
3154 }
3155 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3156 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3157                              blas::Diagonal diag, uint64 n, uint64 k,
3158                              const DeviceMemory<float> &a, int lda,
3159                              DeviceMemory<float> *x, int incx) {
3160   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3161             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3162 
3163   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3164                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3165                int> impl;
3166   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3167               lda, x, incx);
3168 }
3169 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3170 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3171                              blas::Diagonal diag, uint64 n, uint64 k,
3172                              const DeviceMemory<double> &a, int lda,
3173                              DeviceMemory<double> *x, int incx) {
3174   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3175             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3176 
3177   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3178                uint64, const DeviceMemory<double> &, int,
3179                DeviceMemory<double> *, int> impl;
3180   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3181               lda, x, incx);
3182 }
3183 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3184 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3185                              blas::Diagonal diag, uint64 n, uint64 k,
3186                              const DeviceMemory<std::complex<float>> &a,
3187                              int lda, DeviceMemory<std::complex<float>> *x,
3188                              int incx) {
3189   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3190             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3191 
3192   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3193                uint64, const DeviceMemory<std::complex<float>> &, int,
3194                DeviceMemory<std::complex<float>> *, int> impl;
3195   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3196               lda, x, incx);
3197 }
3198 
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3199 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3200                              blas::Diagonal diag, uint64 n, uint64 k,
3201                              const DeviceMemory<std::complex<double>> &a,
3202                              int lda, DeviceMemory<std::complex<double>> *x,
3203                              int incx) {
3204   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3205             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3206 
3207   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3208                uint64, const DeviceMemory<std::complex<double>> &, int,
3209                DeviceMemory<std::complex<double>> *, int> impl;
3210   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3211               lda, x, incx);
3212 }
3213 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3214 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3215                              blas::Diagonal diag, uint64 n, uint64 k,
3216                              const DeviceMemory<float> &a, int lda,
3217                              DeviceMemory<float> *x, int incx) {
3218   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3219             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3220 
3221   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3222                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3223                int> impl;
3224   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3225               lda, x, incx);
3226 }
3227 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3228 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3229                              blas::Diagonal diag, uint64 n, uint64 k,
3230                              const DeviceMemory<double> &a, int lda,
3231                              DeviceMemory<double> *x, int incx) {
3232   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3233             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3234 
3235   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3236                uint64, const DeviceMemory<double> &, int,
3237                DeviceMemory<double> *, int> impl;
3238   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3239               lda, x, incx);
3240 }
3241 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3242 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3243                              blas::Diagonal diag, uint64 n, uint64 k,
3244                              const DeviceMemory<std::complex<float>> &a,
3245                              int lda, DeviceMemory<std::complex<float>> *x,
3246                              int incx) {
3247   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3248             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3249 
3250   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3251                uint64, const DeviceMemory<std::complex<float>> &, int,
3252                DeviceMemory<std::complex<float>> *, int> impl;
3253   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3254               lda, x, incx);
3255 }
3256 
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3257 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3258                              blas::Diagonal diag, uint64 n, uint64 k,
3259                              const DeviceMemory<std::complex<double>> &a,
3260                              int lda, DeviceMemory<std::complex<double>> *x,
3261                              int incx) {
3262   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3263             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3264 
3265   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3266                uint64, const DeviceMemory<std::complex<double>> &, int,
3267                DeviceMemory<std::complex<double>> *, int> impl;
3268   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3269               lda, x, incx);
3270 }
3271 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3272 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3273                              blas::Diagonal diag, uint64 n,
3274                              const DeviceMemory<float> &ap,
3275                              DeviceMemory<float> *x, int incx) {
3276   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3277             PARAM(x), PARAM(incx));
3278 
3279   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3280                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3281   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3282               incx);
3283 }
3284 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3285 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3286                              blas::Diagonal diag, uint64 n,
3287                              const DeviceMemory<double> &ap,
3288                              DeviceMemory<double> *x, int incx) {
3289   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3290             PARAM(x), PARAM(incx));
3291 
3292   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3293                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3294   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3295               incx);
3296 }
3297 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3298 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3299                              blas::Diagonal diag, uint64 n,
3300                              const DeviceMemory<std::complex<float>> &ap,
3301                              DeviceMemory<std::complex<float>> *x, int incx) {
3302   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3303             PARAM(x), PARAM(incx));
3304 
3305   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3306                const DeviceMemory<std::complex<float>> &,
3307                DeviceMemory<std::complex<float>> *, int> impl;
3308   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3309               incx);
3310 }
3311 
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3312 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3313                              blas::Diagonal diag, uint64 n,
3314                              const DeviceMemory<std::complex<double>> &ap,
3315                              DeviceMemory<std::complex<double>> *x, int incx) {
3316   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3317             PARAM(x), PARAM(incx));
3318 
3319   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3320                const DeviceMemory<std::complex<double>> &,
3321                DeviceMemory<std::complex<double>> *, int> impl;
3322   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3323               incx);
3324 }
3325 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3326 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3327                              blas::Diagonal diag, uint64 n,
3328                              const DeviceMemory<float> &ap,
3329                              DeviceMemory<float> *x, int incx) {
3330   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3331             PARAM(x), PARAM(incx));
3332 
3333   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3334                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3335   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3336               incx);
3337 }
3338 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3339 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3340                              blas::Diagonal diag, uint64 n,
3341                              const DeviceMemory<double> &ap,
3342                              DeviceMemory<double> *x, int incx) {
3343   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3344             PARAM(x), PARAM(incx));
3345 
3346   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3347                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3348   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3349               incx);
3350 }
3351 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3352 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3353                              blas::Diagonal diag, uint64 n,
3354                              const DeviceMemory<std::complex<float>> &ap,
3355                              DeviceMemory<std::complex<float>> *x, int incx) {
3356   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3357             PARAM(x), PARAM(incx));
3358 
3359   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3360                const DeviceMemory<std::complex<float>> &,
3361                DeviceMemory<std::complex<float>> *, int> impl;
3362   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3363               incx);
3364 }
3365 
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3366 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3367                              blas::Diagonal diag, uint64 n,
3368                              const DeviceMemory<std::complex<double>> &ap,
3369                              DeviceMemory<std::complex<double>> *x, int incx) {
3370   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3371             PARAM(x), PARAM(incx));
3372 
3373   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3374                const DeviceMemory<std::complex<double>> &,
3375                DeviceMemory<std::complex<double>> *, int> impl;
3376   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3377               incx);
3378 }
3379 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3380 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3381                              blas::Diagonal diag, uint64 n,
3382                              const DeviceMemory<float> &a, int lda,
3383                              DeviceMemory<float> *x, int incx) {
3384   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3385             PARAM(lda), PARAM(x), PARAM(incx));
3386 
3387   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3388                const DeviceMemory<float> &, int, DeviceMemory<float> *,
3389                int> impl;
3390   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3391               lda, x, incx);
3392 }
3393 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3394 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3395                              blas::Diagonal diag, uint64 n,
3396                              const DeviceMemory<double> &a, int lda,
3397                              DeviceMemory<double> *x, int incx) {
3398   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3399             PARAM(lda), PARAM(x), PARAM(incx));
3400 
3401   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3402                const DeviceMemory<double> &, int, DeviceMemory<double> *,
3403                int> impl;
3404   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3405               lda, x, incx);
3406 }
3407 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3408 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3409                              blas::Diagonal diag, uint64 n,
3410                              const DeviceMemory<std::complex<float>> &a,
3411                              int lda, DeviceMemory<std::complex<float>> *x,
3412                              int incx) {
3413   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3414             PARAM(lda), PARAM(x), PARAM(incx));
3415 
3416   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3417                const DeviceMemory<std::complex<float>> &, int,
3418                DeviceMemory<std::complex<float>> *, int> impl;
3419   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3420               lda, x, incx);
3421 }
3422 
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3423 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3424                              blas::Diagonal diag, uint64 n,
3425                              const DeviceMemory<std::complex<double>> &a,
3426                              int lda, DeviceMemory<std::complex<double>> *x,
3427                              int incx) {
3428   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3429             PARAM(lda), PARAM(x), PARAM(incx));
3430 
3431   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3432                const DeviceMemory<std::complex<double>> &, int,
3433                DeviceMemory<std::complex<double>> *, int> impl;
3434   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3435               lda, x, incx);
3436 }
3437 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3438 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3439                              blas::Diagonal diag, uint64 n,
3440                              const DeviceMemory<float> &a, int lda,
3441                              DeviceMemory<float> *x, int incx) {
3442   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3443             PARAM(lda), PARAM(x), PARAM(incx));
3444 
3445   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3446                const DeviceMemory<float> &, int, DeviceMemory<float> *,
3447                int> impl;
3448   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3449               lda, x, incx);
3450 }
3451 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3452 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3453                              blas::Diagonal diag, uint64 n,
3454                              const DeviceMemory<double> &a, int lda,
3455                              DeviceMemory<double> *x, int incx) {
3456   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3457             PARAM(lda), PARAM(x), PARAM(incx));
3458 
3459   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3460                const DeviceMemory<double> &, int, DeviceMemory<double> *,
3461                int> impl;
3462   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3463               lda, x, incx);
3464 }
3465 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3466 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3467                              blas::Diagonal diag, uint64 n,
3468                              const DeviceMemory<std::complex<float>> &a,
3469                              int lda, DeviceMemory<std::complex<float>> *x,
3470                              int incx) {
3471   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3472             PARAM(lda), PARAM(x), PARAM(incx));
3473 
3474   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3475                const DeviceMemory<std::complex<float>> &, int,
3476                DeviceMemory<std::complex<float>> *, int> impl;
3477   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3478               lda, x, incx);
3479 }
3480 
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3481 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3482                              blas::Diagonal diag, uint64 n,
3483                              const DeviceMemory<std::complex<double>> &a,
3484                              int lda, DeviceMemory<std::complex<double>> *x,
3485                              int incx) {
3486   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3487             PARAM(lda), PARAM(x), PARAM(incx));
3488 
3489   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3490                const DeviceMemory<std::complex<double>> &, int,
3491                DeviceMemory<std::complex<double>> *, int> impl;
3492   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3493               lda, x, incx);
3494 }
3495 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3496 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3497                              uint64 m, uint64 n, uint64 k, float alpha,
3498                              const DeviceMemory<Eigen::half> &a, int lda,
3499                              const DeviceMemory<Eigen::half> &b, int ldb,
3500                              float beta,
3501                              DeviceMemory<Eigen::half> *c, int ldc) {
3502   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3503             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3504             PARAM(beta), PARAM(c), PARAM(ldc));
3505 
3506   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3507                const DeviceMemory<Eigen::half> &, int,
3508                const DeviceMemory<Eigen::half> &, int,
3509                float, DeviceMemory<Eigen::half> *, int> impl;
3510   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3511               alpha, a, lda, b, ldb, beta, c, ldc);
3512 }
3513 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3514 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3515                              uint64 m, uint64 n, uint64 k, float alpha,
3516                              const DeviceMemory<float> &a, int lda,
3517                              const DeviceMemory<float> &b, int ldb, float beta,
3518                              DeviceMemory<float> *c, int ldc) {
3519   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3520             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3521             PARAM(beta), PARAM(c), PARAM(ldc));
3522 
3523   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3524                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3525                int, float, DeviceMemory<float> *, int> impl;
3526   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3527               alpha, a, lda, b, ldb, beta, c, ldc);
3528 }
3529 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3530 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3531                              uint64 m, uint64 n, uint64 k, double alpha,
3532                              const DeviceMemory<double> &a, int lda,
3533                              const DeviceMemory<double> &b, int ldb,
3534                              double beta, DeviceMemory<double> *c, int ldc) {
3535   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3536             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3537             PARAM(beta), PARAM(c), PARAM(ldc));
3538 
3539   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3540                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3541                int, double, DeviceMemory<double> *, int> impl;
3542   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3543               alpha, a, lda, b, ldb, beta, c, ldc);
3544 }
3545 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3546 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3547                              uint64 m, uint64 n, uint64 k,
3548                              std::complex<float> alpha,
3549                              const DeviceMemory<std::complex<float>> &a,
3550                              int lda,
3551                              const DeviceMemory<std::complex<float>> &b,
3552                              int ldb, std::complex<float> beta,
3553                              DeviceMemory<std::complex<float>> *c, int ldc) {
3554   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3555             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3556             PARAM(beta), PARAM(c), PARAM(ldc));
3557 
3558   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3559                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3560                int, const DeviceMemory<std::complex<float>> &, int,
3561                std::complex<float>, DeviceMemory<std::complex<float>> *,
3562                int> impl;
3563   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3564               alpha, a, lda, b, ldb, beta, c, ldc);
3565 }
3566 
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3567 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3568                              uint64 m, uint64 n, uint64 k,
3569                              std::complex<double> alpha,
3570                              const DeviceMemory<std::complex<double>> &a,
3571                              int lda,
3572                              const DeviceMemory<std::complex<double>> &b,
3573                              int ldb, std::complex<double> beta,
3574                              DeviceMemory<std::complex<double>> *c, int ldc) {
3575   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3576             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3577             PARAM(beta), PARAM(c), PARAM(ldc));
3578 
3579   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3580                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3581                int, const DeviceMemory<std::complex<double>> &, int,
3582                std::complex<double>, DeviceMemory<std::complex<double>> *,
3583                int> impl;
3584   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3585               alpha, a, lda, b, ldb, beta, c, ldc);
3586 }
3587 
3588 namespace {
3589 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3590 // blas::ProfileResult*.  This functor doesn't put the stream into an error
3591 // state if the op fails and the profile result is non-null.  Instead, the
3592 // error-ness is returned in the profile result itself.
3593 template <typename... Args>
3594 struct ThenBlasWithProfileImpl {
operator ()perftools::gputools::__anoncb2e7d7b0211::ThenBlasWithProfileImpl3595   Stream &operator()(Stream *stream,
3596                      bool (blas::BlasSupport::*blas_func)(
3597                          Stream *, Args..., blas::ProfileResult *),
3598                      Args... args, blas::ProfileResult *profile_result) {
3599     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3600     bool record_error = profile_result == nullptr;
3601     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3602   }
3603 };
3604 }  // anonymous namespace
3605 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3606 Stream &Stream::ThenBlasGemvWithProfiling(
3607     blas::Transpose trans, uint64 m, uint64 n, float alpha,
3608     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3609     int incx, float beta, DeviceMemory<float> *y, int incy,
3610     blas::ProfileResult *output_profile_result) {
3611   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3612             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3613             PARAM(incy));
3614 
3615   ThenBlasWithProfileImpl<
3616       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3617       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3618       impl;
3619   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3620               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3621 }
3622 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3623 Stream &Stream::ThenBlasGemvWithProfiling(
3624     blas::Transpose trans, uint64 m, uint64 n, double alpha,
3625     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3626     int incx, double beta, DeviceMemory<double> *y, int incy,
3627     blas::ProfileResult *output_profile_result) {
3628   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3629             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3630             PARAM(incy));
3631 
3632   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3633                           const DeviceMemory<double> &, int,
3634                           const DeviceMemory<double> &, int, double,
3635                           DeviceMemory<double> *, int>
3636       impl;
3637   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3638               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3639 }
3640 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3641 Stream &Stream::ThenBlasGemvWithProfiling(
3642     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3643     const DeviceMemory<std::complex<float>> &a, int lda,
3644     const DeviceMemory<std::complex<float>> &x, int incx,
3645     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3646     blas::ProfileResult *output_profile_result) {
3647   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3648             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3649             PARAM(incy));
3650 
3651   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3652                           const DeviceMemory<std::complex<float>> &, int,
3653                           const DeviceMemory<std::complex<float>> &, int,
3654                           std::complex<float>,
3655                           DeviceMemory<std::complex<float>> *, int>
3656       impl;
3657   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3658               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3659 }
3660 
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3661 Stream &Stream::ThenBlasGemvWithProfiling(
3662     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3663     const DeviceMemory<std::complex<double>> &a, int lda,
3664     const DeviceMemory<std::complex<double>> &x, int incx,
3665     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3666     blas::ProfileResult *output_profile_result) {
3667   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3668             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3669             PARAM(incy));
3670 
3671   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3672                           const DeviceMemory<std::complex<double>> &, int,
3673                           const DeviceMemory<std::complex<double>> &, int,
3674                           std::complex<double>,
3675                           DeviceMemory<std::complex<double>> *, int>
3676       impl;
3677   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3678               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3679 }
3680 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3681 Stream &Stream::ThenBlasGemmWithProfiling(
3682     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3683     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3684     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3685     DeviceMemory<Eigen::half> *c, int ldc,
3686     blas::ProfileResult *output_profile_result) {
3687   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3688             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3689             PARAM(beta), PARAM(c), PARAM(ldc));
3690 
3691   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3692                           uint64, float, const DeviceMemory<Eigen::half> &, int,
3693                           const DeviceMemory<Eigen::half> &, int, float,
3694                           DeviceMemory<Eigen::half> *, int>
3695       impl;
3696   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3697               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3698               output_profile_result);
3699 }
3700 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3701 Stream &Stream::ThenBlasGemmWithProfiling(
3702     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3703     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3704     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3705     int ldc, blas::ProfileResult *output_profile_result) {
3706   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3707             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3708             PARAM(beta), PARAM(c), PARAM(ldc));
3709 
3710   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3711                           uint64, float, const DeviceMemory<float> &, int,
3712                           const DeviceMemory<float> &, int, float,
3713                           DeviceMemory<float> *, int>
3714       impl;
3715   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3716               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3717               output_profile_result);
3718 }
3719 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3720 Stream &Stream::ThenBlasGemmWithProfiling(
3721     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3722     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3723     const DeviceMemory<double> &b, int ldb, double beta,
3724     DeviceMemory<double> *c, int ldc,
3725     blas::ProfileResult *output_profile_result) {
3726   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3727             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3728             PARAM(beta), PARAM(c), PARAM(ldc));
3729 
3730   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3731                           uint64, double, const DeviceMemory<double> &, int,
3732                           const DeviceMemory<double> &, int, double,
3733                           DeviceMemory<double> *, int>
3734       impl;
3735   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3736               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3737               output_profile_result);
3738 }
3739 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3740 Stream &Stream::ThenBlasGemmWithProfiling(
3741     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3742     uint64 k, std::complex<float> alpha,
3743     const DeviceMemory<std::complex<float>> &a, int lda,
3744     const DeviceMemory<std::complex<float>> &b, int ldb,
3745     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3746     blas::ProfileResult *output_profile_result) {
3747   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3748             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3749             PARAM(beta), PARAM(c), PARAM(ldc));
3750 
3751   ThenBlasWithProfileImpl<
3752       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3753       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3754       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3755       DeviceMemory<std::complex<float>> *, int>
3756       impl;
3757   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3758               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3759               output_profile_result);
3760 }
3761 
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3762 Stream &Stream::ThenBlasGemmWithProfiling(
3763     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3764     uint64 k, std::complex<double> alpha,
3765     const DeviceMemory<std::complex<double>> &a, int lda,
3766     const DeviceMemory<std::complex<double>> &b, int ldb,
3767     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3768     blas::ProfileResult *output_profile_result) {
3769   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3770             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3771             PARAM(beta), PARAM(c), PARAM(ldc));
3772 
3773   ThenBlasWithProfileImpl<
3774       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3775       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3776       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3777       DeviceMemory<std::complex<double>> *, int>
3778       impl;
3779   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3780               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3781               output_profile_result);
3782 }
3783 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const Eigen::half & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const Eigen::half & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3784 Stream &Stream::ThenBlasGemmWithAlgorithm(
3785     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3786     uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
3787     int lda, const DeviceMemory<Eigen::half> &b, int ldb,
3788     const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
3789     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3790     blas::ProfileResult *output_profile_result) {
3791   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3792             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3793             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3794             PARAM(algorithm));
3795 
3796   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3797                           uint64, const Eigen::half &,
3798                           const DeviceMemory<Eigen::half> &, int,
3799                           const DeviceMemory<Eigen::half> &, int,
3800                           const Eigen::half &, DeviceMemory<Eigen::half> *, int,
3801                           blas::ComputationType, blas::AlgorithmType>
3802       impl;
3803   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3804               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3805               algorithm, output_profile_result);
3806 }
3807 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,int alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,int beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3808 Stream &Stream::ThenBlasGemmWithAlgorithm(
3809     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3810     uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
3811     const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
3812     int ldc, blas::ComputationType computation_type,
3813     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3814   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3815             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3816             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3817             PARAM(algorithm));
3818 
3819   ThenBlasWithProfileImpl<
3820       blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
3821       const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
3822       DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3823       impl;
3824   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3825               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3826               algorithm, output_profile_result);
3827 }
3828 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3829 Stream &Stream::ThenBlasGemmWithAlgorithm(
3830     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3831     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3832     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3833     int ldc, blas::ComputationType computation_type,
3834     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3835   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3836             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3837             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3838             PARAM(algorithm));
3839 
3840   ThenBlasWithProfileImpl<
3841       blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3842       const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
3843       DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3844       impl;
3845   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3846               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3847               algorithm, output_profile_result);
3848 }
3849 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3850 Stream &Stream::ThenBlasGemmWithAlgorithm(
3851     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3852     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3853     const DeviceMemory<double> &b, int ldb, double beta,
3854     DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
3855     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3856   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3857             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3858             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3859             PARAM(algorithm));
3860 
3861   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3862                           uint64, double, const DeviceMemory<double> &, int,
3863                           const DeviceMemory<double> &, int, double,
3864                           DeviceMemory<double> *, int, blas::ComputationType,
3865                           blas::AlgorithmType>
3866       impl;
3867   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3868               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3869               algorithm, output_profile_result);
3870 }
3871 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3872 Stream &Stream::ThenBlasGemmWithAlgorithm(
3873     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3874     uint64 k, std::complex<float> alpha,
3875     const DeviceMemory<std::complex<float>> &a, int lda,
3876     const DeviceMemory<std::complex<float>> &b, int ldb,
3877     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3878     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3879     blas::ProfileResult *output_profile_result) {
3880   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3881             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3882             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3883             PARAM(algorithm));
3884 
3885   ThenBlasWithProfileImpl<
3886       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3887       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3888       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3889       DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
3890       blas::AlgorithmType>
3891       impl;
3892   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3893               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3894               algorithm, output_profile_result);
3895 }
3896 
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3897 Stream &Stream::ThenBlasGemmWithAlgorithm(
3898     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3899     uint64 k, std::complex<double> alpha,
3900     const DeviceMemory<std::complex<double>> &a, int lda,
3901     const DeviceMemory<std::complex<double>> &b, int ldb,
3902     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3903     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3904     blas::ProfileResult *output_profile_result) {
3905   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3906             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3907             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3908             PARAM(algorithm));
3909 
3910   ThenBlasWithProfileImpl<
3911       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3912       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3913       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3914       DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
3915       blas::AlgorithmType>
3916       impl;
3917   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3918               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3919               algorithm, output_profile_result);
3920 }
3921 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3922 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3923                              uint64 n, std::complex<float> alpha,
3924                              const DeviceMemory<std::complex<float>> &a,
3925                              int lda,
3926                              const DeviceMemory<std::complex<float>> &b,
3927                              int ldb, std::complex<float> beta,
3928                              DeviceMemory<std::complex<float>> *c, int ldc) {
3929   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3930             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3931             PARAM(ldc));
3932 
3933   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3934                std::complex<float>, const DeviceMemory<std::complex<float>> &,
3935                int, const DeviceMemory<std::complex<float>> &, int,
3936                std::complex<float>, DeviceMemory<std::complex<float>> *,
3937                int> impl;
3938   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3939               lda, b, ldb, beta, c, ldc);
3940 }
3941 
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3942 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3943                              uint64 n, std::complex<double> alpha,
3944                              const DeviceMemory<std::complex<double>> &a,
3945                              int lda,
3946                              const DeviceMemory<std::complex<double>> &b,
3947                              int ldb, std::complex<double> beta,
3948                              DeviceMemory<std::complex<double>> *c, int ldc) {
3949   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3950             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3951             PARAM(ldc));
3952 
3953   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3954                std::complex<double>, const DeviceMemory<std::complex<double>> &,
3955                int, const DeviceMemory<std::complex<double>> &, int,
3956                std::complex<double>, DeviceMemory<std::complex<double>> *,
3957                int> impl;
3958   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3959               lda, b, ldb, beta, c, ldc);
3960 }
3961 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3962 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3963                              uint64 n, uint64 k, float alpha,
3964                              const DeviceMemory<std::complex<float>> &a,
3965                              int lda, float beta,
3966                              DeviceMemory<std::complex<float>> *c, int ldc) {
3967   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3968             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3969 
3970   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3971                const DeviceMemory<std::complex<float>> &, int, float,
3972                DeviceMemory<std::complex<float>> *, int> impl;
3973   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3974               lda, beta, c, ldc);
3975 }
3976 
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3978                              uint64 n, uint64 k, double alpha,
3979                              const DeviceMemory<std::complex<double>> &a,
3980                              int lda, double beta,
3981                              DeviceMemory<std::complex<double>> *c, int ldc) {
3982   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3983             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3984 
3985   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3986                const DeviceMemory<std::complex<double>> &, int, double,
3987                DeviceMemory<std::complex<double>> *, int> impl;
3988   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3989               lda, beta, c, ldc);
3990 }
3991 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3992 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3993                               uint64 n, uint64 k, std::complex<float> alpha,
3994                               const DeviceMemory<std::complex<float>> &a,
3995                               int lda,
3996                               const DeviceMemory<std::complex<float>> &b,
3997                               int ldb, float beta,
3998                               DeviceMemory<std::complex<float>> *c, int ldc) {
3999   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4000             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4001             PARAM(ldc));
4002 
4003   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4004                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4005                int, const DeviceMemory<std::complex<float>> &, int, float,
4006                DeviceMemory<std::complex<float>> *, int> impl;
4007   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4008               a, lda, b, ldb, beta, c, ldc);
4009 }
4010 
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)4011 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4012                               uint64 n, uint64 k, std::complex<double> alpha,
4013                               const DeviceMemory<std::complex<double>> &a,
4014                               int lda,
4015                               const DeviceMemory<std::complex<double>> &b,
4016                               int ldb, double beta,
4017                               DeviceMemory<std::complex<double>> *c, int ldc) {
4018   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4019             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4020             PARAM(ldc));
4021 
4022   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4023                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4024                int, const DeviceMemory<std::complex<double>> &, int, double,
4025                DeviceMemory<std::complex<double>> *, int> impl;
4026   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4027               a, lda, b, ldb, beta, c, ldc);
4028 }
4029 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4030 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4031                              uint64 n, float alpha,
4032                              const DeviceMemory<float> &a, int lda,
4033                              const DeviceMemory<float> &b, int ldb, float beta,
4034                              DeviceMemory<float> *c, int ldc) {
4035   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4036             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4037             PARAM(ldc));
4038 
4039   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
4040                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4041                int, float, DeviceMemory<float> *, int> impl;
4042   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4043               lda, b, ldb, beta, c, ldc);
4044 }
4045 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4046 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4047                              uint64 n, double alpha,
4048                              const DeviceMemory<double> &a, int lda,
4049                              const DeviceMemory<double> &b, int ldb,
4050                              double beta, DeviceMemory<double> *c, int ldc) {
4051   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4052             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4053             PARAM(ldc));
4054 
4055   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
4056                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4057                int, double, DeviceMemory<double> *, int> impl;
4058   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4059               lda, b, ldb, beta, c, ldc);
4060 }
4061 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4062 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4063                              uint64 n, std::complex<float> alpha,
4064                              const DeviceMemory<std::complex<float>> &a,
4065                              int lda,
4066                              const DeviceMemory<std::complex<float>> &b,
4067                              int ldb, std::complex<float> beta,
4068                              DeviceMemory<std::complex<float>> *c, int ldc) {
4069   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4070             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4071             PARAM(ldc));
4072 
4073   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4074                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4075                int, const DeviceMemory<std::complex<float>> &, int,
4076                std::complex<float>, DeviceMemory<std::complex<float>> *,
4077                int> impl;
4078   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4079               lda, b, ldb, beta, c, ldc);
4080 }
4081 
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4082 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4083                              uint64 n, std::complex<double> alpha,
4084                              const DeviceMemory<std::complex<double>> &a,
4085                              int lda,
4086                              const DeviceMemory<std::complex<double>> &b,
4087                              int ldb, std::complex<double> beta,
4088                              DeviceMemory<std::complex<double>> *c, int ldc) {
4089   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4090             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4091             PARAM(ldc));
4092 
4093   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4094                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4095                int, const DeviceMemory<std::complex<double>> &, int,
4096                std::complex<double>, DeviceMemory<std::complex<double>> *,
4097                int> impl;
4098   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4099               lda, b, ldb, beta, c, ldc);
4100 }
4101 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)4102 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4103                              uint64 n, uint64 k, float alpha,
4104                              const DeviceMemory<float> &a, int lda, float beta,
4105                              DeviceMemory<float> *c, int ldc) {
4106   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4107             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4108 
4109   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4110                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
4111                int> impl;
4112   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4113               lda, beta, c, ldc);
4114 }
4115 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)4116 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4117                              uint64 n, uint64 k, double alpha,
4118                              const DeviceMemory<double> &a, int lda,
4119                              double beta, DeviceMemory<double> *c, int ldc) {
4120   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4121             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4122 
4123   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4124                const DeviceMemory<double> &, int, double,
4125                DeviceMemory<double> *, int> impl;
4126   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4127               lda, beta, c, ldc);
4128 }
4129 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4130 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4131                              uint64 n, uint64 k, std::complex<float> alpha,
4132                              const DeviceMemory<std::complex<float>> &a,
4133                              int lda, std::complex<float> beta,
4134                              DeviceMemory<std::complex<float>> *c, int ldc) {
4135   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4136             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4137 
4138   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4139                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4140                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
4141                int> impl;
4142   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4143               lda, beta, c, ldc);
4144 }
4145 
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4146 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4147                              uint64 n, uint64 k, std::complex<double> alpha,
4148                              const DeviceMemory<std::complex<double>> &a,
4149                              int lda, std::complex<double> beta,
4150                              DeviceMemory<std::complex<double>> *c, int ldc) {
4151   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4152             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4153 
4154   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4155                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4156                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
4157                int> impl;
4158   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4159               lda, beta, c, ldc);
4160 }
4161 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4162 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4163                               uint64 n, uint64 k, float alpha,
4164                               const DeviceMemory<float> &a, int lda,
4165                               const DeviceMemory<float> &b, int ldb, float beta,
4166                               DeviceMemory<float> *c, int ldc) {
4167   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4168             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4169             PARAM(ldc));
4170 
4171   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4172                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4173                int, float, DeviceMemory<float> *, int> impl;
4174   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4175               a, lda, b, ldb, beta, c, ldc);
4176 }
4177 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4178 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4179                               uint64 n, uint64 k, double alpha,
4180                               const DeviceMemory<double> &a, int lda,
4181                               const DeviceMemory<double> &b, int ldb,
4182                               double beta, DeviceMemory<double> *c, int ldc) {
4183   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4184             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4185             PARAM(ldc));
4186 
4187   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4188                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4189                int, double, DeviceMemory<double> *, int> impl;
4190   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4191               a, lda, b, ldb, beta, c, ldc);
4192 }
4193 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4194 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4195                               uint64 n, uint64 k, std::complex<float> alpha,
4196                               const DeviceMemory<std::complex<float>> &a,
4197                               int lda,
4198                               const DeviceMemory<std::complex<float>> &b,
4199                               int ldb, std::complex<float> beta,
4200                               DeviceMemory<std::complex<float>> *c, int ldc) {
4201   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4202             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4203             PARAM(ldc));
4204 
4205   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4206                std::complex<float>, const DeviceMemory<std::complex<float>> &,
4207                int, const DeviceMemory<std::complex<float>> &, int,
4208                std::complex<float>, DeviceMemory<std::complex<float>> *,
4209                int> impl;
4210   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4211               a, lda, b, ldb, beta, c, ldc);
4212 }
4213 
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4214 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4215                               uint64 n, uint64 k, std::complex<double> alpha,
4216                               const DeviceMemory<std::complex<double>> &a,
4217                               int lda,
4218                               const DeviceMemory<std::complex<double>> &b,
4219                               int ldb, std::complex<double> beta,
4220                               DeviceMemory<std::complex<double>> *c, int ldc) {
4221   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4222             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4223             PARAM(ldc));
4224 
4225   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4226                std::complex<double>, const DeviceMemory<std::complex<double>> &,
4227                int, const DeviceMemory<std::complex<double>> &, int,
4228                std::complex<double>, DeviceMemory<std::complex<double>> *,
4229                int> impl;
4230   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4231               a, lda, b, ldb, beta, c, ldc);
4232 }
4233 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4234 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4235                              blas::Transpose transa, blas::Diagonal diag,
4236                              uint64 m, uint64 n, float alpha,
4237                              const DeviceMemory<float> &a, int lda,
4238                              DeviceMemory<float> *b, int ldb) {
4239   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4240             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4241 
4242   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4243                uint64, uint64, float, const DeviceMemory<float> &, int,
4244                DeviceMemory<float> *, int> impl;
4245   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4246               n, alpha, a, lda, b, ldb);
4247 }
4248 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4250                              blas::Transpose transa, blas::Diagonal diag,
4251                              uint64 m, uint64 n, double alpha,
4252                              const DeviceMemory<double> &a, int lda,
4253                              DeviceMemory<double> *b, int ldb) {
4254   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4255             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4256 
4257   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4258                uint64, uint64, double, const DeviceMemory<double> &, int,
4259                DeviceMemory<double> *, int> impl;
4260   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4261               n, alpha, a, lda, b, ldb);
4262 }
4263 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4265                              blas::Transpose transa, blas::Diagonal diag,
4266                              uint64 m, uint64 n, std::complex<float> alpha,
4267                              const DeviceMemory<std::complex<float>> &a,
4268                              int lda, DeviceMemory<std::complex<float>> *b,
4269                              int ldb) {
4270   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4271             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4272 
4273   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4274                uint64, uint64, std::complex<float>,
4275                const DeviceMemory<std::complex<float>> &, int,
4276                DeviceMemory<std::complex<float>> *, int> impl;
4277   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4278               n, alpha, a, lda, b, ldb);
4279 }
4280 
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4281 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4282                              blas::Transpose transa, blas::Diagonal diag,
4283                              uint64 m, uint64 n, std::complex<double> alpha,
4284                              const DeviceMemory<std::complex<double>> &a,
4285                              int lda, DeviceMemory<std::complex<double>> *b,
4286                              int ldb) {
4287   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4288             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4289 
4290   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4291                uint64, uint64, std::complex<double>,
4292                const DeviceMemory<std::complex<double>> &, int,
4293                DeviceMemory<std::complex<double>> *, int> impl;
4294   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4295               n, alpha, a, lda, b, ldb);
4296 }
4297 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4298 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4299                              blas::Transpose transa, blas::Diagonal diag,
4300                              uint64 m, uint64 n, float alpha,
4301                              const DeviceMemory<float> &a, int lda,
4302                              DeviceMemory<float> *b, int ldb) {
4303   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4304             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4305 
4306   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4307                uint64, uint64, float, const DeviceMemory<float> &, int,
4308                DeviceMemory<float> *, int> impl;
4309   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4310               n, alpha, a, lda, b, ldb);
4311 }
4312 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4314                              blas::Transpose transa, blas::Diagonal diag,
4315                              uint64 m, uint64 n, double alpha,
4316                              const DeviceMemory<double> &a, int lda,
4317                              DeviceMemory<double> *b, int ldb) {
4318   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4319             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4320 
4321   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4322                uint64, uint64, double, const DeviceMemory<double> &, int,
4323                DeviceMemory<double> *, int> impl;
4324   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4325               n, alpha, a, lda, b, ldb);
4326 }
4327 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4329                              blas::Transpose transa, blas::Diagonal diag,
4330                              uint64 m, uint64 n, std::complex<float> alpha,
4331                              const DeviceMemory<std::complex<float>> &a,
4332                              int lda, DeviceMemory<std::complex<float>> *b,
4333                              int ldb) {
4334   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4335             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4336 
4337   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4338                uint64, uint64, std::complex<float>,
4339                const DeviceMemory<std::complex<float>> &, int,
4340                DeviceMemory<std::complex<float>> *, int> impl;
4341   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4342               n, alpha, a, lda, b, ldb);
4343 }
4344 
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4345 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4346                              blas::Transpose transa, blas::Diagonal diag,
4347                              uint64 m, uint64 n, std::complex<double> alpha,
4348                              const DeviceMemory<std::complex<double>> &a,
4349                              int lda, DeviceMemory<std::complex<double>> *b,
4350                              int ldb) {
4351   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4352             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4353 
4354   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4355                uint64, uint64, std::complex<double>,
4356                const DeviceMemory<std::complex<double>> &, int,
4357                DeviceMemory<std::complex<double>> *, int> impl;
4358   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4359               n, alpha, a, lda, b, ldb);
4360 }
4361 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4362 Stream &Stream::ThenBlasGemmBatched(
4363     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4364     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4365     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4366     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4367     int batch_count) {
4368   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4369                                         b, ldb, beta, c, ldc, batch_count,
4370                                         /*scratch_allocator=*/nullptr);
4371 }
4372 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4373 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4374     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4375     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4376     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4377     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4378     int batch_count, ScratchAllocator *scratch_allocator) {
4379   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4380             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4381             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4382 
4383   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4384                const port::ArraySlice<DeviceMemory<float> *> &, int,
4385                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4386                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4387                ScratchAllocator *>
4388       impl;
4389   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4390               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4391               scratch_allocator);
4392 }
4393 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4394 Stream &Stream::ThenBlasGemmBatched(
4395     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4396     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4397     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4398     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4399     int batch_count) {
4400   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4401                                         b, ldb, beta, c, ldc, batch_count,
4402                                         /*scratch_allocator=*/nullptr);
4403 }
4404 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4405 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4406     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4407     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4408     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4409     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4410     int batch_count, ScratchAllocator *scratch_allocator) {
4411   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4412             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4413             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4414 
4415   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4416                const port::ArraySlice<DeviceMemory<double> *> &, int,
4417                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4418                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4419                ScratchAllocator *>
4420       impl;
4421   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4422               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4423               scratch_allocator);
4424 }
4425 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4426 Stream &Stream::ThenBlasGemmBatched(
4427     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4428     uint64 k, std::complex<float> alpha,
4429     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4430     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4431     std::complex<float> beta,
4432     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4433     int batch_count) {
4434   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4435                                         b, ldb, beta, c, ldc, batch_count,
4436                                         /*scratch_allocator=*/nullptr);
4437 }
4438 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4439 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4440     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4441     uint64 k, std::complex<float> alpha,
4442     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4443     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4444     std::complex<float> beta,
4445     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4446     int batch_count, ScratchAllocator *scratch_allocator) {
4447   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4448             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4449             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4450 
4451   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4452                std::complex<float>,
4453                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4454                int,
4455                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4456                int, std::complex<float>,
4457                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4458                int, int, ScratchAllocator *>
4459       impl;
4460   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4461               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4462               scratch_allocator);
4463 }
4464 
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4465 Stream &Stream::ThenBlasGemmBatched(
4466     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4467     uint64 k, std::complex<double> alpha,
4468     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4469     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4470     std::complex<double> beta,
4471     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4472     int batch_count) {
4473   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4474                                         b, ldb, beta, c, ldc, batch_count,
4475                                         /*scratch_allocator=*/nullptr);
4476 }
4477 
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4478 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4479     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4480     uint64 k, std::complex<double> alpha,
4481     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4482     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4483     std::complex<double> beta,
4484     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4485     int batch_count, ScratchAllocator *scratch_allocator) {
4486   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4487             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4488             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4489 
4490   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4491                std::complex<double>,
4492                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4493                int,
4494                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4495                int, std::complex<double>,
4496                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4497                int, int, ScratchAllocator *>
4498       impl;
4499   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4500               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4501               scratch_allocator);
4502 }
4503 
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4504 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4505   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4506 
4507   if (ok()) {
4508     if (rng::RngSupport *rng = parent_->AsRng()) {
4509       CheckError(rng->SetSeed(this, seed, seed_bytes));
4510     } else {
4511       SetError();
4512       LOG(INFO) << "stream " << this << " unable to initialize RNG";
4513     }
4514   } else {
4515     LOG(INFO) << "stream " << this
4516               << " did not set RNG seed: " << static_cast<const void *>(seed)
4517               << "; bytes: " << seed_bytes;
4518   }
4519   return *this;
4520 }
4521 
ThenPopulateRandUniform(DeviceMemory<float> * values)4522 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4523   VLOG_CALL(PARAM(values));
4524 
4525   if (ok()) {
4526     if (rng::RngSupport *rng = parent_->AsRng()) {
4527       CheckError(rng->DoPopulateRandUniform(this, values));
4528     } else {
4529       SetError();
4530       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4531                    "without RNG support.";
4532     }
4533   }
4534   return *this;
4535 }
4536 
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4537 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4538                                          DeviceMemory<float> *values) {
4539   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4540 
4541   if (ok()) {
4542     if (rng::RngSupport *rng = parent_->AsRng()) {
4543       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4544     } else {
4545       SetError();
4546       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4547                    "without RNG support.";
4548     }
4549   }
4550   return *this;
4551 }
4552 
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4553 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4554                                          DeviceMemory<double> *values) {
4555   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4556 
4557   if (ok()) {
4558     if (rng::RngSupport *rng = parent_->AsRng()) {
4559       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4560     } else {
4561       SetError();
4562       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4563                    "without RNG support.";
4564     }
4565   }
4566   return *this;
4567 }
4568 
ThenPopulateRandUniform(DeviceMemory<double> * values)4569 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4570   VLOG_CALL(PARAM(values));
4571 
4572   if (ok()) {
4573     if (rng::RngSupport *rng = parent_->AsRng()) {
4574       CheckError(rng->DoPopulateRandUniform(this, values));
4575     } else {
4576       SetError();
4577       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4578                    "without RNG support.";
4579     }
4580   }
4581   return *this;
4582 }
4583 
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4584 Stream &Stream::ThenPopulateRandUniform(
4585     DeviceMemory<std::complex<float>> *values) {
4586   VLOG_CALL(PARAM(values));
4587 
4588   if (ok()) {
4589     if (rng::RngSupport *rng = parent_->AsRng()) {
4590       CheckError(rng->DoPopulateRandUniform(this, values));
4591     } else {
4592       SetError();
4593       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4594                    "without RNG support.";
4595     }
4596   }
4597   return *this;
4598 }
4599 
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4600 Stream &Stream::ThenPopulateRandUniform(
4601     DeviceMemory<std::complex<double>> *values) {
4602   VLOG_CALL(PARAM(values));
4603 
4604   if (ok()) {
4605     if (rng::RngSupport *rng = parent_->AsRng()) {
4606       CheckError(rng->DoPopulateRandUniform(this, values));
4607     } else {
4608       SetError();
4609       LOG(INFO) << "stream " << this
4610                 << " attempting to perform RNG operation using StreamExecutor "
4611                    "without RNG support.";
4612     }
4613   }
4614   return *this;
4615 }
4616 
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4617 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4618                            uint64 size) {
4619   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4620 
4621   if (ok()) {
4622     CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4623   } else {
4624     LOG(INFO) << "stream " << this
4625               << " did not memcpy device-to-host; source: " << gpu_src.opaque();
4626   }
4627   return *this;
4628 }
4629 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4630 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4631                            uint64 size) {
4632   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4633 
4634   if (ok()) {
4635     CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4636   } else {
4637     LOG(INFO) << "stream " << this
4638               << " did not memcpy host-to-device; source: " << host_src;
4639   }
4640   return *this;
4641 }
4642 
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4643 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4644                            const DeviceMemoryBase &gpu_src, uint64 size) {
4645   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4646 
4647   if (ok()) {
4648     CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4649   } else {
4650     LOG(INFO) << "stream " << this
4651               << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
4652   }
4653   return *this;
4654 }
4655 
ThenMemZero(DeviceMemoryBase * location,uint64 size)4656 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4657   VLOG_CALL(PARAM(location), PARAM(size));
4658 
4659   if (ok()) {
4660     CheckError(parent_->MemZero(this, location, size));
4661   } else {
4662     LOG(INFO) << "stream " << this
4663               << " did not memzero GPU location; source: " << location;
4664   }
4665   return *this;
4666 }
4667 
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4668 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4669                              uint64 size) {
4670   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4671 
4672   if (ok()) {
4673     CheckError(parent_->Memset32(this, location, pattern, size));
4674   } else {
4675     LOG(INFO) << "stream " << this
4676               << " did not memset GPU location; source: " << location
4677               << "; size: " << size << "; pattern: " << std::hex << pattern;
4678   }
4679   return *this;
4680 }
4681 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)4682 Stream &Stream::ThenRnnForward(
4683     const dnn::RnnDescriptor &rnn_desc,
4684     const dnn::RnnSequenceTensorDescriptor &input_desc,
4685     const DeviceMemory<Eigen::half> &input_data,
4686     const dnn::RnnStateTensorDescriptor &input_h_desc,
4687     const DeviceMemory<Eigen::half> &input_h_data,
4688     const dnn::RnnStateTensorDescriptor &input_c_desc,
4689     const DeviceMemory<Eigen::half> &input_c_data,
4690     const DeviceMemory<Eigen::half> &params,
4691     const dnn::RnnSequenceTensorDescriptor &output_desc,
4692     DeviceMemory<Eigen::half> *output_data,
4693     const dnn::RnnStateTensorDescriptor &output_h_desc,
4694     DeviceMemory<Eigen::half> *output_h_data,
4695     const dnn::RnnStateTensorDescriptor &output_c_desc,
4696     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
4697     ScratchAllocator *reserve_space_allocator,
4698     ScratchAllocator *workspace_allocator) {
4699   // TODO(zhengxq): add VLOG PARAM calls.
4700   if (ok()) {
4701     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4702       CheckError(dnn->DoRnnForward(
4703           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4704           input_c_desc, input_c_data, params, output_desc, output_data,
4705           output_h_desc, output_h_data, output_c_desc, output_c_data,
4706           is_training, reserve_space_allocator, workspace_allocator));
4707     } else {
4708       SetError();
4709       LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
4710     }
4711   }
4712   return *this;
4713 }
4714 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)4715 Stream &Stream::ThenRnnForward(
4716     const dnn::RnnDescriptor &rnn_desc,
4717     const dnn::RnnSequenceTensorDescriptor &input_desc,
4718     const DeviceMemory<float> &input_data,
4719     const dnn::RnnStateTensorDescriptor &input_h_desc,
4720     const DeviceMemory<float> &input_h_data,
4721     const dnn::RnnStateTensorDescriptor &input_c_desc,
4722     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4723     const dnn::RnnSequenceTensorDescriptor &output_desc,
4724     DeviceMemory<float> *output_data,
4725     const dnn::RnnStateTensorDescriptor &output_h_desc,
4726     DeviceMemory<float> *output_h_data,
4727     const dnn::RnnStateTensorDescriptor &output_c_desc,
4728     DeviceMemory<float> *output_c_data, bool is_training,
4729     ScratchAllocator *reserve_space_allocator,
4730     ScratchAllocator *workspace_allocator) {
4731   // TODO(zhengxq): add VLOG PARAM calls.
4732   if (ok()) {
4733     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4734       CheckError(dnn->DoRnnForward(
4735           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4736           input_c_desc, input_c_data, params, output_desc, output_data,
4737           output_h_desc, output_h_data, output_c_desc, output_c_data,
4738           is_training, reserve_space_allocator, workspace_allocator));
4739     } else {
4740       SetError();
4741       LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
4742     }
4743   }
4744   return *this;
4745 }
4746 
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)4747 Stream &Stream::ThenRnnForward(
4748     const dnn::RnnDescriptor &rnn_desc,
4749     const dnn::RnnSequenceTensorDescriptor &input_desc,
4750     const DeviceMemory<double> &input_data,
4751     const dnn::RnnStateTensorDescriptor &input_h_desc,
4752     const DeviceMemory<double> &input_h_data,
4753     const dnn::RnnStateTensorDescriptor &input_c_desc,
4754     const DeviceMemory<double> &input_c_data,
4755     const DeviceMemory<double> &params,
4756     const dnn::RnnSequenceTensorDescriptor &output_desc,
4757     DeviceMemory<double> *output_data,
4758     const dnn::RnnStateTensorDescriptor &output_h_desc,
4759     DeviceMemory<double> *output_h_data,
4760     const dnn::RnnStateTensorDescriptor &output_c_desc,
4761     DeviceMemory<double> *output_c_data, bool is_training,
4762     ScratchAllocator *reserve_space_allocator,
4763     ScratchAllocator *workspace_allocator) {
4764   // TODO(zhengxq): add VLOG PARAM calls.
4765   if (ok()) {
4766     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4767       CheckError(dnn->DoRnnForward(
4768           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4769           input_c_desc, input_c_data, params, output_desc, output_data,
4770           output_h_desc, output_h_data, output_c_desc, output_c_data,
4771           is_training, reserve_space_allocator, workspace_allocator));
4772     } else {
4773       SetError();
4774       LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
4775     }
4776   }
4777   return *this;
4778 }
4779 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)4780 Stream &Stream::ThenRnnBackward(
4781     const dnn::RnnDescriptor &rnn_desc,
4782     const dnn::RnnSequenceTensorDescriptor &input_desc,
4783     const DeviceMemory<Eigen::half> &input_data,
4784     const dnn::RnnStateTensorDescriptor &input_h_desc,
4785     const DeviceMemory<Eigen::half> &input_h_data,
4786     const dnn::RnnStateTensorDescriptor &input_c_desc,
4787     const DeviceMemory<Eigen::half> &input_c_data,
4788     const DeviceMemory<Eigen::half> &params,
4789     const dnn::RnnSequenceTensorDescriptor &output_desc,
4790     const DeviceMemory<Eigen::half> &output_data,
4791     const dnn::RnnStateTensorDescriptor &output_h_desc,
4792     const DeviceMemory<Eigen::half> &output_h_data,
4793     const dnn::RnnStateTensorDescriptor &output_c_desc,
4794     const DeviceMemory<Eigen::half> &output_c_data,
4795     const DeviceMemory<Eigen::half> &output_backprop_data,
4796     const DeviceMemory<Eigen::half> &output_h_backprop_data,
4797     const DeviceMemory<Eigen::half> &output_c_backprop_data,
4798     DeviceMemory<Eigen::half> *input_backprop_data,
4799     DeviceMemory<Eigen::half> *input_h_backprop_data,
4800     DeviceMemory<Eigen::half> *input_c_backprop_data,
4801     DeviceMemory<Eigen::half> *params_backprop_data,
4802     DeviceMemory<uint8> *reserve_space_data,
4803     ScratchAllocator *workspace_allocator) {
4804   // TODO(zhengxq): add VLOG PARAM calls.
4805   if (ok()) {
4806     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4807       CheckError(dnn->DoRnnBackward(
4808           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4809           input_c_desc, input_c_data, params, output_desc, output_data,
4810           output_h_desc, output_h_data, output_c_desc, output_c_data,
4811           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4812           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4813           params_backprop_data, reserve_space_data, workspace_allocator));
4814     } else {
4815       SetError();
4816       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4817     }
4818   }
4819   return *this;
4820 }
4821 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)4822 Stream &Stream::ThenRnnBackward(
4823     const dnn::RnnDescriptor &rnn_desc,
4824     const dnn::RnnSequenceTensorDescriptor &input_desc,
4825     const DeviceMemory<float> &input_data,
4826     const dnn::RnnStateTensorDescriptor &input_h_desc,
4827     const DeviceMemory<float> &input_h_data,
4828     const dnn::RnnStateTensorDescriptor &input_c_desc,
4829     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
4830     const dnn::RnnSequenceTensorDescriptor &output_desc,
4831     const DeviceMemory<float> &output_data,
4832     const dnn::RnnStateTensorDescriptor &output_h_desc,
4833     const DeviceMemory<float> &output_h_data,
4834     const dnn::RnnStateTensorDescriptor &output_c_desc,
4835     const DeviceMemory<float> &output_c_data,
4836     const DeviceMemory<float> &output_backprop_data,
4837     const DeviceMemory<float> &output_h_backprop_data,
4838     const DeviceMemory<float> &output_c_backprop_data,
4839     DeviceMemory<float> *input_backprop_data,
4840     DeviceMemory<float> *input_h_backprop_data,
4841     DeviceMemory<float> *input_c_backprop_data,
4842     DeviceMemory<float> *params_backprop_data,
4843     DeviceMemory<uint8> *reserve_space_data,
4844     ScratchAllocator *workspace_allocator) {
4845   // TODO(zhengxq): add VLOG PARAM calls.
4846   if (ok()) {
4847     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4848       CheckError(dnn->DoRnnBackward(
4849           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4850           input_c_desc, input_c_data, params, output_desc, output_data,
4851           output_h_desc, output_h_data, output_c_desc, output_c_data,
4852           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4853           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4854           params_backprop_data, reserve_space_data, workspace_allocator));
4855     } else {
4856       SetError();
4857       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4858     }
4859   }
4860   return *this;
4861 }
4862 
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)4863 Stream &Stream::ThenRnnBackward(
4864     const dnn::RnnDescriptor &rnn_desc,
4865     const dnn::RnnSequenceTensorDescriptor &input_desc,
4866     const DeviceMemory<double> &input_data,
4867     const dnn::RnnStateTensorDescriptor &input_h_desc,
4868     const DeviceMemory<double> &input_h_data,
4869     const dnn::RnnStateTensorDescriptor &input_c_desc,
4870     const DeviceMemory<double> &input_c_data,
4871     const DeviceMemory<double> &params,
4872     const dnn::RnnSequenceTensorDescriptor &output_desc,
4873     const DeviceMemory<double> &output_data,
4874     const dnn::RnnStateTensorDescriptor &output_h_desc,
4875     const DeviceMemory<double> &output_h_data,
4876     const dnn::RnnStateTensorDescriptor &output_c_desc,
4877     const DeviceMemory<double> &output_c_data,
4878     const DeviceMemory<double> &output_backprop_data,
4879     const DeviceMemory<double> &output_h_backprop_data,
4880     const DeviceMemory<double> &output_c_backprop_data,
4881     DeviceMemory<double> *input_backprop_data,
4882     DeviceMemory<double> *input_h_backprop_data,
4883     DeviceMemory<double> *input_c_backprop_data,
4884     DeviceMemory<double> *params_backprop_data,
4885     DeviceMemory<uint8> *reserve_space_data,
4886     ScratchAllocator *workspace_allocator) {
4887   // TODO(zhengxq): add VLOG PARAM calls.
4888   if (ok()) {
4889     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4890       CheckError(dnn->DoRnnBackward(
4891           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4892           input_c_desc, input_c_data, params, output_desc, output_data,
4893           output_h_desc, output_h_data, output_c_desc, output_c_data,
4894           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4895           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4896           params_backprop_data, reserve_space_data, workspace_allocator));
4897     } else {
4898       SetError();
4899       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4900     }
4901   }
4902   return *this;
4903 }
4904 
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)4905 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
4906                                     dnn::DataType input_type,
4907                                     const DeviceMemoryBase &input_data,
4908                                     const dnn::BatchDescriptor &output_desc,
4909                                     dnn::DataType output_type, float scale,
4910                                     DeviceMemoryBase *output_data) {
4911   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
4912             PARAM(output_desc), PARAM(output_type), PARAM(scale),
4913             PARAM(output_data));
4914   if (ok()) {
4915     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4916       CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
4917                                         input_data, output_desc, output_type,
4918                                         scale, output_data));
4919     } else {
4920       SetErrorAndLogNoDnnSupport();
4921     }
4922   }
4923   return *this;
4924 }
4925 
ThenDoHostCallbackForTest(std::function<void ()> callback)4926 Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
4927   VLOG_CALL(PARAM(callback));
4928 
4929   return ThenDoHostCallback(callback);
4930 }
4931 
ThenDoHostCallback(std::function<void ()> callback)4932 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
4933   VLOG_CALL(PARAM(callback));
4934 
4935   if (ok()) {
4936     CheckError(parent_->HostCallback(this, callback));
4937   } else {
4938     LOG(INFO) << "stream " << this
4939               << " was in error state before adding host callback";
4940   }
4941   return *this;
4942 }
4943 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)4944 Stream &Stream::ThenFft(fft::Plan *plan,
4945                         const DeviceMemory<std::complex<float>> &input,
4946                         DeviceMemory<std::complex<float>> *output) {
4947   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4948 
4949   if (ok()) {
4950     if (fft::FftSupport *fft = parent_->AsFft()) {
4951       CheckError(fft->DoFft(this, plan, input, output));
4952     } else {
4953       SetError();
4954       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
4955                    "without FFT support";
4956     }
4957   }
4958   return *this;
4959 }
4960 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)4961 Stream &Stream::ThenFft(fft::Plan *plan,
4962                         const DeviceMemory<std::complex<double>> &input,
4963                         DeviceMemory<std::complex<double>> *output) {
4964   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4965 
4966   if (ok()) {
4967     if (fft::FftSupport *fft = parent_->AsFft()) {
4968       CheckError(fft->DoFft(this, plan, input, output));
4969     } else {
4970       SetError();
4971       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
4972                    "without FFT support";
4973     }
4974   }
4975   return *this;
4976 }
4977 
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)4978 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
4979                         DeviceMemory<std::complex<float>> *output) {
4980   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4981 
4982   if (ok()) {
4983     if (fft::FftSupport *fft = parent_->AsFft()) {
4984       CheckError(fft->DoFft(this, plan, input, output));
4985     } else {
4986       SetError();
4987       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
4988                    "without FFT support";
4989     }
4990   }
4991   return *this;
4992 }
4993 
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)4994 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
4995                         DeviceMemory<std::complex<double>> *output) {
4996   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4997 
4998   if (ok()) {
4999     if (fft::FftSupport *fft = parent_->AsFft()) {
5000       CheckError(fft->DoFft(this, plan, input, output));
5001     } else {
5002       SetError();
5003       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
5004                    "without FFT support";
5005     }
5006   }
5007   return *this;
5008 }
5009 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)5010 Stream &Stream::ThenFft(fft::Plan *plan,
5011                         const DeviceMemory<std::complex<float>> &input,
5012                         DeviceMemory<float> *output) {
5013   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5014 
5015   if (ok()) {
5016     if (fft::FftSupport *fft = parent_->AsFft()) {
5017       CheckError(fft->DoFft(this, plan, input, output));
5018     } else {
5019       SetError();
5020       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
5021                    "without FFT support";
5022     }
5023   }
5024   return *this;
5025 }
5026 
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)5027 Stream &Stream::ThenFft(fft::Plan *plan,
5028                         const DeviceMemory<std::complex<double>> &input,
5029                         DeviceMemory<double> *output) {
5030   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5031 
5032   if (ok()) {
5033     if (fft::FftSupport *fft = parent_->AsFft()) {
5034       CheckError(fft->DoFft(this, plan, input, output));
5035     } else {
5036       SetError();
5037       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
5038                    "without FFT support";
5039     }
5040   }
5041   return *this;
5042 }
5043 
5044 // It looks confusing, but all this is doing is inserting a callback at the
5045 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)5046 Stream &Stream::ThenEnqueueOnBackgroundThread(
5047     std::function<void(StreamExecutor *)> task) {
5048   VLOG_CALL(PARAM(task));
5049 
5050   StreamExecutor *stream_executor = this->parent_;
5051   std::function<void()> bound_task = std::bind(task, stream_executor);
5052 
5053   return ThenDoHostCallback([stream_executor, bound_task]() {
5054     stream_executor->EnqueueOnBackgroundThread(bound_task);
5055   });
5056 }
5057 
BlockHostUntilDone()5058 port::Status Stream::BlockHostUntilDone() {
5059   VLOG_CALL();
5060 
5061   if (!ok()) {
5062     port::Status status = port::Status(
5063         port::error::INTERNAL,
5064         "stream did not block host until done; was already in an error state");
5065     LOG(INFO) << status << " " << this;
5066     return status;
5067   }
5068 
5069   port::Status first_error;
5070   {
5071     // Wait until all active sub-streams have done their tasks.
5072     mutex_lock lock{mu_};
5073     for (auto &stream : sub_streams_) {
5074       if (!stream.second) {
5075         first_error.Update(stream.first->BlockHostUntilDone());
5076         // Set this sub-stream as available.
5077         stream.second = true;
5078       }
5079     }
5080   }
5081 
5082   temporary_memory_manager_.DeallocateFinalizedTemporaries();
5083 
5084   first_error.Update(parent_->BlockHostUntilDone(this));
5085   CheckError(first_error.ok());
5086   return first_error;
5087 }
5088 
5089 }  // namespace gputools
5090 }  // namespace perftools
5091