• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/kernels/logging_ops.h"
17 
18 #include <iostream>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_split.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace tensorflow {
28 
29 namespace {
30 
31 // If the following string is found at the beginning of an output stream, it
32 // will be interpreted as a file path.
33 const char kOutputStreamEscapeStr[] = "file://";
34 
35 // A mutex that guards appending strings to files.
36 static mutex* file_mutex = new mutex();
37 
38 // Appends the given data to the specified file. It will create the file if it
39 // doesn't already exist.
AppendStringToFile(const std::string & fname,StringPiece data,Env * env)40 Status AppendStringToFile(const std::string& fname, StringPiece data,
41                           Env* env) {
42   // TODO(ckluk): If opening and closing on every log causes performance issues,
43   // we can reimplement using reference counters.
44   mutex_lock l(*file_mutex);
45   std::unique_ptr<WritableFile> file;
46   TF_RETURN_IF_ERROR(env->NewAppendableFile(fname, &file));
47   Status a = file->Append(absl::StrCat(data, "\n"));
48   Status c = file->Close();
49   return a.ok() ? c : a;
50 }
51 
52 }  // namespace
53 
namespace logging {

// Callback signature for receiving printed messages. In this file,
// PrintV2Op::Compute sends each message to every registered listener instead
// of its configured output stream whenever this list is non-empty.
using Listeners = std::vector<void (*)(const char*)>;

// Returns the shared listener list, heap-allocated on first call and never
// freed.
// NOTE(review): no lock protects the list, so registration presumably must
// finish before any PrintV2 kernel runs — confirm with callers.
Listeners* GetListeners() {
  static Listeners* const listeners = new Listeners();
  return listeners;
}

// Appends `listener` to the shared listener list. Always returns true.
bool RegisterListener(void (*listener)(const char*)) {
  Listeners* const all = GetListeners();
  all->emplace_back(listener);
  return true;
}

}  // end namespace logging
69 
70 class AssertOp : public OpKernel {
71  public:
AssertOp(OpKernelConstruction * ctx)72   explicit AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
73     OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
74   }
75 
Compute(OpKernelContext * ctx)76   void Compute(OpKernelContext* ctx) override {
77     const Tensor& cond = ctx->input(0);
78     OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
79                 errors::InvalidArgument("In[0] should be a scalar: ",
80                                         cond.shape().DebugString()));
81 
82     if (cond.scalar<bool>()()) {
83       return;
84     }
85     string msg = "assertion failed: ";
86     for (int i = 1; i < ctx->num_inputs(); ++i) {
87       strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
88                          "]");
89       if (i < ctx->num_inputs() - 1) strings::StrAppend(&msg, " ");
90     }
91     ctx->SetStatus(errors::InvalidArgument(msg));
92   }
93 
94  private:
95   int32 summarize_ = 0;
96 };
97 
98 REGISTER_KERNEL_BUILDER(Name("Assert").Device(DEVICE_CPU), AssertOp);
99 
100 #if GOOGLE_CUDA
101 REGISTER_KERNEL_BUILDER(Name("Assert")
102                             .Device(DEVICE_GPU)
103                             .HostMemory("condition")
104                             .HostMemory("data"),
105                         AssertOp);
106 #endif  // GOOGLE_CUDA
107 
108 class PrintOp : public OpKernel {
109  public:
PrintOp(OpKernelConstruction * ctx)110   explicit PrintOp(OpKernelConstruction* ctx)
111       : OpKernel(ctx), call_counter_(0) {
112     OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &message_));
113     OP_REQUIRES_OK(ctx, ctx->GetAttr("first_n", &first_n_));
114     OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
115   }
116 
Compute(OpKernelContext * ctx)117   void Compute(OpKernelContext* ctx) override {
118     if (IsRefType(ctx->input_dtype(0))) {
119       ctx->forward_ref_input_to_ref_output(0, 0);
120     } else {
121       ctx->set_output(0, ctx->input(0));
122     }
123     if (first_n_ >= 0) {
124       mutex_lock l(mu_);
125       if (call_counter_ >= first_n_) return;
126       call_counter_++;
127     }
128     string msg;
129     strings::StrAppend(&msg, message_);
130     for (int i = 1; i < ctx->num_inputs(); ++i) {
131       strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
132                          "]");
133     }
134     std::cerr << msg << std::endl;
135   }
136 
137  private:
138   mutex mu_;
139   int64 call_counter_ GUARDED_BY(mu_) = 0;
140   int64 first_n_ = 0;
141   int32 summarize_ = 0;
142   string message_;
143 };
144 
145 REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
146 
147 class PrintV2Op : public OpKernel {
148  public:
PrintV2Op(OpKernelConstruction * ctx)149   explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
150     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
151 
152     SetFilePathIfAny();
153     if (!file_path_.empty()) return;
154 
155     auto output_stream_index =
156         std::find(std::begin(valid_output_streams_),
157                   std::end(valid_output_streams_), output_stream_);
158 
159     if (output_stream_index == std::end(valid_output_streams_)) {
160       string error_msg = strings::StrCat(
161           "Unknown output stream: ", output_stream_, ", Valid streams are:");
162       for (auto valid_stream : valid_output_streams_) {
163         strings::StrAppend(&error_msg, " ", valid_stream);
164       }
165       OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
166     }
167   }
168 
Compute(OpKernelContext * ctx)169   void Compute(OpKernelContext* ctx) override {
170     const Tensor* input_;
171     OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
172     const string& msg = input_->scalar<string>()();
173 
174     if (!file_path_.empty()) {
175       // Outputs to a file at the specified path.
176       OP_REQUIRES_OK(ctx, AppendStringToFile(file_path_, msg, ctx->env()));
177       return;
178     }
179     auto listeners = logging::GetListeners();
180     if (!listeners->empty()) {
181       for (auto& listener : *listeners) {
182         listener(msg.c_str());
183       }
184     } else if (output_stream_ == "stdout") {
185       std::cout << msg << std::endl;
186     } else if (output_stream_ == "stderr") {
187       std::cerr << msg << std::endl;
188     } else if (output_stream_ == "log(info)") {
189       LOG(INFO) << msg << std::endl;
190     } else if (output_stream_ == "log(warning)") {
191       LOG(WARNING) << msg << std::endl;
192     } else if (output_stream_ == "log(error)") {
193       LOG(ERROR) << msg << std::endl;
194     } else {
195       string error_msg = strings::StrCat(
196           "Unknown output stream: ", output_stream_, ", Valid streams are:");
197       for (auto valid_stream : valid_output_streams_) {
198         strings::StrAppend(&error_msg, " ", valid_stream);
199       }
200       strings::StrAppend(&error_msg, ", or file://<filename>");
201       OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
202     }
203   }
204 
205   const char* valid_output_streams_[5] = {"stdout", "stderr", "log(info)",
206                                           "log(warning)", "log(error)"};
207 
208  private:
209   // Either output_stream_ or file_path_ (but not both) will be non-empty.
210   string output_stream_;
211   string file_path_;
212 
213   // If output_stream_ is a file path, extracts it to file_path_ and clears
214   // output_stream_; otherwise sets file_paths_ to "".
SetFilePathIfAny()215   void SetFilePathIfAny() {
216     if (absl::StartsWith(output_stream_, kOutputStreamEscapeStr)) {
217       file_path_ = output_stream_.substr(strlen(kOutputStreamEscapeStr));
218       output_stream_ = "";
219     } else {
220       file_path_ = "";
221     }
222   }
223 };
224 
225 REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
226 
227 class TimestampOp : public OpKernel {
228  public:
TimestampOp(OpKernelConstruction * context)229   explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
230 
Compute(OpKernelContext * context)231   void Compute(OpKernelContext* context) override {
232     TensorShape output_shape;  // Default shape is 0 dim, 1 element
233     Tensor* output_tensor = nullptr;
234     OP_REQUIRES_OK(context,
235                    context->allocate_output(0, output_shape, &output_tensor));
236 
237     auto output_scalar = output_tensor->scalar<double>();
238     double now_us = static_cast<double>(Env::Default()->NowMicros());
239     double now_s = now_us / 1000000;
240     output_scalar() = now_s;
241   }
242 };
243 
244 REGISTER_KERNEL_BUILDER(Name("Timestamp").Device(DEVICE_CPU), TimestampOp);
245 
246 }  // end namespace tensorflow
247