1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/stream.h"
17
18 #include "tensorflow/stream_executor/platform/port.h"
19
20 #include "tensorflow/stream_executor/blas.h"
21 #include "tensorflow/stream_executor/host_buffer.h"
22 #include "tensorflow/stream_executor/lib/stacktrace.h"
23 #include "tensorflow/stream_executor/lib/strcat.h"
24 #include "tensorflow/stream_executor/platform.h"
25 #include "tensorflow/stream_executor/platform/logging.h"
26 #include "tensorflow/stream_executor/rng.h"
27 #include "tensorflow/stream_executor/stream_executor_internal.h"
28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
29
30 namespace perftools {
31 namespace gputools {
32
33 namespace {
34 // Code to turn parameters to functions on stream into strings that
35 // will be VLOG'ed. We need overloads, instead of
36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
37 // functions does not know what the type of the parameter is.
ToVlogString(const dnn::BatchDescriptor & descriptor)38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
39 return descriptor.ToShortString();
40 }
41
ToVlogString(const dnn::FilterDescriptor & descriptor)42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
43 return descriptor.ToShortString();
44 }
45
ToVlogString(const dnn::ConvolutionDescriptor & descriptor)46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
47 return descriptor.ToShortString();
48 }
49
ToVlogString(const dnn::PoolingDescriptor & descriptor)50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
51 return descriptor.ToShortString();
52 }
53
ToVlogString(const dnn::NormalizeDescriptor & descriptor)54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
55 return descriptor.ToShortString();
56 }
57
ToVlogString(dnn::ActivationMode mode)58 string ToVlogString(dnn::ActivationMode mode) {
59 return dnn::ActivationModeString(mode);
60 }
61
ToVlogString(const dnn::AlgorithmConfig & algo_config)62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
63 return algo_config.ToString();
64 }
65
ToVlogString(dnn::ElementwiseOperation op)66 string ToVlogString(dnn::ElementwiseOperation op) {
67 return dnn::ElementwiseOperationString(op);
68 }
69
ToVlogString(dnn::QuantizedActivationMode mode)70 string ToVlogString(dnn::QuantizedActivationMode mode) {
71 return dnn::QuantizedActivationModeString(mode);
72 }
73
ToVlogString(blas::Transpose t)74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
75
ToVlogString(blas::UpperLower ul)76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
77
ToVlogString(blas::Diagonal d)78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
79
ToVlogString(blas::Side s)80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
81
ToVlogString(blas::ComputationType ty)82 string ToVlogString(blas::ComputationType ty) {
83 return blas::ComputationTypeString(ty);
84 }
85
// Renders a raw pointer parameter for VLOG output: "null" for a null
// pointer, otherwise whatever operator<< on an ostream prints for the
// address.
string ToVlogString(const void *ptr) {
  if (ptr != nullptr) {
    // StrCat does not convert pointers to text, so go through a stream.
    std::ostringstream formatted;
    formatted << ptr;
    return formatted.str();
  }
  return "null";
}
96
ToVlogString(const HostBuffer & buffer)97 string ToVlogString(const HostBuffer &buffer) { return buffer.AsString(); }
98
// Renders a std::complex parameter for VLOG output using the standard
// stream insertion operator for std::complex.
template <class T>
string ToVlogString(const std::complex<T> &c) {
  // StrCat does not convert std::complex to text, so go through a stream.
  std::ostringstream formatted;
  formatted << c;
  string text = formatted.str();
  return text;
}
106
// Renders a std::function parameter for VLOG output. Only whether the
// function holds a target is reported: "null" when it is empty,
// "<non-null function>" otherwise.
template <class T>
string ToVlogString(const std::function<T> &f) {
  if (f) {
    return "<non-null function>";
  }
  return "null";
}
111
ToVlogString(const DeviceMemoryBase & memory)112 string ToVlogString(const DeviceMemoryBase &memory) {
113 return ToVlogString(memory.opaque());
114 }
115
ToVlogString(const DeviceMemoryBase * memory)116 string ToVlogString(const DeviceMemoryBase *memory) {
117 return ToVlogString(*memory);
118 }
119
ToVlogString(const Eigen::half & h)120 string ToVlogString(const Eigen::half &h) { return port::StrCat(h); }
121
ToVlogString(int i)122 string ToVlogString(int i) { return port::StrCat(i); }
123
ToVlogString(uint32 i)124 string ToVlogString(uint32 i) { return port::StrCat(i); }
125
ToVlogString(uint64 i)126 string ToVlogString(uint64 i) { return port::StrCat(i); }
127
ToVlogString(int64 i)128 string ToVlogString(int64 i) { return port::StrCat(i); }
129
ToVlogString(float f)130 string ToVlogString(float f) { return port::StrCat(f); }
131
ToVlogString(double d)132 string ToVlogString(double d) { return port::StrCat(d); }
133
134 template <class T>
ToVlogString(port::ArraySlice<T> elements)135 string ToVlogString(port::ArraySlice<T> elements) {
136 string str = port::StrCat(
137 ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
138 elements.size(), "]{");
139 const char *separator = "";
140 size_t max_to_show = std::numeric_limits<size_t>::max();
141 if (!VLOG_IS_ON(2)) {
142 max_to_show = 5;
143 } else if (!VLOG_IS_ON(3)) {
144 max_to_show = 20;
145 } else if (!VLOG_IS_ON(11)) {
146 max_to_show = 1000;
147 }
148 for (size_t i = 0; i < elements.size(); ++i) {
149 if (i == max_to_show) {
150 str += ", ...";
151 break;
152 }
153 port::StrAppend(&str, separator, ToVlogString(elements[i]));
154 separator = ", ";
155 }
156 str += "}";
157 return str;
158 }
159
160 template <class T>
ToVlogString(port::MutableArraySlice<T> elements)161 string ToVlogString(port::MutableArraySlice<T> elements) {
162 return ToVlogString(port::ArraySlice<T>(elements));
163 }
164
ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout)165 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
166 switch (depth_to_space_layout) {
167 case dnn::DepthToSpaceLayout::DepthHeightWidth:
168 return "DepthToSpaceLayout::DepthHeightWidth";
169 }
170 return "unknown DepthToSpaceLayout";
171 }
172
ToVlogString(dnn::DataType data_type)173 string ToVlogString(dnn::DataType data_type) {
174 switch (data_type) {
175 case dnn::DataType::kFloat:
176 return "dnn::DataType::kFloat";
177 case dnn::DataType::kDouble:
178 return "dnn::DataType::kDouble";
179 case dnn::DataType::kHalf:
180 return "dnn::DataType::kHalf";
181 case dnn::DataType::kInt8:
182 return "dnn::DataType::kInt8";
183 }
184 }
185
186 // Used together with PARAM to VLOG calls made to the stream. Intended
187 // to be used like this:
188 //
189 // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
190 //
191 // where a and b are the parameters to MyFunction.
192 //
193 // See VLOG_CALL for a short-hand for this. This way of doing it saves
194 // a tremendous amount of boilerplate code given how many functions
195 // there are on Stream and how many parameters they each have.
CallStr(const char * function_name,Stream * stream,std::vector<std::pair<const char *,string>> params)196 string CallStr(const char *function_name, Stream *stream,
197 std::vector<std::pair<const char *, string>> params) {
198 // Do not call this function unless VLOG is on since just
199 // constructing all the strings in params is expensive.
200 CHECK(VLOG_IS_ON(1));
201
202 string str = port::StrCat("Called Stream::", function_name, "(");
203 const char *separator = "";
204 for (const auto ¶m : params) {
205 port::StrAppend(&str, separator, param.first, "=", param.second);
206 separator = ", ";
207 }
208 port::StrAppend(&str, ") stream=", ToVlogString(stream));
209 if (VLOG_IS_ON(10)) {
210 port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
211 }
212 return str;
213 }
214
// Expands to a braced {"name", value-as-string} pair for one parameter, so
// that each parameter only has to be typed once when logging it with VLOG
// and CallStr. The name is the stringified macro argument and the value is
// produced by the matching ToVlogString overload.
#define PARAM(parameter) {#parameter, ToVlogString(parameter)}
219
// Logs, at VLOG(1), a call to the enclosing member function together with
// the given parameters, without having to spell out the function name.
// Intended to be used like this:
//
//   VLOG_CALL(PARAM(a), PARAM(b))
//
// The function name comes from __func__ and the stream from `this`, so the
// macro can only be used inside non-static member functions. Compared to
// the hand-written alternative
//
//   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
//           << ", b=" << ToVlogString(b);
//
// this removes a lot of boilerplate: most parameter names are not short and
// most functions take many more than 2 parameters.
#define VLOG_CALL(...) \
  VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
