• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/summary/summary_db_writer.h"
16 
17 #include <deque>
18 
19 #include "tensorflow/core/summary/summary_converter.h"
20 #include "tensorflow/core/framework/graph.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/register_types.h"
23 #include "tensorflow/core/framework/summary.pb.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/db/sqlite.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/util/event.pb.h"
28 
29 // TODO(jart): Break this up into multiple files with excellent unit tests.
30 // TODO(jart): Make decision to write in separate op.
31 // TODO(jart): Add really good busy handling.
32 
// Applies the macro `m` once to each tensor element type handled by the
// switch statements in this file. Expansion order: string, floating point,
// complex, signed integers, unsigned integers.
// clang-format off
#define CALL_SUPPORTED_TYPES(m)                                             \
  TF_CALL_string(m)                                                         \
  TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)                        \
  TF_CALL_complex64(m) TF_CALL_complex128(m)                                \
  TF_CALL_int8(m) TF_CALL_int16(m) TF_CALL_int32(m) TF_CALL_int64(m)        \
  TF_CALL_uint8(m) TF_CALL_uint16(m) TF_CALL_uint32(m) TF_CALL_uint64(m)
// clang-format on
50 
51 namespace tensorflow {
52 namespace {
53 
54 // https://www.sqlite.org/fileformat.html#record_format
55 const uint64 kIdTiers[] = {
56     0x7fffffULL,        // 23-bit (3 bytes on disk)
57     0x7fffffffULL,      // 31-bit (4 bytes on disk)
58     0x7fffffffffffULL,  // 47-bit (5 bytes on disk)
59                         // remaining bits for future use
60 };
61 const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
62 const int kIdCollisionDelayMicros = 10;
63 const int kMaxIdCollisions = 21;  // sum(2**i*10µs for i in range(21))~=21s
64 const int64 kAbsent = 0LL;
65 
66 const char* kScalarPluginName = "scalars";
67 const char* kImagePluginName = "images";
68 const char* kAudioPluginName = "audio";
69 const char* kHistogramPluginName = "histograms";
70 
71 const int64 kReserveMinBytes = 32;
72 const double kReserveMultiplier = 1.5;
73 const int64 kPreallocateRows = 1000;
74 
75 // Flush is a misnomer because what we're actually doing is having lots
76 // of commits inside any SqliteTransaction that writes potentially
77 // hundreds of megs but doesn't need the transaction to maintain its
78 // invariants. This ensures the WAL read penalty is small and might
79 // allow writers in other processes a chance to schedule.
80 const uint64 kFlushBytes = 1024 * 1024;
81 
DoubleTime(uint64 micros)82 double DoubleTime(uint64 micros) {
83   // TODO(@jart): Follow precise definitions for time laid out in schema.
84   // TODO(@jart): Use monotonic clock from gRPC codebase.
85   return static_cast<double>(micros) / 1.0e6;
86 }
87 
StringifyShape(const TensorShape & shape)88 string StringifyShape(const TensorShape& shape) {
89   string result;
90   bool first = true;
91   for (const auto& dim : shape) {
92     if (first) {
93       first = false;
94     } else {
95       strings::StrAppend(&result, ",");
96     }
97     strings::StrAppend(&result, dim.size);
98   }
99   return result;
100 }
101 
CheckSupportedType(const Tensor & t)102 Status CheckSupportedType(const Tensor& t) {
103 #define CASE(T)                  \
104   case DataTypeToEnum<T>::value: \
105     break;
106   switch (t.dtype()) {
107     CALL_SUPPORTED_TYPES(CASE)
108     default:
109       return errors::Unimplemented(DataTypeString(t.dtype()),
110                                    " tensors unsupported on platform");
111   }
112   return Status::OK();
113 #undef CASE
114 }
115 
/// \brief Builds a rank-0 tensor holding the first flat element of `t`.
///
/// The result keeps the dtype of `t` when that dtype is one of
/// CALL_SUPPORTED_TYPES. For any other dtype the result is a DT_FLOAT
/// scalar containing NAN.
Tensor AsScalar(const Tensor& t) {
  const DataType dtype = t.dtype();
  Tensor out{dtype, {}};
#define CASE(T)                         \
  case DataTypeToEnum<T>::value:        \
    out.scalar<T>()() = t.flat<T>()(0); \
    return out;
  switch (dtype) {
    CALL_SUPPORTED_TYPES(CASE)
    default:
      break;
  }
#undef CASE
  // Unsupported dtype: fall back to a float NAN scalar.
  out = {DT_FLOAT, {}};
  out.scalar<float>()() = NAN;
  return out;
}
132 
PatchPluginName(SummaryMetadata * metadata,const char * name)133 void PatchPluginName(SummaryMetadata* metadata, const char* name) {
134   if (metadata->plugin_data().plugin_name().empty()) {
135     metadata->mutable_plugin_data()->set_plugin_name(name);
136   }
137 }
138 
SetDescription(Sqlite * db,int64 id,const StringPiece & markdown)139 Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
140   const char* sql = R"sql(
141     INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
142   )sql";
143   SqliteStatement insert_desc;
144   TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc));
145   insert_desc.BindInt(1, id);
146   insert_desc.BindText(2, markdown);
147   return insert_desc.StepAndReset();
148 }
149 
150 /// \brief Generates unique IDs randomly in the [1,2**63-1] range.
151 ///
152 /// This class starts off generating IDs in the [1,2**23-1] range,
153 /// because it's human friendly and occupies 4 bytes max on disk with
154 /// SQLite's zigzag varint encoding. Then, each time a collision
155 /// happens, the random space is increased by 8 bits.
156 ///
157 /// This class uses exponential back-off so writes gradually slow down
158 /// as IDs become exhausted but reads are still possible.
159 ///
160 /// This class is thread safe.
161 class IdAllocator {
162  public:
IdAllocator(Env * env,Sqlite * db)163   IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} {
164     DCHECK(env_ != nullptr);
165     DCHECK(db_ != nullptr);
166   }
167 
CreateNewId(int64 * id)168   Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) {
169     mutex_lock lock(mu_);
170     Status s;
171     SqliteStatement stmt;
172     TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt));
173     for (int i = 0; i < kMaxIdCollisions; ++i) {
174       int64 tid = MakeRandomId();
175       stmt.BindInt(1, tid);
176       s = stmt.StepAndReset();
177       if (s.ok()) {
178         *id = tid;
179         break;
180       }
181       // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc
182       if (s.code() != error::INVALID_ARGUMENT) break;
183       if (tier_ < kMaxIdTier) {
184         LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of "
185                   << kMaxIdTier << ") so auto-adjusting to a higher tier";
186         ++tier_;
187       } else {
188         LOG(WARNING) << "IdAllocator (attempt #" << i << ") "
189                      << "resulted in a collision at the highest tier; this "
190                         "is problematic if it happens often; you can try "
191                         "pruning the Ids table; you can also file a bug "
192                         "asking for the ID space to be increased; otherwise "
193                         "writes will gradually slow down over time until they "
194                         "become impossible";
195       }
196       env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros);
197     }
198     return s;
199   }
200 
201  private:
MakeRandomId()202   int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
203     int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]);
204     if (id == kAbsent) ++id;
205     return id;
206   }
207 
208   mutex mu_;
209   Env* const env_;
210   Sqlite* const db_;
211   int tier_ GUARDED_BY(mu_) = 0;
212 
213   TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator);
214 };
215 
216 class GraphWriter {
217  public:
Save(Sqlite * db,SqliteTransaction * txn,IdAllocator * ids,GraphDef * graph,uint64 now,int64 run_id,int64 * graph_id)218   static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
219                      GraphDef* graph, uint64 now, int64 run_id, int64* graph_id)
220       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
221     TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
222     GraphWriter saver{db, txn, graph, now, *graph_id};
223     saver.MapNameToNodeId();
224     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
225     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
226     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
227     return Status::OK();
228   }
229 
230  private:
GraphWriter(Sqlite * db,SqliteTransaction * txn,GraphDef * graph,uint64 now,int64 graph_id)231   GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
232               int64 graph_id)
233       : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}
234 
MapNameToNodeId()235   void MapNameToNodeId() {
236     size_t toto = static_cast<size_t>(graph_->node_size());
237     name_copies_.reserve(toto);
238     name_to_node_id_.reserve(toto);
239     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
240       // Copy name into memory region, since we call clear_name() later.
241       // Then wrap in StringPiece so we can compare slices without copy.
242       name_copies_.emplace_back(graph_->node(node_id).name());
243       name_to_node_id_.emplace(name_copies_.back(), node_id);
244     }
245   }
246 
SaveNodeInputs()247   Status SaveNodeInputs() {
248     const char* sql = R"sql(
249       INSERT INTO NodeInputs (
250         graph_id,
251         node_id,
252         idx,
253         input_node_id,
254         input_node_idx,
255         is_control
256       ) VALUES (?, ?, ?, ?, ?, ?)
257     )sql";
258     SqliteStatement insert;
259     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
260     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
261       const NodeDef& node = graph_->node(node_id);
262       for (int idx = 0; idx < node.input_size(); ++idx) {
263         StringPiece name = node.input(idx);
264         int64 input_node_id;
265         int64 input_node_idx = 0;
266         int64 is_control = 0;
267         size_t i = name.rfind(':');
268         if (i != StringPiece::npos) {
269           if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
270                                      &input_node_idx)) {
271             return errors::DataLoss("Bad NodeDef.input: ", name);
272           }
273           name.remove_suffix(name.size() - i);
274         }
275         if (!name.empty() && name[0] == '^') {
276           name.remove_prefix(1);
277           is_control = 1;
278         }
279         auto e = name_to_node_id_.find(name);
280         if (e == name_to_node_id_.end()) {
281           return errors::DataLoss("Could not find node: ", name);
282         }
283         input_node_id = e->second;
284         insert.BindInt(1, graph_id_);
285         insert.BindInt(2, node_id);
286         insert.BindInt(3, idx);
287         insert.BindInt(4, input_node_id);
288         insert.BindInt(5, input_node_idx);
289         insert.BindInt(6, is_control);
290         unflushed_bytes_ += insert.size();
291         TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
292                                         " -> ", name);
293         TF_RETURN_IF_ERROR(MaybeFlush());
294       }
295     }
296     return Status::OK();
297   }
298 
SaveNodes()299   Status SaveNodes() {
300     const char* sql = R"sql(
301       INSERT INTO Nodes (
302         graph_id,
303         node_id,
304         node_name,
305         op,
306         device,
307         node_def)
308       VALUES (?, ?, ?, ?, ?, ?)
309     )sql";
310     SqliteStatement insert;
311     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
312     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
313       NodeDef* node = graph_->mutable_node(node_id);
314       insert.BindInt(1, graph_id_);
315       insert.BindInt(2, node_id);
316       insert.BindText(3, node->name());
317       insert.BindText(4, node->op());
318       insert.BindText(5, node->device());
319       node->clear_name();
320       node->clear_op();
321       node->clear_device();
322       node->clear_input();
323       string node_def;
324       if (node->SerializeToString(&node_def)) {
325         insert.BindBlobUnsafe(6, node_def);
326       }
327       unflushed_bytes_ += insert.size();
328       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
329       TF_RETURN_IF_ERROR(MaybeFlush());
330     }
331     return Status::OK();
332   }
333 
SaveGraph(int64 run_id)334   Status SaveGraph(int64 run_id) {
335     const char* sql = R"sql(
336       INSERT OR REPLACE INTO Graphs (
337         run_id,
338         graph_id,
339         inserted_time,
340         graph_def
341       ) VALUES (?, ?, ?, ?)
342     )sql";
343     SqliteStatement insert;
344     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
345     if (run_id != kAbsent) insert.BindInt(1, run_id);
346     insert.BindInt(2, graph_id_);
347     insert.BindDouble(3, DoubleTime(now_));
348     graph_->clear_node();
349     string graph_def;
350     if (graph_->SerializeToString(&graph_def)) {
351       insert.BindBlobUnsafe(4, graph_def);
352     }
353     return insert.StepAndReset();
354   }
355 
MaybeFlush()356   Status MaybeFlush() {
357     if (unflushed_bytes_ >= kFlushBytes) {
358       TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
359                                       unflushed_bytes_, " bytes");
360       unflushed_bytes_ = 0;
361     }
362     return Status::OK();
363   }
364 
365   Sqlite* const db_;
366   SqliteTransaction* const txn_;
367   uint64 unflushed_bytes_ = 0;
368   GraphDef* const graph_;
369   const uint64 now_;
370   const int64 graph_id_;
371   std::vector<string> name_copies_;
372   std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_;
373 
374   TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
375 };
376 
377 /// \brief Run metadata manager.
378 ///
379 /// This class gives us Tag IDs we can pass to SeriesWriter. In order
380 /// to do that, rows are created in the Ids, Tags, Runs, Experiments,
381 /// and Users tables.
382 ///
383 /// This class is thread safe.
384 class RunMetadata {
385  public:
RunMetadata(IdAllocator * ids,const string & experiment_name,const string & run_name,const string & user_name)386   RunMetadata(IdAllocator* ids, const string& experiment_name,
387               const string& run_name, const string& user_name)
388       : ids_{ids},
389         experiment_name_{experiment_name},
390         run_name_{run_name},
391         user_name_{user_name} {
392     DCHECK(ids_ != nullptr);
393   }
394 
experiment_name()395   const string& experiment_name() { return experiment_name_; }
run_name()396   const string& run_name() { return run_name_; }
user_name()397   const string& user_name() { return user_name_; }
398 
run_id()399   int64 run_id() LOCKS_EXCLUDED(mu_) {
400     mutex_lock lock(mu_);
401     return run_id_;
402   }
403 
SetGraph(Sqlite * db,uint64 now,double computed_time,std::unique_ptr<GraphDef> g)404   Status SetGraph(Sqlite* db, uint64 now, double computed_time,
405                   std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
406       LOCKS_EXCLUDED(mu_) {
407     int64 run_id;
408     {
409       mutex_lock lock(mu_);
410       TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
411       run_id = run_id_;
412     }
413     int64 graph_id;
414     SqliteTransaction txn(*db);  // only to increase performance
415     TF_RETURN_IF_ERROR(
416         GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
417     return txn.Commit();
418   }
419 
GetTagId(Sqlite * db,uint64 now,double computed_time,const string & tag_name,int64 * tag_id,const SummaryMetadata & metadata)420   Status GetTagId(Sqlite* db, uint64 now, double computed_time,
421                   const string& tag_name, int64* tag_id,
422                   const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) {
423     mutex_lock lock(mu_);
424     TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
425     auto e = tag_ids_.find(tag_name);
426     if (e != tag_ids_.end()) {
427       *tag_id = e->second;
428       return Status::OK();
429     }
430     TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
431     tag_ids_[tag_name] = *tag_id;
432     TF_RETURN_IF_ERROR(
433         SetDescription(db, *tag_id, metadata.summary_description()));
434     const char* sql = R"sql(
435       INSERT INTO Tags (
436         run_id,
437         tag_id,
438         tag_name,
439         inserted_time,
440         display_name,
441         plugin_name,
442         plugin_data
443       ) VALUES (
444         :run_id,
445         :tag_id,
446         :tag_name,
447         :inserted_time,
448         :display_name,
449         :plugin_name,
450         :plugin_data
451       )
452     )sql";
453     SqliteStatement insert;
454     TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
455     if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
456     insert.BindInt(":tag_id", *tag_id);
457     insert.BindTextUnsafe(":tag_name", tag_name);
458     insert.BindDouble(":inserted_time", DoubleTime(now));
459     insert.BindTextUnsafe(":display_name", metadata.display_name());
460     insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
461     insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
462     return insert.StepAndReset();
463   }
464 
465  private:
InitializeUser(Sqlite * db,uint64 now)466   Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
467     if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
468     const char* get_sql = R"sql(
469       SELECT user_id FROM Users WHERE user_name = ?
470     )sql";
471     SqliteStatement get;
472     TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
473     get.BindText(1, user_name_);
474     bool is_done;
475     TF_RETURN_IF_ERROR(get.Step(&is_done));
476     if (!is_done) {
477       user_id_ = get.ColumnInt(0);
478       return Status::OK();
479     }
480     TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
481     const char* insert_sql = R"sql(
482       INSERT INTO Users (
483         user_id,
484         user_name,
485         inserted_time
486       ) VALUES (?, ?, ?)
487     )sql";
488     SqliteStatement insert;
489     TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
490     insert.BindInt(1, user_id_);
491     insert.BindText(2, user_name_);
492     insert.BindDouble(3, DoubleTime(now));
493     TF_RETURN_IF_ERROR(insert.StepAndReset());
494     return Status::OK();
495   }
496 
InitializeExperiment(Sqlite * db,uint64 now,double computed_time)497   Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
498       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
499     if (experiment_name_.empty()) return Status::OK();
500     if (experiment_id_ == kAbsent) {
501       TF_RETURN_IF_ERROR(InitializeUser(db, now));
502       const char* get_sql = R"sql(
503         SELECT
504           experiment_id,
505           started_time
506         FROM
507           Experiments
508         WHERE
509           user_id IS ?
510           AND experiment_name = ?
511       )sql";
512       SqliteStatement get;
513       TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
514       if (user_id_ != kAbsent) get.BindInt(1, user_id_);
515       get.BindText(2, experiment_name_);
516       bool is_done;
517       TF_RETURN_IF_ERROR(get.Step(&is_done));
518       if (!is_done) {
519         experiment_id_ = get.ColumnInt(0);
520         experiment_started_time_ = get.ColumnInt(1);
521       } else {
522         TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
523         experiment_started_time_ = computed_time;
524         const char* insert_sql = R"sql(
525           INSERT INTO Experiments (
526             user_id,
527             experiment_id,
528             experiment_name,
529             inserted_time,
530             started_time,
531             is_watching
532           ) VALUES (?, ?, ?, ?, ?, ?)
533         )sql";
534         SqliteStatement insert;
535         TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
536         if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
537         insert.BindInt(2, experiment_id_);
538         insert.BindText(3, experiment_name_);
539         insert.BindDouble(4, DoubleTime(now));
540         insert.BindDouble(5, computed_time);
541         insert.BindInt(6, 0);
542         TF_RETURN_IF_ERROR(insert.StepAndReset());
543       }
544     }
545     if (computed_time < experiment_started_time_) {
546       experiment_started_time_ = computed_time;
547       const char* update_sql = R"sql(
548         UPDATE
549           Experiments
550         SET
551           started_time = ?
552         WHERE
553           experiment_id = ?
554       )sql";
555       SqliteStatement update;
556       TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
557       update.BindDouble(1, computed_time);
558       update.BindInt(2, experiment_id_);
559       TF_RETURN_IF_ERROR(update.StepAndReset());
560     }
561     return Status::OK();
562   }
563 
InitializeRun(Sqlite * db,uint64 now,double computed_time)564   Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
565       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
566     if (run_name_.empty()) return Status::OK();
567     TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
568     if (run_id_ == kAbsent) {
569       TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
570       run_started_time_ = computed_time;
571       const char* insert_sql = R"sql(
572         INSERT OR REPLACE INTO Runs (
573           experiment_id,
574           run_id,
575           run_name,
576           inserted_time,
577           started_time
578         ) VALUES (?, ?, ?, ?, ?)
579       )sql";
580       SqliteStatement insert;
581       TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
582       if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
583       insert.BindInt(2, run_id_);
584       insert.BindText(3, run_name_);
585       insert.BindDouble(4, DoubleTime(now));
586       insert.BindDouble(5, computed_time);
587       TF_RETURN_IF_ERROR(insert.StepAndReset());
588     }
589     if (computed_time < run_started_time_) {
590       run_started_time_ = computed_time;
591       const char* update_sql = R"sql(
592         UPDATE
593           Runs
594         SET
595           started_time = ?
596         WHERE
597           run_id = ?
598       )sql";
599       SqliteStatement update;
600       TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
601       update.BindDouble(1, computed_time);
602       update.BindInt(2, run_id_);
603       TF_RETURN_IF_ERROR(update.StepAndReset());
604     }
605     return Status::OK();
606   }
607 
608   mutex mu_;
609   IdAllocator* const ids_;
610   const string experiment_name_;
611   const string run_name_;
612   const string user_name_;
613   int64 experiment_id_ GUARDED_BY(mu_) = kAbsent;
614   int64 run_id_ GUARDED_BY(mu_) = kAbsent;
615   int64 user_id_ GUARDED_BY(mu_) = kAbsent;
616   double experiment_started_time_ GUARDED_BY(mu_) = 0.0;
617   double run_started_time_ GUARDED_BY(mu_) = 0.0;
618   std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_);
619 
620   TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
621 };
622 
623 /// \brief Tensor writer for a single series, e.g. Tag.
624 ///
625 /// This class is thread safe.
626 class SeriesWriter {
627  public:
SeriesWriter(int64 series,RunMetadata * meta)628   SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} {
629     DCHECK(series_ > 0);
630   }
631 
Append(Sqlite * db,int64 step,uint64 now,double computed_time,const Tensor & t)632   Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
633                 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
634       LOCKS_EXCLUDED(mu_) {
635     mutex_lock lock(mu_);
636     if (rowids_.empty()) {
637       Status s = Reserve(db, t);
638       if (!s.ok()) {
639         rowids_.clear();
640         return s;
641       }
642     }
643     int64 rowid = rowids_.front();
644     Status s = Write(db, rowid, step, computed_time, t);
645     if (s.ok()) {
646       ++count_;
647     }
648     rowids_.pop_front();
649     return s;
650   }
651 
Finish(Sqlite * db)652   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
653       LOCKS_EXCLUDED(mu_) {
654     mutex_lock lock(mu_);
655     // Delete unused pre-allocated Tensors.
656     if (!rowids_.empty()) {
657       SqliteTransaction txn(*db);
658       const char* sql = R"sql(
659         DELETE FROM Tensors WHERE rowid = ?
660       )sql";
661       SqliteStatement deleter;
662       TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
663       for (size_t i = count_; i < rowids_.size(); ++i) {
664         deleter.BindInt(1, rowids_.front());
665         TF_RETURN_IF_ERROR(deleter.StepAndReset());
666         rowids_.pop_front();
667       }
668       TF_RETURN_IF_ERROR(txn.Commit());
669       rowids_.clear();
670     }
671     return Status::OK();
672   }
673 
674  private:
Write(Sqlite * db,int64 rowid,int64 step,double computed_time,const Tensor & t)675   Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time,
676                const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
677     if (t.dtype() == DT_STRING) {
678       if (t.dims() == 0) {
679         return Update(db, step, computed_time, t, t.scalar<string>()(), rowid);
680       } else {
681         SqliteTransaction txn(*db);
682         TF_RETURN_IF_ERROR(
683             Update(db, step, computed_time, t, StringPiece(), rowid));
684         TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
685         return txn.Commit();
686       }
687     } else {
688       return Update(db, step, computed_time, t, t.tensor_data(), rowid);
689     }
690   }
691 
Update(Sqlite * db,int64 step,double computed_time,const Tensor & t,const StringPiece & data,int64 rowid)692   Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
693                 const StringPiece& data, int64 rowid) {
694     const char* sql = R"sql(
695       UPDATE OR REPLACE
696         Tensors
697       SET
698         step = ?,
699         computed_time = ?,
700         dtype = ?,
701         shape = ?,
702         data = ?
703       WHERE
704         rowid = ?
705     )sql";
706     SqliteStatement stmt;
707     TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
708     stmt.BindInt(1, step);
709     stmt.BindDouble(2, computed_time);
710     stmt.BindInt(3, t.dtype());
711     stmt.BindText(4, StringifyShape(t.shape()));
712     stmt.BindBlobUnsafe(5, data);
713     stmt.BindInt(6, rowid);
714     TF_RETURN_IF_ERROR(stmt.StepAndReset());
715     return Status::OK();
716   }
717 
UpdateNdString(Sqlite * db,const Tensor & t,int64 tensor_rowid)718   Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid)
719       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
720     DCHECK_EQ(t.dtype(), DT_STRING);
721     DCHECK_GT(t.dims(), 0);
722     const char* deleter_sql = R"sql(
723       DELETE FROM TensorStrings WHERE tensor_rowid = ?
724     )sql";
725     SqliteStatement deleter;
726     TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
727     deleter.BindInt(1, tensor_rowid);
728     TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
729     const char* inserter_sql = R"sql(
730       INSERT INTO TensorStrings (
731         tensor_rowid,
732         idx,
733         data
734       ) VALUES (?, ?, ?)
735     )sql";
736     SqliteStatement inserter;
737     TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
738     auto flat = t.flat<string>();
739     for (int64 i = 0; i < flat.size(); ++i) {
740       inserter.BindInt(1, tensor_rowid);
741       inserter.BindInt(2, i);
742       inserter.BindBlobUnsafe(3, flat(i));
743       TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
744     }
745     return Status::OK();
746   }
747 
Reserve(Sqlite * db,const Tensor & t)748   Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
749       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
750     SqliteTransaction txn(*db);  // only for performance
751     unflushed_bytes_ = 0;
752     if (t.dtype() == DT_STRING) {
753       if (t.dims() == 0) {
754         TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size()));
755       } else {
756         TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
757       }
758     } else {
759       TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
760     }
761     return txn.Commit();
762   }
763 
ReserveData(Sqlite * db,SqliteTransaction * txn,size_t size)764   Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
765       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
766           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
767     int64 space =
768         static_cast<int64>(static_cast<double>(size) * kReserveMultiplier);
769     if (space < kReserveMinBytes) space = kReserveMinBytes;
770     return ReserveTensors(db, txn, space);
771   }
772 
ReserveTensors(Sqlite * db,SqliteTransaction * txn,int64 reserved_bytes)773   Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
774                         int64 reserved_bytes)
775       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
776           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
777     const char* sql = R"sql(
778       INSERT INTO Tensors (
779         series,
780         data
781       ) VALUES (?, ZEROBLOB(?))
782     )sql";
783     SqliteStatement insert;
784     TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
785     // TODO(jart): Maybe preallocate index pages by setting step. This
786     //             is tricky because UPDATE OR REPLACE can have a side
787     //             effect of deleting preallocated rows.
788     for (int64 i = 0; i < kPreallocateRows; ++i) {
789       insert.BindInt(1, series_);
790       insert.BindInt(2, reserved_bytes);
791       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
792       rowids_.push_back(db->last_insert_rowid());
793       unflushed_bytes_ += reserved_bytes;
794       TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
795     }
796     return Status::OK();
797   }
798 
MaybeFlush(Sqlite * db,SqliteTransaction * txn)799   Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
800       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
801           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
802     if (unflushed_bytes_ >= kFlushBytes) {
803       TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
804                                       unflushed_bytes_, " bytes");
805       unflushed_bytes_ = 0;
806     }
807     return Status::OK();
808   }
809 
810   mutex mu_;
811   const int64 series_;
812   RunMetadata* const meta_;
813   uint64 count_ GUARDED_BY(mu_) = 0;
814   std::deque<int64> rowids_ GUARDED_BY(mu_);
815   uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0;
816 
817   TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
818 };
819 
820 /// \brief Tensor writer for a single Run.
821 ///
822 /// This class farms out tensors to SeriesWriter instances. It also
823 /// keeps track of whether or not someone is watching the TensorBoard
824 /// GUI, so it can avoid writes when possible.
825 ///
826 /// This class is thread safe.
827 class RunWriter {
828  public:
RunWriter(RunMetadata * meta)829   explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
830 
Append(Sqlite * db,int64 tag_id,int64 step,uint64 now,double computed_time,const Tensor & t)831   Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
832                 double computed_time, const Tensor& t)
833       SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
834     SeriesWriter* writer = GetSeriesWriter(tag_id);
835     return writer->Append(db, step, now, computed_time, t);
836   }
837 
Finish(Sqlite * db)838   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
839       LOCKS_EXCLUDED(mu_) {
840     mutex_lock lock(mu_);
841     if (series_writers_.empty()) return Status::OK();
842     for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) {
843       if (!i->second) continue;
844       TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db),
845                                       "finish tag_id=", i->first);
846       i->second.reset();
847     }
848     return Status::OK();
849   }
850 
851  private:
GetSeriesWriter(int64 tag_id)852   SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) {
853     mutex_lock sl(mu_);
854     auto spot = series_writers_.find(tag_id);
855     if (spot == series_writers_.end()) {
856       SeriesWriter* writer = new SeriesWriter(tag_id, meta_);
857       series_writers_[tag_id].reset(writer);
858       return writer;
859     } else {
860       return spot->second.get();
861     }
862   }
863 
864   mutex mu_;
865   RunMetadata* const meta_;
866   std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_
867       GUARDED_BY(mu_);
868 
869   TF_DISALLOW_COPY_AND_ASSIGN(RunWriter);
870 };
871 
872 /// \brief SQLite implementation of SummaryWriterInterface.
873 ///
874 /// This class is thread safe.
875 class SummaryDbWriter : public SummaryWriterInterface {
876  public:
SummaryDbWriter(Env * env,Sqlite * db,const string & experiment_name,const string & run_name,const string & user_name)877   SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
878                   const string& run_name, const string& user_name)
879       : SummaryWriterInterface(),
880         env_{env},
881         db_{db},
882         ids_{env_, db_},
883         meta_{&ids_, experiment_name, run_name, user_name},
884         run_{&meta_} {
885     DCHECK(env_ != nullptr);
886     db_->Ref();
887   }
888 
~SummaryDbWriter()889   ~SummaryDbWriter() override {
890     core::ScopedUnref unref(db_);
891     Status s = run_.Finish(db_);
892     if (!s.ok()) {
893       // TODO(jart): Retry on transient errors here.
894       LOG(ERROR) << s.ToString();
895     }
896     int64 run_id = meta_.run_id();
897     if (run_id == kAbsent) return;
898     const char* sql = R"sql(
899       UPDATE Runs SET finished_time = ? WHERE run_id = ?
900     )sql";
901     SqliteStatement update;
902     s = db_->Prepare(sql, &update);
903     if (s.ok()) {
904       update.BindDouble(1, DoubleTime(env_->NowMicros()));
905       update.BindInt(2, run_id);
906       s = update.StepAndReset();
907     }
908     if (!s.ok()) {
909       LOG(ERROR) << "Failed to set Runs[" << run_id
910                  << "].finish_time: " << s.ToString();
911     }
912   }
913 
Flush()914   Status Flush() override { return Status::OK(); }
915 
WriteTensor(int64 global_step,Tensor t,const string & tag,const string & serialized_metadata)916   Status WriteTensor(int64 global_step, Tensor t, const string& tag,
917                      const string& serialized_metadata) override {
918     TF_RETURN_IF_ERROR(CheckSupportedType(t));
919     SummaryMetadata metadata;
920     if (!metadata.ParseFromString(serialized_metadata)) {
921       return errors::InvalidArgument("Bad serialized_metadata");
922     }
923     return Write(global_step, t, tag, metadata);
924   }
925 
WriteScalar(int64 global_step,Tensor t,const string & tag)926   Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
927     TF_RETURN_IF_ERROR(CheckSupportedType(t));
928     SummaryMetadata metadata;
929     PatchPluginName(&metadata, kScalarPluginName);
930     return Write(global_step, AsScalar(t), tag, metadata);
931   }
932 
WriteGraph(int64 global_step,std::unique_ptr<GraphDef> g)933   Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
934     uint64 now = env_->NowMicros();
935     return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g));
936   }
937 
WriteEvent(std::unique_ptr<Event> e)938   Status WriteEvent(std::unique_ptr<Event> e) override {
939     return MigrateEvent(std::move(e));
940   }
941 
WriteHistogram(int64 global_step,Tensor t,const string & tag)942   Status WriteHistogram(int64 global_step, Tensor t,
943                         const string& tag) override {
944     uint64 now = env_->NowMicros();
945     std::unique_ptr<Event> e{new Event};
946     e->set_step(global_step);
947     e->set_wall_time(DoubleTime(now));
948     TF_RETURN_IF_ERROR(
949         AddTensorAsHistogramToSummary(t, tag, e->mutable_summary()));
950     return MigrateEvent(std::move(e));
951   }
952 
WriteImage(int64 global_step,Tensor t,const string & tag,int max_images,Tensor bad_color)953   Status WriteImage(int64 global_step, Tensor t, const string& tag,
954                     int max_images, Tensor bad_color) override {
955     uint64 now = env_->NowMicros();
956     std::unique_ptr<Event> e{new Event};
957     e->set_step(global_step);
958     e->set_wall_time(DoubleTime(now));
959     TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color,
960                                                  e->mutable_summary()));
961     return MigrateEvent(std::move(e));
962   }
963 
WriteAudio(int64 global_step,Tensor t,const string & tag,int max_outputs,float sample_rate)964   Status WriteAudio(int64 global_step, Tensor t, const string& tag,
965                     int max_outputs, float sample_rate) override {
966     uint64 now = env_->NowMicros();
967     std::unique_ptr<Event> e{new Event};
968     e->set_step(global_step);
969     e->set_wall_time(DoubleTime(now));
970     TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary(
971         t, tag, max_outputs, sample_rate, e->mutable_summary()));
972     return MigrateEvent(std::move(e));
973   }
974 
DebugString() const975   string DebugString() const override { return "SummaryDbWriter"; }
976 
977  private:
Write(int64 step,const Tensor & t,const string & tag,const SummaryMetadata & metadata)978   Status Write(int64 step, const Tensor& t, const string& tag,
979                const SummaryMetadata& metadata) {
980     uint64 now = env_->NowMicros();
981     double computed_time = DoubleTime(now);
982     int64 tag_id;
983     TF_RETURN_IF_ERROR(
984         meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
985     TF_RETURN_WITH_CONTEXT_IF_ERROR(
986         run_.Append(db_, tag_id, step, now, computed_time, t),
987         meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
988         "/", tag, "@", step);
989     return Status::OK();
990   }
991 
MigrateEvent(std::unique_ptr<Event> e)992   Status MigrateEvent(std::unique_ptr<Event> e) {
993     switch (e->what_case()) {
994       case Event::WhatCase::kSummary: {
995         uint64 now = env_->NowMicros();
996         auto summaries = e->mutable_summary();
997         for (int i = 0; i < summaries->value_size(); ++i) {
998           Summary::Value* value = summaries->mutable_value(i);
999           TF_RETURN_WITH_CONTEXT_IF_ERROR(
1000               MigrateSummary(e.get(), value, now), meta_.user_name(), "/",
1001               meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(),
1002               "@", e->step());
1003         }
1004         break;
1005       }
1006       case Event::WhatCase::kGraphDef:
1007         TF_RETURN_WITH_CONTEXT_IF_ERROR(
1008             MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/",
1009             meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@",
1010             e->step());
1011         break;
1012       default:
1013         // TODO(@jart): Handle other stuff.
1014         break;
1015     }
1016     return Status::OK();
1017   }
1018 
MigrateGraph(const Event * e,const string & graph_def)1019   Status MigrateGraph(const Event* e, const string& graph_def) {
1020     uint64 now = env_->NowMicros();
1021     std::unique_ptr<GraphDef> graph{new GraphDef};
1022     if (!ParseProtoUnlimited(graph.get(), graph_def)) {
1023       return errors::InvalidArgument("bad proto");
1024     }
1025     return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph));
1026   }
1027 
MigrateSummary(const Event * e,Summary::Value * s,uint64 now)1028   Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) {
1029     switch (s->value_case()) {
1030       case Summary::Value::ValueCase::kTensor:
1031         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor");
1032         break;
1033       case Summary::Value::ValueCase::kSimpleValue:
1034         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar");
1035         break;
1036       case Summary::Value::ValueCase::kHisto:
1037         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo");
1038         break;
1039       case Summary::Value::ValueCase::kImage:
1040         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image");
1041         break;
1042       case Summary::Value::ValueCase::kAudio:
1043         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio");
1044         break;
1045       default:
1046         break;
1047     }
1048     return Status::OK();
1049   }
1050 
MigrateTensor(const Event * e,Summary::Value * s,uint64 now)1051   Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) {
1052     Tensor t;
1053     if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto");
1054     TF_RETURN_IF_ERROR(CheckSupportedType(t));
1055     int64 tag_id;
1056     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1057                                       &tag_id, s->metadata()));
1058     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1059   }
1060 
1061   // TODO(jart): Refactor Summary -> Tensor logic into separate file.
1062 
MigrateScalar(const Event * e,Summary::Value * s,uint64 now)1063   Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) {
1064     // See tensorboard/plugins/scalar/summary.py and data_compat.py
1065     Tensor t{DT_FLOAT, {}};
1066     t.scalar<float>()() = s->simple_value();
1067     int64 tag_id;
1068     PatchPluginName(s->mutable_metadata(), kScalarPluginName);
1069     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1070                                       &tag_id, s->metadata()));
1071     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1072   }
1073 
MigrateHistogram(const Event * e,Summary::Value * s,uint64 now)1074   Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
1075     const HistogramProto& histo = s->histo();
1076     int k = histo.bucket_size();
1077     if (k != histo.bucket_limit_size()) {
1078       return errors::InvalidArgument("size mismatch");
1079     }
1080     // See tensorboard/plugins/histogram/summary.py and data_compat.py
1081     Tensor t{DT_DOUBLE, {k, 3}};
1082     auto data = t.flat<double>();
1083     for (int i = 0, j = 0; i < k; ++i) {
1084       // TODO(nickfelt): reconcile with TensorBoard's data_compat.py
1085       // From summary.proto
1086       // Parallel arrays encoding the bucket boundaries and the bucket values.
1087       // bucket(i) is the count for the bucket i.  The range for
1088       // a bucket is:
1089       //   i == 0:  -DBL_MAX .. bucket_limit(0)
1090       //   i != 0:  bucket_limit(i-1) .. bucket_limit(i)
1091       double left_edge = (i == 0) ? std::numeric_limits<double>::min()
1092                                   : histo.bucket_limit(i - 1);
1093 
1094       data(j++) = left_edge;
1095       data(j++) = histo.bucket_limit(i);
1096       data(j++) = histo.bucket(i);
1097     }
1098     int64 tag_id;
1099     PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
1100     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1101                                       &tag_id, s->metadata()));
1102     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1103   }
1104 
MigrateImage(const Event * e,Summary::Value * s,uint64 now)1105   Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
1106     // See tensorboard/plugins/image/summary.py and data_compat.py
1107     Tensor t{DT_STRING, {3}};
1108     auto img = s->mutable_image();
1109     t.flat<string>()(0) = strings::StrCat(img->width());
1110     t.flat<string>()(1) = strings::StrCat(img->height());
1111     t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string());
1112     int64 tag_id;
1113     PatchPluginName(s->mutable_metadata(), kImagePluginName);
1114     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1115                                       &tag_id, s->metadata()));
1116     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1117   }
1118 
MigrateAudio(const Event * e,Summary::Value * s,uint64 now)1119   Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
1120     // See tensorboard/plugins/audio/summary.py and data_compat.py
1121     Tensor t{DT_STRING, {1, 2}};
1122     auto wav = s->mutable_audio();
1123     t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string());
1124     t.flat<string>()(1) = "";
1125     int64 tag_id;
1126     PatchPluginName(s->mutable_metadata(), kAudioPluginName);
1127     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1128                                       &tag_id, s->metadata()));
1129     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1130   }
1131 
1132   Env* const env_;
1133   Sqlite* const db_;
1134   IdAllocator ids_;
1135   RunMetadata meta_;
1136   RunWriter run_;
1137 };
1138 
1139 }  // namespace
1140 
CreateSummaryDbWriter(Sqlite * db,const string & experiment_name,const string & run_name,const string & user_name,Env * env,SummaryWriterInterface ** result)1141 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
1142                              const string& run_name, const string& user_name,
1143                              Env* env, SummaryWriterInterface** result) {
1144   *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
1145   return Status::OK();
1146 }
1147 
1148 }  // namespace tensorflow
1149