• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/util.h"

#include <stdarg.h>
#include <cmath>
#include <limits>
#include <numeric>

#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
32 
33 namespace xla {
34 
// Emits `status` and the current stack trace as two separate VLOG(1) entries,
// then hands `status` back unchanged. `status` must be an error: an OK status
// CHECK-fails.
Status WithLogBacktrace(const Status& status) {
  CHECK(!status.ok());
  if (VLOG_IS_ON(1)) {
    VLOG(1) << status.ToString();
    VLOG(1) << tensorflow::CurrentStackTrace();
  }
  return status;
}
41 
ScopedLoggingTimer(const string & label,bool enabled)42 ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
43     : enabled(enabled), label(label) {
44   if (enabled) {
45     start_micros = tensorflow::Env::Default()->NowMicros();
46   }
47 }
48 
~ScopedLoggingTimer()49 ScopedLoggingTimer::~ScopedLoggingTimer() {
50   if (enabled) {
51     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
52     double secs = (end_micros - start_micros) / 1000000.0;
53 
54     LOG(INFO) << label << " time: "
55               << tensorflow::strings::HumanReadableElapsedTime(secs);
56   }
57 }
58 
AddStatus(Status prior,absl::string_view context)59 Status AddStatus(Status prior, absl::string_view context) {
60   CHECK(!prior.ok());
61   return Status{prior.code(),
62                 absl::StrCat(context, ": ", prior.error_message())};
63 }
64 
AppendStatus(Status prior,absl::string_view context)65 Status AppendStatus(Status prior, absl::string_view context) {
66   CHECK(!prior.ok());
67   return Status{prior.code(),
68                 absl::StrCat(prior.error_message(), ": ", context)};
69 }
70 
Reindent(absl::string_view original,const absl::string_view indentation)71 string Reindent(absl::string_view original,
72                 const absl::string_view indentation) {
73   std::vector<string> pieces =
74       absl::StrSplit(absl::string_view(original.data(), original.size()), '\n');
75   return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) {
76     absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s));
77   });
78 }
79 
IsPermutation(absl::Span<const int64> permutation,int64 rank)80 bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
81   if (rank != permutation.size()) {
82     return false;
83   }
84   absl::InlinedVector<int64, 8> trivial_permutation(rank);
85   absl::c_iota(trivial_permutation, 0);
86   return absl::c_is_permutation(permutation, trivial_permutation);
87 }
88 
InversePermutation(absl::Span<const int64> input_permutation)89 std::vector<int64> InversePermutation(
90     absl::Span<const int64> input_permutation) {
91   DCHECK(IsPermutation(input_permutation, input_permutation.size()));
92   std::vector<int64> output_permutation(input_permutation.size(), -1);
93   for (size_t i = 0; i < input_permutation.size(); ++i) {
94     output_permutation[input_permutation[i]] = i;
95   }
96   return output_permutation;
97 }
98 
ComposePermutations(absl::Span<const int64> p1,absl::Span<const int64> p2)99 std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
100                                        absl::Span<const int64> p2) {
101   CHECK_EQ(p1.size(), p2.size());
102   std::vector<int64> output;
103   for (size_t i = 0; i < p1.size(); ++i) {
104     output.push_back(p1[p2[i]]);
105   }
106   return output;
107 }
108 
IsIdentityPermutation(absl::Span<const int64> permutation)109 bool IsIdentityPermutation(absl::Span<const int64> permutation) {
110   for (int64 i = 0; i < permutation.size(); ++i) {
111     if (permutation[i] != i) {
112       return false;
113     }
114   }
115   return true;
116 }
117 
MakeNoPaddingConfig(int64 rank)118 PaddingConfig MakeNoPaddingConfig(int64 rank) {
119   PaddingConfig padding_config;
120   for (int64 dnum = 0; dnum < rank; ++dnum) {
121     auto dimension = padding_config.add_dimensions();
122     dimension->set_edge_padding_low(0);
123     dimension->set_edge_padding_high(0);
124     dimension->set_interior_padding(0);
125   }
126   return padding_config;
127 }
128 
MakeEdgePaddingConfig(absl::Span<const std::pair<int64,int64>> padding)129 PaddingConfig MakeEdgePaddingConfig(
130     absl::Span<const std::pair<int64, int64>> padding) {
131   PaddingConfig padding_config;
132   for (const std::pair<int64, int64>& dim : padding) {
133     auto dimension = padding_config.add_dimensions();
134     dimension->set_edge_padding_low(dim.first);
135     dimension->set_edge_padding_high(dim.second);
136     dimension->set_interior_padding(0);
137   }
138   return padding_config;
139 }
140 
HasInteriorPadding(const PaddingConfig & config)141 bool HasInteriorPadding(const PaddingConfig& config) {
142   for (const auto& dim : config.dimensions()) {
143     if (dim.interior_padding() != 0) {
144       return true;
145     }
146   }
147   return false;
148 }
149 
150 namespace {
HumanReadableNumOps(double flops,double nanoseconds,absl::string_view op_prefix)151 string HumanReadableNumOps(double flops, double nanoseconds,
152                            absl::string_view op_prefix) {
153   if (nanoseconds == 0) {
154     return absl::StrCat("NaN ", op_prefix, "OP/s");
155   }
156   double nano_flops = flops / nanoseconds;
157   string throughput = tensorflow::strings::HumanReadableNum(
158       static_cast<int64>(nano_flops * 1e9));
159   absl::string_view sp(throughput);
160   // Use the more common "G(FLOPS)", rather than "B(FLOPS)"
161   if (absl::EndsWith(sp, "B") ||  // Ends in 'B', ignoring case
162       absl::EndsWith(sp, "b")) {
163     *throughput.rbegin() = 'G';
164   }
165   throughput += absl::StrCat(op_prefix, "OP/s");
166   return throughput;
167 }
168 }  // namespace
169 
HumanReadableNumFlops(double flops,double nanoseconds)170 string HumanReadableNumFlops(double flops, double nanoseconds) {
171   return HumanReadableNumOps(flops, nanoseconds, "FL");
172 }
173 
HumanReadableNumTranscendentalOps(double trops,double nanoseconds)174 string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
175   return HumanReadableNumOps(trops, nanoseconds, "TR");
176 }
177 
LogLines(int sev,absl::string_view text,const char * fname,int lineno)178 void LogLines(int sev, absl::string_view text, const char* fname, int lineno) {
179   const int orig_sev = sev;
180   if (sev == tensorflow::FATAL) {
181     sev = tensorflow::ERROR;
182   }
183 
184   // Protect calls with a mutex so we don't interleave calls to LogLines from
185   // multiple threads.
186   static tensorflow::mutex log_lines_mu(tensorflow::LINKER_INITIALIZED);
187   tensorflow::mutex_lock lock(log_lines_mu);
188 
189   size_t cur = 0;
190   while (cur < text.size()) {
191     size_t eol = text.find('\n', cur);
192     if (eol == absl::string_view::npos) {
193       eol = text.size();
194     }
195     auto msg = text.substr(cur, eol - cur);
196     tensorflow::internal::LogString(fname, lineno, sev,
197                                     string(msg.data(), msg.size()));
198     cur = eol + 1;
199   }
200 
201   if (orig_sev == tensorflow::FATAL) {
202     tensorflow::internal::LogString(fname, lineno, orig_sev,
203                                     "Aborting due to errors.");
204   }
205 }
206 
Product(absl::Span<const int64> xs)207 int64 Product(absl::Span<const int64> xs) {
208   return std::accumulate(xs.begin(), xs.end(), static_cast<int64>(1),
209                          std::multiplies<int64>());
210 }
211 
CommonFactors(absl::Span<const int64> a,absl::Span<const int64> b)212 std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
213                                                    absl::Span<const int64> b) {
214   CHECK_EQ(Product(a), Product(b));
215   if (0 == Product(a)) {
216     return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
217   }
218 
219   std::vector<std::pair<int64, int64>> bounds;
220   for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
221              partial_size_b = 1;
222        ;) {
223     if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
224       std::tie(prior_i, prior_j) = std::make_pair(i, j);
225       bounds.emplace_back(i, j);
226       continue;
227     }
228     bool in_bounds_i = i < a.size();
229     bool in_bounds_j = j < b.size();
230     if (!(in_bounds_i || in_bounds_j)) {
231       break;
232     }
233     bool next_a =
234         partial_size_a < partial_size_b ||
235         (in_bounds_i &&
236          (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
237     bool next_b =
238         partial_size_b < partial_size_a ||
239         (in_bounds_j &&
240          (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
241     if (next_a) {
242       partial_size_a *= a[i];
243       ++i;
244     }
245     if (next_b) {
246       partial_size_b *= b[j];
247       ++j;
248     }
249   }
250   return bounds;
251 }
252 
SanitizeFileName(string file_name)253 string SanitizeFileName(string file_name) {
254   for (char& c : file_name) {
255     if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {
256       c = '_';
257     }
258   }
259   return file_name;
260 }
261 
262 }  // namespace xla
263