1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/debug/debug_grpc_testlib.h"
17
18 #include "tensorflow/core/debug/debug_graph_utils.h"
19 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
20 #include "tensorflow/core/framework/summary.pb.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/lib/io/path.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/env.h"
25 #include "tensorflow/core/platform/protobuf.h"
26 #include "tensorflow/core/platform/tracing.h"
27
28 namespace tensorflow {
29
30 namespace test {
31
::grpc::Status TestEventListenerImpl::SendEvents(
    ::grpc::ServerContext* context,
    ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>*
        stream) {
  // Bidirectional-streaming handler for the debugger EventListener service.
  // Reads Event protos off the stream and records their contents into the
  // member vectors (debug_metadata_strings, encoded_graph_defs, device_names,
  // node_names, output_slots, debug_ops, debug_tensors) for later inspection
  // by tests. After the client finishes writing, any queued debug-op state
  // changes are sent back as EventReply protos.
  Event event;

  while (stream->Read(&event)) {
    if (event.has_log_message()) {
      // Debugger metadata arrives as a log message; ack with an empty reply.
      debug_metadata_strings.push_back(event.log_message().message());
      stream->Write(EventReply());
    } else if (!event.graph_def().empty()) {
      // Serialized GraphDef content; ack with an empty reply.
      encoded_graph_defs.push_back(event.graph_def());
      stream->Write(EventReply());
    } else if (event.has_summary()) {
      // Tensor values arrive wrapped in a Summary proto. Only the first value
      // is examined; NOTE(review): assumes the summary carries at least one
      // value -- confirm against the publishing side.
      const Summary::Value& val = event.summary().value(0);

      // Node name is expected to be of the form "<node>:<slot>:<debug_op>".
      // NOTE(review): name_items[2] below is out of range for malformed
      // names -- no size check is performed here.
      std::vector<string> name_items =
          tensorflow::str_util::Split(val.node_name(), ':');

      const string node_name = name_items[0];
      const string debug_op = name_items[2];

      const TensorProto& tensor_proto = val.tensor();
      Tensor tensor(tensor_proto.dtype());
      if (!tensor.FromProto(tensor_proto)) {
        // Malformed tensor proto: abort the whole stream.
        return ::grpc::Status::CANCELLED;
      }

      // Obtain the device name, which is encoded in JSON.
      third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
      if (val.metadata().plugin_data().plugin_name() != "debugger") {
        // This plugin data was meant for another plugin.
        continue;
      }
      auto status = tensorflow::protobuf::util::JsonStringToMessage(
          val.metadata().plugin_data().content(), &metadata);
      if (!status.ok()) {
        // The device name could not be determined.
        continue;
      }

      device_names.push_back(metadata.device());
      node_names.push_back(node_name);
      output_slots.push_back(metadata.output_slot());
      debug_ops.push_back(debug_op);
      debug_tensors.push_back(tensor);

      // If the debug node is currently in the READ_WRITE mode, send an
      // EventReply to 1) unblock the execution and 2) optionally modify the
      // value.
      const DebugNodeKey debug_node_key(metadata.device(), node_name,
                                        metadata.output_slot(), debug_op);
      if (write_enabled_debug_node_keys_.find(debug_node_key) !=
          write_enabled_debug_node_keys_.end()) {
        stream->Write(EventReply());
      }
    }
  }

  {
    // Flush state changes queued by RequestDebugOpStateChangeAtNextStream().
    // states_mu_ guards debug_node_keys_ / new_states_ against concurrent
    // queuing from another thread.
    mutex_lock l(states_mu_);
    for (size_t i = 0; i < new_states_.size(); ++i) {
      EventReply event_reply;
      EventReply::DebugOpStateChange* change =
          event_reply.add_debug_op_state_changes();

      // State changes will take effect in the next stream, i.e., next debugged
      // Session.run() call.
      change->set_state(new_states_[i]);
      const DebugNodeKey& debug_node_key = debug_node_keys_[i];
      change->set_node_name(debug_node_key.node_name);
      change->set_output_slot(debug_node_key.output_slot);
      change->set_debug_op(debug_node_key.debug_op);
      stream->Write(event_reply);

      // Track which node keys are in READ_WRITE mode so that later streams
      // know to send the unblocking EventReply (see the summary branch above).
      if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
        write_enabled_debug_node_keys_.insert(debug_node_key);
      } else {
        write_enabled_debug_node_keys_.erase(debug_node_key);
      }
    }

    debug_node_keys_.clear();
    new_states_.clear();
  }

  return ::grpc::Status::OK;
}
120
ClearReceivedDebugData()121 void TestEventListenerImpl::ClearReceivedDebugData() {
122 debug_metadata_strings.clear();
123 encoded_graph_defs.clear();
124 device_names.clear();
125 node_names.clear();
126 output_slots.clear();
127 debug_ops.clear();
128 debug_tensors.clear();
129 }
130
RequestDebugOpStateChangeAtNextStream(const EventReply::DebugOpStateChange::State new_state,const DebugNodeKey & debug_node_key)131 void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
132 const EventReply::DebugOpStateChange::State new_state,
133 const DebugNodeKey& debug_node_key) {
134 mutex_lock l(states_mu_);
135
136 debug_node_keys_.push_back(debug_node_key);
137 new_states_.push_back(new_state);
138 }
139
RunServer(const int server_port)140 void TestEventListenerImpl::RunServer(const int server_port) {
141 ::grpc::ServerBuilder builder;
142 builder.AddListeningPort(strings::StrCat("localhost:", server_port),
143 ::grpc::InsecureServerCredentials());
144 builder.RegisterService(this);
145 std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
146
147 while (!stop_requested_.load()) {
148 Env::Default()->SleepForMicroseconds(200 * 1000);
149 }
150 server->Shutdown();
151 stopped_.store(true);
152 }
153
StopServer()154 void TestEventListenerImpl::StopServer() {
155 stop_requested_.store(true);
156 while (!stopped_.load()) {
157 }
158 }
159
PollTillFirstRequestSucceeds(const string & server_url,const size_t max_attempts)160 bool PollTillFirstRequestSucceeds(const string& server_url,
161 const size_t max_attempts) {
162 const int kSleepDurationMicros = 100 * 1000;
163 size_t n_attempts = 0;
164 bool success = false;
165
166 // Try a number of times to send the Event proto to the server, as it may
167 // take the server a few seconds to start up and become responsive.
168 Tensor prep_tensor(DT_FLOAT, TensorShape({1, 1}));
169 prep_tensor.flat<float>()(0) = 42.0f;
170
171 while (n_attempts++ < max_attempts) {
172 const uint64 wall_time = Env::Default()->NowMicros();
173 Status publish_s = DebugIO::PublishDebugTensor(
174 DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "prep_node", 0,
175 "DebugIdentity"),
176 prep_tensor, wall_time, {server_url});
177 Status close_s = DebugIO::CloseDebugURL(server_url);
178
179 if (publish_s.ok() && close_s.ok()) {
180 success = true;
181 break;
182 } else {
183 Env::Default()->SleepForMicroseconds(kSleepDurationMicros);
184 }
185 }
186
187 return success;
188 }
189
190 } // namespace test
191
192 } // namespace tensorflow
193