• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
17 
18 #include <list>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/base/call_once.h"
22 #include "absl/base/thread_annotations.h"
23 #include "absl/memory/memory.h"
24 #include "absl/synchronization/mutex.h"
25 #include "tensorflow/core/platform/env.h"
26 
27 namespace xla {
28 namespace {
29 
// Guards the list that `outstanding_alarms` points to. Constant-initialized
// (absl::kConstInit).
absl::Mutex mu(absl::kConstInit);
// Condition variable that AlarmLoop() waits on and ScheduleAlarm() signals.
// Allocated by the first call to ScheduleAlarm().
absl::CondVar* ready;
// Makes the one-time setup inside ScheduleAlarm() run exactly once.
absl::once_flag init_flag;
// Alarms that have been scheduled and have neither fired nor been unscheduled.
// nullptr until the first call to ScheduleAlarm() allocates the list.
std::list<SlowOperationAlarm*>* outstanding_alarms ABSL_PT_GUARDED_BY(mu) =
    nullptr;
35 
AlarmLoop()36 void AlarmLoop() {
37   while (true) {
38     absl::MutexLock lock(&mu);
39 
40     // Fire any alarms which are ready.
41     absl::Time now = absl::Now();
42     for (auto it = outstanding_alarms->begin();
43          it != outstanding_alarms->end();) {
44       auto next = std::next(it);
45       auto* alarm = *it;
46       // Fire the alarm if applicable.
47       if (alarm->deadline() <= now) {
48         outstanding_alarms->erase(it);
49         int64 count =
50             alarm->counter() == nullptr ? 0 : alarm->counter()->fetch_add(1);
51         // If the alarm has a counter, only fire if the count is a power of 2.
52         if (count == 0 || (count & (count - 1)) == 0) {
53           // We fire alarms with LOG(ERROR) because otherwise it might not show
54           // up without --logtostderr.
55           LOG(ERROR) << alarm->msg();
56         }
57       }
58       it = next;
59     }
60 
61     if (outstanding_alarms->empty()) {
62       ready->Wait(&mu);
63       continue;
64     }
65 
66     SlowOperationAlarm* next_alarm = *absl::c_min_element(
67         *outstanding_alarms,
68         [](const SlowOperationAlarm* a, const SlowOperationAlarm* b) {
69           return a->deadline() < b->deadline();
70         });
71     ready->WaitWithDeadline(&mu, next_alarm->deadline());
72   }
73 }
74 
ScheduleAlarm(SlowOperationAlarm * alarm)75 void ScheduleAlarm(SlowOperationAlarm* alarm) {
76   absl::call_once(init_flag, [] {
77     ready = new absl::CondVar();
78     outstanding_alarms = new std::list<SlowOperationAlarm*>();
79     (void)tensorflow::Env::Default()->StartThread(
80         tensorflow::ThreadOptions(), "SlowOperationAlarm", [] { AlarmLoop(); });
81   });
82 
83   absl::MutexLock lock(&mu);
84   outstanding_alarms->push_back(alarm);
85   ready->Signal();
86 }
87 
UnscheduleAlarm(const SlowOperationAlarm * alarm)88 void UnscheduleAlarm(const SlowOperationAlarm* alarm) {
89   absl::MutexLock lock(&mu);
90   CHECK(outstanding_alarms != nullptr);
91   auto it = absl::c_find(*outstanding_alarms, alarm);
92   if (it != outstanding_alarms->end()) {
93     outstanding_alarms->erase(it);
94   }
95 }
96 
97 }  // namespace
98 
SlowOperationAlarm(absl::Duration timeout,string msg,std::atomic<int64> * counter)99 SlowOperationAlarm::SlowOperationAlarm(absl::Duration timeout, string msg,
100                                        std::atomic<int64>* counter /*=nullptr*/)
101     : deadline_(absl::Now() + timeout),
102       msg_(std::move(msg)),
103       counter_(counter) {
104   ScheduleAlarm(this);
105 }
106 
~SlowOperationAlarm()107 SlowOperationAlarm::~SlowOperationAlarm() { UnscheduleAlarm(this); }
108 
SlowCompilationAlarm(absl::string_view msg)109 std::unique_ptr<SlowOperationAlarm> SlowCompilationAlarm(
110     absl::string_view msg) {
111   // Pass a counter to these alarms so they only log once every power-of-two
112   // occurrences.
113   static auto* counter = new std::atomic<int64>(0);
114 
115   const char* separator = "\n********************************";
116 
117   std::string msg_suffix;
118   if (!msg.empty()) {
119     msg_suffix = absl::StrCat("\n", msg);
120   }
121 
122 #if NDEBUG
123   return absl::make_unique<SlowOperationAlarm>(
124       absl::Duration(absl::Minutes(2)),
125       absl::StrCat(
126           separator,
127           "\nVery slow compile?  If you want to file a bug, run with envvar "
128           "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.",
129           msg_suffix, separator),
130       counter);
131 #else
132   return absl::make_unique<SlowOperationAlarm>(
133       absl::Duration(absl::Seconds(10)),
134       absl::StrCat(
135           separator,
136           "\nSlow compile?  XLA was built without compiler optimizations, "
137           "which can be slow.  Try rebuilding with -c opt.",
138           msg_suffix, separator),
139       counter);
140 #endif
141 }
142 
143 }  // namespace xla
144