233
234 } // namespace
235
Stream(StreamExecutor * parent)236 Stream::Stream(StreamExecutor *parent)
237 : parent_(parent),
238 implementation_(parent->implementation()->GetStreamImplementation()),
239 allocated_(false),
240 ok_(false),
241 temporary_memory_manager_(this) {
242 VLOG_CALL(PARAM(parent));
243 }
244
Stream(StreamExecutor * parent,internal::StreamInterface * implementation)245 Stream::Stream(StreamExecutor *parent,
246 internal::StreamInterface *implementation)
247 : parent_(parent),
248 implementation_(implementation),
249 allocated_(false),
250 ok_(false),
251 temporary_memory_manager_(this) {
252 VLOG_CALL(PARAM(parent), PARAM(implementation));
253 }
254
~Stream()255 Stream::~Stream() {
256 VLOG_CALL();
257
258 temporary_memory_manager_.ForceDeallocateAll();
259
260 if (allocated_) {
261 parent_->DeallocateStream(this);
262 }
263 }
264
Init()265 Stream &Stream::Init() {
266 VLOG_CALL();
267
268 mutex_lock lock{mu_};
269 CHECK_EQ(false, allocated_)
270 << "stream appears to already have been initialized";
271 CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
272
273 if (parent_->AllocateStream(this)) {
274 // Successful initialization!
275 allocated_ = true;
276 ok_ = true;
277 } else {
278 LOG(ERROR) << "failed to allocate stream during initialization";
279 }
280
281 return *this;
282 }
283
InitTimer(Timer * timer)284 Stream &Stream::InitTimer(Timer *timer) {
285 VLOG_CALL(PARAM(timer));
286
287 if (ok()) {
288 CheckError(parent_->AllocateTimer(timer));
289 } else {
290 LOG(INFO) << "did not allocate timer: " << timer;
291 }
292 return *this;
293 }
294
InitWithTimer(Timer * timer)295 Stream &Stream::InitWithTimer(Timer *timer) {
296 VLOG_CALL(PARAM(timer));
297
298 return Init().InitTimer(timer);
299 }
300
ThenRecordEvent(Event * event)301 Stream &Stream::ThenRecordEvent(Event *event) {
302 VLOG_CALL(PARAM(event));
303
304 port::Status status = parent_->RecordEvent(this, event);
305 if (!status.ok()) {
306 LOG(ERROR) << "Error recording event in stream: " << status.error_message()
307 << "; not marking stream as bad, as the Event object may be "
308 << "at fault. Monitor for further errors.";
309 }
310
311 return *this;
312 }
313
ThenBatchNormalizationForward(const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)314 Stream &Stream::ThenBatchNormalizationForward(
315 const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
316 const DeviceMemory<float> &offset,
317 const DeviceMemory<float> &estimated_mean,
318 const DeviceMemory<float> &estimated_variance,
319 const dnn::BatchDescriptor &x_desc,
320 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
321 DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
322 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
323 DeviceMemory<float> *saved_inv_var, bool is_training,
324 std::function<const DeviceMemory<float> &()> var_to_inv_var,
325 std::function<void()> inv_var_to_var) {
326 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
327 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
328 if (ok()) {
329 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
330 CheckError(dnn->DoBatchNormalizationForward(
331 this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
332 scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
333 saved_inv_var, is_training, std::move(var_to_inv_var),
334 std::move(inv_var_to_var)));
335 } else {
336 SetErrorAndLogNoDnnSupport();
337 }
338 }
339 return *this;
340 }
341
ThenBatchNormalizationBackward(const DeviceMemory<float> & y_backprop,const DeviceMemory<float> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<float> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)342 Stream &Stream::ThenBatchNormalizationBackward(
343 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
344 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
345 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
346 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
347 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
348 DeviceMemory<float> *offset_backprop) {
349 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
350 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
351 PARAM(scale_backprop), PARAM(offset_backprop));
352 if (ok()) {
353 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
354 CheckError(dnn->DoBatchNormalizationBackward(
355 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
356 epsilon, x_backprop, scale_backprop, offset_backprop));
357 } else {
358 SetErrorAndLogNoDnnSupport();
359 }
360 }
361 return *this;
362 }
363
ThenBatchNormalizationForward(const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & offset,const DeviceMemory<float> & estimated_mean,const DeviceMemory<float> & estimated_variance,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * y,DeviceMemory<float> * batch_mean,DeviceMemory<float> * batch_var,DeviceMemory<float> * saved_mean,DeviceMemory<float> * saved_inv_var,bool is_training,std::function<const DeviceMemory<float> & ()> var_to_inv_var,std::function<void ()> inv_var_to_var)364 Stream &Stream::ThenBatchNormalizationForward(
365 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
366 const DeviceMemory<float> &offset,
367 const DeviceMemory<float> &estimated_mean,
368 const DeviceMemory<float> &estimated_variance,
369 const dnn::BatchDescriptor &x_desc,
370 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
371 DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
372 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
373 DeviceMemory<float> *saved_inv_var, bool is_training,
374 std::function<const DeviceMemory<float> &()> var_to_inv_var,
375 std::function<void()> inv_var_to_var) {
376 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
377 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
378 if (ok()) {
379 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
380 CheckError(dnn->DoBatchNormalizationForward(
381 this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
382 scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
383 saved_inv_var, is_training, std::move(var_to_inv_var),
384 std::move(inv_var_to_var)));
385 } else {
386 SetErrorAndLogNoDnnSupport();
387 }
388 }
389 return *this;
390 }
391
ThenBatchNormalizationBackward(const DeviceMemory<Eigen::half> & y_backprop,const DeviceMemory<Eigen::half> & x,const DeviceMemory<float> & scale,const DeviceMemory<float> & mean,const DeviceMemory<float> & inv_var,const dnn::BatchDescriptor & x_desc,const dnn::BatchDescriptor & scale_offset_desc,const double epsilon,DeviceMemory<Eigen::half> * x_backprop,DeviceMemory<float> * scale_backprop,DeviceMemory<float> * offset_backprop)392 Stream &Stream::ThenBatchNormalizationBackward(
393 const DeviceMemory<Eigen::half> &y_backprop,
394 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
395 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
396 const dnn::BatchDescriptor &x_desc,
397 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
398 DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
399 DeviceMemory<float> *offset_backprop) {
400 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
401 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
402 PARAM(scale_backprop), PARAM(offset_backprop));
403 if (ok()) {
404 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
405 CheckError(dnn->DoBatchNormalizationBackward(
406 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
407 epsilon, x_backprop, scale_backprop, offset_backprop));
408 } else {
409 SetErrorAndLogNoDnnSupport();
410 }
411 }
412 return *this;
413 }
414
ThenFusedConvolveWithScratch(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator)415 Stream &Stream::ThenFusedConvolveWithScratch(
416 const dnn::BatchDescriptor &conv_input_descriptor,
417 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
418 const dnn::FilterDescriptor &filter_descriptor,
419 const DeviceMemory<int8> &filter_data,
420 const dnn::ConvolutionDescriptor &convolution_descriptor,
421 const DeviceMemory<int8> &side_input_data, float side_input_scale,
422 const dnn::BatchDescriptor &bias_descriptor,
423 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
424 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
425 ScratchAllocator *scratch_allocator) {
426 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
427 PARAM(conv_input_scale), PARAM(filter_descriptor),
428 PARAM(filter_data), PARAM(convolution_descriptor),
429 PARAM(side_input_data), PARAM(side_input_scale),
430 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
431 PARAM(output_descriptor), PARAM(output));
432
433 if (ok()) {
434 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
435 CheckError(dnn->DoFusedConvolve(
436 this, conv_input_descriptor, conv_input_data, conv_input_scale,
437 filter_descriptor, filter_data, convolution_descriptor,
438 side_input_data, side_input_scale, bias_descriptor, biases,
439 activation_mode, output_descriptor, output, scratch_allocator,
440 dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
441 } else {
442 SetErrorAndLogNoDnnSupport();
443 }
444 }
445 return *this;
446 }
447
ThenFusedConvolveWithScratch(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator)448 Stream &Stream::ThenFusedConvolveWithScratch(
449 const dnn::BatchDescriptor &conv_input_descriptor,
450 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
451 const dnn::FilterDescriptor &filter_descriptor,
452 const DeviceMemory<Eigen::half> &filter_data,
453 const dnn::ConvolutionDescriptor &convolution_descriptor,
454 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
455 const dnn::BatchDescriptor &bias_descriptor,
456 const DeviceMemory<Eigen::half> &biases,
457 dnn::ActivationMode activation_mode,
458 const dnn::BatchDescriptor &output_descriptor,
459 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
460 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
461 PARAM(conv_input_scale), PARAM(filter_descriptor),
462 PARAM(filter_data), PARAM(convolution_descriptor),
463 PARAM(side_input_data), PARAM(side_input_scale),
464 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
465 PARAM(output_descriptor), PARAM(output));
466
467 if (ok()) {
468 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
469 CheckError(dnn->DoFusedConvolve(
470 this, conv_input_descriptor, conv_input_data, conv_input_scale,
471 filter_descriptor, filter_data, convolution_descriptor,
472 side_input_data, side_input_scale, bias_descriptor, biases,
473 activation_mode, output_descriptor, output, scratch_allocator,
474 dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
475 } else {
476 SetErrorAndLogNoDnnSupport();
477 }
478 }
479 return *this;
480 }
481
ThenFusedConvolveWithScratch(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator)482 Stream &Stream::ThenFusedConvolveWithScratch(
483 const dnn::BatchDescriptor &conv_input_descriptor,
484 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
485 const dnn::FilterDescriptor &filter_descriptor,
486 const DeviceMemory<float> &filter_data,
487 const dnn::ConvolutionDescriptor &convolution_descriptor,
488 const DeviceMemory<float> &side_input_data, float side_input_scale,
489 const dnn::BatchDescriptor &bias_descriptor,
490 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
491 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
492 ScratchAllocator *scratch_allocator) {
493 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
494 PARAM(conv_input_scale), PARAM(filter_descriptor),
495 PARAM(filter_data), PARAM(convolution_descriptor),
496 PARAM(side_input_data), PARAM(side_input_scale),
497 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
498 PARAM(output_descriptor), PARAM(output));
499
500 if (ok()) {
501 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
502 CheckError(dnn->DoFusedConvolve(
503 this, conv_input_descriptor, conv_input_data, conv_input_scale,
504 filter_descriptor, filter_data, convolution_descriptor,
505 side_input_data, side_input_scale, bias_descriptor, biases,
506 activation_mode, output_descriptor, output, scratch_allocator,
507 dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
508 } else {
509 SetErrorAndLogNoDnnSupport();
510 }
511 }
512 return *this;
513 }
514
ThenConvolveWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator)515 Stream &Stream::ThenConvolveWithScratch(
516 const dnn::BatchDescriptor &input_descriptor,
517 const DeviceMemory<Eigen::half> &input_data,
518 const dnn::FilterDescriptor &filter_descriptor,
519 const DeviceMemory<Eigen::half> &filter_data,
520 const dnn::ConvolutionDescriptor &convolution_descriptor,
521 const dnn::BatchDescriptor &output_descriptor,
522 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
523 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
524 PARAM(filter_descriptor), PARAM(filter_data),
525 PARAM(convolution_descriptor), PARAM(output_descriptor),
526 PARAM(output));
527
528 if (ok()) {
529 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
530 CheckError(dnn->DoConvolve(
531 this, input_descriptor, input_data, filter_descriptor, filter_data,
532 convolution_descriptor, output_descriptor, output, scratch_allocator,
533 dnn::AlgorithmConfig(),
534 /*output_profile_result=*/nullptr));
535 } else {
536 SetErrorAndLogNoDnnSupport();
537 }
538 }
539 return *this;
540 }
541
ThenConvolveWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator)542 Stream &Stream::ThenConvolveWithScratch(
543 const dnn::BatchDescriptor &input_descriptor,
544 const DeviceMemory<float> &input_data,
545 const dnn::FilterDescriptor &filter_descriptor,
546 const DeviceMemory<float> &filter_data,
547 const dnn::ConvolutionDescriptor &convolution_descriptor,
548 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
549 ScratchAllocator *scratch_allocator) {
550 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
551 PARAM(filter_descriptor), PARAM(filter_data),
552 PARAM(convolution_descriptor), PARAM(output_descriptor),
553 PARAM(output));
554
555 if (ok()) {
556 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
557 CheckError(dnn->DoConvolve(
558 this, input_descriptor, input_data, filter_descriptor, filter_data,
559 convolution_descriptor, output_descriptor, output, scratch_allocator,
560 dnn::AlgorithmConfig(),
561 /*output_profile_result=*/nullptr));
562 } else {
563 SetErrorAndLogNoDnnSupport();
564 }
565 }
566 return *this;
567 }
568
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<float> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<float> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)569 Stream &Stream::ThenFusedConvolveWithAlgorithm(
570 const dnn::BatchDescriptor &conv_input_descriptor,
571 const DeviceMemory<float> &conv_input_data, float conv_input_scale,
572 const dnn::FilterDescriptor &filter_descriptor,
573 const DeviceMemory<float> &filter_data,
574 const dnn::ConvolutionDescriptor &convolution_descriptor,
575 const DeviceMemory<float> &side_input_data, float side_input_scale,
576 const dnn::BatchDescriptor &bias_descriptor,
577 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
578 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
579 ScratchAllocator *scratch_allocator,
580 const dnn::AlgorithmConfig &algorithm_config,
581 dnn::ProfileResult *output_profile_result) {
582 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
583 PARAM(conv_input_scale), PARAM(filter_descriptor),
584 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
585 PARAM(side_input_data), PARAM(side_input_scale),
586 PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
587 PARAM(algorithm_config));
588
589 if (ok()) {
590 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
591 auto status = dnn->DoFusedConvolve(
592 this, conv_input_descriptor, conv_input_data, conv_input_scale,
593 filter_descriptor, filter_data, convolution_descriptor,
594 side_input_data, side_input_scale, bias_descriptor, biases,
595 activation_mode, output_descriptor, output, scratch_allocator,
596 algorithm_config, output_profile_result);
597 if (!status && !output_profile_result) {
598 SetError();
599 }
600 } else {
601 SetErrorAndLogNoDnnSupport();
602 }
603 }
604 return *this;
605 }
606
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<Eigen::half> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<Eigen::half> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<Eigen::half> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)607 Stream &Stream::ThenFusedConvolveWithAlgorithm(
608 const dnn::BatchDescriptor &conv_input_descriptor,
609 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
610 const dnn::FilterDescriptor &filter_descriptor,
611 const DeviceMemory<Eigen::half> &filter_data,
612 const dnn::ConvolutionDescriptor &convolution_descriptor,
613 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
614 const dnn::BatchDescriptor &bias_descriptor,
615 const DeviceMemory<Eigen::half> &biases,
616 dnn::ActivationMode activation_mode,
617 const dnn::BatchDescriptor &output_descriptor,
618 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
619 const dnn::AlgorithmConfig &algorithm_config,
620 dnn::ProfileResult *output_profile_result) {
621 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
622 PARAM(conv_input_scale), PARAM(filter_descriptor),
623 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
624 PARAM(side_input_data), PARAM(side_input_scale),
625 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
626 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
627
628 if (ok()) {
629 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
630 auto status = dnn->DoFusedConvolve(
631 this, conv_input_descriptor, conv_input_data, conv_input_scale,
632 filter_descriptor, filter_data, convolution_descriptor,
633 side_input_data, side_input_scale, bias_descriptor, biases,
634 activation_mode, output_descriptor, output, scratch_allocator,
635 algorithm_config, output_profile_result);
636 if (!status && !output_profile_result) {
637 SetError();
638 }
639 } else {
640 SetErrorAndLogNoDnnSupport();
641 }
642 }
643 return *this;
644 }
645
ThenFusedConvolveWithAlgorithm(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)646 Stream &Stream::ThenFusedConvolveWithAlgorithm(
647 const dnn::BatchDescriptor &conv_input_descriptor,
648 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
649 const dnn::FilterDescriptor &filter_descriptor,
650 const DeviceMemory<int8> &filter_data,
651 const dnn::ConvolutionDescriptor &convolution_descriptor,
652 const DeviceMemory<int8> &side_input_data, float side_input_scale,
653 const dnn::BatchDescriptor &bias_descriptor,
654 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
655 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
656 ScratchAllocator *scratch_allocator,
657 const dnn::AlgorithmConfig &algorithm_config,
658 dnn::ProfileResult *output_profile_result) {
659 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
660 PARAM(conv_input_scale), PARAM(filter_descriptor),
661 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
662 PARAM(side_input_data), PARAM(side_input_scale),
663 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
664 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
665
666 if (ok()) {
667 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
668 auto status = dnn->DoFusedConvolve(
669 this, conv_input_descriptor, conv_input_data, conv_input_scale,
670 filter_descriptor, filter_data, convolution_descriptor,
671 side_input_data, side_input_scale, bias_descriptor, biases,
672 activation_mode, output_descriptor, output, scratch_allocator,
673 algorithm_config, output_profile_result);
674 if (!status && !output_profile_result) {
675 SetError();
676 }
677 } else {
678 SetErrorAndLogNoDnnSupport();
679 }
680 }
681 return *this;
682 }
683
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)684 Stream &Stream::ThenConvolveWithAlgorithm(
685 const dnn::BatchDescriptor &input_descriptor,
686 const DeviceMemory<float> &input_data,
687 const dnn::FilterDescriptor &filter_descriptor,
688 const DeviceMemory<float> &filter_data,
689 const dnn::ConvolutionDescriptor &convolution_descriptor,
690 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
691 ScratchAllocator *scratch_allocator,
692 const dnn::AlgorithmConfig &algorithm_config,
693 dnn::ProfileResult *output_profile_result) {
694 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
695 PARAM(filter_descriptor), PARAM(filter_data),
696 PARAM(convolution_descriptor), PARAM(output_descriptor),
697 PARAM(output), PARAM(algorithm_config));
698
699 if (ok()) {
700 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
701 auto status = dnn->DoConvolve(
702 this, input_descriptor, input_data, filter_descriptor, filter_data,
703 convolution_descriptor, output_descriptor, output, scratch_allocator,
704 algorithm_config, output_profile_result);
705 if (!status && !output_profile_result) {
706 SetError();
707 }
708 } else {
709 SetErrorAndLogNoDnnSupport();
710 }
711 }
712 return *this;
713 }
714
ThenConvolveWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> * output,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)715 Stream &Stream::ThenConvolveWithAlgorithm(
716 const dnn::BatchDescriptor &input_descriptor,
717 const DeviceMemory<Eigen::half> &input_data,
718 const dnn::FilterDescriptor &filter_descriptor,
719 const DeviceMemory<Eigen::half> &filter_data,
720 const dnn::ConvolutionDescriptor &convolution_descriptor,
721 const dnn::BatchDescriptor &output_descriptor,
722 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
723 const dnn::AlgorithmConfig &algorithm_config,
724 dnn::ProfileResult *output_profile_result) {
725 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
726 PARAM(filter_descriptor), PARAM(filter_data),
727 PARAM(convolution_descriptor), PARAM(output_descriptor),
728 PARAM(output), PARAM(algorithm_config));
729
730 if (ok()) {
731 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
732 auto status = dnn->DoConvolve(
733 this, input_descriptor, input_data, filter_descriptor, filter_data,
734 convolution_descriptor, output_descriptor, output, scratch_allocator,
735 algorithm_config, output_profile_result);
736 if (!status && !output_profile_result) {
737 SetError();
738 }
739 } else {
740 SetErrorAndLogNoDnnSupport();
741 }
742 }
743 return *this;
744 }
745
ThenFusedConvolve(const dnn::BatchDescriptor & conv_input_descriptor,const DeviceMemory<int8> & conv_input_data,float conv_input_scale,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const DeviceMemory<int8> & side_input_data,float side_input_scale,const dnn::BatchDescriptor & bias_descriptor,const DeviceMemory<float> & biases,dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<int8> * output)746 Stream &Stream::ThenFusedConvolve(
747 const dnn::BatchDescriptor &conv_input_descriptor,
748 const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
749 const dnn::FilterDescriptor &filter_descriptor,
750 const DeviceMemory<int8> &filter_data,
751 const dnn::ConvolutionDescriptor &convolution_descriptor,
752 const DeviceMemory<int8> &side_input_data, float side_input_scale,
753 const dnn::BatchDescriptor &bias_descriptor,
754 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
755 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output) {
756 return ThenFusedConvolveWithScratch(
757 conv_input_descriptor, conv_input_data, conv_input_scale,
758 filter_descriptor, filter_data, convolution_descriptor, side_input_data,
759 side_input_scale, bias_descriptor, biases, activation_mode,
760 output_descriptor, output,
761 /*scratch_allocator=*/nullptr);
762 }
763
ThenConvolve(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)764 Stream &Stream::ThenConvolve(
765 const dnn::BatchDescriptor &input_descriptor,
766 const DeviceMemory<float> &input_data,
767 const dnn::FilterDescriptor &filter_descriptor,
768 const DeviceMemory<float> &filter_data,
769 const dnn::ConvolutionDescriptor &convolution_descriptor,
770 const dnn::BatchDescriptor &output_descriptor,
771 DeviceMemory<float> *output) {
772 return ThenConvolveWithScratch(input_descriptor, input_data,
773 filter_descriptor, filter_data,
774 convolution_descriptor, output_descriptor,
775 output, /*scratch_allocator=*/nullptr);
776 }
777
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int8> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)778 Stream &Stream::ThenConvolveQuantized(
779 const dnn::BatchDescriptor &input_descriptor,
780 const DeviceMemory<float> &input_data,
781 const dnn::FilterDescriptor &filter_descriptor,
782 const DeviceMemory<int8> &filter_coefficients,
783 const DeviceMemory<float> &coefficient_scales,
784 const dnn::ConvolutionDescriptor &convolution_descriptor,
785 const dnn::BatchDescriptor &output_descriptor,
786 DeviceMemory<float> *output) {
787 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
788 PARAM(filter_descriptor), PARAM(filter_coefficients),
789 PARAM(coefficient_scales), PARAM(convolution_descriptor),
790 PARAM(output_descriptor), PARAM(output));
791
792 if (ok()) {
793 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
794 CheckError(dnn->DoConvolveQuantized(
795 this, input_descriptor, input_data, filter_descriptor,
796 filter_coefficients, coefficient_scales, convolution_descriptor,
797 output_descriptor, output));
798 } else {
799 SetError();
800 LOG(WARNING)
801 << "attempting to perform DNN operation using StreamExecutor "
802 "without DNN support";
803 }
804 }
805 return *this;
806 }
807
ThenConvolveQuantized(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<int16> & filter_coefficients,const DeviceMemory<float> & coefficient_scales,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)808 Stream &Stream::ThenConvolveQuantized(
809 const dnn::BatchDescriptor &input_descriptor,
810 const DeviceMemory<float> &input_data,
811 const dnn::FilterDescriptor &filter_descriptor,
812 const DeviceMemory<int16> &filter_coefficients,
813 const DeviceMemory<float> &coefficient_scales,
814 const dnn::ConvolutionDescriptor &convolution_descriptor,
815 const dnn::BatchDescriptor &output_descriptor,
816 DeviceMemory<float> *output) {
817 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
818 PARAM(filter_descriptor), PARAM(filter_coefficients),
819 PARAM(coefficient_scales), PARAM(convolution_descriptor),
820 PARAM(output_descriptor), PARAM(output));
821
822 if (ok()) {
823 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
824 CheckError(dnn->DoConvolveQuantized(
825 this, input_descriptor, input_data, filter_descriptor,
826 filter_coefficients, coefficient_scales, convolution_descriptor,
827 output_descriptor, output));
828 } else {
829 SetError();
830 LOG(WARNING)
831 << "attempting to perform DNN operation using StreamExecutor "
832 "without DNN support";
833 }
834 }
835 return *this;
836 }
837
ThenSeparableConvolve(const dnn::BatchDescriptor & batch_descriptor,const DeviceMemory<float> & input_data,const dnn::FilterDescriptor & filter_descriptor,int depth_multiplier,const DeviceMemory<float> & first_weights,const DeviceMemory<float> & second_weights,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> * output)838 Stream &Stream::ThenSeparableConvolve(
839 const dnn::BatchDescriptor &batch_descriptor,
840 const DeviceMemory<float> &input_data,
841 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
842 const DeviceMemory<float> &first_weights,
843 const DeviceMemory<float> &second_weights,
844 const dnn::ConvolutionDescriptor &convolution_descriptor,
845 const dnn::BatchDescriptor &output_descriptor,
846 DeviceMemory<float> *output) {
847 VLOG_CALL(
848 PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
849 PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
850 PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
851
852 if (ok()) {
853 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
854 CheckError(dnn->DoSeparableConvolve(
855 this, batch_descriptor, input_data, filter_descriptor,
856 depth_multiplier, first_weights, second_weights,
857 convolution_descriptor, output_descriptor, output));
858 } else {
859 SetErrorAndLogNoDnnSupport();
860 }
861 }
862 return *this;
863 }
864
ThenConvolveBackwardDataWithScratch(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator)865 Stream &Stream::ThenConvolveBackwardDataWithScratch(
866 const dnn::FilterDescriptor &filter_descriptor,
867 const DeviceMemory<float> &filter_data,
868 const dnn::BatchDescriptor &output_descriptor,
869 DeviceMemory<float> backward_output_data,
870 const dnn::ConvolutionDescriptor &convolution_descriptor,
871 const dnn::BatchDescriptor &input_descriptor,
872 DeviceMemory<float> *backward_input_data,
873 ScratchAllocator *scratch_allocator) {
874 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
875 PARAM(output_descriptor), PARAM(backward_output_data),
876 PARAM(convolution_descriptor), PARAM(input_descriptor),
877 PARAM(backward_input_data));
878
879 if (ok()) {
880 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
881 CheckError(dnn->DoConvolveBackwardData(
882 this, filter_descriptor, filter_data, output_descriptor,
883 backward_output_data, convolution_descriptor, input_descriptor,
884 backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
885 /*output_profile_result=*/nullptr));
886 } else {
887 SetErrorAndLogNoDnnSupport();
888 }
889 }
890 return *this;
891 }
892
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)893 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
894 const dnn::FilterDescriptor &filter_descriptor,
895 const DeviceMemory<float> &filter_data,
896 const dnn::BatchDescriptor &output_descriptor,
897 DeviceMemory<float> backward_output_data,
898 const dnn::ConvolutionDescriptor &convolution_descriptor,
899 const dnn::BatchDescriptor &input_descriptor,
900 DeviceMemory<float> *backward_input_data,
901 ScratchAllocator *scratch_allocator,
902 const dnn::AlgorithmConfig &algorithm_config,
903 dnn::ProfileResult *output_profile_result) {
904 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
905 PARAM(output_descriptor), PARAM(backward_output_data),
906 PARAM(convolution_descriptor), PARAM(input_descriptor),
907 PARAM(backward_input_data));
908
909 if (ok()) {
910 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
911 auto status = dnn->DoConvolveBackwardData(
912 this, filter_descriptor, filter_data, output_descriptor,
913 backward_output_data, convolution_descriptor, input_descriptor,
914 backward_input_data, scratch_allocator, algorithm_config,
915 output_profile_result);
916 if (!status && !output_profile_result) {
917 SetError();
918 }
919 } else {
920 SetErrorAndLogNoDnnSupport();
921 }
922 }
923 return *this;
924 }
925
ThenConvolveBackwardDataWithAlgorithm(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)926 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
927 const dnn::FilterDescriptor &filter_descriptor,
928 const DeviceMemory<Eigen::half> &filter_data,
929 const dnn::BatchDescriptor &output_descriptor,
930 DeviceMemory<Eigen::half> backward_output_data,
931 const dnn::ConvolutionDescriptor &convolution_descriptor,
932 const dnn::BatchDescriptor &input_descriptor,
933 DeviceMemory<Eigen::half> *backward_input_data,
934 ScratchAllocator *scratch_allocator,
935 const dnn::AlgorithmConfig &algorithm_config,
936 dnn::ProfileResult *output_profile_result) {
937 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
938 PARAM(output_descriptor), PARAM(backward_output_data),
939 PARAM(convolution_descriptor), PARAM(input_descriptor),
940 PARAM(backward_input_data));
941
942 if (ok()) {
943 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
944 auto status = dnn->DoConvolveBackwardData(
945 this, filter_descriptor, filter_data, output_descriptor,
946 backward_output_data, convolution_descriptor, input_descriptor,
947 backward_input_data, scratch_allocator, algorithm_config,
948 output_profile_result);
949 if (!status && !output_profile_result) {
950 SetError();
951 }
952 } else {
953 SetErrorAndLogNoDnnSupport();
954 }
955 }
956 return *this;
957 }
958
ThenConvolveBackwardDataWithScratch(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<Eigen::half> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<Eigen::half> * backward_input_data,ScratchAllocator * scratch_allocator)959 Stream &Stream::ThenConvolveBackwardDataWithScratch(
960 const dnn::FilterDescriptor &filter_descriptor,
961 const DeviceMemory<Eigen::half> &filter_data,
962 const dnn::BatchDescriptor &output_descriptor,
963 DeviceMemory<Eigen::half> backward_output_data,
964 const dnn::ConvolutionDescriptor &convolution_descriptor,
965 const dnn::BatchDescriptor &input_descriptor,
966 DeviceMemory<Eigen::half> *backward_input_data,
967 ScratchAllocator *scratch_allocator) {
968 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
969 PARAM(output_descriptor), PARAM(backward_output_data),
970 PARAM(convolution_descriptor), PARAM(input_descriptor),
971 PARAM(backward_input_data));
972
973 if (ok()) {
974 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
975 CheckError(dnn->DoConvolveBackwardData(
976 this, filter_descriptor, filter_data, output_descriptor,
977 backward_output_data, convolution_descriptor, input_descriptor,
978 backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
979 /*output_profile_result=*/nullptr));
980 } else {
981 SetErrorAndLogNoDnnSupport();
982 }
983 }
984 return *this;
985 }
986
ThenConvolveBackwardData(const dnn::FilterDescriptor & filter_descriptor,const DeviceMemory<float> & filter_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::BatchDescriptor & input_descriptor,DeviceMemory<float> * backward_input_data)987 Stream &Stream::ThenConvolveBackwardData(
988 const dnn::FilterDescriptor &filter_descriptor,
989 const DeviceMemory<float> &filter_data,
990 const dnn::BatchDescriptor &output_descriptor,
991 DeviceMemory<float> backward_output_data,
992 const dnn::ConvolutionDescriptor &convolution_descriptor,
993 const dnn::BatchDescriptor &input_descriptor,
994 DeviceMemory<float> *backward_input_data) {
995 return ThenConvolveBackwardDataWithScratch(
996 filter_descriptor, filter_data, output_descriptor, backward_output_data,
997 convolution_descriptor, input_descriptor, backward_input_data,
998 /*scratch_allocator=*/nullptr);
999 }
1000
ThenConvolveBackwardFilterWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator)1001 Stream &Stream::ThenConvolveBackwardFilterWithScratch(
1002 const dnn::BatchDescriptor &input_descriptor,
1003 const DeviceMemory<float> &input_data,
1004 const dnn::BatchDescriptor &output_descriptor,
1005 DeviceMemory<float> backward_output_data,
1006 const dnn::ConvolutionDescriptor &convolution_descriptor,
1007 const dnn::FilterDescriptor &filter_descriptor,
1008 DeviceMemory<float> *backward_filter_data,
1009 ScratchAllocator *scratch_allocator) {
1010 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1011 PARAM(output_descriptor), PARAM(backward_output_data),
1012 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1013 PARAM(backward_filter_data));
1014
1015 if (ok()) {
1016 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1017 CheckError(dnn->DoConvolveBackwardFilter(
1018 this, input_descriptor, input_data, output_descriptor,
1019 backward_output_data, convolution_descriptor, filter_descriptor,
1020 backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
1021 /*output_profile_result=*/nullptr));
1022 } else {
1023 SetErrorAndLogNoDnnSupport();
1024 }
1025 }
1026 return *this;
1027 }
1028
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1029 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1030 const dnn::BatchDescriptor &input_descriptor,
1031 const DeviceMemory<float> &input_data,
1032 const dnn::BatchDescriptor &output_descriptor,
1033 DeviceMemory<float> backward_output_data,
1034 const dnn::ConvolutionDescriptor &convolution_descriptor,
1035 const dnn::FilterDescriptor &filter_descriptor,
1036 DeviceMemory<float> *backward_filter_data,
1037 ScratchAllocator *scratch_allocator,
1038 const dnn::AlgorithmConfig &algorithm_config,
1039 dnn::ProfileResult *output_profile_result) {
1040 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1041 PARAM(output_descriptor), PARAM(backward_output_data),
1042 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1043 PARAM(backward_filter_data));
1044
1045 if (ok()) {
1046 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1047 auto status = dnn->DoConvolveBackwardFilter(
1048 this, input_descriptor, input_data, output_descriptor,
1049 backward_output_data, convolution_descriptor, filter_descriptor,
1050 backward_filter_data, scratch_allocator, algorithm_config,
1051 output_profile_result);
1052 if (!status && !output_profile_result) {
1053 SetError();
1054 }
1055 } else {
1056 SetErrorAndLogNoDnnSupport();
1057 }
1058 }
1059 return *this;
1060 }
1061
ThenConvolveBackwardFilterWithScratch(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator)1062 Stream &Stream::ThenConvolveBackwardFilterWithScratch(
1063 const dnn::BatchDescriptor &input_descriptor,
1064 const DeviceMemory<Eigen::half> &input_data,
1065 const dnn::BatchDescriptor &output_descriptor,
1066 DeviceMemory<Eigen::half> backward_output_data,
1067 const dnn::ConvolutionDescriptor &convolution_descriptor,
1068 const dnn::FilterDescriptor &filter_descriptor,
1069 DeviceMemory<Eigen::half> *backward_filter_data,
1070 ScratchAllocator *scratch_allocator) {
1071 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1072 PARAM(output_descriptor), PARAM(backward_output_data),
1073 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1074 PARAM(backward_filter_data));
1075
1076 if (ok()) {
1077 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1078 CheckError(dnn->DoConvolveBackwardFilter(
1079 this, input_descriptor, input_data, output_descriptor,
1080 backward_output_data, convolution_descriptor, filter_descriptor,
1081 backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
1082 /*output_profile_result=*/nullptr));
1083 } else {
1084 SetErrorAndLogNoDnnSupport();
1085 }
1086 }
1087 return *this;
1088 }
1089
ThenConvolveBackwardFilterWithAlgorithm(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<Eigen::half> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<Eigen::half> * backward_filter_data,ScratchAllocator * scratch_allocator,const dnn::AlgorithmConfig & algorithm_config,dnn::ProfileResult * output_profile_result)1090 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
1091 const dnn::BatchDescriptor &input_descriptor,
1092 const DeviceMemory<Eigen::half> &input_data,
1093 const dnn::BatchDescriptor &output_descriptor,
1094 DeviceMemory<Eigen::half> backward_output_data,
1095 const dnn::ConvolutionDescriptor &convolution_descriptor,
1096 const dnn::FilterDescriptor &filter_descriptor,
1097 DeviceMemory<Eigen::half> *backward_filter_data,
1098 ScratchAllocator *scratch_allocator,
1099 const dnn::AlgorithmConfig &algorithm_config,
1100 dnn::ProfileResult *output_profile_result) {
1101 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
1102 PARAM(output_descriptor), PARAM(backward_output_data),
1103 PARAM(convolution_descriptor), PARAM(filter_descriptor),
1104 PARAM(backward_filter_data));
1105
1106 if (ok()) {
1107 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1108 auto status = dnn->DoConvolveBackwardFilter(
1109 this, input_descriptor, input_data, output_descriptor,
1110 backward_output_data, convolution_descriptor, filter_descriptor,
1111 backward_filter_data, scratch_allocator, algorithm_config,
1112 output_profile_result);
1113 if (!status && !output_profile_result) {
1114 SetError();
1115 }
1116 } else {
1117 SetErrorAndLogNoDnnSupport();
1118 }
1119 }
1120 return *this;
1121 }
1122
ThenConvolveBackwardFilter(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_descriptor,DeviceMemory<float> backward_output_data,const dnn::ConvolutionDescriptor & convolution_descriptor,const dnn::FilterDescriptor & filter_descriptor,DeviceMemory<float> * backward_filter_data)1123 Stream &Stream::ThenConvolveBackwardFilter(
1124 const dnn::BatchDescriptor &input_descriptor,
1125 const DeviceMemory<float> &input_data,
1126 const dnn::BatchDescriptor &output_descriptor,
1127 DeviceMemory<float> backward_output_data,
1128 const dnn::ConvolutionDescriptor &convolution_descriptor,
1129 const dnn::FilterDescriptor &filter_descriptor,
1130 DeviceMemory<float> *backward_filter_data) {
1131 return ThenConvolveBackwardFilterWithScratch(
1132 input_descriptor, input_data, output_descriptor, backward_output_data,
1133 convolution_descriptor, filter_descriptor, backward_filter_data,
1134 /*scratch_allocator=*/nullptr);
1135 }
1136
1137 template <typename T>
ThenConvolveBackwardBiasImpl(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<T> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<T> * backward_bias_data)1138 Stream &Stream::ThenConvolveBackwardBiasImpl(
1139 const dnn::BatchDescriptor &input_descriptor,
1140 const DeviceMemory<T> &input_data,
1141 const dnn::BatchDescriptor &bias_descriptor,
1142 DeviceMemory<T> *backward_bias_data) {
1143 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
1144 PARAM(backward_bias_data));
1145
1146 if (ok()) {
1147 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1148 CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
1149 bias_descriptor,
1150 backward_bias_data));
1151 } else {
1152 SetErrorAndLogNoDnnSupport();
1153 }
1154 }
1155 return *this;
1156 }
1157
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<double> * backward_bias_data)1158 Stream &Stream::ThenConvolveBackwardBias(
1159 const dnn::BatchDescriptor &input_descriptor,
1160 const DeviceMemory<double> &input_data,
1161 const dnn::BatchDescriptor &bias_descriptor,
1162 DeviceMemory<double> *backward_bias_data) {
1163 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1164 bias_descriptor, backward_bias_data);
1165 }
1166
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<float> * backward_bias_data)1167 Stream &Stream::ThenConvolveBackwardBias(
1168 const dnn::BatchDescriptor &input_descriptor,
1169 const DeviceMemory<float> &input_data,
1170 const dnn::BatchDescriptor &bias_descriptor,
1171 DeviceMemory<float> *backward_bias_data) {
1172 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1173 bias_descriptor, backward_bias_data);
1174 }
1175
ThenConvolveBackwardBias(const dnn::BatchDescriptor & input_descriptor,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & bias_descriptor,DeviceMemory<Eigen::half> * backward_bias_data)1176 Stream &Stream::ThenConvolveBackwardBias(
1177 const dnn::BatchDescriptor &input_descriptor,
1178 const DeviceMemory<Eigen::half> &input_data,
1179 const dnn::BatchDescriptor &bias_descriptor,
1180 DeviceMemory<Eigen::half> *backward_bias_data) {
1181 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
1182 bias_descriptor, backward_bias_data);
1183 }
1184
ThenMatMul(const DeviceMemory<float> & input_data,const DeviceMemory<float> & weights,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1185 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
1186 const DeviceMemory<float> &weights,
1187 const dnn::BatchDescriptor &input_dimensions,
1188 const dnn::BatchDescriptor &output_dimensions,
1189 DeviceMemory<float> *output_data) {
1190 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
1191 PARAM(output_dimensions), PARAM(output_data));
1192
1193 if (ok()) {
1194 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1195 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
1196 output_dimensions, output_data));
1197 } else {
1198 SetErrorAndLogNoDnnSupport();
1199 }
1200 }
1201 return *this;
1202 }
1203
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int8> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1204 Stream &Stream::ThenMatMulQuantized(
1205 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
1206 const DeviceMemory<float> &weight_scales,
1207 const dnn::BatchDescriptor &input_dimensions,
1208 const dnn::BatchDescriptor &output_dimensions,
1209 DeviceMemory<float> *output_data) {
1210 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1211 PARAM(input_dimensions), PARAM(output_dimensions),
1212 PARAM(output_data));
1213
1214 if (ok()) {
1215 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1216 CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1217 weight_scales, input_dimensions,
1218 output_dimensions, output_data));
1219 } else {
1220 SetErrorAndLogNoDnnSupport();
1221 }
1222 }
1223 return *this;
1224 }
1225
ThenMatMulQuantized(const DeviceMemory<float> & input_data,const DeviceMemory<int16> & weights,const DeviceMemory<float> & weight_scales,const dnn::BatchDescriptor & input_dimensions,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1226 Stream &Stream::ThenMatMulQuantized(
1227 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
1228 const DeviceMemory<float> &weight_scales,
1229 const dnn::BatchDescriptor &input_dimensions,
1230 const dnn::BatchDescriptor &output_dimensions,
1231 DeviceMemory<float> *output_data) {
1232 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
1233 PARAM(input_dimensions), PARAM(output_dimensions),
1234 PARAM(output_data));
1235
1236 if (ok()) {
1237 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1238 CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
1239 weight_scales, input_dimensions,
1240 output_dimensions, output_data));
1241 } else {
1242 SetErrorAndLogNoDnnSupport();
1243 }
1244 }
1245 return *this;
1246 }
1247
ThenBiasAdd(const DeviceMemory<float> & input_data,const DeviceMemory<float> & biases,const dnn::BatchDescriptor & dimensions,DeviceMemory<float> * output_data)1248 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
1249 const DeviceMemory<float> &biases,
1250 const dnn::BatchDescriptor &dimensions,
1251 DeviceMemory<float> *output_data) {
1252 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
1253 PARAM(output_data));
1254
1255 if (ok()) {
1256 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1257 CheckError(
1258 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
1259 } else {
1260 SetErrorAndLogNoDnnSupport();
1261 }
1262 }
1263 return *this;
1264 }
1265
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<double> * output_data)1266 Stream &Stream::ThenPoolForward(
1267 const dnn::PoolingDescriptor &pooling_dimensions,
1268 const dnn::BatchDescriptor &input_dimensions,
1269 const DeviceMemory<double> &input_data,
1270 const dnn::BatchDescriptor &output_dimensions,
1271 DeviceMemory<double> *output_data) {
1272 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1273 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
1274
1275 if (ok()) {
1276 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1277 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1278 input_data, output_dimensions,
1279 output_data));
1280 } else {
1281 SetError();
1282 LOG(WARNING)
1283 << "attempting to perform DNN operation using StreamExecutor "
1284 "without DNN support";
1285 }
1286 }
1287 return *this;
1288 }
1289
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1290 Stream &Stream::ThenPoolForward(
1291 const dnn::PoolingDescriptor &pooling_dimensions,
1292 const dnn::BatchDescriptor &input_dimensions,
1293 const DeviceMemory<float> &input_data,
1294 const dnn::BatchDescriptor &output_dimensions,
1295 DeviceMemory<float> *output_data) {
1296 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1297 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
1298
1299 if (ok()) {
1300 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1301 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1302 input_data, output_dimensions,
1303 output_data));
1304 } else {
1305 SetErrorAndLogNoDnnSupport();
1306 }
1307 }
1308 return *this;
1309 }
1310
ThenPoolForward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<Eigen::half> * output_data)1311 Stream &Stream::ThenPoolForward(
1312 const dnn::PoolingDescriptor &pooling_dimensions,
1313 const dnn::BatchDescriptor &input_dimensions,
1314 const DeviceMemory<Eigen::half> &input_data,
1315 const dnn::BatchDescriptor &output_dimensions,
1316 DeviceMemory<Eigen::half> *output_data) {
1317 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1318 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
1319
1320 if (ok()) {
1321 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1322 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
1323 input_data, output_dimensions,
1324 output_data));
1325 } else {
1326 SetErrorAndLogNoDnnSupport();
1327 }
1328 }
1329 return *this;
1330 }
1331
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<double> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<double> & output_data,const DeviceMemory<double> & input_diff_data,DeviceMemory<double> * output_diff_data)1332 Stream &Stream::ThenPoolBackward(
1333 const dnn::PoolingDescriptor &pooling_dimensions,
1334 const dnn::BatchDescriptor &input_dimensions,
1335 const DeviceMemory<double> &input_data,
1336 const dnn::BatchDescriptor &output_dimensions,
1337 const DeviceMemory<double> &output_data,
1338 const DeviceMemory<double> &input_diff_data,
1339 DeviceMemory<double> *output_diff_data) {
1340 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1341 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1342 PARAM(input_diff_data), PARAM(output_diff_data));
1343
1344 if (ok()) {
1345 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1346 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1347 input_data, output_dimensions, output_data,
1348 input_diff_data, output_diff_data));
1349 } else {
1350 SetError();
1351 LOG(WARNING)
1352 << "attempting to perform DNN operation using StreamExecutor "
1353 "without DNN support";
1354 }
1355 }
1356 return *this;
1357 }
1358
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<float> & output_data,const DeviceMemory<float> & input_diff_data,DeviceMemory<float> * output_diff_data)1359 Stream &Stream::ThenPoolBackward(
1360 const dnn::PoolingDescriptor &pooling_dimensions,
1361 const dnn::BatchDescriptor &input_dimensions,
1362 const DeviceMemory<float> &input_data,
1363 const dnn::BatchDescriptor &output_dimensions,
1364 const DeviceMemory<float> &output_data,
1365 const DeviceMemory<float> &input_diff_data,
1366 DeviceMemory<float> *output_diff_data) {
1367 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1368 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1369 PARAM(input_diff_data), PARAM(output_diff_data));
1370
1371 if (ok()) {
1372 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1373 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1374 input_data, output_dimensions, output_data,
1375 input_diff_data, output_diff_data));
1376 } else {
1377 SetErrorAndLogNoDnnSupport();
1378 }
1379 }
1380 return *this;
1381 }
1382
ThenPoolBackward(const dnn::PoolingDescriptor & pooling_dimensions,const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<Eigen::half> & input_data,const dnn::BatchDescriptor & output_dimensions,const DeviceMemory<Eigen::half> & output_data,const DeviceMemory<Eigen::half> & input_diff_data,DeviceMemory<Eigen::half> * output_diff_data)1383 Stream &Stream::ThenPoolBackward(
1384 const dnn::PoolingDescriptor &pooling_dimensions,
1385 const dnn::BatchDescriptor &input_dimensions,
1386 const DeviceMemory<Eigen::half> &input_data,
1387 const dnn::BatchDescriptor &output_dimensions,
1388 const DeviceMemory<Eigen::half> &output_data,
1389 const DeviceMemory<Eigen::half> &input_diff_data,
1390 DeviceMemory<Eigen::half> *output_diff_data) {
1391 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
1392 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
1393 PARAM(input_diff_data), PARAM(output_diff_data));
1394
1395 if (ok()) {
1396 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1397 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
1398 input_data, output_dimensions, output_data,
1399 input_diff_data, output_diff_data));
1400 } else {
1401 SetErrorAndLogNoDnnSupport();
1402 }
1403 }
1404 return *this;
1405 }
1406
ThenNormalize(const dnn::NormalizeDescriptor & normalize_descriptor,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1407 Stream &Stream::ThenNormalize(
1408 const dnn::NormalizeDescriptor &normalize_descriptor,
1409 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1410 VLOG_CALL(PARAM(normalize_descriptor), PARAM(input_data), PARAM(output_data));
1411
1412 if (ok()) {
1413 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1414 CheckError(dnn->DoNormalize(this, normalize_descriptor, input_data,
1415 output_data));
1416 } else {
1417 SetErrorAndLogNoDnnSupport();
1418 }
1419 }
1420 return *this;
1421 }
1422
ThenNormalizeWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1423 Stream &Stream::ThenNormalizeWithDimensions(
1424 const dnn::NormalizeDescriptor &normalize_descriptor,
1425 const dnn::BatchDescriptor &dimensions,
1426 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
1427 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
1428 PARAM(output_data));
1429
1430 if (ok()) {
1431 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1432 CheckError(dnn->DoNormalizeWithDimensions(
1433 this, normalize_descriptor, dimensions, input_data, output_data));
1434 } else {
1435 SetErrorAndLogNoDnnSupport();
1436 }
1437 }
1438 return *this;
1439 }
1440
ThenNormalizeBackwardWithDimensions(const dnn::NormalizeDescriptor & normalize_descriptor,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & raw_data,const DeviceMemory<float> & normalized_data,const DeviceMemory<float> & normalized_variable_gradient,DeviceMemory<float> * raw_variable_gradient)1441 Stream &Stream::ThenNormalizeBackwardWithDimensions(
1442 const dnn::NormalizeDescriptor &normalize_descriptor,
1443 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
1444 const DeviceMemory<float> &normalized_data,
1445 const DeviceMemory<float> &normalized_variable_gradient,
1446 DeviceMemory<float> *raw_variable_gradient) {
1447 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
1448 PARAM(normalized_data), PARAM(normalized_variable_gradient),
1449 PARAM(raw_variable_gradient));
1450
1451 if (ok()) {
1452 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1453 CheckError(dnn->DoNormalizeBackwardWithDimensions(
1454 this, normalize_descriptor, dimensions, raw_data, normalized_data,
1455 normalized_variable_gradient, raw_variable_gradient));
1456 } else {
1457 SetErrorAndLogNoDnnSupport();
1458 }
1459 }
1460 return *this;
1461 }
1462
ThenActivate(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data)1463 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
1464 const dnn::BatchDescriptor &dimensions,
1465 const DeviceMemory<float> &input_data,
1466 DeviceMemory<float> *output_data) {
1467 return ThenActivateWithOptions(activation_mode, dimensions, input_data,
1468 output_data, /*options=*/0);
1469 }
1470
ThenActivateWithOptions(dnn::ActivationMode activation_mode,const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,DeviceMemory<float> * output_data,uint64 options)1471 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
1472 const dnn::BatchDescriptor &dimensions,
1473 const DeviceMemory<float> &input_data,
1474 DeviceMemory<float> *output_data,
1475 uint64 options) {
1476 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
1477 PARAM(output_data), PARAM(options));
1478
1479 if (ok()) {
1480 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1481 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
1482 output_data, options));
1483 } else {
1484 SetErrorAndLogNoDnnSupport();
1485 }
1486 }
1487 return *this;
1488 }
1489
ThenDepthConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data)1490 Stream &Stream::ThenDepthConcatenate(
1491 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1492 port::ArraySlice<const DeviceMemory<float> *> input_data,
1493 DeviceMemory<float> *output_data) {
1494 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1495
1496 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1497 if (input_dimensions[i].count() != input_dimensions[0].count() ||
1498 input_dimensions[i].height() != input_dimensions[0].height() ||
1499 input_dimensions[i].width() != input_dimensions[0].width()) {
1500 SetError();
1501 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
1502 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1503 << "input_dimensions[" << i
1504 << "]: " << input_dimensions[i].ToString();
1505 return *this;
1506 }
1507 }
1508
1509 if (ok()) {
1510 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1511 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
1512 output_data));
1513 } else {
1514 SetErrorAndLogNoDnnSupport();
1515 }
1516 }
1517 return *this;
1518 }
1519
ThenSpaceConcatenate(port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,DeviceMemory<float> * output_data,dnn::SpaceConcatenateMode concat_direction)1520 Stream &Stream::ThenSpaceConcatenate(
1521 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1522 port::ArraySlice<const DeviceMemory<float> *> input_data,
1523 DeviceMemory<float> *output_data,
1524 dnn::SpaceConcatenateMode concat_direction) {
1525 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
1526
1527 // Check that the input dimensions of all the other batches match those of the
1528 // first batch.
1529 for (size_t i = 1; i < input_dimensions.size(); ++i) {
1530 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
1531 (input_dimensions[i].count() != input_dimensions[0].count() ||
1532 input_dimensions[i].height() != input_dimensions[0].height() ||
1533 input_dimensions[i].feature_map_count() !=
1534 input_dimensions[0].feature_map_count())) {
1535 SetError();
1536 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
1537 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1538 << "input_dimensions[" << i
1539 << "]: " << input_dimensions[i].ToString();
1540 return *this;
1541 }
1542
1543 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
1544 (input_dimensions[i].count() != input_dimensions[0].count() ||
1545 input_dimensions[i].width() != input_dimensions[0].width() ||
1546 input_dimensions[i].feature_map_count() !=
1547 input_dimensions[0].feature_map_count())) {
1548 SetError();
1549 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
1550 << "input_dimensions[0]: " << input_dimensions[0].ToString()
1551 << "input_dimensions[" << i
1552 << "]: " << input_dimensions[i].ToString();
1553 return *this;
1554 }
1555 }
1556 if (ok()) {
1557 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1558 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
1559 output_data, concat_direction));
1560 } else {
1561 SetErrorAndLogNoDnnSupport();
1562 }
1563 }
1564 return *this;
1565 }
1566
ThenReshape(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1567 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
1568 const DeviceMemory<float> &input_data,
1569 const dnn::BatchDescriptor &output_dimensions,
1570 DeviceMemory<float> *output_data) {
1571 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1572 PARAM(output_dimensions), PARAM(output_data));
1573
1574 if (ok()) {
1575 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1576 CheckError(dnn->DoReshape(this, input_dimensions, input_data,
1577 output_dimensions, output_data));
1578 } else {
1579 SetErrorAndLogNoDnnSupport();
1580 }
1581 }
1582 return *this;
1583 }
1584
ThenDepthToSpace(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & depth_to_space_layout,const int sqrt_depth_reduction,DeviceMemory<float> * output_data)1585 Stream &Stream::ThenDepthToSpace(
1586 const dnn::BatchDescriptor &input_dimensions,
1587 const DeviceMemory<float> &input_data,
1588 const dnn::DepthToSpaceLayout &depth_to_space_layout,
1589 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
1590 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1591 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
1592 PARAM(output_data));
1593
1594 if (ok()) {
1595 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1596 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
1597 depth_to_space_layout,
1598 sqrt_depth_reduction, output_data));
1599 } else {
1600 SetErrorAndLogNoDnnSupport();
1601 }
1602 }
1603 return *this;
1604 }
1605
ThenSpaceToDepth(const dnn::BatchDescriptor & input_dimensions,const DeviceMemory<float> & input_data,const dnn::DepthToSpaceLayout & space_to_depth_layout,const int sqrt_depth_increase,DeviceMemory<float> * output_data)1606 Stream &Stream::ThenSpaceToDepth(
1607 const dnn::BatchDescriptor &input_dimensions,
1608 const DeviceMemory<float> &input_data,
1609 const dnn::DepthToSpaceLayout &space_to_depth_layout,
1610 const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
1611 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
1612 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
1613 PARAM(output_data));
1614
1615 if (ok()) {
1616 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1617 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
1618 space_to_depth_layout, sqrt_depth_increase,
1619 output_data));
1620 } else {
1621 SetErrorAndLogNoDnnSupport();
1622 }
1623 }
1624 return *this;
1625 }
1626
ThenElementwiseOperate(dnn::ElementwiseOperation operation,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1627 Stream &Stream::ThenElementwiseOperate(
1628 dnn::ElementwiseOperation operation,
1629 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1630 port::ArraySlice<const DeviceMemory<float> *> input_data,
1631 const dnn::BatchDescriptor &output_dimensions,
1632 DeviceMemory<float> *output_data) {
1633 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
1634 PARAM(output_dimensions), PARAM(output_data));
1635
1636 if (ok()) {
1637 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1638 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
1639 input_data, output_dimensions,
1640 output_data));
1641 } else {
1642 SetErrorAndLogNoDnnSupport();
1643 }
1644 }
1645 return *this;
1646 }
1647
ThenElementwiseOperateScaledQuantized(dnn::ElementwiseOperation operation,port::ArraySlice<int> input_multiplicands,int output_divisor,port::ArraySlice<dnn::BatchDescriptor> input_dimensions,port::ArraySlice<const DeviceMemory<float> * > input_data,const dnn::BatchDescriptor & output_dimensions,DeviceMemory<float> * output_data)1648 Stream &Stream::ThenElementwiseOperateScaledQuantized(
1649 dnn::ElementwiseOperation operation,
1650 port::ArraySlice<int> input_multiplicands, int output_divisor,
1651 port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
1652 port::ArraySlice<const DeviceMemory<float> *> input_data,
1653 const dnn::BatchDescriptor &output_dimensions,
1654 DeviceMemory<float> *output_data) {
1655 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
1656 PARAM(input_dimensions), PARAM(input_data),
1657 PARAM(output_dimensions), PARAM(output_data));
1658
1659 if (ok()) {
1660 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1661 CheckError(dnn->DoElementwiseOperateScaledQuantized(
1662 this, operation, input_multiplicands, output_divisor,
1663 input_dimensions, input_data, output_dimensions, output_data));
1664 } else {
1665 SetErrorAndLogNoDnnSupport();
1666 }
1667 }
1668 return *this;
1669 }
1670
ThenXYPad(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_pad,int64 right_pad,int64 top_pad,int64 bottom_pad,DeviceMemory<float> * output_data)1671 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
1672 const DeviceMemory<float> &input_data, int64 left_pad,
1673 int64 right_pad, int64 top_pad, int64 bottom_pad,
1674 DeviceMemory<float> *output_data) {
1675 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
1676 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
1677 PARAM(output_data));
1678
1679 if (ok()) {
1680 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1681 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
1682 top_pad, bottom_pad, output_data));
1683 } else {
1684 SetErrorAndLogNoDnnSupport();
1685 }
1686 }
1687 return *this;
1688 }
1689
ThenXYSlice(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 left_trim,int64 right_trim,int64 top_trim,int64 bottom_trim,DeviceMemory<float> * output_data)1690 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
1691 const DeviceMemory<float> &input_data,
1692 int64 left_trim, int64 right_trim, int64 top_trim,
1693 int64 bottom_trim,
1694 DeviceMemory<float> *output_data) {
1695 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
1696 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
1697 PARAM(output_data));
1698
1699 if (ok()) {
1700 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1701 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
1702 right_trim, top_trim, bottom_trim,
1703 output_data));
1704 } else {
1705 SetErrorAndLogNoDnnSupport();
1706 }
1707 }
1708 return *this;
1709 }
1710
ThenXYBroadcast(const dnn::BatchDescriptor & dimensions,const DeviceMemory<float> & input_data,int64 replicate_x,int64 replicate_y,DeviceMemory<float> * output_data)1711 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
1712 const DeviceMemory<float> &input_data,
1713 int64 replicate_x, int64 replicate_y,
1714 DeviceMemory<float> *output_data) {
1715 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
1716 PARAM(replicate_y), PARAM(output_data));
1717
1718 if (ok()) {
1719 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1720 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
1721 replicate_y, output_data));
1722 } else {
1723 SetErrorAndLogNoDnnSupport();
1724 }
1725 }
1726 return *this;
1727 }
1728
ThenMemcpyD2HQuantized(const DeviceMemory<float> & gpu_unquantized_src,dnn::QuantizedActivationMode mode,void * host_dst,uint64 size)1729 Stream &Stream::ThenMemcpyD2HQuantized(
1730 const DeviceMemory<float> &gpu_unquantized_src,
1731 dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
1732 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
1733 PARAM(size));
1734
1735 if (ok()) {
1736 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1737 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
1738 host_dst, size));
1739 } else {
1740 SetErrorAndLogNoDnnSupport();
1741 }
1742 }
1743 return *this;
1744 }
1745
ThenMemcpyH2DQuantized(const void * host_src,uint64 size,dnn::QuantizedActivationMode mode,DeviceMemory<float> * gpu_unquantized_dst)1746 Stream &Stream::ThenMemcpyH2DQuantized(
1747 const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
1748 DeviceMemory<float> *gpu_unquantized_dst) {
1749 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
1750 PARAM(gpu_unquantized_dst));
1751
1752 if (ok()) {
1753 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1754 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
1755 gpu_unquantized_dst));
1756 } else {
1757 SetErrorAndLogNoDnnSupport();
1758 }
1759 }
1760 return *this;
1761 }
1762
ThenCopyHostBuffer2Device(HostBuffer * buffer_src,DeviceMemory<float> * gpu_unquantized_dst)1763 Stream &Stream::ThenCopyHostBuffer2Device(
1764 HostBuffer *buffer_src, DeviceMemory<float> *gpu_unquantized_dst) {
1765 VLOG_CALL(PARAM(*buffer_src), PARAM(gpu_unquantized_dst));
1766
1767 if (ok()) {
1768 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1769 CheckError(
1770 dnn->DoCopyHostBuffer2Device(this, buffer_src, gpu_unquantized_dst));
1771 } else {
1772 SetErrorAndLogNoDnnSupport();
1773 }
1774 }
1775 return *this;
1776 }
1777
ThenCopyDevice2HostBuffer(const DeviceMemory<float> & gpu_unquantized_src,HostBuffer * buffer_dst)1778 Stream &Stream::ThenCopyDevice2HostBuffer(
1779 const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst) {
1780 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(*buffer_dst));
1781
1782 if (ok()) {
1783 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
1784 CheckError(
1785 dnn->DoCopyDevice2HostBuffer(this, gpu_unquantized_src, buffer_dst));
1786 } else {
1787 SetErrorAndLogNoDnnSupport();
1788 }
1789 }
1790 return *this;
1791 }
1792
GetOrCreateSubStream()1793 Stream *Stream::GetOrCreateSubStream() {
1794 mutex_lock lock{mu_};
1795 for (auto &stream : sub_streams_) {
1796 if (stream.second) {
1797 stream.second = false;
1798 return stream.first.get();
1799 }
1800 }
1801 sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
1802 false);
1803 Stream *sub_stream = sub_streams_.back().first.get();
1804 sub_stream->Init();
1805 CHECK(ok_) << "sub-stream failed to be initialized";
1806
1807 return sub_stream;
1808 }
1809
ReturnSubStream(Stream * sub_stream)1810 void Stream::ReturnSubStream(Stream *sub_stream) {
1811 mutex_lock lock{mu_};
1812 for (auto &stream : sub_streams_) {
1813 if (stream.first.get() == sub_stream) {
1814 stream.second = true;
1815 return;
1816 }
1817 }
1818 LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
1819 }
1820
ThenStartTimer(Timer * t)1821 Stream &Stream::ThenStartTimer(Timer *t) {
1822 VLOG_CALL(PARAM(t));
1823
1824 if (ok()) {
1825 CheckError(parent_->StartTimer(this, t));
1826 } else {
1827 LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
1828 }
1829 return *this;
1830 }
1831
ThenStopTimer(Timer * t)1832 Stream &Stream::ThenStopTimer(Timer *t) {
1833 VLOG_CALL(PARAM(t));
1834
1835 if (ok()) {
1836 CheckError(parent_->StopTimer(this, t));
1837 } else {
1838 LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
1839 }
1840 return *this;
1841 }
1842
ThenWaitFor(Stream * other)1843 Stream &Stream::ThenWaitFor(Stream *other) {
1844 VLOG_CALL(PARAM(other));
1845
1846 CHECK(this != other) << "stream cannot wait for itself";
1847 if (ok() && other->ok()) {
1848 CheckError(parent_->CreateStreamDependency(this, other));
1849 } else {
1850 SetError();
1851 LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
1852 }
1853 return *this;
1854 }
1855
ThenWaitFor(Event * event)1856 Stream &Stream::ThenWaitFor(Event *event) {
1857 VLOG_CALL(PARAM(event));
1858
1859 if (ok()) {
1860 port::Status status = parent_->WaitForEvent(this, event);
1861 if (!status.ok()) {
1862 LOG(ERROR) << "Error waiting for event in stream: "
1863 << status.error_message()
1864 << "; not marking stream as bad, as the Event object may be "
1865 << "at fault. Monitor for further errors.";
1866 }
1867 } else {
1868 LOG(INFO) << "stream " << this << " did not wait for an event.";
1869 }
1870 return *this;
1871 }
1872
1873 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
1874 // functions and logs for errors.
1875 template <typename... Args>
1876 struct ThenBlasImpl {
1877 // blas_func is the DoBlasXXX member function pointer, and args are its
1878 // arguments except the first one of Stream* type.
operator ()perftools::gputools::ThenBlasImpl1879 Stream &operator()(Stream *stream,
1880 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1881 Args... args) {
1882 return Run(stream, blas_func, /*record_error=*/true, args...);
1883 }
1884
1885 // Like operator(), but only calls stream->CheckError() if record_error is
1886 // true.
1887 Stream &Run(Stream *stream,
1888 bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1889 bool record_error, Args... args);
1890 };
1891
1892 template <typename... Args>
Run(Stream * stream,bool (blas::BlasSupport::* blas_func)(Stream *,Args...),bool record_error,Args...args)1893 Stream &ThenBlasImpl<Args...>::Run(
1894 Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
1895 bool record_error, Args... args) {
1896 if (stream->ok()) {
1897 bool ok;
1898 if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
1899 ok = (blas->*blas_func)(stream, args...);
1900 } else {
1901 LOG(WARNING)
1902 << "attempting to perform BLAS operation using StreamExecutor "
1903 "without BLAS support";
1904 ok = false;
1905 }
1906 if (record_error) {
1907 stream->CheckError(ok);
1908 }
1909 }
1910 return *stream;
1911 }
1912
ThenBlasAsum(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)1913 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
1914 int incx, DeviceMemory<float> *result) {
1915 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1916
1917 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
1918 impl;
1919 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1920 result);
1921 }
1922
ThenBlasAsum(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)1923 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
1924 int incx, DeviceMemory<double> *result) {
1925 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1926
1927 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
1928 DeviceMemory<double> *> impl;
1929 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1930 result);
1931 }
1932
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)1933 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1934 const DeviceMemory<std::complex<float>> &x,
1935 int incx, DeviceMemory<float> *result) {
1936 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1937
1938 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
1939 DeviceMemory<float> *> impl;
1940 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1941 result);
1942 }
1943
ThenBlasAsum(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)1944 Stream &Stream::ThenBlasAsum(uint64 elem_count,
1945 const DeviceMemory<std::complex<double>> &x,
1946 int incx, DeviceMemory<double> *result) {
1947 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
1948
1949 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
1950 DeviceMemory<double> *> impl;
1951 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
1952 result);
1953 }
1954
ThenBlasAxpy(uint64 elem_count,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)1955 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
1956 const DeviceMemory<float> &x, int incx,
1957 DeviceMemory<float> *y, int incy) {
1958 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1959 PARAM(incy));
1960
1961 ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
1962 DeviceMemory<float> *, int> impl;
1963 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1964 y, incy);
1965 }
1966
ThenBlasAxpy(uint64 elem_count,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)1967 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
1968 const DeviceMemory<double> &x, int incx,
1969 DeviceMemory<double> *y, int incy) {
1970 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1971 PARAM(incy));
1972
1973 ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
1974 DeviceMemory<double> *, int> impl;
1975 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1976 y, incy);
1977 }
1978
ThenBlasAxpy(uint64 elem_count,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)1979 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
1980 const DeviceMemory<std::complex<float>> &x,
1981 int incx, DeviceMemory<std::complex<float>> *y,
1982 int incy) {
1983 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1984 PARAM(incy));
1985
1986 ThenBlasImpl<uint64, std::complex<float>,
1987 const DeviceMemory<std::complex<float>> &, int,
1988 DeviceMemory<std::complex<float>> *, int> impl;
1989 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
1990 y, incy);
1991 }
1992
ThenBlasAxpy(uint64 elem_count,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)1993 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
1994 const DeviceMemory<std::complex<double>> &x,
1995 int incx, DeviceMemory<std::complex<double>> *y,
1996 int incy) {
1997 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
1998 PARAM(incy));
1999
2000 ThenBlasImpl<uint64, std::complex<double>,
2001 const DeviceMemory<std::complex<double>> &, int,
2002 DeviceMemory<std::complex<double>> *, int> impl;
2003 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
2004 y, incy);
2005 }
2006
ThenBlasCopy(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * y,int incy)2007 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
2008 int incx, DeviceMemory<float> *y, int incy) {
2009 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2010
2011 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
2012 int> impl;
2013 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2014 incy);
2015 }
2016
ThenBlasCopy(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * y,int incy)2017 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
2018 int incx, DeviceMemory<double> *y, int incy) {
2019 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2020
2021 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2022 DeviceMemory<double> *, int> impl;
2023 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2024 incy);
2025 }
2026
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2027 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2028 const DeviceMemory<std::complex<float>> &x,
2029 int incx, DeviceMemory<std::complex<float>> *y,
2030 int incy) {
2031 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2032
2033 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2034 DeviceMemory<std::complex<float>> *, int> impl;
2035 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2036 incy);
2037 }
2038
ThenBlasCopy(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2039 Stream &Stream::ThenBlasCopy(uint64 elem_count,
2040 const DeviceMemory<std::complex<double>> &x,
2041 int incx, DeviceMemory<std::complex<double>> *y,
2042 int incy) {
2043 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2044
2045 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2046 DeviceMemory<std::complex<double>> *, int> impl;
2047 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
2048 incy);
2049 }
2050
ThenBlasDot(uint64 elem_count,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * result)2051 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
2052 int incx, const DeviceMemory<float> &y, int incy,
2053 DeviceMemory<float> *result) {
2054 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2055 PARAM(result));
2056
2057 ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
2058 const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
2059 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2060 result);
2061 }
2062
ThenBlasDot(uint64 elem_count,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * result)2063 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
2064 int incx, const DeviceMemory<double> &y, int incy,
2065 DeviceMemory<double> *result) {
2066 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2067 PARAM(result));
2068
2069 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2070 const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
2071 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
2072 result);
2073 }
2074
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2075 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2076 const DeviceMemory<std::complex<float>> &x,
2077 int incx,
2078 const DeviceMemory<std::complex<float>> &y,
2079 int incy,
2080 DeviceMemory<std::complex<float>> *result) {
2081 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2082 PARAM(result));
2083
2084 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2085 const DeviceMemory<std::complex<float>> &, int,
2086 DeviceMemory<std::complex<float>> *> impl;
2087 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2088 incy, result);
2089 }
2090
ThenBlasDotc(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2091 Stream &Stream::ThenBlasDotc(uint64 elem_count,
2092 const DeviceMemory<std::complex<double>> &x,
2093 int incx,
2094 const DeviceMemory<std::complex<double>> &y,
2095 int incy,
2096 DeviceMemory<std::complex<double>> *result) {
2097 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2098 PARAM(result));
2099
2100 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2101 const DeviceMemory<std::complex<double>> &, int,
2102 DeviceMemory<std::complex<double>> *> impl;
2103 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
2104 incy, result);
2105 }
2106
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * result)2107 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2108 const DeviceMemory<std::complex<float>> &x,
2109 int incx,
2110 const DeviceMemory<std::complex<float>> &y,
2111 int incy,
2112 DeviceMemory<std::complex<float>> *result) {
2113 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2114 PARAM(result));
2115
2116 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2117 const DeviceMemory<std::complex<float>> &, int,
2118 DeviceMemory<std::complex<float>> *> impl;
2119 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2120 incy, result);
2121 }
2122
ThenBlasDotu(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * result)2123 Stream &Stream::ThenBlasDotu(uint64 elem_count,
2124 const DeviceMemory<std::complex<double>> &x,
2125 int incx,
2126 const DeviceMemory<std::complex<double>> &y,
2127 int incy,
2128 DeviceMemory<std::complex<double>> *result) {
2129 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2130 PARAM(result));
2131
2132 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2133 const DeviceMemory<std::complex<double>> &, int,
2134 DeviceMemory<std::complex<double>> *> impl;
2135 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
2136 incy, result);
2137 }
2138
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * result)2139 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
2140 int incx, DeviceMemory<float> *result) {
2141 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2142
2143 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
2144 impl;
2145 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2146 result);
2147 }
2148
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * result)2149 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
2150 int incx, DeviceMemory<double> *result) {
2151 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2152
2153 ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
2154 DeviceMemory<double> *> impl;
2155 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2156 result);
2157 }
2158
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<float> * result)2159 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2160 const DeviceMemory<std::complex<float>> &x,
2161 int incx, DeviceMemory<float> *result) {
2162 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2163
2164 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2165 DeviceMemory<float> *> impl;
2166 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2167 result);
2168 }
2169
ThenBlasNrm2(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<double> * result)2170 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
2171 const DeviceMemory<std::complex<double>> &x,
2172 int incx, DeviceMemory<double> *result) {
2173 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2174
2175 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2176 DeviceMemory<double> *> impl;
2177 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
2178 result);
2179 }
2180
ThenBlasRot(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,float c,float s)2181 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
2182 DeviceMemory<float> *y, int incy, float c,
2183 float s) {
2184 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2185 PARAM(c), PARAM(s));
2186
2187 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2188 float, float> impl;
2189 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2190 c, s);
2191 }
2192
ThenBlasRot(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,double c,double s)2193 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
2194 int incx, DeviceMemory<double> *y, int incy,
2195 double c, double s) {
2196 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2197 PARAM(c), PARAM(s));
2198
2199 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2200 double, double> impl;
2201 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2202 c, s);
2203 }
2204
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy,float c,float s)2205 Stream &Stream::ThenBlasRot(uint64 elem_count,
2206 DeviceMemory<std::complex<float>> *x, int incx,
2207 DeviceMemory<std::complex<float>> *y, int incy,
2208 float c, float s) {
2209 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2210 PARAM(c), PARAM(s));
2211
2212 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2213 DeviceMemory<std::complex<float>> *, int, float, float> impl;
2214 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2215 c, s);
2216 }
2217
ThenBlasRot(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy,double c,double s)2218 Stream &Stream::ThenBlasRot(uint64 elem_count,
2219 DeviceMemory<std::complex<double>> *x, int incx,
2220 DeviceMemory<std::complex<double>> *y, int incy,
2221 double c, double s) {
2222 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2223 PARAM(c), PARAM(s));
2224
2225 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2226 DeviceMemory<std::complex<double>> *, int, double, double> impl;
2227 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
2228 c, s);
2229 }
2230
ThenBlasRotg(DeviceMemory<float> * a,DeviceMemory<float> * b,DeviceMemory<float> * c,DeviceMemory<float> * s)2231 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
2232 DeviceMemory<float> *c, DeviceMemory<float> *s) {
2233 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2234
2235 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2236 DeviceMemory<float> *, DeviceMemory<float> *> impl;
2237 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2238 }
2239
ThenBlasRotg(DeviceMemory<double> * a,DeviceMemory<double> * b,DeviceMemory<double> * c,DeviceMemory<double> * s)2240 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
2241 DeviceMemory<double> *c, DeviceMemory<double> *s) {
2242 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2243
2244 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2245 DeviceMemory<double> *, DeviceMemory<double> *> impl;
2246 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2247 }
2248
ThenBlasRotg(DeviceMemory<std::complex<float>> * a,DeviceMemory<std::complex<float>> * b,DeviceMemory<float> * c,DeviceMemory<std::complex<float>> * s)2249 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
2250 DeviceMemory<std::complex<float>> *b,
2251 DeviceMemory<float> *c,
2252 DeviceMemory<std::complex<float>> *s) {
2253 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2254
2255 ThenBlasImpl<DeviceMemory<std::complex<float>> *,
2256 DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
2257 DeviceMemory<std::complex<float>> *> impl;
2258 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2259 }
2260
ThenBlasRotg(DeviceMemory<std::complex<double>> * a,DeviceMemory<std::complex<double>> * b,DeviceMemory<double> * c,DeviceMemory<std::complex<double>> * s)2261 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
2262 DeviceMemory<std::complex<double>> *b,
2263 DeviceMemory<double> *c,
2264 DeviceMemory<std::complex<double>> *s) {
2265 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
2266
2267 ThenBlasImpl<DeviceMemory<std::complex<double>> *,
2268 DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
2269 DeviceMemory<std::complex<double>> *> impl;
2270 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
2271 }
2272
ThenBlasRotm(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy,const DeviceMemory<float> & param)2273 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
2274 int incx, DeviceMemory<float> *y, int incy,
2275 const DeviceMemory<float> ¶m) {
2276 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2277 PARAM(param));
2278
2279 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
2280 const DeviceMemory<float> &> impl;
2281 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2282 incy, param);
2283 }
2284
ThenBlasRotm(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy,const DeviceMemory<double> & param)2285 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
2286 int incx, DeviceMemory<double> *y, int incy,
2287 const DeviceMemory<double> ¶m) {
2288 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
2289 PARAM(param));
2290
2291 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
2292 const DeviceMemory<double> &> impl;
2293 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
2294 incy, param);
2295 }
2296
ThenBlasRotmg(DeviceMemory<float> * d1,DeviceMemory<float> * d2,DeviceMemory<float> * x1,const DeviceMemory<float> & y1,DeviceMemory<float> * param)2297 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
2298 DeviceMemory<float> *x1,
2299 const DeviceMemory<float> &y1,
2300 DeviceMemory<float> *param) {
2301 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2302
2303 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
2304 DeviceMemory<float> *, const DeviceMemory<float> &,
2305 DeviceMemory<float> *> impl;
2306 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2307 }
2308
ThenBlasRotmg(DeviceMemory<double> * d1,DeviceMemory<double> * d2,DeviceMemory<double> * x1,const DeviceMemory<double> & y1,DeviceMemory<double> * param)2309 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
2310 DeviceMemory<double> *d2,
2311 DeviceMemory<double> *x1,
2312 const DeviceMemory<double> &y1,
2313 DeviceMemory<double> *param) {
2314 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
2315
2316 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
2317 DeviceMemory<double> *, const DeviceMemory<double> &,
2318 DeviceMemory<double> *> impl;
2319 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
2320 }
2321
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<float> * x,int incx)2322 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2323 DeviceMemory<float> *x, int incx) {
2324 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2325
2326 ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
2327 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2328 }
2329
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<double> * x,int incx)2330 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2331 DeviceMemory<double> *x, int incx) {
2332 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2333
2334 ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
2335 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2336 }
2337
ThenBlasScal(uint64 elem_count,float alpha,DeviceMemory<std::complex<float>> * x,int incx)2338 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
2339 DeviceMemory<std::complex<float>> *x, int incx) {
2340 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2341
2342 ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
2343 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2344 }
2345
ThenBlasScal(uint64 elem_count,double alpha,DeviceMemory<std::complex<double>> * x,int incx)2346 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
2347 DeviceMemory<std::complex<double>> *x, int incx) {
2348 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2349
2350 ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
2351 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2352 }
2353
ThenBlasScal(uint64 elem_count,std::complex<float> alpha,DeviceMemory<std::complex<float>> * x,int incx)2354 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
2355 DeviceMemory<std::complex<float>> *x, int incx) {
2356 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2357
2358 ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
2359 int> impl;
2360 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2361 }
2362
ThenBlasScal(uint64 elem_count,std::complex<double> alpha,DeviceMemory<std::complex<double>> * x,int incx)2363 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
2364 DeviceMemory<std::complex<double>> *x, int incx) {
2365 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
2366
2367 ThenBlasImpl<uint64, std::complex<double>,
2368 DeviceMemory<std::complex<double>> *, int> impl;
2369 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
2370 }
2371
ThenBlasSwap(uint64 elem_count,DeviceMemory<float> * x,int incx,DeviceMemory<float> * y,int incy)2372 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
2373 int incx, DeviceMemory<float> *y, int incy) {
2374 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2375
2376 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
2377 impl;
2378 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2379 incy);
2380 }
2381
ThenBlasSwap(uint64 elem_count,DeviceMemory<double> * x,int incx,DeviceMemory<double> * y,int incy)2382 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
2383 int incx, DeviceMemory<double> *y, int incy) {
2384 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2385
2386 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
2387 impl;
2388 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2389 incy);
2390 }
2391
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<float>> * x,int incx,DeviceMemory<std::complex<float>> * y,int incy)2392 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2393 DeviceMemory<std::complex<float>> *x, int incx,
2394 DeviceMemory<std::complex<float>> *y, int incy) {
2395 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2396
2397 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
2398 DeviceMemory<std::complex<float>> *, int> impl;
2399 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2400 incy);
2401 }
2402
ThenBlasSwap(uint64 elem_count,DeviceMemory<std::complex<double>> * x,int incx,DeviceMemory<std::complex<double>> * y,int incy)2403 Stream &Stream::ThenBlasSwap(uint64 elem_count,
2404 DeviceMemory<std::complex<double>> *x, int incx,
2405 DeviceMemory<std::complex<double>> *y, int incy) {
2406 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
2407
2408 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
2409 DeviceMemory<std::complex<double>> *, int> impl;
2410 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
2411 incy);
2412 }
2413
ThenBlasIamax(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2414 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
2415 int incx, DeviceMemory<int> *result) {
2416 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2417
2418 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2419 impl;
2420 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2421 result);
2422 }
2423
ThenBlasIamax(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2424 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
2425 int incx, DeviceMemory<int> *result) {
2426 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2427
2428 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2429 impl;
2430 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2431 result);
2432 }
2433
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2434 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2435 const DeviceMemory<std::complex<float>> &x,
2436 int incx, DeviceMemory<int> *result) {
2437 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2438
2439 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2440 DeviceMemory<int> *> impl;
2441 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2442 result);
2443 }
2444
ThenBlasIamax(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2445 Stream &Stream::ThenBlasIamax(uint64 elem_count,
2446 const DeviceMemory<std::complex<double>> &x,
2447 int incx, DeviceMemory<int> *result) {
2448 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2449
2450 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2451 DeviceMemory<int> *> impl;
2452 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
2453 result);
2454 }
2455
ThenBlasIamin(uint64 elem_count,const DeviceMemory<float> & x,int incx,DeviceMemory<int> * result)2456 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
2457 int incx, DeviceMemory<int> *result) {
2458 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2459
2460 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
2461 impl;
2462 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2463 result);
2464 }
2465
ThenBlasIamin(uint64 elem_count,const DeviceMemory<double> & x,int incx,DeviceMemory<int> * result)2466 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
2467 int incx, DeviceMemory<int> *result) {
2468 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2469
2470 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
2471 impl;
2472 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2473 result);
2474 }
2475
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<int> * result)2476 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2477 const DeviceMemory<std::complex<float>> &x,
2478 int incx, DeviceMemory<int> *result) {
2479 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2480
2481 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
2482 DeviceMemory<int> *> impl;
2483 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2484 result);
2485 }
2486
ThenBlasIamin(uint64 elem_count,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<int> * result)2487 Stream &Stream::ThenBlasIamin(uint64 elem_count,
2488 const DeviceMemory<std::complex<double>> &x,
2489 int incx, DeviceMemory<int> *result) {
2490 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
2491
2492 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
2493 DeviceMemory<int> *> impl;
2494 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
2495 result);
2496 }
2497
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2498 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2499 uint64 kl, uint64 ku, float alpha,
2500 const DeviceMemory<float> &a, int lda,
2501 const DeviceMemory<float> &x, int incx, float beta,
2502 DeviceMemory<float> *y, int incy) {
2503 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2504 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2505 PARAM(beta), PARAM(y), PARAM(incy));
2506
2507 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
2508 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2509 int, float, DeviceMemory<float> *, int> impl;
2510 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2511 a, lda, x, incx, beta, y, incy);
2512 }
2513
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2514 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2515 uint64 kl, uint64 ku, double alpha,
2516 const DeviceMemory<double> &a, int lda,
2517 const DeviceMemory<double> &x, int incx,
2518 double beta, DeviceMemory<double> *y, int incy) {
2519 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2520 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2521 PARAM(beta), PARAM(y), PARAM(incy));
2522
2523 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
2524 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2525 int, double, DeviceMemory<double> *, int> impl;
2526 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2527 a, lda, x, incx, beta, y, incy);
2528 }
2529
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2530 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2531 uint64 kl, uint64 ku, std::complex<float> alpha,
2532 const DeviceMemory<std::complex<float>> &a,
2533 int lda,
2534 const DeviceMemory<std::complex<float>> &x,
2535 int incx, std::complex<float> beta,
2536 DeviceMemory<std::complex<float>> *y, int incy) {
2537 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2538 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2539 PARAM(beta), PARAM(y), PARAM(incy));
2540
2541 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2542 std::complex<float>, const DeviceMemory<std::complex<float>> &,
2543 int, const DeviceMemory<std::complex<float>> &, int,
2544 std::complex<float>, DeviceMemory<std::complex<float>> *,
2545 int> impl;
2546 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2547 a, lda, x, incx, beta, y, incy);
2548 }
2549
ThenBlasGbmv(blas::Transpose trans,uint64 m,uint64 n,uint64 kl,uint64 ku,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2550 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
2551 uint64 kl, uint64 ku, std::complex<double> alpha,
2552 const DeviceMemory<std::complex<double>> &a,
2553 int lda,
2554 const DeviceMemory<std::complex<double>> &x,
2555 int incx, std::complex<double> beta,
2556 DeviceMemory<std::complex<double>> *y, int incy) {
2557 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
2558 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
2559 PARAM(beta), PARAM(y), PARAM(incy));
2560
2561 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
2562 std::complex<double>, const DeviceMemory<std::complex<double>> &,
2563 int, const DeviceMemory<std::complex<double>> &, int,
2564 std::complex<double>, DeviceMemory<std::complex<double>> *,
2565 int> impl;
2566 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
2567 a, lda, x, incx, beta, y, incy);
2568 }
2569
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2570 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2571 float alpha, const DeviceMemory<float> &a, int lda,
2572 const DeviceMemory<float> &x, int incx, float beta,
2573 DeviceMemory<float> *y, int incy) {
2574 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2575 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2576 PARAM(incy));
2577
2578 ThenBlasImpl<blas::Transpose, uint64, uint64, float,
2579 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2580 int, float, DeviceMemory<float> *, int> impl;
2581 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2582 x, incx, beta, y, incy);
2583 }
2584
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2585 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2586 double alpha, const DeviceMemory<double> &a,
2587 int lda, const DeviceMemory<double> &x, int incx,
2588 double beta, DeviceMemory<double> *y, int incy) {
2589 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2590 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2591 PARAM(incy));
2592
2593 ThenBlasImpl<blas::Transpose, uint64, uint64, double,
2594 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2595 int, double, DeviceMemory<double> *, int> impl;
2596 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2597 x, incx, beta, y, incy);
2598 }
2599
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2600 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2601 std::complex<float> alpha,
2602 const DeviceMemory<std::complex<float>> &a,
2603 int lda,
2604 const DeviceMemory<std::complex<float>> &x,
2605 int incx, std::complex<float> beta,
2606 DeviceMemory<std::complex<float>> *y, int incy) {
2607 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2608 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2609 PARAM(incy));
2610
2611 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
2612 const DeviceMemory<std::complex<float>> &, int,
2613 const DeviceMemory<std::complex<float>> &, int,
2614 std::complex<float>, DeviceMemory<std::complex<float>> *,
2615 int> impl;
2616 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2617 x, incx, beta, y, incy);
2618 }
2619
ThenBlasGemv(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2620 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
2621 std::complex<double> alpha,
2622 const DeviceMemory<std::complex<double>> &a,
2623 int lda,
2624 const DeviceMemory<std::complex<double>> &x,
2625 int incx, std::complex<double> beta,
2626 DeviceMemory<std::complex<double>> *y, int incy) {
2627 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
2628 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
2629 PARAM(incy));
2630
2631 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
2632 const DeviceMemory<std::complex<double>> &, int,
2633 const DeviceMemory<std::complex<double>> &, int,
2634 std::complex<double>, DeviceMemory<std::complex<double>> *,
2635 int> impl;
2636 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
2637 x, incx, beta, y, incy);
2638 }
2639
ThenBlasGer(uint64 m,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)2640 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
2641 const DeviceMemory<float> &x, int incx,
2642 const DeviceMemory<float> &y, int incy,
2643 DeviceMemory<float> *a, int lda) {
2644 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2645 PARAM(incy), PARAM(a), PARAM(lda));
2646
2647 ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
2648 const DeviceMemory<float> &, int, DeviceMemory<float> *,
2649 int> impl;
2650 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2651 incy, a, lda);
2652 }
2653
ThenBlasGer(uint64 m,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)2654 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
2655 const DeviceMemory<double> &x, int incx,
2656 const DeviceMemory<double> &y, int incy,
2657 DeviceMemory<double> *a, int lda) {
2658 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2659 PARAM(incy), PARAM(a), PARAM(lda));
2660
2661 ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
2662 const DeviceMemory<double> &, int, DeviceMemory<double> *,
2663 int> impl;
2664 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
2665 incy, a, lda);
2666 }
2667
ThenBlasGerc(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2668 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
2669 const DeviceMemory<std::complex<float>> &x,
2670 int incx,
2671 const DeviceMemory<std::complex<float>> &y,
2672 int incy, DeviceMemory<std::complex<float>> *a,
2673 int lda) {
2674 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2675 PARAM(incy), PARAM(a), PARAM(lda));
2676
2677 ThenBlasImpl<uint64, uint64, std::complex<float>,
2678 const DeviceMemory<std::complex<float>> &, int,
2679 const DeviceMemory<std::complex<float>> &, int,
2680 DeviceMemory<std::complex<float>> *, int> impl;
2681 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2682 incy, a, lda);
2683 }
2684
ThenBlasGerc(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2685 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
2686 const DeviceMemory<std::complex<double>> &x,
2687 int incx,
2688 const DeviceMemory<std::complex<double>> &y,
2689 int incy, DeviceMemory<std::complex<double>> *a,
2690 int lda) {
2691 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2692 PARAM(incy), PARAM(a), PARAM(lda));
2693
2694 ThenBlasImpl<uint64, uint64, std::complex<double>,
2695 const DeviceMemory<std::complex<double>> &, int,
2696 const DeviceMemory<std::complex<double>> &, int,
2697 DeviceMemory<std::complex<double>> *, int> impl;
2698 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
2699 incy, a, lda);
2700 }
2701
ThenBlasGeru(uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2702 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
2703 const DeviceMemory<std::complex<float>> &x,
2704 int incx,
2705 const DeviceMemory<std::complex<float>> &y,
2706 int incy, DeviceMemory<std::complex<float>> *a,
2707 int lda) {
2708 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2709 PARAM(incy), PARAM(a), PARAM(lda));
2710
2711 ThenBlasImpl<uint64, uint64, std::complex<float>,
2712 const DeviceMemory<std::complex<float>> &, int,
2713 const DeviceMemory<std::complex<float>> &, int,
2714 DeviceMemory<std::complex<float>> *, int> impl;
2715 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2716 incy, a, lda);
2717 }
2718
ThenBlasGeru(uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2719 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
2720 const DeviceMemory<std::complex<double>> &x,
2721 int incx,
2722 const DeviceMemory<std::complex<double>> &y,
2723 int incy, DeviceMemory<std::complex<double>> *a,
2724 int lda) {
2725 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
2726 PARAM(incy), PARAM(a), PARAM(lda));
2727
2728 ThenBlasImpl<uint64, uint64, std::complex<double>,
2729 const DeviceMemory<std::complex<double>> &, int,
2730 const DeviceMemory<std::complex<double>> &, int,
2731 DeviceMemory<std::complex<double>> *, int> impl;
2732 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
2733 incy, a, lda);
2734 }
2735
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2736 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2737 std::complex<float> alpha,
2738 const DeviceMemory<std::complex<float>> &a,
2739 int lda,
2740 const DeviceMemory<std::complex<float>> &x,
2741 int incx, std::complex<float> beta,
2742 DeviceMemory<std::complex<float>> *y, int incy) {
2743 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2744 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2745
2746 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
2747 const DeviceMemory<std::complex<float>> &, int,
2748 const DeviceMemory<std::complex<float>> &, int,
2749 std::complex<float>, DeviceMemory<std::complex<float>> *,
2750 int> impl;
2751 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2752 x, incx, beta, y, incy);
2753 }
2754
ThenBlasHbmv(blas::UpperLower uplo,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2755 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2756 std::complex<double> alpha,
2757 const DeviceMemory<std::complex<double>> &a,
2758 int lda,
2759 const DeviceMemory<std::complex<double>> &x,
2760 int incx, std::complex<double> beta,
2761 DeviceMemory<std::complex<double>> *y, int incy) {
2762 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2763 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2764
2765 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
2766 const DeviceMemory<std::complex<double>> &, int,
2767 const DeviceMemory<std::complex<double>> &, int,
2768 std::complex<double>, DeviceMemory<std::complex<double>> *,
2769 int> impl;
2770 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
2771 x, incx, beta, y, incy);
2772 }
2773
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2774 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2775 std::complex<float> alpha,
2776 const DeviceMemory<std::complex<float>> &a,
2777 int lda,
2778 const DeviceMemory<std::complex<float>> &x,
2779 int incx, std::complex<float> beta,
2780 DeviceMemory<std::complex<float>> *y, int incy) {
2781 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2782 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2783
2784 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2785 const DeviceMemory<std::complex<float>> &, int,
2786 const DeviceMemory<std::complex<float>> &, int,
2787 std::complex<float>, DeviceMemory<std::complex<float>> *,
2788 int> impl;
2789 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2790 incx, beta, y, incy);
2791 }
2792
ThenBlasHemv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2793 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
2794 std::complex<double> alpha,
2795 const DeviceMemory<std::complex<double>> &a,
2796 int lda,
2797 const DeviceMemory<std::complex<double>> &x,
2798 int incx, std::complex<double> beta,
2799 DeviceMemory<std::complex<double>> *y, int incy) {
2800 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
2801 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2802
2803 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2804 const DeviceMemory<std::complex<double>> &, int,
2805 const DeviceMemory<std::complex<double>> &, int,
2806 std::complex<double>, DeviceMemory<std::complex<double>> *,
2807 int> impl;
2808 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
2809 incx, beta, y, incy);
2810 }
2811
ThenBlasHer(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * a,int lda)2812 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
2813 const DeviceMemory<std::complex<float>> &x,
2814 int incx, DeviceMemory<std::complex<float>> *a,
2815 int lda) {
2816 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2817 PARAM(a), PARAM(lda));
2818
2819 ThenBlasImpl<blas::UpperLower, uint64, float,
2820 const DeviceMemory<std::complex<float>> &, int,
2821 DeviceMemory<std::complex<float>> *, int> impl;
2822 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2823 lda);
2824 }
2825
ThenBlasHer(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * a,int lda)2826 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
2827 const DeviceMemory<std::complex<double>> &x,
2828 int incx, DeviceMemory<std::complex<double>> *a,
2829 int lda) {
2830 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2831 PARAM(a), PARAM(lda));
2832
2833 ThenBlasImpl<blas::UpperLower, uint64, double,
2834 const DeviceMemory<std::complex<double>> &, int,
2835 DeviceMemory<std::complex<double>> *, int> impl;
2836 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
2837 lda);
2838 }
2839
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * a,int lda)2840 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2841 std::complex<float> alpha,
2842 const DeviceMemory<std::complex<float>> &x,
2843 int incx,
2844 const DeviceMemory<std::complex<float>> &y,
2845 int incy, DeviceMemory<std::complex<float>> *a,
2846 int lda) {
2847 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2848 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2849
2850 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2851 const DeviceMemory<std::complex<float>> &, int,
2852 const DeviceMemory<std::complex<float>> &, int,
2853 DeviceMemory<std::complex<float>> *, int> impl;
2854 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2855 incy, a, lda);
2856 }
2857
ThenBlasHer2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * a,int lda)2858 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
2859 std::complex<double> alpha,
2860 const DeviceMemory<std::complex<double>> &x,
2861 int incx,
2862 const DeviceMemory<std::complex<double>> &y,
2863 int incy, DeviceMemory<std::complex<double>> *a,
2864 int lda) {
2865 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2866 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
2867
2868 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2869 const DeviceMemory<std::complex<double>> &, int,
2870 const DeviceMemory<std::complex<double>> &, int,
2871 DeviceMemory<std::complex<double>> *, int> impl;
2872 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
2873 incy, a, lda);
2874 }
2875
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & ap,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy)2876 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2877 std::complex<float> alpha,
2878 const DeviceMemory<std::complex<float>> &ap,
2879 const DeviceMemory<std::complex<float>> &x,
2880 int incx, std::complex<float> beta,
2881 DeviceMemory<std::complex<float>> *y, int incy) {
2882 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2883 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2884
2885 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2886 const DeviceMemory<std::complex<float>> &,
2887 const DeviceMemory<std::complex<float>> &, int,
2888 std::complex<float>, DeviceMemory<std::complex<float>> *,
2889 int> impl;
2890 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2891 beta, y, incy);
2892 }
2893
ThenBlasHpmv(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & ap,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy)2894 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
2895 std::complex<double> alpha,
2896 const DeviceMemory<std::complex<double>> &ap,
2897 const DeviceMemory<std::complex<double>> &x,
2898 int incx, std::complex<double> beta,
2899 DeviceMemory<std::complex<double>> *y, int incy) {
2900 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
2901 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2902
2903 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2904 const DeviceMemory<std::complex<double>> &,
2905 const DeviceMemory<std::complex<double>> &, int,
2906 std::complex<double>, DeviceMemory<std::complex<double>> *,
2907 int> impl;
2908 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
2909 beta, y, incy);
2910 }
2911
ThenBlasHpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<std::complex<float>> & x,int incx,DeviceMemory<std::complex<float>> * ap)2912 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
2913 const DeviceMemory<std::complex<float>> &x,
2914 int incx, DeviceMemory<std::complex<float>> *ap) {
2915 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2916 PARAM(ap));
2917
2918 ThenBlasImpl<blas::UpperLower, uint64, float,
2919 const DeviceMemory<std::complex<float>> &, int,
2920 DeviceMemory<std::complex<float>> *> impl;
2921 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2922 }
2923
ThenBlasHpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<std::complex<double>> & x,int incx,DeviceMemory<std::complex<double>> * ap)2924 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
2925 const DeviceMemory<std::complex<double>> &x,
2926 int incx, DeviceMemory<std::complex<double>> *ap) {
2927 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2928 PARAM(ap));
2929
2930 ThenBlasImpl<blas::UpperLower, uint64, double,
2931 const DeviceMemory<std::complex<double>> &, int,
2932 DeviceMemory<std::complex<double>> *> impl;
2933 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
2934 }
2935
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & x,int incx,const DeviceMemory<std::complex<float>> & y,int incy,DeviceMemory<std::complex<float>> * ap)2936 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2937 std::complex<float> alpha,
2938 const DeviceMemory<std::complex<float>> &x,
2939 int incx,
2940 const DeviceMemory<std::complex<float>> &y,
2941 int incy, DeviceMemory<std::complex<float>> *ap) {
2942 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2943 PARAM(y), PARAM(incy), PARAM(ap));
2944
2945 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
2946 const DeviceMemory<std::complex<float>> &, int,
2947 const DeviceMemory<std::complex<float>> &, int,
2948 DeviceMemory<std::complex<float>> *> impl;
2949 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2950 incy, ap);
2951 }
2952
ThenBlasHpr2(blas::UpperLower uplo,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & x,int incx,const DeviceMemory<std::complex<double>> & y,int incy,DeviceMemory<std::complex<double>> * ap)2953 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
2954 std::complex<double> alpha,
2955 const DeviceMemory<std::complex<double>> &x,
2956 int incx,
2957 const DeviceMemory<std::complex<double>> &y,
2958 int incy, DeviceMemory<std::complex<double>> *ap) {
2959 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
2960 PARAM(y), PARAM(incy), PARAM(ap));
2961
2962 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
2963 const DeviceMemory<std::complex<double>> &, int,
2964 const DeviceMemory<std::complex<double>> &, int,
2965 DeviceMemory<std::complex<double>> *> impl;
2966 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
2967 incy, ap);
2968 }
2969
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2970 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2971 float alpha, const DeviceMemory<float> &a, int lda,
2972 const DeviceMemory<float> &x, int incx, float beta,
2973 DeviceMemory<float> *y, int incy) {
2974 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2975 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2976
2977 ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
2978 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
2979 int, float, DeviceMemory<float> *, int> impl;
2980 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2981 x, incx, beta, y, incy);
2982 }
2983
ThenBlasSbmv(blas::UpperLower uplo,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)2984 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
2985 double alpha, const DeviceMemory<double> &a,
2986 int lda, const DeviceMemory<double> &x, int incx,
2987 double beta, DeviceMemory<double> *y, int incy) {
2988 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
2989 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
2990
2991 ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
2992 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
2993 int, double, DeviceMemory<double> *, int> impl;
2994 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
2995 x, incx, beta, y, incy);
2996 }
2997
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & ap,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)2998 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
2999 const DeviceMemory<float> &ap,
3000 const DeviceMemory<float> &x, int incx, float beta,
3001 DeviceMemory<float> *y, int incy) {
3002 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3003 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3004
3005 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3006 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
3007 int> impl;
3008 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3009 beta, y, incy);
3010 }
3011
ThenBlasSpmv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & ap,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3012 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
3013 const DeviceMemory<double> &ap,
3014 const DeviceMemory<double> &x, int incx,
3015 double beta, DeviceMemory<double> *y, int incy) {
3016 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
3017 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3018
3019 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3020 const DeviceMemory<double> &, int, double,
3021 DeviceMemory<double> *, int> impl;
3022 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
3023 beta, y, incy);
3024 }
3025
ThenBlasSpr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * ap)3026 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
3027 const DeviceMemory<float> &x, int incx,
3028 DeviceMemory<float> *ap) {
3029 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3030 PARAM(ap));
3031
3032 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3033 int, DeviceMemory<float> *> impl;
3034 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3035 }
3036
ThenBlasSpr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * ap)3037 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
3038 const DeviceMemory<double> &x, int incx,
3039 DeviceMemory<double> *ap) {
3040 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3041 PARAM(ap));
3042
3043 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3044 int, DeviceMemory<double> *> impl;
3045 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
3046 }
3047
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * ap)3048 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
3049 const DeviceMemory<float> &x, int incx,
3050 const DeviceMemory<float> &y, int incy,
3051 DeviceMemory<float> *ap) {
3052 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3053 PARAM(y), PARAM(incy), PARAM(ap));
3054
3055 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3056 int, const DeviceMemory<float> &, int,
3057 DeviceMemory<float> *> impl;
3058 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3059 incy, ap);
3060 }
3061
ThenBlasSpr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * ap)3062 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
3063 const DeviceMemory<double> &x, int incx,
3064 const DeviceMemory<double> &y, int incy,
3065 DeviceMemory<double> *ap) {
3066 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3067 PARAM(y), PARAM(incy), PARAM(ap));
3068
3069 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3070 int, const DeviceMemory<double> &, int,
3071 DeviceMemory<double> *> impl;
3072 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
3073 incy, ap);
3074 }
3075
ThenBlasSymv(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy)3076 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
3077 const DeviceMemory<float> &a, int lda,
3078 const DeviceMemory<float> &x, int incx, float beta,
3079 DeviceMemory<float> *y, int incy) {
3080 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3081 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3082
3083 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3084 int, const DeviceMemory<float> &, int, float,
3085 DeviceMemory<float> *, int> impl;
3086 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3087 incx, beta, y, incy);
3088 }
3089
ThenBlasSymv(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy)3090 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
3091 const DeviceMemory<double> &a, int lda,
3092 const DeviceMemory<double> &x, int incx,
3093 double beta, DeviceMemory<double> *y, int incy) {
3094 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
3095 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
3096
3097 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3098 int, const DeviceMemory<double> &, int, double,
3099 DeviceMemory<double> *, int> impl;
3100 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
3101 incx, beta, y, incy);
3102 }
3103
ThenBlasSyr(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,DeviceMemory<float> * a,int lda)3104 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
3105 const DeviceMemory<float> &x, int incx,
3106 DeviceMemory<float> *a, int lda) {
3107 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3108 PARAM(a), PARAM(lda));
3109
3110 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3111 int, DeviceMemory<float> *, int> impl;
3112 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3113 lda);
3114 }
3115
ThenBlasSyr(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,DeviceMemory<double> * a,int lda)3116 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
3117 const DeviceMemory<double> &x, int incx,
3118 DeviceMemory<double> *a, int lda) {
3119 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3120 PARAM(a), PARAM(lda));
3121
3122 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3123 int, DeviceMemory<double> *, int> impl;
3124 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
3125 lda);
3126 }
3127
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,float alpha,const DeviceMemory<float> & x,int incx,const DeviceMemory<float> & y,int incy,DeviceMemory<float> * a,int lda)3128 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
3129 const DeviceMemory<float> &x, int incx,
3130 const DeviceMemory<float> &y, int incy,
3131 DeviceMemory<float> *a, int lda) {
3132 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3133 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3134
3135 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
3136 int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3137 int> impl;
3138 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3139 incy, a, lda);
3140 }
3141
ThenBlasSyr2(blas::UpperLower uplo,uint64 n,double alpha,const DeviceMemory<double> & x,int incx,const DeviceMemory<double> & y,int incy,DeviceMemory<double> * a,int lda)3142 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
3143 const DeviceMemory<double> &x, int incx,
3144 const DeviceMemory<double> &y, int incy,
3145 DeviceMemory<double> *a, int lda) {
3146 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
3147 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
3148
3149 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
3150 int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
3151 int> impl;
3152 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
3153 incy, a, lda);
3154 }
3155
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3156 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3157 blas::Diagonal diag, uint64 n, uint64 k,
3158 const DeviceMemory<float> &a, int lda,
3159 DeviceMemory<float> *x, int incx) {
3160 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3161 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3162
3163 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3164 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3165 int> impl;
3166 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3167 lda, x, incx);
3168 }
3169
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3170 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3171 blas::Diagonal diag, uint64 n, uint64 k,
3172 const DeviceMemory<double> &a, int lda,
3173 DeviceMemory<double> *x, int incx) {
3174 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3175 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3176
3177 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3178 uint64, const DeviceMemory<double> &, int,
3179 DeviceMemory<double> *, int> impl;
3180 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3181 lda, x, incx);
3182 }
3183
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3184 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3185 blas::Diagonal diag, uint64 n, uint64 k,
3186 const DeviceMemory<std::complex<float>> &a,
3187 int lda, DeviceMemory<std::complex<float>> *x,
3188 int incx) {
3189 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3190 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3191
3192 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3193 uint64, const DeviceMemory<std::complex<float>> &, int,
3194 DeviceMemory<std::complex<float>> *, int> impl;
3195 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3196 lda, x, incx);
3197 }
3198
ThenBlasTbmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3199 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
3200 blas::Diagonal diag, uint64 n, uint64 k,
3201 const DeviceMemory<std::complex<double>> &a,
3202 int lda, DeviceMemory<std::complex<double>> *x,
3203 int incx) {
3204 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3205 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3206
3207 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3208 uint64, const DeviceMemory<std::complex<double>> &, int,
3209 DeviceMemory<std::complex<double>> *, int> impl;
3210 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
3211 lda, x, incx);
3212 }
3213
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3214 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3215 blas::Diagonal diag, uint64 n, uint64 k,
3216 const DeviceMemory<float> &a, int lda,
3217 DeviceMemory<float> *x, int incx) {
3218 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3219 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3220
3221 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3222 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
3223 int> impl;
3224 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3225 lda, x, incx);
3226 }
3227
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3228 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3229 blas::Diagonal diag, uint64 n, uint64 k,
3230 const DeviceMemory<double> &a, int lda,
3231 DeviceMemory<double> *x, int incx) {
3232 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3233 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3234
3235 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3236 uint64, const DeviceMemory<double> &, int,
3237 DeviceMemory<double> *, int> impl;
3238 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3239 lda, x, incx);
3240 }
3241
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3242 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3243 blas::Diagonal diag, uint64 n, uint64 k,
3244 const DeviceMemory<std::complex<float>> &a,
3245 int lda, DeviceMemory<std::complex<float>> *x,
3246 int incx) {
3247 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3248 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3249
3250 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3251 uint64, const DeviceMemory<std::complex<float>> &, int,
3252 DeviceMemory<std::complex<float>> *, int> impl;
3253 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3254 lda, x, incx);
3255 }
3256
ThenBlasTbsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,uint64 k,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3257 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
3258 blas::Diagonal diag, uint64 n, uint64 k,
3259 const DeviceMemory<std::complex<double>> &a,
3260 int lda, DeviceMemory<std::complex<double>> *x,
3261 int incx) {
3262 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
3263 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
3264
3265 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3266 uint64, const DeviceMemory<std::complex<double>> &, int,
3267 DeviceMemory<std::complex<double>> *, int> impl;
3268 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
3269 lda, x, incx);
3270 }
3271
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3272 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3273 blas::Diagonal diag, uint64 n,
3274 const DeviceMemory<float> &ap,
3275 DeviceMemory<float> *x, int incx) {
3276 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3277 PARAM(x), PARAM(incx));
3278
3279 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3280 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3281 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3282 incx);
3283 }
3284
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3285 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3286 blas::Diagonal diag, uint64 n,
3287 const DeviceMemory<double> &ap,
3288 DeviceMemory<double> *x, int incx) {
3289 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3290 PARAM(x), PARAM(incx));
3291
3292 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3293 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3294 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3295 incx);
3296 }
3297
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3298 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3299 blas::Diagonal diag, uint64 n,
3300 const DeviceMemory<std::complex<float>> &ap,
3301 DeviceMemory<std::complex<float>> *x, int incx) {
3302 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3303 PARAM(x), PARAM(incx));
3304
3305 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3306 const DeviceMemory<std::complex<float>> &,
3307 DeviceMemory<std::complex<float>> *, int> impl;
3308 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3309 incx);
3310 }
3311
ThenBlasTpmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3312 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
3313 blas::Diagonal diag, uint64 n,
3314 const DeviceMemory<std::complex<double>> &ap,
3315 DeviceMemory<std::complex<double>> *x, int incx) {
3316 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3317 PARAM(x), PARAM(incx));
3318
3319 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3320 const DeviceMemory<std::complex<double>> &,
3321 DeviceMemory<std::complex<double>> *, int> impl;
3322 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
3323 incx);
3324 }
3325
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & ap,DeviceMemory<float> * x,int incx)3326 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3327 blas::Diagonal diag, uint64 n,
3328 const DeviceMemory<float> &ap,
3329 DeviceMemory<float> *x, int incx) {
3330 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3331 PARAM(x), PARAM(incx));
3332
3333 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3334 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
3335 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3336 incx);
3337 }
3338
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & ap,DeviceMemory<double> * x,int incx)3339 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3340 blas::Diagonal diag, uint64 n,
3341 const DeviceMemory<double> &ap,
3342 DeviceMemory<double> *x, int incx) {
3343 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3344 PARAM(x), PARAM(incx));
3345
3346 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3347 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
3348 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3349 incx);
3350 }
3351
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & ap,DeviceMemory<std::complex<float>> * x,int incx)3352 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3353 blas::Diagonal diag, uint64 n,
3354 const DeviceMemory<std::complex<float>> &ap,
3355 DeviceMemory<std::complex<float>> *x, int incx) {
3356 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3357 PARAM(x), PARAM(incx));
3358
3359 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3360 const DeviceMemory<std::complex<float>> &,
3361 DeviceMemory<std::complex<float>> *, int> impl;
3362 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3363 incx);
3364 }
3365
ThenBlasTpsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & ap,DeviceMemory<std::complex<double>> * x,int incx)3366 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
3367 blas::Diagonal diag, uint64 n,
3368 const DeviceMemory<std::complex<double>> &ap,
3369 DeviceMemory<std::complex<double>> *x, int incx) {
3370 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
3371 PARAM(x), PARAM(incx));
3372
3373 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3374 const DeviceMemory<std::complex<double>> &,
3375 DeviceMemory<std::complex<double>> *, int> impl;
3376 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
3377 incx);
3378 }
3379
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3380 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3381 blas::Diagonal diag, uint64 n,
3382 const DeviceMemory<float> &a, int lda,
3383 DeviceMemory<float> *x, int incx) {
3384 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3385 PARAM(lda), PARAM(x), PARAM(incx));
3386
3387 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3388 const DeviceMemory<float> &, int, DeviceMemory<float> *,
3389 int> impl;
3390 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3391 lda, x, incx);
3392 }
3393
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3394 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3395 blas::Diagonal diag, uint64 n,
3396 const DeviceMemory<double> &a, int lda,
3397 DeviceMemory<double> *x, int incx) {
3398 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3399 PARAM(lda), PARAM(x), PARAM(incx));
3400
3401 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3402 const DeviceMemory<double> &, int, DeviceMemory<double> *,
3403 int> impl;
3404 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3405 lda, x, incx);
3406 }
3407
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3408 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3409 blas::Diagonal diag, uint64 n,
3410 const DeviceMemory<std::complex<float>> &a,
3411 int lda, DeviceMemory<std::complex<float>> *x,
3412 int incx) {
3413 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3414 PARAM(lda), PARAM(x), PARAM(incx));
3415
3416 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3417 const DeviceMemory<std::complex<float>> &, int,
3418 DeviceMemory<std::complex<float>> *, int> impl;
3419 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3420 lda, x, incx);
3421 }
3422
ThenBlasTrmv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3423 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
3424 blas::Diagonal diag, uint64 n,
3425 const DeviceMemory<std::complex<double>> &a,
3426 int lda, DeviceMemory<std::complex<double>> *x,
3427 int incx) {
3428 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3429 PARAM(lda), PARAM(x), PARAM(incx));
3430
3431 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3432 const DeviceMemory<std::complex<double>> &, int,
3433 DeviceMemory<std::complex<double>> *, int> impl;
3434 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
3435 lda, x, incx);
3436 }
3437
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * x,int incx)3438 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3439 blas::Diagonal diag, uint64 n,
3440 const DeviceMemory<float> &a, int lda,
3441 DeviceMemory<float> *x, int incx) {
3442 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3443 PARAM(lda), PARAM(x), PARAM(incx));
3444
3445 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3446 const DeviceMemory<float> &, int, DeviceMemory<float> *,
3447 int> impl;
3448 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3449 lda, x, incx);
3450 }
3451
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * x,int incx)3452 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3453 blas::Diagonal diag, uint64 n,
3454 const DeviceMemory<double> &a, int lda,
3455 DeviceMemory<double> *x, int incx) {
3456 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3457 PARAM(lda), PARAM(x), PARAM(incx));
3458
3459 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3460 const DeviceMemory<double> &, int, DeviceMemory<double> *,
3461 int> impl;
3462 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3463 lda, x, incx);
3464 }
3465
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * x,int incx)3466 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3467 blas::Diagonal diag, uint64 n,
3468 const DeviceMemory<std::complex<float>> &a,
3469 int lda, DeviceMemory<std::complex<float>> *x,
3470 int incx) {
3471 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3472 PARAM(lda), PARAM(x), PARAM(incx));
3473
3474 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3475 const DeviceMemory<std::complex<float>> &, int,
3476 DeviceMemory<std::complex<float>> *, int> impl;
3477 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3478 lda, x, incx);
3479 }
3480
ThenBlasTrsv(blas::UpperLower uplo,blas::Transpose trans,blas::Diagonal diag,uint64 n,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * x,int incx)3481 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
3482 blas::Diagonal diag, uint64 n,
3483 const DeviceMemory<std::complex<double>> &a,
3484 int lda, DeviceMemory<std::complex<double>> *x,
3485 int incx) {
3486 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
3487 PARAM(lda), PARAM(x), PARAM(incx));
3488
3489 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
3490 const DeviceMemory<std::complex<double>> &, int,
3491 DeviceMemory<std::complex<double>> *, int> impl;
3492 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
3493 lda, x, incx);
3494 }
3495
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc)3496 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3497 uint64 m, uint64 n, uint64 k, float alpha,
3498 const DeviceMemory<Eigen::half> &a, int lda,
3499 const DeviceMemory<Eigen::half> &b, int ldb,
3500 float beta,
3501 DeviceMemory<Eigen::half> *c, int ldc) {
3502 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3503 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3504 PARAM(beta), PARAM(c), PARAM(ldc));
3505
3506 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3507 const DeviceMemory<Eigen::half> &, int,
3508 const DeviceMemory<Eigen::half> &, int,
3509 float, DeviceMemory<Eigen::half> *, int> impl;
3510 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3511 alpha, a, lda, b, ldb, beta, c, ldc);
3512 }
3513
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)3514 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3515 uint64 m, uint64 n, uint64 k, float alpha,
3516 const DeviceMemory<float> &a, int lda,
3517 const DeviceMemory<float> &b, int ldb, float beta,
3518 DeviceMemory<float> *c, int ldc) {
3519 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3520 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3521 PARAM(beta), PARAM(c), PARAM(ldc));
3522
3523 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3524 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
3525 int, float, DeviceMemory<float> *, int> impl;
3526 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3527 alpha, a, lda, b, ldb, beta, c, ldc);
3528 }
3529
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)3530 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3531 uint64 m, uint64 n, uint64 k, double alpha,
3532 const DeviceMemory<double> &a, int lda,
3533 const DeviceMemory<double> &b, int ldb,
3534 double beta, DeviceMemory<double> *c, int ldc) {
3535 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3536 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3537 PARAM(beta), PARAM(c), PARAM(ldc));
3538
3539 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
3540 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
3541 int, double, DeviceMemory<double> *, int> impl;
3542 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3543 alpha, a, lda, b, ldb, beta, c, ldc);
3544 }
3545
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3546 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3547 uint64 m, uint64 n, uint64 k,
3548 std::complex<float> alpha,
3549 const DeviceMemory<std::complex<float>> &a,
3550 int lda,
3551 const DeviceMemory<std::complex<float>> &b,
3552 int ldb, std::complex<float> beta,
3553 DeviceMemory<std::complex<float>> *c, int ldc) {
3554 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3555 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3556 PARAM(beta), PARAM(c), PARAM(ldc));
3557
3558 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3559 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3560 int, const DeviceMemory<std::complex<float>> &, int,
3561 std::complex<float>, DeviceMemory<std::complex<float>> *,
3562 int> impl;
3563 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3564 alpha, a, lda, b, ldb, beta, c, ldc);
3565 }
3566
ThenBlasGemm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3567 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
3568 uint64 m, uint64 n, uint64 k,
3569 std::complex<double> alpha,
3570 const DeviceMemory<std::complex<double>> &a,
3571 int lda,
3572 const DeviceMemory<std::complex<double>> &b,
3573 int ldb, std::complex<double> beta,
3574 DeviceMemory<std::complex<double>> *c, int ldc) {
3575 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3576 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3577 PARAM(beta), PARAM(c), PARAM(ldc));
3578
3579 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3580 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3581 int, const DeviceMemory<std::complex<double>> &, int,
3582 std::complex<double>, DeviceMemory<std::complex<double>> *,
3583 int> impl;
3584 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
3585 alpha, a, lda, b, ldb, beta, c, ldc);
3586 }
3587
3588 namespace {
3589 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
3590 // blas::ProfileResult*. This functor doesn't put the stream into an error
3591 // state if the op fails and the profile result is non-null. Instead, the
3592 // error-ness is returned in the profile result itself.
3593 template <typename... Args>
3594 struct ThenBlasWithProfileImpl {
operator ()perftools::gputools::__anoncb2e7d7b0211::ThenBlasWithProfileImpl3595 Stream &operator()(Stream *stream,
3596 bool (blas::BlasSupport::*blas_func)(
3597 Stream *, Args..., blas::ProfileResult *),
3598 Args... args, blas::ProfileResult *profile_result) {
3599 ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
3600 bool record_error = profile_result == nullptr;
3601 return Runner.Run(stream, blas_func, record_error, args..., profile_result);
3602 }
3603 };
3604 } // anonymous namespace
3605
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & x,int incx,float beta,DeviceMemory<float> * y,int incy,blas::ProfileResult * output_profile_result)3606 Stream &Stream::ThenBlasGemvWithProfiling(
3607 blas::Transpose trans, uint64 m, uint64 n, float alpha,
3608 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
3609 int incx, float beta, DeviceMemory<float> *y, int incy,
3610 blas::ProfileResult *output_profile_result) {
3611 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3612 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3613 PARAM(incy));
3614
3615 ThenBlasWithProfileImpl<
3616 blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
3617 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
3618 impl;
3619 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3620 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3621 }
3622
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & x,int incx,double beta,DeviceMemory<double> * y,int incy,blas::ProfileResult * output_profile_result)3623 Stream &Stream::ThenBlasGemvWithProfiling(
3624 blas::Transpose trans, uint64 m, uint64 n, double alpha,
3625 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
3626 int incx, double beta, DeviceMemory<double> *y, int incy,
3627 blas::ProfileResult *output_profile_result) {
3628 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3629 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3630 PARAM(incy));
3631
3632 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
3633 const DeviceMemory<double> &, int,
3634 const DeviceMemory<double> &, int, double,
3635 DeviceMemory<double> *, int>
3636 impl;
3637 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3638 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3639 }
3640
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & x,int incx,std::complex<float> beta,DeviceMemory<std::complex<float>> * y,int incy,blas::ProfileResult * output_profile_result)3641 Stream &Stream::ThenBlasGemvWithProfiling(
3642 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
3643 const DeviceMemory<std::complex<float>> &a, int lda,
3644 const DeviceMemory<std::complex<float>> &x, int incx,
3645 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
3646 blas::ProfileResult *output_profile_result) {
3647 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3648 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3649 PARAM(incy));
3650
3651 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
3652 const DeviceMemory<std::complex<float>> &, int,
3653 const DeviceMemory<std::complex<float>> &, int,
3654 std::complex<float>,
3655 DeviceMemory<std::complex<float>> *, int>
3656 impl;
3657 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3658 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3659 }
3660
ThenBlasGemvWithProfiling(blas::Transpose trans,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & x,int incx,std::complex<double> beta,DeviceMemory<std::complex<double>> * y,int incy,blas::ProfileResult * output_profile_result)3661 Stream &Stream::ThenBlasGemvWithProfiling(
3662 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
3663 const DeviceMemory<std::complex<double>> &a, int lda,
3664 const DeviceMemory<std::complex<double>> &x, int incx,
3665 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
3666 blas::ProfileResult *output_profile_result) {
3667 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
3668 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
3669 PARAM(incy));
3670
3671 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
3672 const DeviceMemory<std::complex<double>> &, int,
3673 const DeviceMemory<std::complex<double>> &, int,
3674 std::complex<double>,
3675 DeviceMemory<std::complex<double>> *, int>
3676 impl;
3677 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
3678 alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
3679 }
3680
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,float beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ProfileResult * output_profile_result)3681 Stream &Stream::ThenBlasGemmWithProfiling(
3682 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3683 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
3684 const DeviceMemory<Eigen::half> &b, int ldb, float beta,
3685 DeviceMemory<Eigen::half> *c, int ldc,
3686 blas::ProfileResult *output_profile_result) {
3687 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3688 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3689 PARAM(beta), PARAM(c), PARAM(ldc));
3690
3691 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3692 uint64, float, const DeviceMemory<Eigen::half> &, int,
3693 const DeviceMemory<Eigen::half> &, int, float,
3694 DeviceMemory<Eigen::half> *, int>
3695 impl;
3696 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3697 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3698 output_profile_result);
3699 }
3700
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ProfileResult * output_profile_result)3701 Stream &Stream::ThenBlasGemmWithProfiling(
3702 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3703 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3704 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3705 int ldc, blas::ProfileResult *output_profile_result) {
3706 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3707 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3708 PARAM(beta), PARAM(c), PARAM(ldc));
3709
3710 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3711 uint64, float, const DeviceMemory<float> &, int,
3712 const DeviceMemory<float> &, int, float,
3713 DeviceMemory<float> *, int>
3714 impl;
3715 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3716 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3717 output_profile_result);
3718 }
3719
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ProfileResult * output_profile_result)3720 Stream &Stream::ThenBlasGemmWithProfiling(
3721 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3722 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3723 const DeviceMemory<double> &b, int ldb, double beta,
3724 DeviceMemory<double> *c, int ldc,
3725 blas::ProfileResult *output_profile_result) {
3726 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3727 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3728 PARAM(beta), PARAM(c), PARAM(ldc));
3729
3730 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3731 uint64, double, const DeviceMemory<double> &, int,
3732 const DeviceMemory<double> &, int, double,
3733 DeviceMemory<double> *, int>
3734 impl;
3735 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3736 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3737 output_profile_result);
3738 }
3739
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ProfileResult * output_profile_result)3740 Stream &Stream::ThenBlasGemmWithProfiling(
3741 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3742 uint64 k, std::complex<float> alpha,
3743 const DeviceMemory<std::complex<float>> &a, int lda,
3744 const DeviceMemory<std::complex<float>> &b, int ldb,
3745 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3746 blas::ProfileResult *output_profile_result) {
3747 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3748 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3749 PARAM(beta), PARAM(c), PARAM(ldc));
3750
3751 ThenBlasWithProfileImpl<
3752 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3753 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3754 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3755 DeviceMemory<std::complex<float>> *, int>
3756 impl;
3757 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3758 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3759 output_profile_result);
3760 }
3761
ThenBlasGemmWithProfiling(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ProfileResult * output_profile_result)3762 Stream &Stream::ThenBlasGemmWithProfiling(
3763 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3764 uint64 k, std::complex<double> alpha,
3765 const DeviceMemory<std::complex<double>> &a, int lda,
3766 const DeviceMemory<std::complex<double>> &b, int ldb,
3767 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3768 blas::ProfileResult *output_profile_result) {
3769 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3770 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3771 PARAM(beta), PARAM(c), PARAM(ldc));
3772
3773 ThenBlasWithProfileImpl<
3774 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3775 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3776 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3777 DeviceMemory<std::complex<double>> *, int>
3778 impl;
3779 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
3780 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
3781 output_profile_result);
3782 }
3783
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,const Eigen::half & alpha,const DeviceMemory<Eigen::half> & a,int lda,const DeviceMemory<Eigen::half> & b,int ldb,const Eigen::half & beta,DeviceMemory<Eigen::half> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3784 Stream &Stream::ThenBlasGemmWithAlgorithm(
3785 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3786 uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
3787 int lda, const DeviceMemory<Eigen::half> &b, int ldb,
3788 const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
3789 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3790 blas::ProfileResult *output_profile_result) {
3791 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3792 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3793 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3794 PARAM(algorithm));
3795
3796 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3797 uint64, const Eigen::half &,
3798 const DeviceMemory<Eigen::half> &, int,
3799 const DeviceMemory<Eigen::half> &, int,
3800 const Eigen::half &, DeviceMemory<Eigen::half> *, int,
3801 blas::ComputationType, blas::AlgorithmType>
3802 impl;
3803 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3804 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3805 algorithm, output_profile_result);
3806 }
3807
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,int alpha,const DeviceMemory<int8> & a,int lda,const DeviceMemory<int8> & b,int ldb,int beta,DeviceMemory<int> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3808 Stream &Stream::ThenBlasGemmWithAlgorithm(
3809 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3810 uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
3811 const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
3812 int ldc, blas::ComputationType computation_type,
3813 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3814 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3815 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3816 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3817 PARAM(algorithm));
3818
3819 ThenBlasWithProfileImpl<
3820 blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
3821 const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
3822 DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
3823 impl;
3824 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3825 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3826 algorithm, output_profile_result);
3827 }
3828
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3829 Stream &Stream::ThenBlasGemmWithAlgorithm(
3830 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3831 uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
3832 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
3833 int ldc, blas::ComputationType computation_type,
3834 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3835 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3836 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3837 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3838 PARAM(algorithm));
3839
3840 ThenBlasWithProfileImpl<
3841 blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
3842 const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
3843 DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
3844 impl;
3845 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3846 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3847 algorithm, output_profile_result);
3848 }
3849
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3850 Stream &Stream::ThenBlasGemmWithAlgorithm(
3851 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3852 uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
3853 const DeviceMemory<double> &b, int ldb, double beta,
3854 DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
3855 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
3856 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3857 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3858 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3859 PARAM(algorithm));
3860
3861 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
3862 uint64, double, const DeviceMemory<double> &, int,
3863 const DeviceMemory<double> &, int, double,
3864 DeviceMemory<double> *, int, blas::ComputationType,
3865 blas::AlgorithmType>
3866 impl;
3867 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3868 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3869 algorithm, output_profile_result);
3870 }
3871
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3872 Stream &Stream::ThenBlasGemmWithAlgorithm(
3873 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3874 uint64 k, std::complex<float> alpha,
3875 const DeviceMemory<std::complex<float>> &a, int lda,
3876 const DeviceMemory<std::complex<float>> &b, int ldb,
3877 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
3878 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3879 blas::ProfileResult *output_profile_result) {
3880 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3881 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3882 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3883 PARAM(algorithm));
3884
3885 ThenBlasWithProfileImpl<
3886 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3887 std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
3888 const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
3889 DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
3890 blas::AlgorithmType>
3891 impl;
3892 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3893 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3894 algorithm, output_profile_result);
3895 }
3896
ThenBlasGemmWithAlgorithm(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc,blas::ComputationType computation_type,blas::AlgorithmType algorithm,blas::ProfileResult * output_profile_result)3897 Stream &Stream::ThenBlasGemmWithAlgorithm(
3898 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
3899 uint64 k, std::complex<double> alpha,
3900 const DeviceMemory<std::complex<double>> &a, int lda,
3901 const DeviceMemory<std::complex<double>> &b, int ldb,
3902 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
3903 blas::ComputationType computation_type, blas::AlgorithmType algorithm,
3904 blas::ProfileResult *output_profile_result) {
3905 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
3906 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
3907 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
3908 PARAM(algorithm));
3909
3910 ThenBlasWithProfileImpl<
3911 blas::Transpose, blas::Transpose, uint64, uint64, uint64,
3912 std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
3913 const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
3914 DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
3915 blas::AlgorithmType>
3916 impl;
3917 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
3918 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
3919 algorithm, output_profile_result);
3920 }
3921
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)3922 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3923 uint64 n, std::complex<float> alpha,
3924 const DeviceMemory<std::complex<float>> &a,
3925 int lda,
3926 const DeviceMemory<std::complex<float>> &b,
3927 int ldb, std::complex<float> beta,
3928 DeviceMemory<std::complex<float>> *c, int ldc) {
3929 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3930 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3931 PARAM(ldc));
3932
3933 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3934 std::complex<float>, const DeviceMemory<std::complex<float>> &,
3935 int, const DeviceMemory<std::complex<float>> &, int,
3936 std::complex<float>, DeviceMemory<std::complex<float>> *,
3937 int> impl;
3938 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3939 lda, b, ldb, beta, c, ldc);
3940 }
3941
ThenBlasHemm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)3942 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
3943 uint64 n, std::complex<double> alpha,
3944 const DeviceMemory<std::complex<double>> &a,
3945 int lda,
3946 const DeviceMemory<std::complex<double>> &b,
3947 int ldb, std::complex<double> beta,
3948 DeviceMemory<std::complex<double>> *c, int ldc) {
3949 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
3950 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
3951 PARAM(ldc));
3952
3953 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
3954 std::complex<double>, const DeviceMemory<std::complex<double>> &,
3955 int, const DeviceMemory<std::complex<double>> &, int,
3956 std::complex<double>, DeviceMemory<std::complex<double>> *,
3957 int> impl;
3958 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
3959 lda, b, ldb, beta, c, ldc);
3960 }
3961
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<std::complex<float>> & a,int lda,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3962 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3963 uint64 n, uint64 k, float alpha,
3964 const DeviceMemory<std::complex<float>> &a,
3965 int lda, float beta,
3966 DeviceMemory<std::complex<float>> *c, int ldc) {
3967 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3968 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3969
3970 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
3971 const DeviceMemory<std::complex<float>> &, int, float,
3972 DeviceMemory<std::complex<float>> *, int> impl;
3973 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3974 lda, beta, c, ldc);
3975 }
3976
ThenBlasHerk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<std::complex<double>> & a,int lda,double beta,DeviceMemory<std::complex<double>> * c,int ldc)3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
3978 uint64 n, uint64 k, double alpha,
3979 const DeviceMemory<std::complex<double>> &a,
3980 int lda, double beta,
3981 DeviceMemory<std::complex<double>> *c, int ldc) {
3982 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
3983 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
3984
3985 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
3986 const DeviceMemory<std::complex<double>> &, int, double,
3987 DeviceMemory<std::complex<double>> *, int> impl;
3988 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
3989 lda, beta, c, ldc);
3990 }
3991
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,float beta,DeviceMemory<std::complex<float>> * c,int ldc)3992 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
3993 uint64 n, uint64 k, std::complex<float> alpha,
3994 const DeviceMemory<std::complex<float>> &a,
3995 int lda,
3996 const DeviceMemory<std::complex<float>> &b,
3997 int ldb, float beta,
3998 DeviceMemory<std::complex<float>> *c, int ldc) {
3999 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4000 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4001 PARAM(ldc));
4002
4003 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4004 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4005 int, const DeviceMemory<std::complex<float>> &, int, float,
4006 DeviceMemory<std::complex<float>> *, int> impl;
4007 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4008 a, lda, b, ldb, beta, c, ldc);
4009 }
4010
ThenBlasHer2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,double beta,DeviceMemory<std::complex<double>> * c,int ldc)4011 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
4012 uint64 n, uint64 k, std::complex<double> alpha,
4013 const DeviceMemory<std::complex<double>> &a,
4014 int lda,
4015 const DeviceMemory<std::complex<double>> &b,
4016 int ldb, double beta,
4017 DeviceMemory<std::complex<double>> *c, int ldc) {
4018 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4019 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4020 PARAM(ldc));
4021
4022 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4023 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4024 int, const DeviceMemory<std::complex<double>> &, int, double,
4025 DeviceMemory<std::complex<double>> *, int> impl;
4026 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
4027 a, lda, b, ldb, beta, c, ldc);
4028 }
4029
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4030 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4031 uint64 n, float alpha,
4032 const DeviceMemory<float> &a, int lda,
4033 const DeviceMemory<float> &b, int ldb, float beta,
4034 DeviceMemory<float> *c, int ldc) {
4035 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4036 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4037 PARAM(ldc));
4038
4039 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
4040 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4041 int, float, DeviceMemory<float> *, int> impl;
4042 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4043 lda, b, ldb, beta, c, ldc);
4044 }
4045
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4046 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4047 uint64 n, double alpha,
4048 const DeviceMemory<double> &a, int lda,
4049 const DeviceMemory<double> &b, int ldb,
4050 double beta, DeviceMemory<double> *c, int ldc) {
4051 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4052 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4053 PARAM(ldc));
4054
4055 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
4056 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4057 int, double, DeviceMemory<double> *, int> impl;
4058 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4059 lda, b, ldb, beta, c, ldc);
4060 }
4061
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4062 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4063 uint64 n, std::complex<float> alpha,
4064 const DeviceMemory<std::complex<float>> &a,
4065 int lda,
4066 const DeviceMemory<std::complex<float>> &b,
4067 int ldb, std::complex<float> beta,
4068 DeviceMemory<std::complex<float>> *c, int ldc) {
4069 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4070 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4071 PARAM(ldc));
4072
4073 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4074 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4075 int, const DeviceMemory<std::complex<float>> &, int,
4076 std::complex<float>, DeviceMemory<std::complex<float>> *,
4077 int> impl;
4078 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4079 lda, b, ldb, beta, c, ldc);
4080 }
4081
ThenBlasSymm(blas::Side side,blas::UpperLower uplo,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4082 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
4083 uint64 n, std::complex<double> alpha,
4084 const DeviceMemory<std::complex<double>> &a,
4085 int lda,
4086 const DeviceMemory<std::complex<double>> &b,
4087 int ldb, std::complex<double> beta,
4088 DeviceMemory<std::complex<double>> *c, int ldc) {
4089 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
4090 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4091 PARAM(ldc));
4092
4093 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
4094 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4095 int, const DeviceMemory<std::complex<double>> &, int,
4096 std::complex<double>, DeviceMemory<std::complex<double>> *,
4097 int> impl;
4098 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
4099 lda, b, ldb, beta, c, ldc);
4100 }
4101
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,float beta,DeviceMemory<float> * c,int ldc)4102 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4103 uint64 n, uint64 k, float alpha,
4104 const DeviceMemory<float> &a, int lda, float beta,
4105 DeviceMemory<float> *c, int ldc) {
4106 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4107 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4108
4109 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4110 const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
4111 int> impl;
4112 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4113 lda, beta, c, ldc);
4114 }
4115
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,double beta,DeviceMemory<double> * c,int ldc)4116 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4117 uint64 n, uint64 k, double alpha,
4118 const DeviceMemory<double> &a, int lda,
4119 double beta, DeviceMemory<double> *c, int ldc) {
4120 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4121 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4122
4123 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4124 const DeviceMemory<double> &, int, double,
4125 DeviceMemory<double> *, int> impl;
4126 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4127 lda, beta, c, ldc);
4128 }
4129
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4130 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4131 uint64 n, uint64 k, std::complex<float> alpha,
4132 const DeviceMemory<std::complex<float>> &a,
4133 int lda, std::complex<float> beta,
4134 DeviceMemory<std::complex<float>> *c, int ldc) {
4135 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4136 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4137
4138 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4139 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4140 int, std::complex<float>, DeviceMemory<std::complex<float>> *,
4141 int> impl;
4142 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4143 lda, beta, c, ldc);
4144 }
4145
ThenBlasSyrk(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4146 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
4147 uint64 n, uint64 k, std::complex<double> alpha,
4148 const DeviceMemory<std::complex<double>> &a,
4149 int lda, std::complex<double> beta,
4150 DeviceMemory<std::complex<double>> *c, int ldc) {
4151 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4152 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
4153
4154 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4155 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4156 int, std::complex<double>, DeviceMemory<std::complex<double>> *,
4157 int> impl;
4158 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
4159 lda, beta, c, ldc);
4160 }
4161
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,float alpha,const DeviceMemory<float> & a,int lda,const DeviceMemory<float> & b,int ldb,float beta,DeviceMemory<float> * c,int ldc)4162 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4163 uint64 n, uint64 k, float alpha,
4164 const DeviceMemory<float> &a, int lda,
4165 const DeviceMemory<float> &b, int ldb, float beta,
4166 DeviceMemory<float> *c, int ldc) {
4167 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4168 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4169 PARAM(ldc));
4170
4171 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
4172 const DeviceMemory<float> &, int, const DeviceMemory<float> &,
4173 int, float, DeviceMemory<float> *, int> impl;
4174 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4175 a, lda, b, ldb, beta, c, ldc);
4176 }
4177
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,double alpha,const DeviceMemory<double> & a,int lda,const DeviceMemory<double> & b,int ldb,double beta,DeviceMemory<double> * c,int ldc)4178 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4179 uint64 n, uint64 k, double alpha,
4180 const DeviceMemory<double> &a, int lda,
4181 const DeviceMemory<double> &b, int ldb,
4182 double beta, DeviceMemory<double> *c, int ldc) {
4183 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4184 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4185 PARAM(ldc));
4186
4187 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
4188 const DeviceMemory<double> &, int, const DeviceMemory<double> &,
4189 int, double, DeviceMemory<double> *, int> impl;
4190 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4191 a, lda, b, ldb, beta, c, ldc);
4192 }
4193
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,const DeviceMemory<std::complex<float>> & b,int ldb,std::complex<float> beta,DeviceMemory<std::complex<float>> * c,int ldc)4194 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4195 uint64 n, uint64 k, std::complex<float> alpha,
4196 const DeviceMemory<std::complex<float>> &a,
4197 int lda,
4198 const DeviceMemory<std::complex<float>> &b,
4199 int ldb, std::complex<float> beta,
4200 DeviceMemory<std::complex<float>> *c, int ldc) {
4201 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4202 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4203 PARAM(ldc));
4204
4205 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4206 std::complex<float>, const DeviceMemory<std::complex<float>> &,
4207 int, const DeviceMemory<std::complex<float>> &, int,
4208 std::complex<float>, DeviceMemory<std::complex<float>> *,
4209 int> impl;
4210 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4211 a, lda, b, ldb, beta, c, ldc);
4212 }
4213
ThenBlasSyr2k(blas::UpperLower uplo,blas::Transpose trans,uint64 n,uint64 k,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,const DeviceMemory<std::complex<double>> & b,int ldb,std::complex<double> beta,DeviceMemory<std::complex<double>> * c,int ldc)4214 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
4215 uint64 n, uint64 k, std::complex<double> alpha,
4216 const DeviceMemory<std::complex<double>> &a,
4217 int lda,
4218 const DeviceMemory<std::complex<double>> &b,
4219 int ldb, std::complex<double> beta,
4220 DeviceMemory<std::complex<double>> *c, int ldc) {
4221 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
4222 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
4223 PARAM(ldc));
4224
4225 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
4226 std::complex<double>, const DeviceMemory<std::complex<double>> &,
4227 int, const DeviceMemory<std::complex<double>> &, int,
4228 std::complex<double>, DeviceMemory<std::complex<double>> *,
4229 int> impl;
4230 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
4231 a, lda, b, ldb, beta, c, ldc);
4232 }
4233
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4234 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4235 blas::Transpose transa, blas::Diagonal diag,
4236 uint64 m, uint64 n, float alpha,
4237 const DeviceMemory<float> &a, int lda,
4238 DeviceMemory<float> *b, int ldb) {
4239 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4240 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4241
4242 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4243 uint64, uint64, float, const DeviceMemory<float> &, int,
4244 DeviceMemory<float> *, int> impl;
4245 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4246 n, alpha, a, lda, b, ldb);
4247 }
4248
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4250 blas::Transpose transa, blas::Diagonal diag,
4251 uint64 m, uint64 n, double alpha,
4252 const DeviceMemory<double> &a, int lda,
4253 DeviceMemory<double> *b, int ldb) {
4254 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4255 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4256
4257 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4258 uint64, uint64, double, const DeviceMemory<double> &, int,
4259 DeviceMemory<double> *, int> impl;
4260 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4261 n, alpha, a, lda, b, ldb);
4262 }
4263
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4265 blas::Transpose transa, blas::Diagonal diag,
4266 uint64 m, uint64 n, std::complex<float> alpha,
4267 const DeviceMemory<std::complex<float>> &a,
4268 int lda, DeviceMemory<std::complex<float>> *b,
4269 int ldb) {
4270 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4271 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4272
4273 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4274 uint64, uint64, std::complex<float>,
4275 const DeviceMemory<std::complex<float>> &, int,
4276 DeviceMemory<std::complex<float>> *, int> impl;
4277 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4278 n, alpha, a, lda, b, ldb);
4279 }
4280
ThenBlasTrmm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4281 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
4282 blas::Transpose transa, blas::Diagonal diag,
4283 uint64 m, uint64 n, std::complex<double> alpha,
4284 const DeviceMemory<std::complex<double>> &a,
4285 int lda, DeviceMemory<std::complex<double>> *b,
4286 int ldb) {
4287 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4288 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4289
4290 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4291 uint64, uint64, std::complex<double>,
4292 const DeviceMemory<std::complex<double>> &, int,
4293 DeviceMemory<std::complex<double>> *, int> impl;
4294 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
4295 n, alpha, a, lda, b, ldb);
4296 }
4297
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,float alpha,const DeviceMemory<float> & a,int lda,DeviceMemory<float> * b,int ldb)4298 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4299 blas::Transpose transa, blas::Diagonal diag,
4300 uint64 m, uint64 n, float alpha,
4301 const DeviceMemory<float> &a, int lda,
4302 DeviceMemory<float> *b, int ldb) {
4303 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4304 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4305
4306 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4307 uint64, uint64, float, const DeviceMemory<float> &, int,
4308 DeviceMemory<float> *, int> impl;
4309 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4310 n, alpha, a, lda, b, ldb);
4311 }
4312
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,double alpha,const DeviceMemory<double> & a,int lda,DeviceMemory<double> * b,int ldb)4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4314 blas::Transpose transa, blas::Diagonal diag,
4315 uint64 m, uint64 n, double alpha,
4316 const DeviceMemory<double> &a, int lda,
4317 DeviceMemory<double> *b, int ldb) {
4318 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4319 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4320
4321 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4322 uint64, uint64, double, const DeviceMemory<double> &, int,
4323 DeviceMemory<double> *, int> impl;
4324 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4325 n, alpha, a, lda, b, ldb);
4326 }
4327
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<float> alpha,const DeviceMemory<std::complex<float>> & a,int lda,DeviceMemory<std::complex<float>> * b,int ldb)4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4329 blas::Transpose transa, blas::Diagonal diag,
4330 uint64 m, uint64 n, std::complex<float> alpha,
4331 const DeviceMemory<std::complex<float>> &a,
4332 int lda, DeviceMemory<std::complex<float>> *b,
4333 int ldb) {
4334 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4335 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4336
4337 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4338 uint64, uint64, std::complex<float>,
4339 const DeviceMemory<std::complex<float>> &, int,
4340 DeviceMemory<std::complex<float>> *, int> impl;
4341 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4342 n, alpha, a, lda, b, ldb);
4343 }
4344
ThenBlasTrsm(blas::Side side,blas::UpperLower uplo,blas::Transpose transa,blas::Diagonal diag,uint64 m,uint64 n,std::complex<double> alpha,const DeviceMemory<std::complex<double>> & a,int lda,DeviceMemory<std::complex<double>> * b,int ldb)4345 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
4346 blas::Transpose transa, blas::Diagonal diag,
4347 uint64 m, uint64 n, std::complex<double> alpha,
4348 const DeviceMemory<std::complex<double>> &a,
4349 int lda, DeviceMemory<std::complex<double>> *b,
4350 int ldb) {
4351 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
4352 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
4353
4354 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
4355 uint64, uint64, std::complex<double>,
4356 const DeviceMemory<std::complex<double>> &, int,
4357 DeviceMemory<std::complex<double>> *, int> impl;
4358 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
4359 n, alpha, a, lda, b, ldb);
4360 }
4361
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count)4362 Stream &Stream::ThenBlasGemmBatched(
4363 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4364 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4365 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4366 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4367 int batch_count) {
4368 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4369 b, ldb, beta, c, ldc, batch_count,
4370 /*scratch_allocator=*/nullptr);
4371 }
4372
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,float alpha,const port::ArraySlice<DeviceMemory<float> * > & a,int lda,const port::ArraySlice<DeviceMemory<float> * > & b,int ldb,float beta,const port::ArraySlice<DeviceMemory<float> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4373 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4374 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4375 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
4376 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
4377 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
4378 int batch_count, ScratchAllocator *scratch_allocator) {
4379 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4380 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4381 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4382
4383 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
4384 const port::ArraySlice<DeviceMemory<float> *> &, int,
4385 const port::ArraySlice<DeviceMemory<float> *> &, int, float,
4386 const port::ArraySlice<DeviceMemory<float> *> &, int, int,
4387 ScratchAllocator *>
4388 impl;
4389 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4390 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4391 scratch_allocator);
4392 }
4393
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count)4394 Stream &Stream::ThenBlasGemmBatched(
4395 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4396 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4397 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4398 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4399 int batch_count) {
4400 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4401 b, ldb, beta, c, ldc, batch_count,
4402 /*scratch_allocator=*/nullptr);
4403 }
4404
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,double alpha,const port::ArraySlice<DeviceMemory<double> * > & a,int lda,const port::ArraySlice<DeviceMemory<double> * > & b,int ldb,double beta,const port::ArraySlice<DeviceMemory<double> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4405 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4406 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4407 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
4408 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
4409 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
4410 int batch_count, ScratchAllocator *scratch_allocator) {
4411 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4412 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4413 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4414
4415 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
4416 const port::ArraySlice<DeviceMemory<double> *> &, int,
4417 const port::ArraySlice<DeviceMemory<double> *> &, int, double,
4418 const port::ArraySlice<DeviceMemory<double> *> &, int, int,
4419 ScratchAllocator *>
4420 impl;
4421 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4422 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4423 scratch_allocator);
4424 }
4425
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count)4426 Stream &Stream::ThenBlasGemmBatched(
4427 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4428 uint64 k, std::complex<float> alpha,
4429 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4430 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4431 std::complex<float> beta,
4432 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4433 int batch_count) {
4434 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4435 b, ldb, beta, c, ldc, batch_count,
4436 /*scratch_allocator=*/nullptr);
4437 }
4438
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<float> alpha,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & b,int ldb,std::complex<float> beta,const port::ArraySlice<DeviceMemory<std::complex<float>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4439 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4440 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4441 uint64 k, std::complex<float> alpha,
4442 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
4443 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
4444 std::complex<float> beta,
4445 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
4446 int batch_count, ScratchAllocator *scratch_allocator) {
4447 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4448 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4449 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4450
4451 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4452 std::complex<float>,
4453 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4454 int,
4455 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4456 int, std::complex<float>,
4457 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
4458 int, int, ScratchAllocator *>
4459 impl;
4460 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4461 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4462 scratch_allocator);
4463 }
4464
ThenBlasGemmBatched(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count)4465 Stream &Stream::ThenBlasGemmBatched(
4466 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4467 uint64 k, std::complex<double> alpha,
4468 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4469 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4470 std::complex<double> beta,
4471 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4472 int batch_count) {
4473 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
4474 b, ldb, beta, c, ldc, batch_count,
4475 /*scratch_allocator=*/nullptr);
4476 }
4477
ThenBlasGemmBatchedWithScratch(blas::Transpose transa,blas::Transpose transb,uint64 m,uint64 n,uint64 k,std::complex<double> alpha,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & a,int lda,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & b,int ldb,std::complex<double> beta,const port::ArraySlice<DeviceMemory<std::complex<double>> * > & c,int ldc,int batch_count,ScratchAllocator * scratch_allocator)4478 Stream &Stream::ThenBlasGemmBatchedWithScratch(
4479 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
4480 uint64 k, std::complex<double> alpha,
4481 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
4482 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
4483 std::complex<double> beta,
4484 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
4485 int batch_count, ScratchAllocator *scratch_allocator) {
4486 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
4487 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
4488 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
4489
4490 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
4491 std::complex<double>,
4492 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4493 int,
4494 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4495 int, std::complex<double>,
4496 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
4497 int, int, ScratchAllocator *>
4498 impl;
4499 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
4500 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
4501 scratch_allocator);
4502 }
4503
ThenSetRngSeed(const uint8 * seed,uint64 seed_bytes)4504 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
4505 VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
4506
4507 if (ok()) {
4508 if (rng::RngSupport *rng = parent_->AsRng()) {
4509 CheckError(rng->SetSeed(this, seed, seed_bytes));
4510 } else {
4511 SetError();
4512 LOG(INFO) << "stream " << this << " unable to initialize RNG";
4513 }
4514 } else {
4515 LOG(INFO) << "stream " << this
4516 << " did not set RNG seed: " << static_cast<const void *>(seed)
4517 << "; bytes: " << seed_bytes;
4518 }
4519 return *this;
4520 }
4521
ThenPopulateRandUniform(DeviceMemory<float> * values)4522 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
4523 VLOG_CALL(PARAM(values));
4524
4525 if (ok()) {
4526 if (rng::RngSupport *rng = parent_->AsRng()) {
4527 CheckError(rng->DoPopulateRandUniform(this, values));
4528 } else {
4529 SetError();
4530 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4531 "without RNG support.";
4532 }
4533 }
4534 return *this;
4535 }
4536
ThenPopulateRandGaussian(float mean,float sd,DeviceMemory<float> * values)4537 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
4538 DeviceMemory<float> *values) {
4539 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4540
4541 if (ok()) {
4542 if (rng::RngSupport *rng = parent_->AsRng()) {
4543 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4544 } else {
4545 SetError();
4546 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4547 "without RNG support.";
4548 }
4549 }
4550 return *this;
4551 }
4552
ThenPopulateRandGaussian(double mean,double sd,DeviceMemory<double> * values)4553 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
4554 DeviceMemory<double> *values) {
4555 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
4556
4557 if (ok()) {
4558 if (rng::RngSupport *rng = parent_->AsRng()) {
4559 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
4560 } else {
4561 SetError();
4562 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4563 "without RNG support.";
4564 }
4565 }
4566 return *this;
4567 }
4568
ThenPopulateRandUniform(DeviceMemory<double> * values)4569 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
4570 VLOG_CALL(PARAM(values));
4571
4572 if (ok()) {
4573 if (rng::RngSupport *rng = parent_->AsRng()) {
4574 CheckError(rng->DoPopulateRandUniform(this, values));
4575 } else {
4576 SetError();
4577 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4578 "without RNG support.";
4579 }
4580 }
4581 return *this;
4582 }
4583
ThenPopulateRandUniform(DeviceMemory<std::complex<float>> * values)4584 Stream &Stream::ThenPopulateRandUniform(
4585 DeviceMemory<std::complex<float>> *values) {
4586 VLOG_CALL(PARAM(values));
4587
4588 if (ok()) {
4589 if (rng::RngSupport *rng = parent_->AsRng()) {
4590 CheckError(rng->DoPopulateRandUniform(this, values));
4591 } else {
4592 SetError();
4593 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
4594 "without RNG support.";
4595 }
4596 }
4597 return *this;
4598 }
4599
ThenPopulateRandUniform(DeviceMemory<std::complex<double>> * values)4600 Stream &Stream::ThenPopulateRandUniform(
4601 DeviceMemory<std::complex<double>> *values) {
4602 VLOG_CALL(PARAM(values));
4603
4604 if (ok()) {
4605 if (rng::RngSupport *rng = parent_->AsRng()) {
4606 CheckError(rng->DoPopulateRandUniform(this, values));
4607 } else {
4608 SetError();
4609 LOG(INFO) << "stream " << this
4610 << " attempting to perform RNG operation using StreamExecutor "
4611 "without RNG support.";
4612 }
4613 }
4614 return *this;
4615 }
4616
ThenMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)4617 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
4618 uint64 size) {
4619 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
4620
4621 if (ok()) {
4622 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
4623 } else {
4624 LOG(INFO) << "stream " << this
4625 << " did not memcpy device-to-host; source: " << gpu_src.opaque();
4626 }
4627 return *this;
4628 }
4629
ThenMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)4630 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
4631 uint64 size) {
4632 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
4633
4634 if (ok()) {
4635 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
4636 } else {
4637 LOG(INFO) << "stream " << this
4638 << " did not memcpy host-to-device; source: " << host_src;
4639 }
4640 return *this;
4641 }
4642
ThenMemcpy(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)4643 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
4644 const DeviceMemoryBase &gpu_src, uint64 size) {
4645 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
4646
4647 if (ok()) {
4648 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
4649 } else {
4650 LOG(INFO) << "stream " << this
4651 << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
4652 }
4653 return *this;
4654 }
4655
ThenMemZero(DeviceMemoryBase * location,uint64 size)4656 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
4657 VLOG_CALL(PARAM(location), PARAM(size));
4658
4659 if (ok()) {
4660 CheckError(parent_->MemZero(this, location, size));
4661 } else {
4662 LOG(INFO) << "stream " << this
4663 << " did not memzero GPU location; source: " << location;
4664 }
4665 return *this;
4666 }
4667
ThenMemset32(DeviceMemoryBase * location,uint32 pattern,uint64 size)4668 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
4669 uint64 size) {
4670 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
4671
4672 if (ok()) {
4673 CheckError(parent_->Memset32(this, location, pattern, size));
4674 } else {
4675 LOG(INFO) << "stream " << this
4676 << " did not memset GPU location; source: " << location
4677 << "; size: " << size << "; pattern: " << std::hex << pattern;
4678 }
4679 return *this;
4680 }
4681
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<Eigen::half> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<Eigen::half> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<Eigen::half> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)4682 Stream &Stream::ThenRnnForward(
4683 const dnn::RnnDescriptor &rnn_desc,
4684 const dnn::RnnSequenceTensorDescriptor &input_desc,
4685 const DeviceMemory<Eigen::half> &input_data,
4686 const dnn::RnnStateTensorDescriptor &input_h_desc,
4687 const DeviceMemory<Eigen::half> &input_h_data,
4688 const dnn::RnnStateTensorDescriptor &input_c_desc,
4689 const DeviceMemory<Eigen::half> &input_c_data,
4690 const DeviceMemory<Eigen::half> ¶ms,
4691 const dnn::RnnSequenceTensorDescriptor &output_desc,
4692 DeviceMemory<Eigen::half> *output_data,
4693 const dnn::RnnStateTensorDescriptor &output_h_desc,
4694 DeviceMemory<Eigen::half> *output_h_data,
4695 const dnn::RnnStateTensorDescriptor &output_c_desc,
4696 DeviceMemory<Eigen::half> *output_c_data, bool is_training,
4697 ScratchAllocator *reserve_space_allocator,
4698 ScratchAllocator *workspace_allocator) {
4699 // TODO(zhengxq): add VLOG PARAM calls.
4700 if (ok()) {
4701 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4702 CheckError(dnn->DoRnnForward(
4703 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4704 input_c_desc, input_c_data, params, output_desc, output_data,
4705 output_h_desc, output_h_data, output_c_desc, output_c_data,
4706 is_training, reserve_space_allocator, workspace_allocator));
4707 } else {
4708 SetError();
4709 LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
4710 }
4711 }
4712 return *this;
4713 }
4714
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<float> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<float> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<float> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)4715 Stream &Stream::ThenRnnForward(
4716 const dnn::RnnDescriptor &rnn_desc,
4717 const dnn::RnnSequenceTensorDescriptor &input_desc,
4718 const DeviceMemory<float> &input_data,
4719 const dnn::RnnStateTensorDescriptor &input_h_desc,
4720 const DeviceMemory<float> &input_h_data,
4721 const dnn::RnnStateTensorDescriptor &input_c_desc,
4722 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4723 const dnn::RnnSequenceTensorDescriptor &output_desc,
4724 DeviceMemory<float> *output_data,
4725 const dnn::RnnStateTensorDescriptor &output_h_desc,
4726 DeviceMemory<float> *output_h_data,
4727 const dnn::RnnStateTensorDescriptor &output_c_desc,
4728 DeviceMemory<float> *output_c_data, bool is_training,
4729 ScratchAllocator *reserve_space_allocator,
4730 ScratchAllocator *workspace_allocator) {
4731 // TODO(zhengxq): add VLOG PARAM calls.
4732 if (ok()) {
4733 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4734 CheckError(dnn->DoRnnForward(
4735 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4736 input_c_desc, input_c_data, params, output_desc, output_data,
4737 output_h_desc, output_h_data, output_c_desc, output_c_data,
4738 is_training, reserve_space_allocator, workspace_allocator));
4739 } else {
4740 SetError();
4741 LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
4742 }
4743 }
4744 return *this;
4745 }
4746
ThenRnnForward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,DeviceMemory<double> * output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,DeviceMemory<double> * output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,DeviceMemory<double> * output_c_data,bool is_training,ScratchAllocator * reserve_space_allocator,ScratchAllocator * workspace_allocator)4747 Stream &Stream::ThenRnnForward(
4748 const dnn::RnnDescriptor &rnn_desc,
4749 const dnn::RnnSequenceTensorDescriptor &input_desc,
4750 const DeviceMemory<double> &input_data,
4751 const dnn::RnnStateTensorDescriptor &input_h_desc,
4752 const DeviceMemory<double> &input_h_data,
4753 const dnn::RnnStateTensorDescriptor &input_c_desc,
4754 const DeviceMemory<double> &input_c_data,
4755 const DeviceMemory<double> ¶ms,
4756 const dnn::RnnSequenceTensorDescriptor &output_desc,
4757 DeviceMemory<double> *output_data,
4758 const dnn::RnnStateTensorDescriptor &output_h_desc,
4759 DeviceMemory<double> *output_h_data,
4760 const dnn::RnnStateTensorDescriptor &output_c_desc,
4761 DeviceMemory<double> *output_c_data, bool is_training,
4762 ScratchAllocator *reserve_space_allocator,
4763 ScratchAllocator *workspace_allocator) {
4764 // TODO(zhengxq): add VLOG PARAM calls.
4765 if (ok()) {
4766 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4767 CheckError(dnn->DoRnnForward(
4768 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4769 input_c_desc, input_c_data, params, output_desc, output_data,
4770 output_h_desc, output_h_data, output_c_desc, output_c_data,
4771 is_training, reserve_space_allocator, workspace_allocator));
4772 } else {
4773 SetError();
4774 LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
4775 }
4776 }
4777 return *this;
4778 }
4779
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<Eigen::half> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<Eigen::half> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<Eigen::half> & input_c_data,const DeviceMemory<Eigen::half> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<Eigen::half> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<Eigen::half> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<Eigen::half> & output_c_data,const DeviceMemory<Eigen::half> & output_backprop_data,const DeviceMemory<Eigen::half> & output_h_backprop_data,const DeviceMemory<Eigen::half> & output_c_backprop_data,DeviceMemory<Eigen::half> * input_backprop_data,DeviceMemory<Eigen::half> * input_h_backprop_data,DeviceMemory<Eigen::half> * input_c_backprop_data,DeviceMemory<Eigen::half> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)4780 Stream &Stream::ThenRnnBackward(
4781 const dnn::RnnDescriptor &rnn_desc,
4782 const dnn::RnnSequenceTensorDescriptor &input_desc,
4783 const DeviceMemory<Eigen::half> &input_data,
4784 const dnn::RnnStateTensorDescriptor &input_h_desc,
4785 const DeviceMemory<Eigen::half> &input_h_data,
4786 const dnn::RnnStateTensorDescriptor &input_c_desc,
4787 const DeviceMemory<Eigen::half> &input_c_data,
4788 const DeviceMemory<Eigen::half> ¶ms,
4789 const dnn::RnnSequenceTensorDescriptor &output_desc,
4790 const DeviceMemory<Eigen::half> &output_data,
4791 const dnn::RnnStateTensorDescriptor &output_h_desc,
4792 const DeviceMemory<Eigen::half> &output_h_data,
4793 const dnn::RnnStateTensorDescriptor &output_c_desc,
4794 const DeviceMemory<Eigen::half> &output_c_data,
4795 const DeviceMemory<Eigen::half> &output_backprop_data,
4796 const DeviceMemory<Eigen::half> &output_h_backprop_data,
4797 const DeviceMemory<Eigen::half> &output_c_backprop_data,
4798 DeviceMemory<Eigen::half> *input_backprop_data,
4799 DeviceMemory<Eigen::half> *input_h_backprop_data,
4800 DeviceMemory<Eigen::half> *input_c_backprop_data,
4801 DeviceMemory<Eigen::half> *params_backprop_data,
4802 DeviceMemory<uint8> *reserve_space_data,
4803 ScratchAllocator *workspace_allocator) {
4804 // TODO(zhengxq): add VLOG PARAM calls.
4805 if (ok()) {
4806 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4807 CheckError(dnn->DoRnnBackward(
4808 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4809 input_c_desc, input_c_data, params, output_desc, output_data,
4810 output_h_desc, output_h_data, output_c_desc, output_c_data,
4811 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4812 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4813 params_backprop_data, reserve_space_data, workspace_allocator));
4814 } else {
4815 SetError();
4816 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4817 }
4818 }
4819 return *this;
4820 }
4821
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<float> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<float> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<float> & input_c_data,const DeviceMemory<float> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<float> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<float> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<float> & output_c_data,const DeviceMemory<float> & output_backprop_data,const DeviceMemory<float> & output_h_backprop_data,const DeviceMemory<float> & output_c_backprop_data,DeviceMemory<float> * input_backprop_data,DeviceMemory<float> * input_h_backprop_data,DeviceMemory<float> * input_c_backprop_data,DeviceMemory<float> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)4822 Stream &Stream::ThenRnnBackward(
4823 const dnn::RnnDescriptor &rnn_desc,
4824 const dnn::RnnSequenceTensorDescriptor &input_desc,
4825 const DeviceMemory<float> &input_data,
4826 const dnn::RnnStateTensorDescriptor &input_h_desc,
4827 const DeviceMemory<float> &input_h_data,
4828 const dnn::RnnStateTensorDescriptor &input_c_desc,
4829 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms,
4830 const dnn::RnnSequenceTensorDescriptor &output_desc,
4831 const DeviceMemory<float> &output_data,
4832 const dnn::RnnStateTensorDescriptor &output_h_desc,
4833 const DeviceMemory<float> &output_h_data,
4834 const dnn::RnnStateTensorDescriptor &output_c_desc,
4835 const DeviceMemory<float> &output_c_data,
4836 const DeviceMemory<float> &output_backprop_data,
4837 const DeviceMemory<float> &output_h_backprop_data,
4838 const DeviceMemory<float> &output_c_backprop_data,
4839 DeviceMemory<float> *input_backprop_data,
4840 DeviceMemory<float> *input_h_backprop_data,
4841 DeviceMemory<float> *input_c_backprop_data,
4842 DeviceMemory<float> *params_backprop_data,
4843 DeviceMemory<uint8> *reserve_space_data,
4844 ScratchAllocator *workspace_allocator) {
4845 // TODO(zhengxq): add VLOG PARAM calls.
4846 if (ok()) {
4847 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4848 CheckError(dnn->DoRnnBackward(
4849 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4850 input_c_desc, input_c_data, params, output_desc, output_data,
4851 output_h_desc, output_h_data, output_c_desc, output_c_data,
4852 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4853 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4854 params_backprop_data, reserve_space_data, workspace_allocator));
4855 } else {
4856 SetError();
4857 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4858 }
4859 }
4860 return *this;
4861 }
4862
ThenRnnBackward(const dnn::RnnDescriptor & rnn_desc,const dnn::RnnSequenceTensorDescriptor & input_desc,const DeviceMemory<double> & input_data,const dnn::RnnStateTensorDescriptor & input_h_desc,const DeviceMemory<double> & input_h_data,const dnn::RnnStateTensorDescriptor & input_c_desc,const DeviceMemory<double> & input_c_data,const DeviceMemory<double> & params,const dnn::RnnSequenceTensorDescriptor & output_desc,const DeviceMemory<double> & output_data,const dnn::RnnStateTensorDescriptor & output_h_desc,const DeviceMemory<double> & output_h_data,const dnn::RnnStateTensorDescriptor & output_c_desc,const DeviceMemory<double> & output_c_data,const DeviceMemory<double> & output_backprop_data,const DeviceMemory<double> & output_h_backprop_data,const DeviceMemory<double> & output_c_backprop_data,DeviceMemory<double> * input_backprop_data,DeviceMemory<double> * input_h_backprop_data,DeviceMemory<double> * input_c_backprop_data,DeviceMemory<double> * params_backprop_data,DeviceMemory<uint8> * reserve_space_data,ScratchAllocator * workspace_allocator)4863 Stream &Stream::ThenRnnBackward(
4864 const dnn::RnnDescriptor &rnn_desc,
4865 const dnn::RnnSequenceTensorDescriptor &input_desc,
4866 const DeviceMemory<double> &input_data,
4867 const dnn::RnnStateTensorDescriptor &input_h_desc,
4868 const DeviceMemory<double> &input_h_data,
4869 const dnn::RnnStateTensorDescriptor &input_c_desc,
4870 const DeviceMemory<double> &input_c_data,
4871 const DeviceMemory<double> ¶ms,
4872 const dnn::RnnSequenceTensorDescriptor &output_desc,
4873 const DeviceMemory<double> &output_data,
4874 const dnn::RnnStateTensorDescriptor &output_h_desc,
4875 const DeviceMemory<double> &output_h_data,
4876 const dnn::RnnStateTensorDescriptor &output_c_desc,
4877 const DeviceMemory<double> &output_c_data,
4878 const DeviceMemory<double> &output_backprop_data,
4879 const DeviceMemory<double> &output_h_backprop_data,
4880 const DeviceMemory<double> &output_c_backprop_data,
4881 DeviceMemory<double> *input_backprop_data,
4882 DeviceMemory<double> *input_h_backprop_data,
4883 DeviceMemory<double> *input_c_backprop_data,
4884 DeviceMemory<double> *params_backprop_data,
4885 DeviceMemory<uint8> *reserve_space_data,
4886 ScratchAllocator *workspace_allocator) {
4887 // TODO(zhengxq): add VLOG PARAM calls.
4888 if (ok()) {
4889 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4890 CheckError(dnn->DoRnnBackward(
4891 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
4892 input_c_desc, input_c_data, params, output_desc, output_data,
4893 output_h_desc, output_h_data, output_c_desc, output_c_data,
4894 output_backprop_data, output_h_backprop_data, output_c_backprop_data,
4895 input_backprop_data, input_h_backprop_data, input_c_backprop_data,
4896 params_backprop_data, reserve_space_data, workspace_allocator));
4897 } else {
4898 SetError();
4899 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
4900 }
4901 }
4902 return *this;
4903 }
4904
ThenTransformTensor(const dnn::BatchDescriptor & input_desc,dnn::DataType input_type,const DeviceMemoryBase & input_data,const dnn::BatchDescriptor & output_desc,dnn::DataType output_type,float scale,DeviceMemoryBase * output_data)4905 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
4906 dnn::DataType input_type,
4907 const DeviceMemoryBase &input_data,
4908 const dnn::BatchDescriptor &output_desc,
4909 dnn::DataType output_type, float scale,
4910 DeviceMemoryBase *output_data) {
4911 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
4912 PARAM(output_desc), PARAM(output_type), PARAM(scale),
4913 PARAM(output_data));
4914 if (ok()) {
4915 if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
4916 CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
4917 input_data, output_desc, output_type,
4918 scale, output_data));
4919 } else {
4920 SetErrorAndLogNoDnnSupport();
4921 }
4922 }
4923 return *this;
4924 }
4925
ThenDoHostCallbackForTest(std::function<void ()> callback)4926 Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
4927 VLOG_CALL(PARAM(callback));
4928
4929 return ThenDoHostCallback(callback);
4930 }
4931
ThenDoHostCallback(std::function<void ()> callback)4932 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
4933 VLOG_CALL(PARAM(callback));
4934
4935 if (ok()) {
4936 CheckError(parent_->HostCallback(this, callback));
4937 } else {
4938 LOG(INFO) << "stream " << this
4939 << " was in error state before adding host callback";
4940 }
4941 return *this;
4942 }
4943
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<std::complex<float>> * output)4944 Stream &Stream::ThenFft(fft::Plan *plan,
4945 const DeviceMemory<std::complex<float>> &input,
4946 DeviceMemory<std::complex<float>> *output) {
4947 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4948
4949 if (ok()) {
4950 if (fft::FftSupport *fft = parent_->AsFft()) {
4951 CheckError(fft->DoFft(this, plan, input, output));
4952 } else {
4953 SetError();
4954 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
4955 "without FFT support";
4956 }
4957 }
4958 return *this;
4959 }
4960
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<std::complex<double>> * output)4961 Stream &Stream::ThenFft(fft::Plan *plan,
4962 const DeviceMemory<std::complex<double>> &input,
4963 DeviceMemory<std::complex<double>> *output) {
4964 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4965
4966 if (ok()) {
4967 if (fft::FftSupport *fft = parent_->AsFft()) {
4968 CheckError(fft->DoFft(this, plan, input, output));
4969 } else {
4970 SetError();
4971 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
4972 "without FFT support";
4973 }
4974 }
4975 return *this;
4976 }
4977
ThenFft(fft::Plan * plan,const DeviceMemory<float> & input,DeviceMemory<std::complex<float>> * output)4978 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
4979 DeviceMemory<std::complex<float>> *output) {
4980 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4981
4982 if (ok()) {
4983 if (fft::FftSupport *fft = parent_->AsFft()) {
4984 CheckError(fft->DoFft(this, plan, input, output));
4985 } else {
4986 SetError();
4987 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
4988 "without FFT support";
4989 }
4990 }
4991 return *this;
4992 }
4993
ThenFft(fft::Plan * plan,const DeviceMemory<double> & input,DeviceMemory<std::complex<double>> * output)4994 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
4995 DeviceMemory<std::complex<double>> *output) {
4996 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
4997
4998 if (ok()) {
4999 if (fft::FftSupport *fft = parent_->AsFft()) {
5000 CheckError(fft->DoFft(this, plan, input, output));
5001 } else {
5002 SetError();
5003 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
5004 "without FFT support";
5005 }
5006 }
5007 return *this;
5008 }
5009
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<float>> & input,DeviceMemory<float> * output)5010 Stream &Stream::ThenFft(fft::Plan *plan,
5011 const DeviceMemory<std::complex<float>> &input,
5012 DeviceMemory<float> *output) {
5013 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5014
5015 if (ok()) {
5016 if (fft::FftSupport *fft = parent_->AsFft()) {
5017 CheckError(fft->DoFft(this, plan, input, output));
5018 } else {
5019 SetError();
5020 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
5021 "without FFT support";
5022 }
5023 }
5024 return *this;
5025 }
5026
ThenFft(fft::Plan * plan,const DeviceMemory<std::complex<double>> & input,DeviceMemory<double> * output)5027 Stream &Stream::ThenFft(fft::Plan *plan,
5028 const DeviceMemory<std::complex<double>> &input,
5029 DeviceMemory<double> *output) {
5030 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
5031
5032 if (ok()) {
5033 if (fft::FftSupport *fft = parent_->AsFft()) {
5034 CheckError(fft->DoFft(this, plan, input, output));
5035 } else {
5036 SetError();
5037 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
5038 "without FFT support";
5039 }
5040 }
5041 return *this;
5042 }
5043
5044 // It looks confusing, but all this is doing is inserting a callback at the
5045 // present point in the stream to then enqueue a task on the host executor.
ThenEnqueueOnBackgroundThread(std::function<void (StreamExecutor *)> task)5046 Stream &Stream::ThenEnqueueOnBackgroundThread(
5047 std::function<void(StreamExecutor *)> task) {
5048 VLOG_CALL(PARAM(task));
5049
5050 StreamExecutor *stream_executor = this->parent_;
5051 std::function<void()> bound_task = std::bind(task, stream_executor);
5052
5053 return ThenDoHostCallback([stream_executor, bound_task]() {
5054 stream_executor->EnqueueOnBackgroundThread(bound_task);
5055 });
5056 }
5057
BlockHostUntilDone()5058 port::Status Stream::BlockHostUntilDone() {
5059 VLOG_CALL();
5060
5061 if (!ok()) {
5062 port::Status status = port::Status(
5063 port::error::INTERNAL,
5064 "stream did not block host until done; was already in an error state");
5065 LOG(INFO) << status << " " << this;
5066 return status;
5067 }
5068
5069 port::Status first_error;
5070 {
5071 // Wait until all active sub-streams have done their tasks.
5072 mutex_lock lock{mu_};
5073 for (auto &stream : sub_streams_) {
5074 if (!stream.second) {
5075 first_error.Update(stream.first->BlockHostUntilDone());
5076 // Set this sub-stream as available.
5077 stream.second = true;
5078 }
5079 }
5080 }
5081
5082 temporary_memory_manager_.DeallocateFinalizedTemporaries();
5083
5084 first_error.Update(parent_->BlockHostUntilDone(this));
5085 CheckError(first_error.ok());
5086 return first_error;
5087 }
5088
5089 } // namespace gputools
5090 } // namespace perftools
5091