/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/logging_ops.h"

#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {

// If the following string is found at the beginning of an output stream, it
// will be interpreted as a file path.
const char kOutputStreamEscapeStr[] = "file://";

// A mutex that guards appending strings to files.
static mutex* file_mutex = new mutex();

// Appends the given data to the specified file. It will create the file if it
// doesn't already exist.
Status AppendStringToFile(const std::string& fname, StringPiece data,
                          Env* env) {
  // TODO(ckluk): If opening and closing on every log causes performance
  // issues, we can reimplement using reference counters.
  mutex_lock l(*file_mutex);
  std::unique_ptr<WritableFile> file;
  TF_RETURN_IF_ERROR(env->NewAppendableFile(fname, &file));
  Status append_status = file->Append(absl::StrCat(data, "\n"));
  Status close_status = file->Close();
  // Prefer reporting the append error; otherwise report the close status.
  return append_status.ok() ? close_status : append_status;
}

}  // namespace

namespace logging {

typedef std::vector<void (*)(const char*)> Listeners;

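// Returns the process-wide list of registered print listeners. The vector is
// heap-allocated and never deleted, a common way to avoid static
// destruction-order problems.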
Listeners* GetListeners() {
  static Listeners* listeners = new Listeners;
  return listeners;
}

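// Registers a callback to be invoked with the message of each PrintV2 call.
// Always returns true, so the result can be used to trigger registration from
// a static initializer, e.g.:
//
//   // Hypothetical client code, shown only for illustration:
//   void MyListener(const char* msg) { /* forward msg somewhere */ }
//   static bool registered = logging::RegisterListener(MyListener);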
bool RegisterListener(void (*listener)(const char*)) {
  GetListeners()->push_back(listener);
  return true;
}

}  // end namespace logging

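// AssertOp checks that its scalar boolean input is true; if not, it fails the
// step with an InvalidArgument error that summarizes up to `summarize_`
// entries of each remaining input tensor.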
class AssertOp : public OpKernel {
 public:
  explicit AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& cond = ctx->input(0);
    OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
                errors::InvalidArgument("In[0] should be a scalar: ",
                                        cond.shape().DebugString()));

    if (cond.scalar<bool>()()) {
      return;
    }
    string msg = "assertion failed: ";
    for (int i = 1; i < ctx->num_inputs(); ++i) {
      strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
                         "]");
      if (i < ctx->num_inputs() - 1) strings::StrAppend(&msg, " ");
    }
    ctx->SetStatus(errors::InvalidArgument(msg));
  }

 private:
  int32 summarize_ = 0;
};

REGISTER_KERNEL_BUILDER(Name("Assert").Device(DEVICE_CPU), AssertOp);

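// The GPU registration pins both inputs to host memory: the kernel inspects
// the condition and summarizes the data on the CPU, so the tensors must not
// live in device memory.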
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("Assert")
                            .Device(DEVICE_GPU)
                            .HostMemory("condition")
                            .HostMemory("data"),
                        AssertOp);
#endif  // GOOGLE_CUDA

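// PrintOp implements the deprecated Print op: it forwards its first input
// unchanged and, as a side effect, writes `message_` plus a summary of the
// remaining inputs to stderr, at most `first_n_` times when first_n >= 0.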
class PrintOp : public OpKernel {
 public:
  explicit PrintOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), call_counter_(0) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &message_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("first_n", &first_n_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
  }

  void Compute(OpKernelContext* ctx) override {
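    // Forward the first input to the output unchanged; printing is purely a
    // side effect.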
    if (IsRefType(ctx->input_dtype(0))) {
      ctx->forward_ref_input_to_ref_output(0, 0);
    } else {
      ctx->set_output(0, ctx->input(0));
    }
    if (first_n_ >= 0) {
      mutex_lock l(mu_);
      if (call_counter_ >= first_n_) return;
      call_counter_++;
    }
    string msg;
    strings::StrAppend(&msg, message_);
    for (int i = 1; i < ctx->num_inputs(); ++i) {
      strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
                         "]");
    }
    std::cerr << msg << std::endl;
  }

 private:
  mutex mu_;
  int64 call_counter_ GUARDED_BY(mu_) = 0;
  int64 first_n_ = 0;
  int32 summarize_ = 0;
  string message_;
};

REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);

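// PrintV2Op writes its scalar string input to the destination named by the
// `output_stream` attribute: one of the streams in valid_output_streams_
// (e.g. "stderr" or "log(info)"), or a file path in the form
// "file:///tmp/out.txt" (the path shown is just an example), in which case
// each message is appended to that file.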
class PrintV2Op : public OpKernel {
 public:
  explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));

    SetFilePathIfAny();
    if (!file_path_.empty()) return;

    auto output_stream_index =
        std::find(std::begin(valid_output_streams_),
                  std::end(valid_output_streams_), output_stream_);

    if (output_stream_index == std::end(valid_output_streams_)) {
      string error_msg = strings::StrCat(
          "Unknown output stream: ", output_stream_, ", Valid streams are:");
      for (auto valid_stream : valid_output_streams_) {
        strings::StrAppend(&error_msg, " ", valid_stream);
      }
      OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor* input_;
    OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
    const string& msg = input_->scalar<string>()();

    if (!file_path_.empty()) {
      // Outputs to a file at the specified path.
      OP_REQUIRES_OK(ctx, AppendStringToFile(file_path_, msg, ctx->env()));
      return;
    }
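    // Registered listeners take precedence over the selected output stream:
    // if any exist, the message goes only to them.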
    auto listeners = logging::GetListeners();
    if (!listeners->empty()) {
      for (auto& listener : *listeners) {
        listener(msg.c_str());
      }
    } else if (output_stream_ == "stdout") {
      std::cout << msg << std::endl;
    } else if (output_stream_ == "stderr") {
      std::cerr << msg << std::endl;
    } else if (output_stream_ == "log(info)") {
      LOG(INFO) << msg << std::endl;
    } else if (output_stream_ == "log(warning)") {
      LOG(WARNING) << msg << std::endl;
    } else if (output_stream_ == "log(error)") {
      LOG(ERROR) << msg << std::endl;
    } else {
      string error_msg = strings::StrCat(
          "Unknown output stream: ", output_stream_, ", Valid streams are:");
      for (auto valid_stream : valid_output_streams_) {
        strings::StrAppend(&error_msg, " ", valid_stream);
      }
      strings::StrAppend(&error_msg, ", or file://<filename>");
      OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
    }
  }

  const char* valid_output_streams_[5] = {"stdout", "stderr", "log(info)",
                                          "log(warning)", "log(error)"};

 private:
  // Either output_stream_ or file_path_ (but not both) will be non-empty.
  string output_stream_;
  string file_path_;

  // If output_stream_ is a file path, extracts it to file_path_ and clears
  // output_stream_; otherwise sets file_path_ to "".
  void SetFilePathIfAny() {
    if (absl::StartsWith(output_stream_, kOutputStreamEscapeStr)) {
      file_path_ = output_stream_.substr(strlen(kOutputStreamEscapeStr));
      output_stream_ = "";
    } else {
      file_path_ = "";
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);

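// TimestampOp returns the current wall-clock time as a double scalar holding
// seconds (with microsecond resolution) since the Unix epoch.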
class TimestampOp : public OpKernel {
 public:
  explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    TensorShape output_shape;  // Default shape is 0 dim, 1 element
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, output_shape, &output_tensor));

    auto output_scalar = output_tensor->scalar<double>();
    double now_us = static_cast<double>(Env::Default()->NowMicros());
    double now_s = now_us / 1000000;
    output_scalar() = now_s;
  }
};

REGISTER_KERNEL_BUILDER(Name("Timestamp").Device(DEVICE_CPU), TimestampOp);

}  // end namespace tensorflow