• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
17 
18 #include "tensorflow/core/framework/cost_graph.pb.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/clusters/utils.h"
22 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
23 
24 namespace tensorflow {
25 namespace grappler {
26 
VirtualCluster(const std::unordered_map<string,DeviceProperties> & devices)27 VirtualCluster::VirtualCluster(
28     const std::unordered_map<string, DeviceProperties>& devices)
29     : VirtualCluster(devices, std::make_unique<OpLevelCostEstimator>(),
30                      ReadyNodeManagerFactory("FirstReady")) {}
31 
VirtualCluster(const std::unordered_map<string,DeviceProperties> & devices,std::unique_ptr<OpLevelCostEstimator> node_estimator,std::unique_ptr<ReadyNodeManager> node_manager)32 VirtualCluster::VirtualCluster(
33     const std::unordered_map<string, DeviceProperties>& devices,
34     std::unique_ptr<OpLevelCostEstimator> node_estimator,
35     std::unique_ptr<ReadyNodeManager> node_manager)
36     : Cluster(0) {
37   devices_ = devices;
38 
39   // Note that we do not use aggressive shape inference to preserve unknown
40   // shapes from the input graph.
41   estimator_ = std::make_unique<AnalyticalCostEstimator>(
42       this, std::move(node_estimator), std::move(node_manager),
43       /*use_static_shapes=*/true, /*use_aggressive_shape_inference=*/false);
44 }
45 
VirtualCluster(const DeviceSet * device_set)46 VirtualCluster::VirtualCluster(const DeviceSet* device_set)
47     : VirtualCluster(std::unordered_map<string, DeviceProperties>()) {
48   device_set_ = device_set;
49   for (const auto& device : device_set_->devices()) {
50     DeviceProperties props = GetDeviceInfo(device->parsed_name());
51     if (props.type() == "UNKNOWN") continue;
52     auto attrs = device->attributes();
53     props.set_memory_size(attrs.memory_limit());
54     devices_[device->name()] = props;
55   }
56 }
57 
~VirtualCluster()58 VirtualCluster::~VirtualCluster() {}
59 
Provision()60 Status VirtualCluster::Provision() { return OkStatus(); }
61 
Initialize(const GrapplerItem & item)62 Status VirtualCluster::Initialize(const GrapplerItem& item) {
63   return OkStatus();
64 }
65 
Run(const GraphDef & graph,const std::vector<std::pair<string,Tensor>> & feed,const std::vector<string> & fetch,RunMetadata * metadata)66 Status VirtualCluster::Run(const GraphDef& graph,
67                            const std::vector<std::pair<string, Tensor>>& feed,
68                            const std::vector<string>& fetch,
69                            RunMetadata* metadata) {
70   GrapplerItem item;
71   item.graph = graph;
72   item.feed = feed;
73   item.fetch = fetch;
74   return Run(item, metadata);
75 }
76 
Run(const GrapplerItem & item,RunMetadata * metadata)77 Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) {
78   // Initializes an analytical cost estimator to estimate the graph cost. Makes
79   // sure to use static shape inference to prevent the virtual scheduler from
80   // calling the Run method on the cluster and creating an infinite loop.
81   if (metadata) {
82     metadata->clear_step_stats();
83     metadata->clear_cost_graph();
84     metadata->clear_partition_graphs();
85   }
86 
87   TF_RETURN_IF_ERROR(estimator_->Initialize(item));
88   TF_RETURN_IF_ERROR(
89       estimator_->PredictCosts(item.graph, metadata, /*cost=*/nullptr));
90 
91   const std::unordered_map<string, DeviceProperties>& device = GetDevices();
92   std::unordered_map<string, int64_t> peak_mem_usage =
93       estimator_->GetScheduler()->GetPeakMemoryUsage();
94   for (const auto& mem_usage : peak_mem_usage) {
95     const string& device_name = mem_usage.first;
96     auto it = device.find(device_name);
97     if (it == device.end()) {
98       // It's probably the fake send/recv device. Eventually we'll need to
99       // remove this fake device to ensure proper memory accounting for
100       // multi-device settings.
101       continue;
102     }
103     const DeviceProperties& dev = it->second;
104     if (dev.memory_size() <= 0) {
105       // Available device memory unknown
106       continue;
107     }
108     int64_t peak_mem = mem_usage.second;
109     if (peak_mem >= dev.memory_size()) {
110       return errors::ResourceExhausted(
111           "Graph requires ", peak_mem, " bytes of memory on device ",
112           device_name, " to run ", " but device only has ", dev.memory_size(),
113           " available.");
114     }
115   }
116 
117   return OkStatus();
118 }
119 
120 }  // namespace grappler
121 }  // namespace tensorflow
122