• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/costmodel.h"
17 
18 #include <algorithm>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/allocation_description.pb.h"
22 #include "tensorflow/core/framework/cost_graph.pb.h"
23 #include "tensorflow/core/framework/step_stats.pb.h"
24 #include "tensorflow/core/framework/tensor_description.pb.h"
25 #include "tensorflow/core/graph/graph.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace tensorflow {
29 namespace {
30 const Microseconds kDefaultTimeEstimate(1);
31 const Microseconds kMinTimeEstimate(1);
32 }  // namespace
33 
SuppressInfrequent()34 void CostModel::SuppressInfrequent() {
35   // Find the median of the non-zero counts, and use half of its value
36   // as the cutoff for a "normal" execution mode node.
37   if (count_.empty()) return;
38   std::vector<int32> non_zero;
39   for (auto v : count_) {
40     if (v > 0) non_zero.push_back(v);
41   }
42   const size_t sz = non_zero.size();
43   if (sz > 0) {
44     std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2,
45                      non_zero.end());
46     int32 median_value = non_zero[sz / 2];
47     min_count_ = median_value / 2;
48     VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value "
49             << median_value;
50   } else {
51     min_count_ = 1;
52   }
53 }
54 
MergeFromLocal(const Graph & g,const CostModel & cm)55 void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
56   CHECK(is_global_);
57   CHECK(!cm.is_global());
58   for (const Node* n : g.nodes()) {
59     const int local_id = cm.Id(n);
60     const int global_id = Id(n);
61     if (local_id < 0 || global_id < 0) continue;
62     int num_slots = cm.slot_bytes_[local_id].size();
63     Ensure(global_id, num_slots);
64     count_[global_id] += cm.count_[local_id];
65     time_[global_id] += cm.time_[local_id];
66     if (num_slots > 0) {
67       if (slot_bytes_[global_id].empty()) {
68         slot_bytes_[global_id].resize(num_slots);
69       } else {
70         CHECK_EQ(num_slots, slot_bytes_[global_id].size());
71       }
72       for (int s = 0; s < num_slots; ++s) {
73         slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s];
74       }
75     }
76   }
77 }
78 
MergeFromGlobal(const CostModel & cm)79 void CostModel::MergeFromGlobal(const CostModel& cm) {
80   CHECK(is_global_);
81   CHECK_EQ(true, cm.is_global());
82   const int num_nodes = cm.count_.size();
83   for (int i = num_nodes - 1; i >= 0; --i) {
84     count_[i] += cm.count_[i];
85     time_[i] += cm.time_[i];
86     int num_slots = cm.slot_bytes_[i].size();
87     Ensure(i, num_slots);
88     if (num_slots > 0) {
89       if (slot_bytes_[i].empty()) {
90         slot_bytes_[i].resize(num_slots);
91       } else {
92         CHECK_EQ(num_slots, slot_bytes_[i].size());
93       }
94       for (int s = 0; s < num_slots; ++s) {
95         slot_bytes_[i][s] += cm.slot_bytes_[i][s];
96       }
97     }
98   }
99 }
100 
MergeFromStats(const NodeNameToCostIdMap & map,const StepStats & ss)101 void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
102                                const StepStats& ss) {
103   CHECK(is_global_);
104   for (auto& ds : ss.dev_stats()) {
105     for (auto& ns : ds.node_stats()) {
106       NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name());
107       // We don't keep stats for nodes not in the global graph, i.e.
108       // copy/send/recv nodes, feed/fetch, etc.
109       if (iter == map.end()) continue;
110       int32 global_id = iter->second;
111       Ensure(global_id, ns.output_size());
112       int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros();
113       count_[global_id]++;
114       time_[global_id] += elapsed_micros;
115       for (auto& no : ns.output()) {
116         int si = no.slot();
117         if (static_cast<size_t>(si) >= slot_bytes_[global_id].size()) {
118           slot_bytes_[global_id].resize(1 + si);
119         }
120         slot_bytes_[global_id][si] +=
121             no.tensor_description().allocation_description().requested_bytes();
122       }
123     }
124   }
125 }
126 
Ensure(int id,int num_outputs)127 void CostModel::Ensure(int id, int num_outputs) {
128   if (slot_bytes_.size() <= static_cast<size_t>(id)) {
129     slot_bytes_.resize(id + 1);
130     count_.resize(id + 1);
131     time_.resize(id + 1);
132     max_mem_usage_.resize(id + 1);
133     max_exec_time_.resize(id + 1);
134     output_port_alloc_ids_.resize(id + 1);
135   }
136   if (num_outputs > 0) {
137     auto perslot = &slot_bytes_[id];
138     auto output_port_alloc_ids = &output_port_alloc_ids_[id];
139     auto max_mem_usage = &max_mem_usage_[id];
140 
141     CHECK_LE(perslot->size(), num_outputs);
142     DCHECK_EQ(output_port_alloc_ids->size(), perslot->size());
143     DCHECK_EQ(max_mem_usage->output_port_mem.size(), perslot->size());
144     DCHECK_EQ(max_mem_usage->output_port_shape.size(), perslot->size());
145     DCHECK_EQ(max_mem_usage->output_port_type.size(), perslot->size());
146 
147     perslot->resize(num_outputs, Bytes(-1));
148     output_port_alloc_ids->resize(num_outputs, -1);
149     max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
150     max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
151     max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
152   }
153 }
154 
SetNumOutputs(const Node * node,int num_outputs)155 void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
156   const int id = Id(node);
157   if (id < 0) return;
158   // Do not resize the number of slots before checking its existing number of
159   // slots.
160   Ensure(id, 0);
161   auto perslot = &slot_bytes_[id];
162   if (!perslot->empty()) {
163     CHECK_EQ(num_outputs, perslot->size())
164         << "Cannot resize slot_bytes, node=" << node->name();
165   }
166   Ensure(id, num_outputs);
167 }
168 
RecordCount(const Node * node,int count)169 void CostModel::RecordCount(const Node* node, int count) {
170   const int id = Id(node);
171   if (id < 0) return;
172   CHECK_LT(id, slot_bytes_.size());
173   count_[id] += count;
174 }
175 
TotalCount(const Node * node) const176 int32 CostModel::TotalCount(const Node* node) const {
177   const int id = Id(node);
178   if (id < 0) return 0;
179   return (static_cast<size_t>(id) < slot_bytes_.size()) ? count_[id] : 0;
180 }
181 
RecordSize(const Node * node,int slot,Bytes bytes)182 void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) {
183   const int id = Id(node);
184   if (id < 0) return;
185   CHECK_LT(id, slot_bytes_.size());
186   auto perslot = &slot_bytes_[id];
187   CHECK_LT(slot, perslot->size());
188   auto v = &(*perslot)[slot];
189   if (*v >= 0) {
190     *v += bytes;
191   } else {
192     *v = bytes;
193   }
194 }
195 
TotalBytes(const Node * node,int slot) const196 Bytes CostModel::TotalBytes(const Node* node, int slot) const {
197   const int id = Id(node);
198   if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
199       slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
200     return Bytes(0);
201   }
202   return slot_bytes_[id][slot];
203 }
204 
SizeEstimate(const Node * node,int slot) const205 Bytes CostModel::SizeEstimate(const Node* node, int slot) const {
206   int32 count = TotalCount(node);
207   if (count < min_count_) return Bytes(0);
208   return TotalBytes(node, slot) / std::max(1, TotalCount(node));
209 }
210 
RecordTime(const Node * node,Microseconds time)211 void CostModel::RecordTime(const Node* node, Microseconds time) {
212   const int id = Id(node);
213   if (id < 0) return;
214   DCHECK(node->IsOp()) << node->DebugString();
215   Ensure(id, node->num_outputs());
216   time_[id] += time;
217 }
218 
TotalTime(const Node * node) const219 Microseconds CostModel::TotalTime(const Node* node) const {
220   DCHECK(node->IsOp()) << node->DebugString();
221   const int id = Id(node);
222   if (id < 0 || static_cast<size_t>(id) >= time_.size() ||
223       time_[id] < Microseconds(0)) {
224     return Microseconds(0);
225   }
226   return time_[id];
227 }
228 
TimeEstimate(const Node * node) const229 Microseconds CostModel::TimeEstimate(const Node* node) const {
230   int32 count = TotalCount(node);
231   if (count <= min_count_) return kMinTimeEstimate;
232   return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count));
233 }
234 
CheckInitialized(const Graph & graph) const235 void CostModel::CheckInitialized(const Graph& graph) const {
236   for (const Node* n : graph.op_nodes()) {
237     CHECK(static_cast<size_t>(n->id()) < time_.size() &&
238           time_[n->id()] >= Microseconds(0))
239         << ": no time estimate for " << n->DebugString();
240 
241     CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
242         << ": no size estimate for " << n->DebugString();
243     const auto& perslot = slot_bytes_[n->id()];
244     for (size_t i = 0; i < perslot.size(); i++) {
245       CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
246                                      << " of " << n->DebugString();
247     }
248   }
249 }
250 
RecordMaxMemorySize(const Node * node,int output_slot,Bytes bytes,const TensorShapeProto & tensor_shape,const DataType & dtype)251 void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
252                                     Bytes bytes,
253                                     const TensorShapeProto& tensor_shape,
254                                     const DataType& dtype) {
255   const int id = Id(node);
256   if (id < 0) return;
257   if (output_slot >= node->num_outputs()) {
258     LOG(ERROR) << "Unexpected output slot for node " << node->DebugString()
259                << ". Got " << output_slot << " but its num_outputs is "
260                << node->num_outputs();
261     return;
262   }
263   Ensure(id, node->num_outputs());
264   auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
265   // If the memory allocator doesn't track memory usage, let's infer a lower
266   // bound from the tensor shape and its data type.
267   if (bytes.value() < 0) {
268     bytes = MinTensorMemoryUsage(tensor_shape, dtype);
269   }
270   if (bytes.value() > current_max.value()) {
271     current_max = bytes.value();
272     max_mem_usage_[id].output_port_shape[output_slot] = tensor_shape;
273     max_mem_usage_[id].output_port_type[output_slot] = dtype;
274   }
275 }
276 
MaxMemorySize(const Node * node,int slot) const277 Bytes CostModel::MaxMemorySize(const Node* node, int slot) const {
278   const int id = Id(node);
279   if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
280       max_mem_usage_[id].output_port_mem.size() <= static_cast<size_t>(slot)) {
281     return Bytes(0);
282   }
283   return max_mem_usage_[id].output_port_mem[slot];
284 }
285 
MaxMemoryShape(const Node * node,int slot) const286 const TensorShapeProto& CostModel::MaxMemoryShape(const Node* node,
287                                                   int slot) const {
288   const int id = Id(node);
289   if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
290       max_mem_usage_[id].output_port_shape.size() <=
291           static_cast<size_t>(slot)) {
292     return unknown_shape_;
293   }
294   return max_mem_usage_[id].output_port_shape[slot];
295 }
296 
MaxMemoryType(const Node * node,int slot) const297 DataType CostModel::MaxMemoryType(const Node* node, int slot) const {
298   const int id = Id(node);
299   if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
300       max_mem_usage_[id].output_port_type.size() <= static_cast<size_t>(slot)) {
301     return DT_INVALID;
302   }
303   return max_mem_usage_[id].output_port_type[slot];
304 }
305 
TempMemorySize(const Node * node) const306 Bytes CostModel::TempMemorySize(const Node* node) const {
307   const int id = Id(node);
308   if (id < 0) {
309     return Bytes(0);
310   }
311   return max_mem_usage_[id].temp_memory_size;
312 }
313 
PersistentMemorySize(const Node * node) const314 Bytes CostModel::PersistentMemorySize(const Node* node) const {
315   const int id = Id(node);
316   if (id < 0) {
317     return Bytes(0);
318   }
319   return max_mem_usage_[id].persistent_memory_size;
320 }
321 
RecordMemoryStats(const Node * node,const MemoryStats & memory_stats)322 void CostModel::RecordMemoryStats(const Node* node,
323                                   const MemoryStats& memory_stats) {
324   const int id = Id(node);
325   if (id < 0) return;
326   max_mem_usage_[id].temp_memory_size = memory_stats.temp_memory_size();
327   max_mem_usage_[id].persistent_memory_size =
328       memory_stats.persistent_memory_size();
329   for (int64 alloc_id : memory_stats.persistent_tensor_alloc_ids()) {
330     if (alloc_id > 0) {
331       persistent_alloc_ids_.insert(alloc_id);
332     }
333   }
334 }
335 
RecordMaxExecutionTime(const Node * node,Microseconds time)336 void CostModel::RecordMaxExecutionTime(const Node* node, Microseconds time) {
337   const int id = Id(node);
338   if (id < 0) return;
339   Ensure(id, node->num_outputs());
340   max_exec_time_[id] = std::max(max_exec_time_[id], time);
341 }
342 
MaxExecutionTime(const Node * node) const343 Microseconds CostModel::MaxExecutionTime(const Node* node) const {
344   const int id = Id(node);
345   if (id < 0 || static_cast<size_t>(id) >= max_exec_time_.size()) {
346     return Microseconds(0);
347   }
348   return max_exec_time_[id];
349 }
350 
RecordAllocationId(const Node * node,int output_slot,int64 alloc_id)351 void CostModel::RecordAllocationId(const Node* node, int output_slot,
352                                    int64 alloc_id) {
353   const int id = Id(node);
354   if (id < 0) return;
355   Ensure(id, node->num_outputs());
356   output_port_alloc_ids_[id][output_slot] = alloc_id;
357 }
358 
AllocationId(const Node * node,int slot) const359 int64 CostModel::AllocationId(const Node* node, int slot) const {
360   const int id = Id(node);
361   if (id < 0 || static_cast<size_t>(id) >= output_port_alloc_ids_.size() ||
362       output_port_alloc_ids_[id].size() <= static_cast<size_t>(slot)) {
363     return -1;
364   }
365   return output_port_alloc_ids_[id][slot];
366 }
367 
IsPersistentTensor(const Node * node,int64 alloc_id) const368 bool CostModel::IsPersistentTensor(const Node* node, int64 alloc_id) const {
369   if (persistent_alloc_ids_.count(alloc_id) > 0) {
370     return true;
371   }
372   if (persistent_alloc_ids_by_devices_.find(node->assigned_device_name()) ==
373       persistent_alloc_ids_by_devices_.end()) {
374     return false;
375   }
376   return persistent_alloc_ids_by_devices_.at(node->assigned_device_name())
377       .count(alloc_id);
378 }
379 
CopyTimeEstimate(Bytes b,double network_latency_millis,double estimated_gbps)380 Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis,
381                                          double estimated_gbps) {
382   // TODO(jeff,sanjay): estimate cost based on bandwidth along the
383   // communication path and the type of transport we are using between
384   // devices.
385   //
386   // We assume the copy time follows a linear model:
387   //    copy_time = copy_bytes / rate + min_time
388   int64 copy_bytes = b.value();
389   const double bytes_per_usec = estimated_gbps * 1000.0 / 8;
390   const double min_micros = network_latency_millis * 1000.0;
391   return Microseconds(
392       static_cast<int64>(copy_bytes / bytes_per_usec + min_micros));
393 }
394 
ComputationTimeEstimate(int64 math_ops)395 Microseconds CostModel::ComputationTimeEstimate(int64 math_ops) {
396   // TODO(jeff,sanjay): Eventually we should pass in the type of device
397   // (GPU vs. CPU) and use that to affect the estimate.
398 
399   // We estimate the microseconds using that value.  We divide
400   // by 1000 to convert the madd number into microseconds (assuming
401   // roughly 1000 madds per microsecond (~1 GHz for one core)).
402   return Microseconds(math_ops / 1000);
403 }
404 
IncrementUpdateTimes()405 void CostModel::IncrementUpdateTimes() { update_times_++; }
406 
GetUpdateTimes() const407 int32 CostModel::GetUpdateTimes() const { return update_times_; }
408 
409 // ----------------------------------------------------------------------------
410 // InitCostModel
411 // ----------------------------------------------------------------------------
412 
413 namespace {
414 
AddNodesToCostModel(const Graph & g,CostModel * cost_model)415 static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) {
416   for (Node* n : g.nodes()) {
417     const int num_outputs = n->num_outputs();
418     cost_model->SetNumOutputs(n, num_outputs);
419     for (int output = 0; output < num_outputs; output++) {
420       // Set up an initial bogus estimate for the node's outputs
421       cost_model->RecordSize(n, output, Bytes(1));
422     }
423   }
424 }
425 
AssignSizes(const Graph & g,CostModel * cost_model)426 static void AssignSizes(const Graph& g, CostModel* cost_model) {
427   for (const Edge* e : g.edges()) {
428     // Skip if it is a control edge.
429     if (e->IsControlEdge()) {
430       continue;
431     }
432     const Node* src = e->src();
433 
434     // TODO(josh11b): Get an estimate from the Op
435     Bytes size(1);
436     cost_model->RecordSize(src, e->src_output(), size);
437   }
438 }
439 
440 // This generates an extremely simple initial guess for the
441 // computation cost of each node. For ordinary Ops, its value should quickly
442 // be wiped out by the real runtime measurements.  For other Ops we don't
443 // actually generate measurements, so suppression of infrequent Ops ends up
444 // giving them 0 costs.  So, this is not of much consequence except perhaps
445 // in tests.
TimeEstimateForNode(CostModel * cost_model,Node * n)446 static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) {
447   CHECK(n->IsOp());
448   VLOG(2) << "Node " << n->id() << ": " << n->name()
449           << " type_string: " << n->type_string();
450   if (IsConstant(n) || IsVariable(n)) {
451     return Microseconds(0);
452   }
453   return kDefaultTimeEstimate;
454 }
455 
EstimateComputationCosts(const Graph & g,CostModel * cost_model)456 static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) {
457   for (Node* n : g.nodes()) {
458     if (!n->IsOp()) continue;
459     cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n));
460   }
461 }
462 
463 }  // namespace
464 
InitFromGraph(const Graph & g)465 void CostModel::InitFromGraph(const Graph& g) {
466   const int num_node_ids = g.num_node_ids();
467   slot_bytes_.reserve(num_node_ids);
468   count_.reserve(num_node_ids);
469   time_.reserve(num_node_ids);
470   max_mem_usage_.reserve(num_node_ids);
471   max_exec_time_.reserve(num_node_ids);
472   output_port_alloc_ids_.reserve(num_node_ids);
473 
474   AddNodesToCostModel(g, this);
475   AssignSizes(g, this);
476   EstimateComputationCosts(g, this);
477   CheckInitialized(g);
478 }
479 
AddToCostGraphDef(const Graph * graph,CostGraphDef * cost_graph) const480 void CostModel::AddToCostGraphDef(const Graph* graph,
481                                   CostGraphDef* cost_graph) const {
482   std::vector<const Edge*> inputs;
483   std::vector<const Edge*> control_inputs;
484   int offset = cost_graph->node_size();
485   for (const Node* n : graph->nodes()) {
486     CostGraphDef::Node* cnode = cost_graph->add_node();
487     cnode->set_name(n->name());
488     cnode->set_device(n->assigned_device_name());
489     cnode->set_id(GlobalId(n, offset));
490 
491     inputs.clear();
492     inputs.resize(n->num_inputs(), nullptr);
493     control_inputs.clear();
494     for (const Edge* e : n->in_edges()) {
495       if (e->IsControlEdge()) {
496         control_inputs.push_back(e);
497       } else {
498         inputs[e->dst_input()] = e;
499       }
500     }
501     std::sort(control_inputs.begin(), control_inputs.end(),
502               [this](Edge const* a, Edge const* b) {
503                 return Id(a->src()) < Id(b->src());
504               });
505 
506     for (const Edge* e : inputs) {
507       CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
508       input_info->set_preceding_node(GlobalId(e->src(), offset));
509       input_info->set_preceding_port(e->src_output());
510     }
511 
512     for (int i = 0; i < n->num_outputs(); i++) {
513       CostGraphDef::Node::OutputInfo* output_info = cnode->add_output_info();
514       int64 alloc_id = AllocationId(n, i);
515       int64 alias_to_input = -1;
516       for (const Edge* e : inputs) {
517         int64 input_alloc_id = AllocationId(e->src(), e->src_output());
518         if (input_alloc_id == alloc_id) {
519           alias_to_input = e->dst_input();
520           break;
521         }
522       }
523       output_info->set_alias_input_port(alias_to_input);
524       output_info->set_dtype(MaxMemoryType(n, i));
525       *output_info->mutable_shape() = MaxMemoryShape(n, i);
526       if (alias_to_input < 0 && IsPersistentTensor(n, alloc_id)) {
527         output_info->set_size(0);
528       } else {
529         output_info->set_size(MaxMemorySize(n, i).value());
530       }
531     }
532 
533     for (const Edge* e : control_inputs) {
534       cnode->add_control_input(GlobalId(e->src(), offset));
535     }
536 
537     cnode->set_temporary_memory_size(TempMemorySize(n).value());
538     cnode->set_persistent_memory_size(PersistentMemorySize(n).value());
539 
540     cnode->set_compute_cost(MaxExecutionTime(n).value());
541 
542     // For now we treat all send nodes as final.
543     // TODO(yuanbyu): Send nodes for fetches shouldn't be treated as final.
544     cnode->set_is_final(n->IsSend());
545   }
546 }
547 
WriteSummaryToLog() const548 void CostModel::WriteSummaryToLog() const {
549   LOG(INFO) << " min_count_=" << min_count_;
550   for (size_t i = 0; i < count_.size(); ++i) {
551     LOG(INFO) << "Node " << i << " count " << count_[i] << " total time "
552               << time_[i] << " avg time "
553               << (time_[i] / (std::max(1, count_[i])));
554   }
555 }
556 
MinTensorMemoryUsage(const TensorShapeProto & tensor_shape,const DataType & dtype)557 Bytes CostModel::MinTensorMemoryUsage(const TensorShapeProto& tensor_shape,
558                                       const DataType& dtype) {
559   if (tensor_shape.unknown_rank()) {
560     return Bytes(-1);
561   }
562 
563   size_t num_coefficients = 1;
564   for (const TensorShapeProto::Dim& dim : tensor_shape.dim()) {
565     // If the dimension is unknown, it has to be at least 1
566     num_coefficients *= std::max<size_t>(dim.size(), 1);
567   }
568   return Bytes(num_coefficients * DataTypeSize(dtype));
569 }
570 
571 }  // namespace tensorflow
572