1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/util.h"

#include <stdarg.h>

#include <cmath>
#include <limits>
#include <numeric>

#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
32
33 namespace xla {
34
// Logs a non-OK `status` together with the current stack trace (both at
// VLOG level 1) and hands the same status back, so it can wrap a return value.
// CHECK-fails if `status` is OK.
Status WithLogBacktrace(const Status& status) {
  CHECK(!status.ok());

  // First the error itself, then where we are when it is being propagated.
  VLOG(1) << status.ToString();
  VLOG(1) << tensorflow::CurrentStackTrace();

  return status;
}
41
ScopedLoggingTimer(const string & label,bool enabled)42 ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
43 : enabled(enabled), label(label) {
44 if (enabled) {
45 start_micros = tensorflow::Env::Default()->NowMicros();
46 }
47 }
48
~ScopedLoggingTimer()49 ScopedLoggingTimer::~ScopedLoggingTimer() {
50 if (enabled) {
51 uint64 end_micros = tensorflow::Env::Default()->NowMicros();
52 double secs = (end_micros - start_micros) / 1000000.0;
53
54 LOG(INFO) << label << " time: "
55 << tensorflow::strings::HumanReadableElapsedTime(secs);
56 }
57 }
58
AddStatus(Status prior,absl::string_view context)59 Status AddStatus(Status prior, absl::string_view context) {
60 CHECK(!prior.ok());
61 return Status{prior.code(),
62 absl::StrCat(context, ": ", prior.error_message())};
63 }
64
AppendStatus(Status prior,absl::string_view context)65 Status AppendStatus(Status prior, absl::string_view context) {
66 CHECK(!prior.ok());
67 return Status{prior.code(),
68 absl::StrCat(prior.error_message(), ": ", context)};
69 }
70
Reindent(absl::string_view original,const absl::string_view indentation)71 string Reindent(absl::string_view original,
72 const absl::string_view indentation) {
73 std::vector<string> pieces =
74 absl::StrSplit(absl::string_view(original.data(), original.size()), '\n');
75 return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) {
76 absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s));
77 });
78 }
79
IsPermutation(absl::Span<const int64> permutation,int64 rank)80 bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
81 if (rank != permutation.size()) {
82 return false;
83 }
84 absl::InlinedVector<int64, 8> trivial_permutation(rank);
85 absl::c_iota(trivial_permutation, 0);
86 return absl::c_is_permutation(permutation, trivial_permutation);
87 }
88
InversePermutation(absl::Span<const int64> input_permutation)89 std::vector<int64> InversePermutation(
90 absl::Span<const int64> input_permutation) {
91 DCHECK(IsPermutation(input_permutation, input_permutation.size()));
92 std::vector<int64> output_permutation(input_permutation.size(), -1);
93 for (size_t i = 0; i < input_permutation.size(); ++i) {
94 output_permutation[input_permutation[i]] = i;
95 }
96 return output_permutation;
97 }
98
ComposePermutations(absl::Span<const int64> p1,absl::Span<const int64> p2)99 std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
100 absl::Span<const int64> p2) {
101 CHECK_EQ(p1.size(), p2.size());
102 std::vector<int64> output;
103 for (size_t i = 0; i < p1.size(); ++i) {
104 output.push_back(p1[p2[i]]);
105 }
106 return output;
107 }
108
IsIdentityPermutation(absl::Span<const int64> permutation)109 bool IsIdentityPermutation(absl::Span<const int64> permutation) {
110 for (int64 i = 0; i < permutation.size(); ++i) {
111 if (permutation[i] != i) {
112 return false;
113 }
114 }
115 return true;
116 }
117
MakeNoPaddingConfig(int64 rank)118 PaddingConfig MakeNoPaddingConfig(int64 rank) {
119 PaddingConfig padding_config;
120 for (int64 dnum = 0; dnum < rank; ++dnum) {
121 auto dimension = padding_config.add_dimensions();
122 dimension->set_edge_padding_low(0);
123 dimension->set_edge_padding_high(0);
124 dimension->set_interior_padding(0);
125 }
126 return padding_config;
127 }
128
MakeEdgePaddingConfig(absl::Span<const std::pair<int64,int64>> padding)129 PaddingConfig MakeEdgePaddingConfig(
130 absl::Span<const std::pair<int64, int64>> padding) {
131 PaddingConfig padding_config;
132 for (const std::pair<int64, int64>& dim : padding) {
133 auto dimension = padding_config.add_dimensions();
134 dimension->set_edge_padding_low(dim.first);
135 dimension->set_edge_padding_high(dim.second);
136 dimension->set_interior_padding(0);
137 }
138 return padding_config;
139 }
140
HasInteriorPadding(const PaddingConfig & config)141 bool HasInteriorPadding(const PaddingConfig& config) {
142 for (const auto& dim : config.dimensions()) {
143 if (dim.interior_padding() != 0) {
144 return true;
145 }
146 }
147 return false;
148 }
149
150 namespace {
HumanReadableNumOps(double flops,double nanoseconds,absl::string_view op_prefix)151 string HumanReadableNumOps(double flops, double nanoseconds,
152 absl::string_view op_prefix) {
153 if (nanoseconds == 0) {
154 return absl::StrCat("NaN ", op_prefix, "OP/s");
155 }
156 double nano_flops = flops / nanoseconds;
157 string throughput = tensorflow::strings::HumanReadableNum(
158 static_cast<int64>(nano_flops * 1e9));
159 absl::string_view sp(throughput);
160 // Use the more common "G(FLOPS)", rather than "B(FLOPS)"
161 if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case
162 absl::EndsWith(sp, "b")) {
163 *throughput.rbegin() = 'G';
164 }
165 throughput += absl::StrCat(op_prefix, "OP/s");
166 return throughput;
167 }
168 } // namespace
169
HumanReadableNumFlops(double flops,double nanoseconds)170 string HumanReadableNumFlops(double flops, double nanoseconds) {
171 return HumanReadableNumOps(flops, nanoseconds, "FL");
172 }
173
HumanReadableNumTranscendentalOps(double trops,double nanoseconds)174 string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
175 return HumanReadableNumOps(trops, nanoseconds, "TR");
176 }
177
LogLines(int sev,absl::string_view text,const char * fname,int lineno)178 void LogLines(int sev, absl::string_view text, const char* fname, int lineno) {
179 const int orig_sev = sev;
180 if (sev == tensorflow::FATAL) {
181 sev = tensorflow::ERROR;
182 }
183
184 // Protect calls with a mutex so we don't interleave calls to LogLines from
185 // multiple threads.
186 static tensorflow::mutex log_lines_mu(tensorflow::LINKER_INITIALIZED);
187 tensorflow::mutex_lock lock(log_lines_mu);
188
189 size_t cur = 0;
190 while (cur < text.size()) {
191 size_t eol = text.find('\n', cur);
192 if (eol == absl::string_view::npos) {
193 eol = text.size();
194 }
195 auto msg = text.substr(cur, eol - cur);
196 tensorflow::internal::LogString(fname, lineno, sev,
197 string(msg.data(), msg.size()));
198 cur = eol + 1;
199 }
200
201 if (orig_sev == tensorflow::FATAL) {
202 tensorflow::internal::LogString(fname, lineno, orig_sev,
203 "Aborting due to errors.");
204 }
205 }
206
Product(absl::Span<const int64> xs)207 int64 Product(absl::Span<const int64> xs) {
208 return std::accumulate(xs.begin(), xs.end(), static_cast<int64>(1),
209 std::multiplies<int64>());
210 }
211
CommonFactors(absl::Span<const int64> a,absl::Span<const int64> b)212 std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
213 absl::Span<const int64> b) {
214 CHECK_EQ(Product(a), Product(b));
215 if (0 == Product(a)) {
216 return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
217 }
218
219 std::vector<std::pair<int64, int64>> bounds;
220 for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
221 partial_size_b = 1;
222 ;) {
223 if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
224 std::tie(prior_i, prior_j) = std::make_pair(i, j);
225 bounds.emplace_back(i, j);
226 continue;
227 }
228 bool in_bounds_i = i < a.size();
229 bool in_bounds_j = j < b.size();
230 if (!(in_bounds_i || in_bounds_j)) {
231 break;
232 }
233 bool next_a =
234 partial_size_a < partial_size_b ||
235 (in_bounds_i &&
236 (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
237 bool next_b =
238 partial_size_b < partial_size_a ||
239 (in_bounds_j &&
240 (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
241 if (next_a) {
242 partial_size_a *= a[i];
243 ++i;
244 }
245 if (next_b) {
246 partial_size_b *= b[j];
247 ++j;
248 }
249 }
250 return bounds;
251 }
252
// Returns `file_name` with every '/', '\\', '[', ']' and ' ' replaced by '_',
// so the result can be used as a single path component.
string SanitizeFileName(string file_name) {
  for (size_t pos = 0; pos < file_name.size(); ++pos) {
    switch (file_name[pos]) {
      case '/':
      case '\\':
      case '[':
      case ']':
      case ' ':
        file_name[pos] = '_';
        break;
      default:
        break;
    }
  }
  return file_name;
}
261
262 } // namespace xla
263