• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_memory.h"
17 
18 #include <deque>
19 #include "tensorflow/core/framework/allocation_description.pb.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/step_stats.pb.h"
23 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
24 #include "tensorflow/core/framework/tensor_description.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
28 #include "tensorflow/core/grappler/costs/graph_properties.h"
29 #include "tensorflow/core/grappler/utils.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 
InferStatically(const std::unordered_map<string,DeviceProperties> & devices)34 Status GraphMemory::InferStatically(
35     const std::unordered_map<string, DeviceProperties>& devices) {
36   VirtualCluster cluster(devices);
37   TF_RETURN_IF_ERROR(cluster.Provision());
38   TF_RETURN_IF_ERROR(cluster.Initialize(item_));
39   RunMetadata metadata;
40   Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
41   // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
42   // that the model would run out of memory. We still get the metadata we need
43   // out of the simulation, so we just ignore this error.
44   if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
45     return s;
46   }
47   InferFromTrace(metadata.step_stats());
48   return Status::OK();
49 }
50 
InferDynamically(Cluster * cluster)51 Status GraphMemory::InferDynamically(Cluster* cluster) {
52   if (!cluster->DetailedStatsEnabled()) {
53     return errors::Unavailable("Detailed stats collection must be enabled");
54   }
55 
56   TF_RETURN_IF_ERROR(cluster->Initialize(item_));
57   RunMetadata metadata;
58   TF_RETURN_IF_ERROR(
59       cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
60   InferFromTrace(metadata.step_stats());
61   return Status::OK();
62 }
63 
GetWorstCaseMemoryUsage() const64 int64 GraphMemory::GetWorstCaseMemoryUsage() const {
65   int64 worst_case = -1;
66   for (const auto& peak_usage : peak_usage_) {
67     worst_case = std::max(worst_case, peak_usage.second.used_memory);
68   }
69   return worst_case;
70 }
71 
InferMemUsageForNodes(const std::vector<const NodeDef * > & nodes,GraphProperties * properties,int64 * worst_case_memory_usage,int64 * best_case_memory_usage) const72 void GraphMemory::InferMemUsageForNodes(
73     const std::vector<const NodeDef*>& nodes, GraphProperties* properties,
74     int64* worst_case_memory_usage, int64* best_case_memory_usage) const {
75   // TODO(bsteiner) refine this: we should consider the multidevice case.
76   *worst_case_memory_usage = 0;
77   *best_case_memory_usage = 0;
78   for (const auto& node : item_.graph.node()) {
79     // Estimate the memory required to store the tensors generated by the node.
80     std::vector<OpInfo::TensorProperties> outputs =
81         properties->GetOutputProperties(node.name());
82     int64 node_memory_usage = InferMemUsageForNeighbors(outputs);
83 
84     // Worst case memory usage corresponds to the case where all the nodes are
85     // alive.
86     *worst_case_memory_usage += node_memory_usage;
87 
88     // Estimate the memory required to store the input tensors needed by the
89     // node.
90     std::vector<OpInfo::TensorProperties> inputs =
91         properties->GetInputProperties(node.name());
92     node_memory_usage += InferMemUsageForNeighbors(inputs);
93 
94     *best_case_memory_usage =
95         std::max(*best_case_memory_usage, node_memory_usage);
96   }
97 }
98 
InferMemUsageForNeighbors(const std::vector<OpInfo::TensorProperties> & props) const99 int64 GraphMemory::InferMemUsageForNeighbors(
100     const std::vector<OpInfo::TensorProperties>& props) const {
101   int64 neighbors_memory_usage = 0;
102   for (const auto& prop : props) {
103     DataType dtype = prop.dtype();
104     int size = DataTypeSize(dtype);
105     TensorShapeProto shape = prop.shape();
106     if (shape.unknown_rank()) {
107       // Can't infer the size if the rank is unknown, just skip.
108       continue;
109     }
110     // If one of the dimensions is unknown statically, assume it's one.
111     for (int i = 0; i < shape.dim_size(); ++i) {
112       if (shape.dim(i).size() < 0) {
113         shape.mutable_dim(i)->set_size(1);
114       }
115     }
116     int num_elems = TensorShape(shape).num_elements();
117     neighbors_memory_usage += num_elems * size;
118   }
119   return neighbors_memory_usage;
120 }
121 
FindOrCreateLiveTensor(const string & node_name,int output_id,std::unordered_map<string,GraphMemory::LiveTensor * > * live_tensors,std::deque<GraphMemory::LiveTensor> * device_tensors)122 static GraphMemory::LiveTensor* FindOrCreateLiveTensor(
123     const string& node_name, int output_id,
124     std::unordered_map<string, GraphMemory::LiveTensor*>* live_tensors,
125     std::deque<GraphMemory::LiveTensor>* device_tensors) {
126   string name = strings::StrCat(node_name, ":", output_id);
127   GraphMemory::LiveTensor* live;
128   auto it = live_tensors->find(name);
129   if (it == live_tensors->end()) {
130     GraphMemory::LiveTensor temp;
131     temp.node = node_name;
132     temp.output_id = output_id;
133     temp.allocation_time = 0;
134     temp.deallocation_time = 0;
135     device_tensors->push_front(temp);
136     live = &device_tensors->front();
137     (*live_tensors)[name] = live;
138   } else {
139     live = it->second;
140   }
141   return live;
142 }
143 
144 namespace {
145 struct Event {
Eventtensorflow::grappler::__anonf4a2a89b0111::Event146   Event(int64 _timestamp, bool _allocated,
147         const GraphMemory::LiveTensor* _tensor)
148       : timestamp(_timestamp), allocated(_allocated), tensor(_tensor) {}
149 
150   int64 timestamp;
151   bool allocated;
152   const GraphMemory::LiveTensor* tensor;
153 
operator <tensorflow::grappler::__anonf4a2a89b0111::Event154   bool operator<(const Event& other) const {
155     return timestamp < other.timestamp;
156   }
157 };
158 }  // namespace
159 
InferFromTrace(const StepStats & timeline)160 void GraphMemory::InferFromTrace(const StepStats& timeline) {
161   std::unordered_map<string, string> node_placement;
162   for (const auto& dev_stats : timeline.dev_stats()) {
163     for (const auto& node_stats : dev_stats.node_stats()) {
164       node_placement[node_stats.node_name()] = dev_stats.device();
165     }
166   }
167 
168   std::unordered_map<string, LiveTensor*> live_tensors;
169   std::unordered_map<string, std::deque<LiveTensor>> live_tensors_per_device;
170   std::unordered_map<string, const NodeDef*> node_map;
171   for (const NodeDef& node : item_.graph.node()) {
172     node_map[node.name()] = &node;
173   }
174   for (const auto& dev_stats : timeline.dev_stats()) {
175     const string& device_name = dev_stats.device();
176     const bool is_gpu = (device_name.find("GPU:") || device_name.find("gpu:"));
177     std::deque<LiveTensor>& device_tensors =
178         live_tensors_per_device[dev_stats.device()];
179     for (const auto& node_stats : dev_stats.node_stats()) {
180       for (int i = 0; i < node_stats.output_size(); ++i) {
181         const auto& output = node_stats.output(i);
182 
183         LiveTensor* live = FindOrCreateLiveTensor(
184             node_stats.node_name(), i, &live_tensors, &device_tensors);
185         live->memory_used = output.tensor_description()
186                                 .allocation_description()
187                                 .allocated_bytes();
188 
189         // Allocations typically take place at the very beginning of the op
190         // execution.
191         live->allocation_time =
192             Costs::MicroSeconds(node_stats.all_start_micros());
193         // Add one nanosecond to the completion time of the ops to account for
194         // TF overhead that slightly delays deallocations.
195         live->deallocation_time = std::max<Costs::Duration>(
196             live->deallocation_time,
197             Costs::NanoSeconds(1) +
198                 Costs::MicroSeconds(node_stats.all_start_micros() +
199                                     node_stats.op_end_rel_micros()));
200       }
201 
202       auto it = node_map.find(node_stats.node_name());
203       if (it == node_map.end()) {
204         // Skip nodes inserted by TF since they don't exist in the original
205         // graph (e.g _Send/_Recv nodes).
206         continue;
207       }
208       const NodeDef* node = it->second;
209       std::unordered_set<int> swapped_inputs;
210       if (is_gpu) {
211         auto it = node->attr().find("_swap_to_host");
212         if (it != node->attr().end()) {
213           const AttrValue& val = it->second;
214           for (int port_id : val.list().i()) {
215             swapped_inputs.insert(port_id);
216           }
217         }
218       }
219       for (int i = 0; i < node->input_size(); ++i) {
220         if (swapped_inputs.find(i) != swapped_inputs.end()) {
221           // The memory of swapped inputs will be released as early as possible:
222           // therefore ignore this input when determining the deallocation time
223           // of the tensor.
224           continue;
225         }
226         const string& input = node->input(i);
227         int position;
228         string input_node = ParseNodeName(input, &position);
229         if (position < 0) {
230           // Skip control dependencies
231           continue;
232         }
233         LiveTensor* live = FindOrCreateLiveTensor(
234             input_node, position, &live_tensors,
235             &live_tensors_per_device[node_placement[input_node]]);
236         live->deallocation_time = std::max<Costs::Duration>(
237             live->deallocation_time,
238             Costs::NanoSeconds(1) +
239                 Costs::MicroSeconds(node_stats.all_start_micros() +
240                                     node_stats.op_end_rel_micros()));
241       }
242     }
243   }
244 
245   for (const auto& live_per_device : live_tensors_per_device) {
246     std::vector<Event> events;
247     events.reserve(2 * live_per_device.second.size());
248     for (const auto& live : live_per_device.second) {
249       events.emplace_back(static_cast<int64>(live.allocation_time.count()),
250                           true, &live);
251       events.emplace_back(static_cast<int64>(live.deallocation_time.count()),
252                           false, &live);
253     }
254     std::stable_sort(events.begin(), events.end());
255     size_t peak = 0;
256     std::unordered_set<const LiveTensor*> live_at_peak;
257     size_t current = 0;
258     std::unordered_set<const LiveTensor*> currently_live;
259     for (int i = 0; i < events.size(); ++i) {
260       const auto& event = events[i];
261 
262       if (event.allocated) {
263         VLOG(1) << "At time " << event.timestamp << " allocated "
264                 << event.tensor->memory_used << " for tensor "
265                 << event.tensor->node << ":" << event.tensor->output_id;
266         current += event.tensor->memory_used;
267         currently_live.insert(event.tensor);
268       } else {
269         VLOG(1) << "At time " << event.timestamp << " deallocated "
270                 << event.tensor->memory_used << " for tensor "
271                 << event.tensor->node << ":" << event.tensor->output_id;
272         current -= event.tensor->memory_used;
273         currently_live.erase(event.tensor);
274       }
275       if (i + 1 == events.size() ||
276           event.timestamp != events[i + 1].timestamp) {
277         if (current > peak) {
278           peak = current;
279           live_at_peak = currently_live;
280         }
281       }
282     }
283     MemoryUsage& peak_mem_usage = peak_usage_[live_per_device.first];
284     peak_mem_usage.used_memory = peak;
285     peak_mem_usage.live_tensors.clear();
286     peak_mem_usage.live_tensors.reserve(live_at_peak.size());
287     for (const auto& live : live_at_peak) {
288       peak_mem_usage.live_tensors.push_back(*live);
289     }
290   }
291 }
292 
293 }  // end namespace grappler
294 }  // end namespace tensorflow
295