1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
17
18 #include "tensorflow/core/framework/cost_graph.pb.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/clusters/utils.h"
22 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
23
24 namespace tensorflow {
25 namespace grappler {
26
VirtualCluster(const std::unordered_map<string,DeviceProperties> & devices)27 VirtualCluster::VirtualCluster(
28 const std::unordered_map<string, DeviceProperties>& devices)
29 : VirtualCluster(devices, std::make_unique<OpLevelCostEstimator>(),
30 ReadyNodeManagerFactory("FirstReady")) {}
31
VirtualCluster(const std::unordered_map<string,DeviceProperties> & devices,std::unique_ptr<OpLevelCostEstimator> node_estimator,std::unique_ptr<ReadyNodeManager> node_manager)32 VirtualCluster::VirtualCluster(
33 const std::unordered_map<string, DeviceProperties>& devices,
34 std::unique_ptr<OpLevelCostEstimator> node_estimator,
35 std::unique_ptr<ReadyNodeManager> node_manager)
36 : Cluster(0) {
37 devices_ = devices;
38
39 // Note that we do not use aggressive shape inference to preserve unknown
40 // shapes from the input graph.
41 estimator_ = std::make_unique<AnalyticalCostEstimator>(
42 this, std::move(node_estimator), std::move(node_manager),
43 /*use_static_shapes=*/true, /*use_aggressive_shape_inference=*/false);
44 }
45
VirtualCluster(const DeviceSet * device_set)46 VirtualCluster::VirtualCluster(const DeviceSet* device_set)
47 : VirtualCluster(std::unordered_map<string, DeviceProperties>()) {
48 device_set_ = device_set;
49 for (const auto& device : device_set_->devices()) {
50 DeviceProperties props = GetDeviceInfo(device->parsed_name());
51 if (props.type() == "UNKNOWN") continue;
52 auto attrs = device->attributes();
53 props.set_memory_size(attrs.memory_limit());
54 devices_[device->name()] = props;
55 }
56 }
57
~VirtualCluster()58 VirtualCluster::~VirtualCluster() {}
59
Provision()60 Status VirtualCluster::Provision() { return OkStatus(); }
61
Initialize(const GrapplerItem & item)62 Status VirtualCluster::Initialize(const GrapplerItem& item) {
63 return OkStatus();
64 }
65
Run(const GraphDef & graph,const std::vector<std::pair<string,Tensor>> & feed,const std::vector<string> & fetch,RunMetadata * metadata)66 Status VirtualCluster::Run(const GraphDef& graph,
67 const std::vector<std::pair<string, Tensor>>& feed,
68 const std::vector<string>& fetch,
69 RunMetadata* metadata) {
70 GrapplerItem item;
71 item.graph = graph;
72 item.feed = feed;
73 item.fetch = fetch;
74 return Run(item, metadata);
75 }
76
Run(const GrapplerItem & item,RunMetadata * metadata)77 Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) {
78 // Initializes an analytical cost estimator to estimate the graph cost. Makes
79 // sure to use static shape inference to prevent the virtual scheduler from
80 // calling the Run method on the cluster and creating an infinite loop.
81 if (metadata) {
82 metadata->clear_step_stats();
83 metadata->clear_cost_graph();
84 metadata->clear_partition_graphs();
85 }
86
87 TF_RETURN_IF_ERROR(estimator_->Initialize(item));
88 TF_RETURN_IF_ERROR(
89 estimator_->PredictCosts(item.graph, metadata, /*cost=*/nullptr));
90
91 const std::unordered_map<string, DeviceProperties>& device = GetDevices();
92 std::unordered_map<string, int64_t> peak_mem_usage =
93 estimator_->GetScheduler()->GetPeakMemoryUsage();
94 for (const auto& mem_usage : peak_mem_usage) {
95 const string& device_name = mem_usage.first;
96 auto it = device.find(device_name);
97 if (it == device.end()) {
98 // It's probably the fake send/recv device. Eventually we'll need to
99 // remove this fake device to ensure proper memory accounting for
100 // multi-device settings.
101 continue;
102 }
103 const DeviceProperties& dev = it->second;
104 if (dev.memory_size() <= 0) {
105 // Available device memory unknown
106 continue;
107 }
108 int64_t peak_mem = mem_usage.second;
109 if (peak_mem >= dev.memory_size()) {
110 return errors::ResourceExhausted(
111 "Graph requires ", peak_mem, " bytes of memory on device ",
112 device_name, " to run ", " but device only has ", dev.memory_size(),
113 " available.");
114 }
115 }
116
117 return OkStatus();
118 }
119
120 } // namespace grappler
121 } // namespace tensorflow
122