• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/costmodel.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/allocation_description.pb.h"
20 #include "tensorflow/core/framework/cost_graph.pb.h"
21 #include "tensorflow/core/framework/step_stats.pb.h"
22 #include "tensorflow/core/framework/tensor_description.pb.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 namespace tensorflow {
27 namespace {
28 const Microseconds kDefaultTimeEstimate(1);
29 const Microseconds kMinTimeEstimate(1);
30 }  // namespace
31 
SuppressInfrequent()32 void CostModel::SuppressInfrequent() {
33   // Find the median of the non-zero counts, and use half of its value
34   // as the cutoff for a "normal" execution mode node.
35   if (count_.empty()) return;
36   std::vector<int32> non_zero;
37   for (auto v : count_) {
38     if (v > 0) non_zero.push_back(v);
39   }
40   const size_t sz = non_zero.size();
41   if (sz > 0) {
42     std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2,
43                      non_zero.end());
44     int32 median_value = non_zero[sz / 2];
45     min_count_ = median_value / 2;
46     VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value "
47             << median_value;
48   } else {
49     min_count_ = 1;
50   }
51 }
52 
MergeFromLocal(const Graph & g,const CostModel & cm)53 void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
54   CHECK(is_global_);
55   CHECK(!cm.is_global());
56   for (const Node* n : g.nodes()) {
57     const int local_id = cm.Id(n);
58     const int global_id = Id(n);
59     if (local_id < 0 || global_id < 0) continue;
60     int num_slots = cm.slot_bytes_[local_id].size();
61     Ensure(global_id, num_slots);
62     count_[global_id] += cm.count_[local_id];
63     time_[global_id] += cm.time_[local_id];
64     if (num_slots > 0) {
65       if (slot_bytes_[global_id].empty()) {
66         slot_bytes_[global_id].resize(num_slots);
67       } else {
68         CHECK_EQ(num_slots, slot_bytes_[global_id].size());
69       }
70       for (int s = 0; s < num_slots; ++s) {
71         slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s];
72       }
73     }
74   }
75 }
76 
MergeFromGlobal(const CostModel & cm)77 void CostModel::MergeFromGlobal(const CostModel& cm) {
78   CHECK(is_global_);
79   CHECK_EQ(true, cm.is_global());
80   const int num_nodes = cm.count_.size();
81   for (int i = num_nodes - 1; i >= 0; --i) {
82     count_[i] += cm.count_[i];
83     time_[i] += cm.time_[i];
84     int num_slots = cm.slot_bytes_[i].size();
85     Ensure(i, num_slots);
86     if (num_slots > 0) {
87       if (slot_bytes_[i].empty()) {
88         slot_bytes_[i].resize(num_slots);
89       } else {
90         CHECK_EQ(num_slots, slot_bytes_[i].size());
91       }
92       for (int s = 0; s < num_slots; ++s) {
93         slot_bytes_[i][s] += cm.slot_bytes_[i][s];
94       }
95     }
96   }
97 }
98 
MergeFromStats(const NodeNameToCostIdMap & map,const StepStats & ss)99 void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
100                                const StepStats& ss) {
101   CHECK(is_global_);
102   for (auto& ds : ss.dev_stats()) {
103     for (auto& ns : ds.node_stats()) {
104       NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name());
105       // We don't keep stats for nodes not in the global graph, i.e.
106       // copy/send/recv nodes, feed/fetch, etc.
107       if (iter == map.end()) continue;
108       int32 global_id = iter->second;
109       Ensure(global_id, ns.output_size());
110       int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros();
111       count_[global_id]++;
112       time_[global_id] += elapsed_micros;
113       for (auto& no : ns.output()) {
114         int si = no.slot();
115         if (static_cast<size_t>(si) >= slot_bytes_[global_id].size()) {
116           slot_bytes_[global_id].resize(1 + si);
117         }
118         slot_bytes_[global_id][si] +=
119             no.tensor_description().allocation_description().requested_bytes();
120       }
121     }
122   }
123 }
124 
Ensure(int id,int num_outputs)125 void CostModel::Ensure(int id, int num_outputs) {
126   if (slot_bytes_.size() <= static_cast<size_t>(id)) {
127     slot_bytes_.resize(id + 1);
128     count_.resize(id + 1);
129     time_.resize(id + 1);
130     max_mem_usage_.resize(id + 1);
131     max_exec_time_.resize(id + 1);
132     output_port_alloc_ids_.resize(id + 1);
133   }
134   if (num_outputs > 0) {
135     auto perslot = &slot_bytes_[id];
136     auto output_port_alloc_ids = &output_port_alloc_ids_[id];
137     auto max_mem_usage = &max_mem_usage_[id];
138 
139     CHECK_LE(perslot->size(), num_outputs);
140     DCHECK_EQ(output_port_alloc_ids->size(), perslot->size());
141     DCHECK_EQ(max_mem_usage->output_port_mem.size(), perslot->size());
142     DCHECK_EQ(max_mem_usage->output_port_shape.size(), perslot->size());
143     DCHECK_EQ(max_mem_usage->output_port_type.size(), perslot->size());
144 
145     perslot->resize(num_outputs, Bytes(-1));
146     output_port_alloc_ids->resize(num_outputs, -1);
147     max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
148     max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
149     max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
150   }
151 }
152 
SetNumOutputs(const Node * node,int num_outputs)153 void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
154   const int id = Id(node);
155   if (id < 0) return;
156   // Do not resize the number of slots before checking its existing number of
157   // slots.
158   Ensure(id, 0);
159   auto perslot = &slot_bytes_[id];
160   if (!perslot->empty()) {
161     CHECK_EQ(num_outputs, perslot->size())
162         << "Cannot resize slot_bytes, node=" << node->name();
163   }
164   Ensure(id, num_outputs);
165 }
166 
RecordCount(const Node * node,int count)167 void CostModel::RecordCount(const Node* node, int count) {
168   const int id = Id(node);
169   if (id < 0) return;
170   CHECK_LT(id, slot_bytes_.size());
171   count_[id] += count;
172 }
173 
TotalCount(const Node * node) const174 int32 CostModel::TotalCount(const Node* node) const {
175   const int id = Id(node);
176   if (id < 0) return 0;
177   return (static_cast<size_t>(id) < slot_bytes_.size()) ? count_[id] : 0;
178 }
179 
RecordSize(const Node * node,int slot,Bytes bytes)180 void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) {
181   const int id = Id(node);
182   if (id < 0) return;
183   CHECK_LT(id, slot_bytes_.size());
184   auto perslot = &slot_bytes_[id];
185   CHECK_LT(slot, perslot->size());
186   auto v = &(*perslot)[slot];
187   if (*v >= 0) {
188     *v += bytes;
189   } else {
190     *v = bytes;
191   }
192 }
193 
TotalBytes(const Node * node,int slot) const194 Bytes CostModel::TotalBytes(const Node* node, int slot) const {
195   const int id = Id(node);
196   if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
197       slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
198     return Bytes(0);
199   }
200   return slot_bytes_[id][slot];
201 }
202 
SizeEstimate(const Node * node,int slot) const203 Bytes CostModel::SizeEstimate(const Node* node, int slot) const {
204   int32 count = TotalCount(node);
205   if (count < min_count_) return Bytes(0);
206   return TotalBytes(node, slot) / std::max(1, TotalCount(node));
207 }
208 
RecordTime(const Node * node,Microseconds time)209 void CostModel::RecordTime(const Node* node, Microseconds time) {
210   const int id = Id(node);
211   if (id < 0) return;
212   DCHECK(node->IsOp()) << node->DebugString();
213   Ensure(id, node->num_outputs());
214   time_[id] += time;
215 }
216 
TotalTime(const Node * node) const217 Microseconds CostModel::TotalTime(const Node* node) const {
218   DCHECK(node->IsOp()) << node->DebugString();
219   const int id = Id(node);
220   if (id < 0 || static_cast<size_t>(id) >= time_.size() ||
221       time_[id] < Microseconds(0)) {
222     return Microseconds(0);
223   }
224   return time_[id];
225 }
226 
TimeEstimate(const Node * node) const227 Microseconds CostModel::TimeEstimate(const Node* node) const {
228   int32 count = TotalCount(node);
229   if (count <= min_count_) return kMinTimeEstimate;
230   return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count));
231 }
232 
CheckInitialized(const Graph & graph) const233 void CostModel::CheckInitialized(const Graph& graph) const {
234   for (const Node* n : graph.op_nodes()) {
235     CHECK(static_cast<size_t>(n->id()) < time_.size() &&
236           time_[n->id()] >= Microseconds(0))
237         << ": no time estimate for " << n->DebugString();
238 
239     CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
240         << ": no size estimate for " << n->DebugString();
241     const auto& perslot = slot_bytes_[n->id()];
242     for (size_t i = 0; i < perslot.size(); i++) {
243       CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
244                                      << " of " << n->DebugString();
245     }
246   }
247 }
248 
RecordMaxMemorySize(const Node * node,int output_slot,Bytes bytes,const TensorShapeProto & tensor_shape,const DataType & dtype)249 void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
250                                     Bytes bytes,
251                                     const TensorShapeProto& tensor_shape,
252                                     const DataType& dtype) {
253   const int id = Id(node);
254   if (id < 0) return;
255   if (output_slot >= node->num_outputs()) {
256     LOG(ERROR) << "Unexpected output slot for node " << node->DebugString()
257                << ". Got " << output_slot << " but its num_outputs is "
258                << node->num_outputs();
259     return;
260   }
261   Ensure(id, node->num_outputs());
262   auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
263   // If the memory allocator doesn't track memory usage, let's infer a lower
264   // bound from the tensor shape and its data type.
265   if (bytes.value() < 0) {
266     bytes = MinTensorMemoryUsage(tensor_shape, dtype);
267   }
268   if (bytes.value() > current_max.value()) {
269     current_max = bytes.value();
270     max_mem_usage_[id].output_port_shape[output_slot] = tensor_shape;
271     max_mem_usage_[id].output_port_type[output_slot] = dtype;
272   }
273 }
274 
MaxMemorySize(const Node * node,int slot) const275 Bytes CostModel::MaxMemorySize(const Node* node, int slot) const {
276   const int id = Id(node);
277   if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
278       max_mem_usage_[id].output_port_mem.size() <= static_cast<size_t>(slot)) {
279     return Bytes(0);
280   }
281   return max_mem_usage_[id].output_port_mem[slot];
282 }
283 
MaxMemoryShape(const Node * node,int slot) const284 const TensorShapeProto& CostModel::MaxMemoryShape(const Node* node,
285                                                   int slot) const {
286   const int id = Id(node);
287   if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
288       max_mem_usage_[id].output_port_shape.size() <=
289           static_cast<size_t>(slot)) {
290     return unknown_shape_;
291   }
292   return max_mem_usage_[id].output_port_shape[slot];
293 }
294 
MaxMemoryType(const Node * node,int slot) const295 DataType CostModel::MaxMemoryType(const Node* node, int slot) const {
296   const int id = Id(node);
297   if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
298       max_mem_usage_[id].output_port_type.size() <= static_cast<size_t>(slot)) {
299     return DT_INVALID;
300   }
301   return max_mem_usage_[id].output_port_type[slot];
302 }
303 
TempMemorySize(const Node * node) const304 Bytes CostModel::TempMemorySize(const Node* node) const {
305   const int id = Id(node);
306   if (id < 0) {
307     return Bytes(0);
308   }
309   return max_mem_usage_[id].temp_memory_size;
310 }
311 
PersistentMemorySize(const Node * node) const312 Bytes CostModel::PersistentMemorySize(const Node* node) const {
313   const int id = Id(node);
314   if (id < 0) {
315     return Bytes(0);
316   }
317   return max_mem_usage_[id].persistent_memory_size;
318 }
319 
RecordMemoryStats(const Node * node,const MemoryStats & memory_stats)320 void CostModel::RecordMemoryStats(const Node* node,
321                                   const MemoryStats& memory_stats) {
322   const int id = Id(node);
323   if (id < 0) return;
324   max_mem_usage_[id].temp_memory_size = memory_stats.temp_memory_size();
325   max_mem_usage_[id].persistent_memory_size =
326       memory_stats.persistent_memory_size();
327   for (int64 alloc_id : memory_stats.persistent_tensor_alloc_ids()) {
328     if (alloc_id > 0) {
329       persistent_alloc_ids_.insert(alloc_id);
330     }
331   }
332 }
333 
RecordMaxExecutionTime(const Node * node,Microseconds time)334 void CostModel::RecordMaxExecutionTime(const Node* node, Microseconds time) {
335   const int id = Id(node);
336   if (id < 0) return;
337   Ensure(id, node->num_outputs());
338   max_exec_time_[id] = std::max(max_exec_time_[id], time);
339 }
340 
MaxExecutionTime(const Node * node) const341 Microseconds CostModel::MaxExecutionTime(const Node* node) const {
342   const int id = Id(node);
343   if (id < 0 || static_cast<size_t>(id) >= max_exec_time_.size()) {
344     return Microseconds(0);
345   }
346   return max_exec_time_[id];
347 }
348 
RecordAllocationId(const Node * node,int output_slot,int64 alloc_id)349 void CostModel::RecordAllocationId(const Node* node, int output_slot,
350                                    int64 alloc_id) {
351   const int id = Id(node);
352   if (id < 0) return;
353   Ensure(id, node->num_outputs());
354   output_port_alloc_ids_[id][output_slot] = alloc_id;
355 }
356 
AllocationId(const Node * node,int slot) const357 int64 CostModel::AllocationId(const Node* node, int slot) const {
358   const int id = Id(node);
359   if (id < 0 || static_cast<size_t>(id) >= output_port_alloc_ids_.size() ||
360       output_port_alloc_ids_[id].size() <= static_cast<size_t>(slot)) {
361     return -1;
362   }
363   return output_port_alloc_ids_[id][slot];
364 }
365 
IsPersistentTensor(const Node * node,int64 alloc_id) const366 bool CostModel::IsPersistentTensor(const Node* node, int64 alloc_id) const {
367   if (persistent_alloc_ids_.count(alloc_id) > 0) {
368     return true;
369   }
370   if (persistent_alloc_ids_by_devices_.find(node->assigned_device_name()) ==
371       persistent_alloc_ids_by_devices_.end()) {
372     return false;
373   }
374   return persistent_alloc_ids_by_devices_.at(node->assigned_device_name())
375       .count(alloc_id);
376 }
377 
CopyTimeEstimate(Bytes b,double network_latency_millis,double estimated_gbps)378 Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis,
379                                          double estimated_gbps) {
380   // TODO(jeff,sanjay): estimate cost based on bandwidth along the
381   // communication path and the type of transport we are using between
382   // devices.
383   //
384   // We assume the copy time follows a linear model:
385   //    copy_time = copy_bytes / rate + min_time
386   int64 copy_bytes = b.value();
387   const double bytes_per_usec = estimated_gbps * 1000.0 / 8;
388   const double min_micros = network_latency_millis * 1000.0;
389   return Microseconds(
390       static_cast<int64>(copy_bytes / bytes_per_usec + min_micros));
391 }
392 
ComputationTimeEstimate(int64 math_ops)393 Microseconds CostModel::ComputationTimeEstimate(int64 math_ops) {
394   // TODO(jeff,sanjay): Eventually we should pass in the type of device
395   // (GPU vs. CPU) and use that to affect the estimate.
396 
397   // We estimate the microseconds using that value.  We divide
398   // by 1000 to convert the madd number into microseconds (assuming
399   // roughly 1000 madds per microsecond (~1 GHz for one core)).
400   return Microseconds(math_ops / 1000);
401 }
402 
IncrementUpdateTimes()403 void CostModel::IncrementUpdateTimes() { update_times_++; }
404 
GetUpdateTimes() const405 int32 CostModel::GetUpdateTimes() const { return update_times_; }
406 
407 // ----------------------------------------------------------------------------
408 // InitCostModel
409 // ----------------------------------------------------------------------------
410 
411 namespace {
412 
AddNodesToCostModel(const Graph & g,CostModel * cost_model)413 static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) {
414   for (Node* n : g.nodes()) {
415     const int num_outputs = n->num_outputs();
416     cost_model->SetNumOutputs(n, num_outputs);
417     for (int output = 0; output < num_outputs; output++) {
418       // Set up an initial bogus estimate for the node's outputs
419       cost_model->RecordSize(n, output, Bytes(1));
420     }
421   }
422 }
423 
AssignSizes(const Graph & g,CostModel * cost_model)424 static void AssignSizes(const Graph& g, CostModel* cost_model) {
425   for (const Edge* e : g.edges()) {
426     // Skip if it is a control edge.
427     if (e->IsControlEdge()) {
428       continue;
429     }
430     const Node* src = e->src();
431 
432     // TODO(josh11b): Get an estimate from the Op
433     Bytes size(1);
434     cost_model->RecordSize(src, e->src_output(), size);
435   }
436 }
437 
438 // This generates an extremely simple initial guess for the
439 // computation cost of each node. For ordinary Ops, its value should quickly
440 // be wiped out by the real runtime measurements.  For other Ops we don't
441 // actually generate measurements, so suppression of infrequent Ops ends up
442 // giving them 0 costs.  So, this is not of much consequence except perhaps
443 // in tests.
TimeEstimateForNode(CostModel * cost_model,Node * n)444 static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) {
445   CHECK(n->IsOp());
446   VLOG(2) << "Node " << n->id() << ": " << n->name()
447           << " type_string: " << n->type_string();
448   if (IsConstant(n) || IsVariable(n)) {
449     return Microseconds(0);
450   }
451   return kDefaultTimeEstimate;
452 }
453 
EstimateComputationCosts(const Graph & g,CostModel * cost_model)454 static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) {
455   for (Node* n : g.nodes()) {
456     if (!n->IsOp()) continue;
457     cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n));
458   }
459 }
460 
461 }  // namespace
462 
InitFromGraph(const Graph & g)463 void CostModel::InitFromGraph(const Graph& g) {
464   const int num_node_ids = g.num_node_ids();
465   slot_bytes_.reserve(num_node_ids);
466   count_.reserve(num_node_ids);
467   time_.reserve(num_node_ids);
468   max_mem_usage_.reserve(num_node_ids);
469   max_exec_time_.reserve(num_node_ids);
470   output_port_alloc_ids_.reserve(num_node_ids);
471 
472   AddNodesToCostModel(g, this);
473   AssignSizes(g, this);
474   EstimateComputationCosts(g, this);
475   CheckInitialized(g);
476 }
477 
AddToCostGraphDef(const Graph * graph,CostGraphDef * cost_graph) const478 void CostModel::AddToCostGraphDef(const Graph* graph,
479                                   CostGraphDef* cost_graph) const {
480   std::vector<const Edge*> inputs;
481   std::vector<const Edge*> control_inputs;
482   for (const Node* n : graph->nodes()) {
483     CostGraphDef::Node* cnode = cost_graph->add_node();
484     cnode->set_name(n->name());
485     cnode->set_device(n->assigned_device_name());
486     cnode->set_id(Id(n));
487 
488     inputs.clear();
489     inputs.resize(n->num_inputs(), nullptr);
490     control_inputs.clear();
491     for (const Edge* e : n->in_edges()) {
492       if (e->IsControlEdge()) {
493         control_inputs.push_back(e);
494       } else {
495         inputs[e->dst_input()] = e;
496       }
497     }
498     std::sort(control_inputs.begin(), control_inputs.end(),
499               [this](Edge const* a, Edge const* b) {
500                 return Id(a->src()) < Id(b->src());
501               });
502 
503     for (const Edge* e : inputs) {
504       CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
505       input_info->set_preceding_node(Id(e->src()));
506       input_info->set_preceding_port(e->src_output());
507     }
508 
509     for (int i = 0; i < n->num_outputs(); i++) {
510       CostGraphDef::Node::OutputInfo* output_info = cnode->add_output_info();
511       int64 alloc_id = AllocationId(n, i);
512       int64 alias_to_input = -1;
513       for (const Edge* e : inputs) {
514         int64 input_alloc_id = AllocationId(e->src(), e->src_output());
515         if (input_alloc_id == alloc_id) {
516           alias_to_input = e->dst_input();
517           break;
518         }
519       }
520       output_info->set_alias_input_port(alias_to_input);
521       output_info->set_dtype(MaxMemoryType(n, i));
522       *output_info->mutable_shape() = MaxMemoryShape(n, i);
523       if (alias_to_input < 0 && IsPersistentTensor(n, alloc_id)) {
524         output_info->set_size(0);
525       } else {
526         output_info->set_size(MaxMemorySize(n, i).value());
527       }
528     }
529 
530     for (const Edge* e : control_inputs) {
531       cnode->add_control_input(Id(e->src()));
532     }
533 
534     cnode->set_temporary_memory_size(TempMemorySize(n).value());
535     cnode->set_persistent_memory_size(PersistentMemorySize(n).value());
536 
537     cnode->set_compute_cost(MaxExecutionTime(n).value());
538 
539     // For now we treat all send nodes as final.
540     // TODO(yuanbyu): Send nodes for fetches shouldn't be treated as final.
541     cnode->set_is_final(n->IsSend());
542   }
543 }
544 
WriteSummaryToLog() const545 void CostModel::WriteSummaryToLog() const {
546   LOG(INFO) << " min_count_=" << min_count_;
547   for (size_t i = 0; i < count_.size(); ++i) {
548     LOG(INFO) << "Node " << i << " count " << count_[i] << " total time "
549               << time_[i] << " avg time "
550               << (time_[i] / (std::max(1, count_[i])));
551   }
552 }
553 
MinTensorMemoryUsage(const TensorShapeProto & tensor_shape,const DataType & dtype)554 Bytes CostModel::MinTensorMemoryUsage(const TensorShapeProto& tensor_shape,
555                                       const DataType& dtype) {
556   if (tensor_shape.unknown_rank()) {
557     return Bytes(-1);
558   }
559 
560   size_t num_coefficients = 1;
561   for (const TensorShapeProto::Dim& dim : tensor_shape.dim()) {
562     // If the dimension is unknown, it has to be at least 1
563     num_coefficients *= std::max<size_t>(dim.size(), 1);
564   }
565   return Bytes(num_coefficients * DataTypeSize(dtype));
566 }
567 
568 }  // namespace tensorflow
569