• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
17 
18 namespace xla {
19 namespace gpu {
20 
ExecuteParams(const ServiceExecutableRunOptions & run_options,const BufferAllocations & buffer_allocations,se::Stream * stream,se::Stream * async_comms_stream)21 Thunk::ExecuteParams::ExecuteParams(
22     const ServiceExecutableRunOptions& run_options,
23     const BufferAllocations& buffer_allocations, se::Stream* stream,
24     se::Stream* async_comms_stream)
25     : buffer_allocations(&buffer_allocations),
26       stream(stream),
27       async_comms_stream(async_comms_stream),
28       nccl_params(run_options, stream) {}
29 
KindToString(Thunk::Kind kind)30 /*static*/ absl::string_view Thunk::KindToString(Thunk::Kind kind) {
31   switch (kind) {
32     case Thunk::kCholesky:
33       return "kCholesky";
34     case Thunk::kCollectivePermute:
35       return "kCollectivePermute";
36     case Thunk::kConditional:
37       return "kConditional";
38     case Thunk::kConvolution:
39       return "kConvolution";
40     case Thunk::kCopy:
41       return "kCopy";
42     case Thunk::kCublasLtMatmul:
43       return "kCublasLtMatmul";
44     case Thunk::kCustomCall:
45       return "kCustomCall";
46     case Thunk::kNcclAllGather:
47       return "kNcclAllGather";
48     case Thunk::kNcclAllReduce:
49       return "kNcclAllReduce";
50     case Thunk::kNcclAllReduceStart:
51       return "kNcclAllReduceStart";
52     case Thunk::kNcclAllReduceDone:
53       return "kNcclAllReduceDone";
54     case Thunk::kNcclReduceScatter:
55       return "kNcclReduceScatter";
56     case Thunk::kNcclAllToAll:
57       return "kNcclAllToAll";
58     case Thunk::kFft:
59       return "kFft";
60     case Thunk::kGemm:
61       return "kGemm";
62     case Thunk::kInfeed:
63       return "kInfeed";
64     case Thunk::kKernel:
65       return "kKernel";
66     case Thunk::kMemset32BitValue:
67       return "kMemset32BitValue";
68     case Thunk::kMemzero:
69       return "kMemzero";
70     case Thunk::kOutfeed:
71       return "kOutfeed";
72     case Thunk::kReplicaId:
73       return "kReplicaId";
74     case Thunk::kPartitionId:
75       return "kPartitionId";
76     case Thunk::kSequential:
77       return "kSequential";
78     case Thunk::kTriangularSolve:
79       return "kTriangularSolve";
80     case Thunk::kWhile:
81       return "kWhile";
82   }
83 }
84 
operator <<(std::ostream & os,Thunk::Kind kind)85 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind) {
86   return os << Thunk::KindToString(kind);
87 }
88 
ToString(int indent,std::function<std::string (const Thunk *)> get_thunk_annotation) const89 std::string ThunkSequence::ToString(
90     int indent,
91     std::function<std::string(const Thunk*)> get_thunk_annotation) const {
92   const std::string indent_str(indent * 2, ' ');
93   if (empty()) return indent_str + "No thunks.";
94 
95   auto thunk_with_longest_kind = absl::c_max_element(
96       *this,
97       [](const std::unique_ptr<Thunk>& a, const std::unique_ptr<Thunk>& b) {
98         return Thunk::KindToString(a->kind()).length() <
99                Thunk::KindToString(b->kind()).length();
100       });
101   int64_t max_thunk_kind_len =
102       Thunk::KindToString(thunk_with_longest_kind->get()->kind()).length();
103   std::string result;
104   for (const std::unique_ptr<Thunk>& thunk : *this) {
105     // Write out the thunk kind, padded out to max_thunk_kind_len.
106     absl::string_view kind_str = Thunk::KindToString(thunk->kind());
107     absl::StrAppend(&result, indent_str, kind_str,
108                     std::string(max_thunk_kind_len - kind_str.length(), ' '),
109                     "\t");
110     if (get_thunk_annotation) {
111       absl::StrAppend(&result, get_thunk_annotation(thunk.get()));
112     }
113     absl::StrAppend(&result, thunk->ToStringExtra(indent));
114     absl::StrAppend(&result, "\n");
115   }
116   return result;
117 }
118 
119 }  // namespace gpu
120 }  // namespace xla
121