/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/util.h"

#include <stdarg.h>
#include <numeric>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"

namespace xla {

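// Logs `status` (which must not be OK) together with the current stack trace
// at VLOG level 1, then returns `status` unchanged.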
Status WithLogBacktrace(const Status& status) {
  CHECK(!status.ok());
  VLOG(1) << status.ToString();
  VLOG(1) << tensorflow::CurrentStackTrace();
  return status;
}

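// ScopedLoggingTimer records the start time on construction and, when
// `enabled`, logs the elapsed wall time under `label` on destruction.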
ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
    : enabled(enabled), label(label) {
  if (enabled) {
    start_micros = tensorflow::Env::Default()->NowMicros();
  }
}

ScopedLoggingTimer::~ScopedLoggingTimer() {
  if (enabled) {
    uint64 end_micros = tensorflow::Env::Default()->NowMicros();
    double secs = (end_micros - start_micros) / 1000000.0;

    LOG(INFO) << label << " time: "
              << tensorflow::strings::HumanReadableElapsedTime(secs);
  }
}

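// Returns a Status with the same error code as `prior`, with `context`
// prepended to the original error message.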
Status AddStatus(Status prior, tensorflow::StringPiece context) {
  CHECK(!prior.ok());
  return Status{prior.code(), tensorflow::strings::StrCat(
                                  context, ": ", prior.error_message())};
}

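// Returns a Status with the same error code as `prior`, with `context`
// appended to the original error message.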
Status AppendStatus(Status prior, tensorflow::StringPiece context) {
  CHECK(!prior.ok());
  return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(),
                                                          ": ", context)};
}

// Implementation note: we can't common these out (without using macros)
// because they all need to va_start/va_end their varargs in their frame.

Status InvalidArgumentV(const char* format, va_list args) {
  string message;
  tensorflow::strings::Appendv(&message, format, args);
  return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
}

Status InvalidArgument(const char* format, ...) {
  va_list args;
  va_start(args, format);
  Status result = InvalidArgumentV(format, args);
  va_end(args);
  return result;
}

Status Unimplemented(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Unimplemented(message));
}

Status InternalError(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Internal(message));
}

Status FailedPrecondition(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message));
}

Status Cancelled(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Cancelled(message));
}

Status ResourceExhausted(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message));
}

Status NotFound(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::NotFound(message));
}

Status Unavailable(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Unavailable(message));
}

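// Splits `original` on newlines, strips surrounding whitespace from each
// line, and rejoins the lines with `indentation` prepended to each.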
string Reindent(tensorflow::StringPiece original,
                const tensorflow::StringPiece indentation) {
  std::vector<string> pieces = tensorflow::str_util::Split(
      tensorflow::StringPiece(original.data(), original.size()), '\n');
  return tensorflow::str_util::Join(
      pieces, "\n", [indentation](string* out, string s) {
        tensorflow::StringPiece piece(s);
        tensorflow::str_util::RemoveWhitespaceContext(&piece);
        tensorflow::strings::StrAppend(out, indentation, piece);
      });
}

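// Returns true iff `permutation` is a permutation of the integers [0, rank),
// i.e. it has size `rank` and contains each index exactly once.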
bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
  if (rank != permutation.size()) {
    return false;
  }
  std::vector<int64> output(permutation.size(), -1);
  for (auto index : permutation) {
    CHECK_GE(index, 0);
    CHECK_LT(index, rank);
    output[index] = 0;
  }
  return std::find(output.begin(), output.end(), -1) == output.end();
}

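// Returns the inverse of `input_permutation`: for every i,
// output[input_permutation[i]] == i.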
std::vector<int64> InversePermutation(
    tensorflow::gtl::ArraySlice<int64> input_permutation) {
  DCHECK(IsPermutation(input_permutation, input_permutation.size()));
  std::vector<int64> output_permutation(input_permutation.size(), -1);
  for (size_t i = 0; i < input_permutation.size(); ++i) {
    output_permutation[input_permutation[i]] = i;
  }
  return output_permutation;
}

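// Composes two permutations of equal size: output[i] = p1[p2[i]].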
std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
                                       tensorflow::gtl::ArraySlice<int64> p2) {
  CHECK_EQ(p1.size(), p2.size());
  std::vector<int64> output;
  for (size_t i = 0; i < p1.size(); ++i) {
    output.push_back(p1[p2[i]]);
  }
  return output;
}

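// Returns true iff `permutation` maps every index to itself.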
bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation) {
  for (int64 i = 0; i < permutation.size(); ++i) {
    if (permutation[i] != i) {
      return false;
    }
  }
  return true;
}

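// Returns a PaddingConfig of the given rank with zero edge and interior
// padding in every dimension.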
PaddingConfig MakeNoPaddingConfig(int64 rank) {
  PaddingConfig padding_config;
  for (int64 dnum = 0; dnum < rank; ++dnum) {
    auto dimension = padding_config.add_dimensions();
    dimension->set_edge_padding_low(0);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(0);
  }
  return padding_config;
}

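// Returns a PaddingConfig with the given (low, high) edge padding for each
// dimension and no interior padding.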
PaddingConfig MakeEdgePaddingConfig(
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  PaddingConfig padding_config;
  for (const std::pair<int64, int64>& dim : padding) {
    auto dimension = padding_config.add_dimensions();
    dimension->set_edge_padding_low(dim.first);
    dimension->set_edge_padding_high(dim.second);
    dimension->set_interior_padding(0);
  }
  return padding_config;
}

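// Returns true iff any dimension of `config` has nonzero interior padding.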
bool HasInteriorPadding(const PaddingConfig& config) {
  for (const auto& dim : config.dimensions()) {
    if (dim.interior_padding() != 0) {
      return true;
    }
  }
  return false;
}

namespace {
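// Formats a rate of `flops` operations over `nanoseconds` as a
// human-readable throughput string, e.g. "1.5GFLOP/s" for op_prefix "FL".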
string HumanReadableNumOps(double flops, double nanoseconds,
                           tensorflow::StringPiece op_prefix) {
  if (nanoseconds == 0) {
    return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s");
  }
  double nano_flops = flops / nanoseconds;
  string throughput = tensorflow::strings::HumanReadableNum(
      static_cast<int64>(nano_flops * 1e9));
  tensorflow::StringPiece sp(throughput);
  // Use the more common "G(FLOPS)", rather than "B(FLOPS)".
  if (sp.ends_with("B") ||  // Ends in 'B', ignoring case.
      sp.ends_with("b")) {
    *throughput.rbegin() = 'G';
  }
  throughput += tensorflow::strings::StrCat(op_prefix, "OP/s");
  return throughput;
}
}  // namespace

string HumanReadableNumFlops(double flops, double nanoseconds) {
  return HumanReadableNumOps(flops, nanoseconds, "FL");
}

string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
  return HumanReadableNumOps(trops, nanoseconds, "TR");
}

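// Logs `text` one line at a time at severity `sev`, attributing each line to
// `fname:lineno`. FATAL is downgraded to ERROR for the individual lines, and
// a final FATAL message is emitted afterwards so that every line is printed
// before the process aborts.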
void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
              int lineno) {
  const int orig_sev = sev;
  if (sev == tensorflow::FATAL) {
    sev = tensorflow::ERROR;
  }

  // Protect calls with a mutex so we don't interleave calls to LogLines from
  // multiple threads.
  static tensorflow::mutex log_lines_mu(tensorflow::LINKER_INITIALIZED);
  tensorflow::mutex_lock lock(log_lines_mu);

  size_t cur = 0;
  while (cur < text.size()) {
    size_t eol = text.find('\n', cur);
    if (eol == tensorflow::StringPiece::npos) {
      eol = text.size();
    }
    auto msg = text.substr(cur, eol - cur);
    tensorflow::internal::LogString(fname, lineno, sev,
                                    string(msg.data(), msg.size()));
    cur = eol + 1;
  }

  if (orig_sev == tensorflow::FATAL) {
    tensorflow::internal::LogString(fname, lineno, orig_sev,
                                    "Aborting due to errors.");
  }
}

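// Returns the product of the elements of `xs` (1 for an empty slice).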
int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
  // Accumulate with an int64 initial value so the product is not truncated
  // to int.
  return std::accumulate(xs.begin(), xs.end(), int64{1},
                         std::multiplies<int64>());
}

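// Given two dimension sequences `a` and `b` with equal element products,
// returns, in increasing order, the index pairs (i, j) at which the products
// of a[0:i) and b[0:j) are equal; (0, 0) and (a.size(), b.size()) are always
// included.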
std::vector<std::pair<int64, int64>> CommonFactors(
    tensorflow::gtl::ArraySlice<int64> a,
    tensorflow::gtl::ArraySlice<int64> b) {
  CHECK_EQ(Product(a), Product(b));
  if (0 == Product(a)) {
    return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
  }

  std::vector<std::pair<int64, int64>> bounds;
  for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
             partial_size_b = 1;
       ;) {
    if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
      std::tie(prior_i, prior_j) = std::make_pair(i, j);
      bounds.emplace_back(i, j);
      continue;
    }
    bool in_bounds_i = i < a.size();
    bool in_bounds_j = j < b.size();
    if (!(in_bounds_i || in_bounds_j)) {
      break;
    }
    bool next_a =
        partial_size_a < partial_size_b ||
        (in_bounds_i &&
         (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
    bool next_b =
        partial_size_b < partial_size_a ||
        (in_bounds_j &&
         (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
    if (next_a) {
      partial_size_a *= a[i];
      ++i;
    }
    if (next_b) {
      partial_size_b *= b[j];
      ++j;
    }
  }
  return bounds;
}

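// Replaces characters that are awkward in file names ('/', '\', '[', ']' and
// spaces) with underscores.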
string SanitizeFileName(string file_name) {
  for (char& c : file_name) {
    if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {
      c = '_';
    }
  }
  return file_name;
}

}  // namespace xla