• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7   http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/debug/debug_grpc_testlib.h"
17 
18 #include "tensorflow/core/debug/debug_graph_utils.h"
19 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
20 #include "tensorflow/core/framework/summary.pb.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/lib/io/path.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/env.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/platform/tracing.h"
27 
28 namespace tensorflow {
29 
30 namespace test {
31 
// Handles one bidirectional debug gRPC stream: reads Event protos published
// by a debugged session, records their contents into the public member
// vectors for later inspection by tests, and — once the client closes its
// side of the stream — replies with any queued debug-op state changes.
::grpc::Status TestEventListenerImpl::SendEvents(
    ::grpc::ServerContext* context,
    ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>*
        stream) {
  Event event;

  while (stream->Read(&event)) {
    if (event.has_log_message()) {
      // Log-message events: record the message text and acknowledge.
      debug_metadata_strings.push_back(event.log_message().message());
      stream->Write(EventReply());
    } else if (!event.graph_def().empty()) {
      // GraphDef events: record the serialized graph and acknowledge.
      encoded_graph_defs.push_back(event.graph_def());
      stream->Write(EventReply());
    } else if (event.has_summary()) {
      // Tensor-value events: the first summary value carries the tensor.
      // NOTE(review): value(0) is unchecked — assumes the sender always
      // populates at least one summary value; confirm against DebugIO.
      const Summary::Value& val = event.summary().value(0);

      // node_name is expected to look like "<node>:<slot>:<debug_op>".
      // NOTE(review): the [0]/[2] indexing below is unchecked; a name with
      // fewer than three ':'-separated parts would read out of bounds.
      std::vector<string> name_items =
          tensorflow::str_util::Split(val.node_name(), ':');

      const string node_name = name_items[0];
      const string debug_op = name_items[2];

      const TensorProto& tensor_proto = val.tensor();
      Tensor tensor(tensor_proto.dtype());
      if (!tensor.FromProto(tensor_proto)) {
        // Malformed tensor payload: abort the whole stream.
        return ::grpc::Status::CANCELLED;
      }

      // Obtain the device name, which is encoded in JSON.
      third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
      if (val.metadata().plugin_data().plugin_name() != "debugger") {
        // This plugin data was meant for another plugin.
        continue;
      }
      auto status = tensorflow::protobuf::util::JsonStringToMessage(
          val.metadata().plugin_data().content(), &metadata);
      if (!status.ok()) {
        // The device name could not be determined.
        continue;
      }

      device_names.push_back(metadata.device());
      node_names.push_back(node_name);
      output_slots.push_back(metadata.output_slot());
      debug_ops.push_back(debug_op);
      debug_tensors.push_back(tensor);

      // If the debug node is currently in the READ_WRITE mode, send an
      // EventReply to 1) unblock the execution and 2) optionally modify the
      // value.
      const DebugNodeKey debug_node_key(metadata.device(), node_name,
                                        metadata.output_slot(), debug_op);
      if (write_enabled_debug_node_keys_.find(debug_node_key) !=
          write_enabled_debug_node_keys_.end()) {
        stream->Write(EventReply());
      }
    }
  }

  // The client closed its side of the stream. Flush state-change requests
  // queued by RequestDebugOpStateChangeAtNextStream() under the lock;
  // debug_node_keys_ and new_states_ are parallel vectors.
  {
    mutex_lock l(states_mu_);
    for (size_t i = 0; i < new_states_.size(); ++i) {
      EventReply event_reply;
      EventReply::DebugOpStateChange* change =
          event_reply.add_debug_op_state_changes();

      // State changes will take effect in the next stream, i.e., next debugged
      // Session.run() call.
      change->set_state(new_states_[i]);
      const DebugNodeKey& debug_node_key = debug_node_keys_[i];
      change->set_node_name(debug_node_key.node_name);
      change->set_output_slot(debug_node_key.output_slot);
      change->set_debug_op(debug_node_key.debug_op);
      stream->Write(event_reply);

      // Track which node keys are in READ_WRITE mode so that tensor events
      // for them get an unblocking EventReply in the next stream (see above).
      if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
        write_enabled_debug_node_keys_.insert(debug_node_key);
      } else {
        write_enabled_debug_node_keys_.erase(debug_node_key);
      }
    }

    debug_node_keys_.clear();
    new_states_.clear();
  }

  return ::grpc::Status::OK;
}
120 
ClearReceivedDebugData()121 void TestEventListenerImpl::ClearReceivedDebugData() {
122   debug_metadata_strings.clear();
123   encoded_graph_defs.clear();
124   device_names.clear();
125   node_names.clear();
126   output_slots.clear();
127   debug_ops.clear();
128   debug_tensors.clear();
129 }
130 
RequestDebugOpStateChangeAtNextStream(const EventReply::DebugOpStateChange::State new_state,const DebugNodeKey & debug_node_key)131 void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
132     const EventReply::DebugOpStateChange::State new_state,
133     const DebugNodeKey& debug_node_key) {
134   mutex_lock l(states_mu_);
135 
136   debug_node_keys_.push_back(debug_node_key);
137   new_states_.push_back(new_state);
138 }
139 
RunServer(const int server_port)140 void TestEventListenerImpl::RunServer(const int server_port) {
141   ::grpc::ServerBuilder builder;
142   builder.AddListeningPort(strings::StrCat("localhost:", server_port),
143                            ::grpc::InsecureServerCredentials());
144   builder.RegisterService(this);
145   std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
146 
147   while (!stop_requested_.load()) {
148     Env::Default()->SleepForMicroseconds(200 * 1000);
149   }
150   server->Shutdown();
151   stopped_.store(true);
152 }
153 
StopServer()154 void TestEventListenerImpl::StopServer() {
155   stop_requested_.store(true);
156   while (!stopped_.load()) {
157   }
158 }
159 
PollTillFirstRequestSucceeds(const string & server_url,const size_t max_attempts)160 bool PollTillFirstRequestSucceeds(const string& server_url,
161                                   const size_t max_attempts) {
162   const int kSleepDurationMicros = 100 * 1000;
163   size_t n_attempts = 0;
164   bool success = false;
165 
166   // Try a number of times to send the Event proto to the server, as it may
167   // take the server a few seconds to start up and become responsive.
168   Tensor prep_tensor(DT_FLOAT, TensorShape({1, 1}));
169   prep_tensor.flat<float>()(0) = 42.0f;
170 
171   while (n_attempts++ < max_attempts) {
172     const uint64 wall_time = Env::Default()->NowMicros();
173     Status publish_s = DebugIO::PublishDebugTensor(
174         DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "prep_node", 0,
175                      "DebugIdentity"),
176         prep_tensor, wall_time, {server_url});
177     Status close_s = DebugIO::CloseDebugURL(server_url);
178 
179     if (publish_s.ok() && close_s.ok()) {
180       success = true;
181       break;
182     } else {
183       Env::Default()->SleepForMicroseconds(kSleepDurationMicros);
184     }
185   }
186 
187   return success;
188 }
189 
190 }  // namespace test
191 
192 }  // namespace tensorflow
193