/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/graph_memory.h"

#include <algorithm>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/utils.h"

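// A rough usage sketch, assuming the interface declared in graph_memory.h
// (the GrapplerItem constructor and the accessors shown below are inferred
// from how this file uses item_ and peak_usage_, not quoted from the header):
//
//   GrapplerItem item = ...;  // graph to analyze
//   GraphMemory memory(item);
//   std::unordered_map<string, DeviceProperties> devices = ...;
//   TF_CHECK_OK(memory.InferStatically(devices));
//   int64 worst_case = memory.GetWorstCaseMemoryUsage();  // bytes, or -1
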
namespace tensorflow {
namespace grappler {

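// Estimates memory usage without running the graph: the item is executed on a
// simulated (virtual) cluster built from the given device descriptions, and
// the per-device peak usage is derived from the simulated step stats.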
Status GraphMemory::InferStatically(
    const std::unordered_map<string, DeviceProperties>& devices) {
  VirtualCluster cluster(devices);
  TF_RETURN_IF_ERROR(cluster.Provision());
  TF_RETURN_IF_ERROR(cluster.Initialize(item_));
  RunMetadata metadata;
  Status s = cluster.Run(item_, &metadata);
  // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
  // that the model would run out of memory. We still get the metadata we need
  // out of the simulation, so we just ignore this error.
  if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
    return s;
  }
  InferFromTrace(metadata.step_stats());
  return Status::OK();
}

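// Estimates memory usage by actually running the item on the given cluster,
// which must have detailed stats collection enabled, and analyzing the
// resulting execution trace.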
Status GraphMemory::InferDynamically(Cluster* cluster) {
  if (!cluster->DetailedStatsEnabled()) {
    return errors::Unavailable("Detailed stats collection must be enabled");
  }

  TF_RETURN_IF_ERROR(cluster->Initialize(item_));
  RunMetadata metadata;
  TF_RETURN_IF_ERROR(cluster->Run(item_, &metadata));
  InferFromTrace(metadata.step_stats());
  return Status::OK();
}

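// Returns the largest peak memory usage recorded across all devices, or -1 if
// nothing has been inferred yet.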
int64 GraphMemory::GetWorstCaseMemoryUsage() const {
  int64 worst_case = -1;
  for (const auto& peak_usage : peak_usage_) {
    worst_case = std::max(worst_case, peak_usage.second.used_memory);
  }
  return worst_case;
}

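// Coarse, shape-based estimate: the worst case assumes that every tensor in
// the graph is alive at the same time, while the best case is driven by the
// single node whose inputs and outputs require the most memory.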
void GraphMemory::InferMemUsageForNodes(
    const std::vector<const NodeDef*>& nodes, GraphProperties* properties,
    int64* worst_case_memory_usage, int64* best_case_memory_usage) const {
  // TODO(bsteiner) refine this: we should consider the multidevice case.
  *worst_case_memory_usage = 0;
  *best_case_memory_usage = 0;
  for (const auto& node : item_.graph.node()) {
    // Estimate the memory required to store the tensors generated by the node.
    std::vector<OpInfo::TensorProperties> outputs =
        properties->GetOutputProperties(node.name());
    int64 node_memory_usage = InferMemUsageForNeighbors(outputs);

    // Worst case memory usage corresponds to the case where all the nodes are
    // alive.
    *worst_case_memory_usage += node_memory_usage;

    // Estimate the memory required to store the input tensors needed by the
    // node.
    std::vector<OpInfo::TensorProperties> inputs =
        properties->GetInputProperties(node.name());
    node_memory_usage += InferMemUsageForNeighbors(inputs);

    *best_case_memory_usage =
        std::max(*best_case_memory_usage, node_memory_usage);
  }
}

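// Sums the storage needed for the given tensors based on their data type and
// statically inferred shape.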
int64 GraphMemory::InferMemUsageForNeighbors(
    const std::vector<OpInfo::TensorProperties>& props) const {
  int64 neighbors_memory_usage = 0;
  for (const auto& prop : props) {
    DataType dtype = prop.dtype();
    int size = DataTypeSize(dtype);
    TensorShapeProto shape = prop.shape();
    if (shape.unknown_rank()) {
      // Can't infer the size if the rank is unknown, just skip.
      continue;
    }
    // If one of the dimensions is unknown statically, assume it's one.
    for (int i = 0; i < shape.dim_size(); ++i) {
      if (shape.dim(i).size() < 0) {
        shape.mutable_dim(i)->set_size(1);
      }
    }
    int64 num_elems = TensorShape(shape).num_elements();
    neighbors_memory_usage += num_elems * size;
  }
  return neighbors_memory_usage;
}

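// Returns the LiveTensor tracking the given node output, creating and
// registering a new record if this output hasn't been seen yet.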
static GraphMemory::LiveTensor* FindOrCreateLiveTensor(
    const string& node_name, int output_id,
    std::unordered_map<string, GraphMemory::LiveTensor*>* live_tensors,
    std::deque<GraphMemory::LiveTensor>* device_tensors) {
  string name = strings::StrCat(node_name, ":", output_id);
  GraphMemory::LiveTensor* live;
  auto it = live_tensors->find(name);
  if (it == live_tensors->end()) {
    GraphMemory::LiveTensor temp;
    temp.node = node_name;
    temp.output_id = output_id;
    temp.allocation_time = 0;
    temp.deallocation_time = 0;
    device_tensors->push_front(temp);
    live = &device_tensors->front();
    (*live_tensors)[name] = live;
  } else {
    live = it->second;
  }
  return live;
}

namespace {
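// An allocation or deallocation event on a device timeline, ordered by
// timestamp so that peak memory can be computed with a single sweep.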
struct Event {
  Event(int64 _timestamp, bool _allocated,
        const GraphMemory::LiveTensor* _tensor)
      : timestamp(_timestamp), allocated(_allocated), tensor(_tensor) {}

  int64 timestamp;
  bool allocated;
  const GraphMemory::LiveTensor* tensor;

  bool operator<(const Event& other) const {
    return timestamp < other.timestamp;
  }
};
}  // namespace

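// Reconstructs tensor lifetimes from the execution trace: each output tensor
// becomes live when its producer starts and stays live until its last
// consumer completes. The peak usage of a device is the maximum, over time,
// of the total memory held by its live tensors.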
void GraphMemory::InferFromTrace(const StepStats& timeline) {
  std::unordered_map<string, string> node_placement;
  for (const auto& dev_stats : timeline.dev_stats()) {
    for (const auto& node_stats : dev_stats.node_stats()) {
      node_placement[node_stats.node_name()] = dev_stats.device();
    }
  }

  std::unordered_map<string, LiveTensor*> live_tensors;
  std::unordered_map<string, std::deque<LiveTensor>> live_tensors_per_device;
  std::unordered_map<string, const NodeDef*> node_map;
  for (const NodeDef& node : item_.graph.node()) {
    node_map[node.name()] = &node;
  }
  for (const auto& dev_stats : timeline.dev_stats()) {
    const string& device_name = dev_stats.device();
    const bool is_gpu = device_name.find("GPU:") != string::npos ||
                        device_name.find("gpu:") != string::npos;
    std::deque<LiveTensor>& device_tensors =
        live_tensors_per_device[dev_stats.device()];
    for (const auto& node_stats : dev_stats.node_stats()) {
      for (int i = 0; i < node_stats.output_size(); ++i) {
        const auto& output = node_stats.output(i);

        LiveTensor* live = FindOrCreateLiveTensor(
            node_stats.node_name(), i, &live_tensors, &device_tensors);
        live->memory_used = output.tensor_description()
                                .allocation_description()
                                .allocated_bytes();

        // Allocations typically take place at the very beginning of the op
        // execution.
        live->allocation_time =
            Costs::MicroSeconds(node_stats.all_start_micros());
        // Add one nanosecond to the completion time of the ops to account for
        // TF overhead that slightly delays deallocations.
        live->deallocation_time = std::max<Costs::Duration>(
            live->deallocation_time,
            Costs::NanoSeconds(1) +
                Costs::MicroSeconds(node_stats.all_start_micros() +
                                    node_stats.op_end_rel_micros()));
      }

      auto it = node_map.find(node_stats.node_name());
      if (it == node_map.end()) {
        // Skip nodes inserted by TF since they don't exist in the original
        // graph (e.g. _Send/_Recv nodes).
        continue;
      }
      const NodeDef* node = it->second;
      std::unordered_set<int> swapped_inputs;
      if (is_gpu) {
        auto it = node->attr().find("_swap_to_host");
        if (it != node->attr().end()) {
          const AttrValue& val = it->second;
          for (int port_id : val.list().i()) {
            swapped_inputs.insert(port_id);
          }
        }
      }
      for (int i = 0; i < node->input_size(); ++i) {
        if (swapped_inputs.find(i) != swapped_inputs.end()) {
          // The memory of swapped inputs will be released as early as
          // possible: therefore ignore this input when determining the
          // deallocation time of the tensor.
          continue;
        }
        const string& input = node->input(i);
        int position;
        string input_node = ParseNodeName(input, &position);
        if (position < 0) {
          // Skip control dependencies.
          continue;
        }
        LiveTensor* live = FindOrCreateLiveTensor(
            input_node, position, &live_tensors,
            &live_tensors_per_device[node_placement[input_node]]);
        live->deallocation_time = std::max<Costs::Duration>(
            live->deallocation_time,
            Costs::NanoSeconds(1) +
                Costs::MicroSeconds(node_stats.all_start_micros() +
                                    node_stats.op_end_rel_micros()));
      }
    }
  }

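  // For each device, sweep the allocation/deallocation events in timestamp
  // order to find the peak memory usage and the set of tensors live at that
  // peak.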
  for (const auto& live_per_device : live_tensors_per_device) {
    std::vector<Event> events;
    events.reserve(2 * live_per_device.second.size());
    for (const auto& live : live_per_device.second) {
      events.emplace_back(static_cast<int64>(live.allocation_time.count()),
                          true, &live);
      events.emplace_back(static_cast<int64>(live.deallocation_time.count()),
                          false, &live);
    }
    std::stable_sort(events.begin(), events.end());
    size_t peak = 0;
    std::unordered_set<const LiveTensor*> live_at_peak;
    size_t current = 0;
    std::unordered_set<const LiveTensor*> currently_live;
    int events_size = events.size();
    for (int i = 0; i < events_size; ++i) {
      const auto& event = events[i];

      if (event.allocated) {
        VLOG(1) << "At time " << event.timestamp << " allocated "
                << event.tensor->memory_used << " for tensor "
                << event.tensor->node << ":" << event.tensor->output_id;
        current += event.tensor->memory_used;
        currently_live.insert(event.tensor);
      } else {
        VLOG(1) << "At time " << event.timestamp << " deallocated "
                << event.tensor->memory_used << " for tensor "
                << event.tensor->node << ":" << event.tensor->output_id;
        current -= event.tensor->memory_used;
        currently_live.erase(event.tensor);
      }
      if (i + 1 == events_size || event.timestamp != events[i + 1].timestamp) {
        if (current > peak) {
          peak = current;
          live_at_peak = currently_live;
        }
      }
    }
    MemoryUsage& peak_mem_usage = peak_usage_[live_per_device.first];
    peak_mem_usage.used_memory = peak;
    peak_mem_usage.live_tensors.clear();
    peak_mem_usage.live_tensors.reserve(live_at_peak.size());
    for (const auto& live : live_at_peak) {
      peak_mem_usage.live_tensors.push_back(*live);
    }
  }
}

}  // end namespace grappler
}  // end namespace tensorflow