1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
17
18 namespace xla {
19 namespace gpu {
20
ExecuteParams(const ServiceExecutableRunOptions & run_options,const BufferAllocations & buffer_allocations,se::Stream * stream,se::Stream * async_comms_stream)21 Thunk::ExecuteParams::ExecuteParams(
22 const ServiceExecutableRunOptions& run_options,
23 const BufferAllocations& buffer_allocations, se::Stream* stream,
24 se::Stream* async_comms_stream)
25 : buffer_allocations(&buffer_allocations),
26 stream(stream),
27 async_comms_stream(async_comms_stream),
28 nccl_params(run_options, stream) {}
29
KindToString(Thunk::Kind kind)30 /*static*/ absl::string_view Thunk::KindToString(Thunk::Kind kind) {
31 switch (kind) {
32 case Thunk::kCholesky:
33 return "kCholesky";
34 case Thunk::kCollectivePermute:
35 return "kCollectivePermute";
36 case Thunk::kConditional:
37 return "kConditional";
38 case Thunk::kConvolution:
39 return "kConvolution";
40 case Thunk::kCopy:
41 return "kCopy";
42 case Thunk::kCublasLtMatmul:
43 return "kCublasLtMatmul";
44 case Thunk::kCustomCall:
45 return "kCustomCall";
46 case Thunk::kNcclAllGather:
47 return "kNcclAllGather";
48 case Thunk::kNcclAllReduce:
49 return "kNcclAllReduce";
50 case Thunk::kNcclAllReduceStart:
51 return "kNcclAllReduceStart";
52 case Thunk::kNcclAllReduceDone:
53 return "kNcclAllReduceDone";
54 case Thunk::kNcclReduceScatter:
55 return "kNcclReduceScatter";
56 case Thunk::kNcclAllToAll:
57 return "kNcclAllToAll";
58 case Thunk::kFft:
59 return "kFft";
60 case Thunk::kGemm:
61 return "kGemm";
62 case Thunk::kInfeed:
63 return "kInfeed";
64 case Thunk::kKernel:
65 return "kKernel";
66 case Thunk::kMemset32BitValue:
67 return "kMemset32BitValue";
68 case Thunk::kMemzero:
69 return "kMemzero";
70 case Thunk::kOutfeed:
71 return "kOutfeed";
72 case Thunk::kReplicaId:
73 return "kReplicaId";
74 case Thunk::kPartitionId:
75 return "kPartitionId";
76 case Thunk::kSequential:
77 return "kSequential";
78 case Thunk::kTriangularSolve:
79 return "kTriangularSolve";
80 case Thunk::kWhile:
81 return "kWhile";
82 }
83 }
84
operator <<(std::ostream & os,Thunk::Kind kind)85 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
86 return os << Thunk::KindToString(kind);
87 }
88
ToString(int indent,std::function<std::string (const Thunk *)> get_thunk_annotation) const89 std::string ThunkSequence::ToString(
90 int indent,
91 std::function<std::string(const Thunk*)> get_thunk_annotation) const {
92 const std::string indent_str(indent * 2, ' ');
93 if (empty()) return indent_str + "No thunks.";
94
95 auto thunk_with_longest_kind = absl::c_max_element(
96 *this,
97 [](const std::unique_ptr<Thunk>& a, const std::unique_ptr<Thunk>& b) {
98 return Thunk::KindToString(a->kind()).length() <
99 Thunk::KindToString(b->kind()).length();
100 });
101 int64_t max_thunk_kind_len =
102 Thunk::KindToString(thunk_with_longest_kind->get()->kind()).length();
103 std::string result;
104 for (const std::unique_ptr<Thunk>& thunk : *this) {
105 // Write out the thunk kind, padded out to max_thunk_kind_len.
106 absl::string_view kind_str = Thunk::KindToString(thunk->kind());
107 absl::StrAppend(&result, indent_str, kind_str,
108 std::string(max_thunk_kind_len - kind_str.length(), ' '),
109 "\t");
110 if (get_thunk_annotation) {
111 absl::StrAppend(&result, get_thunk_annotation(thunk.get()));
112 }
113 absl::StrAppend(&result, thunk->ToStringExtra(indent));
114 absl::StrAppend(&result, "\n");
115 }
116 return result;
117 }
118
119 } // namespace gpu
120 } // namespace xla
121