• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 %include "tensorflow/python/platform/base.i"
17 %include <std_shared_ptr.i>
18 %include "item.i"
19 
// Handle type exposed to SWIG. Keeping the cluster behind a shared_ptr inside
// a by-value struct means the generated wrapper runs the destructor when the
// Python object is garbage collected, rather than leaking the cluster.
struct GCluster {
  std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
};
25 
26 %{
27 #include "tensorflow/core/protobuf/device_properties.pb.h"
28 
29 template <>
_PyObjAs(PyObject * input,tensorflow::NamedDevice * out)30 bool _PyObjAs(PyObject *input, tensorflow::NamedDevice *out) {
31   char* c_string;
32   Py_ssize_t py_size;
33   if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
34     // Python has raised an error (likely TypeError or UnicodeEncodeError).
35     return false;
36   }
37 
38   tensorflow::NamedDevice named_device;
39   if (!named_device.ParseFromString(string(c_string, py_size))) {
40     PyErr_SetString(
41         PyExc_TypeError,
42         "The NamedDevice could not be parsed as a valid protocol buffer");
43     return false;
44   }
45   if (out) *out = named_device;
46   return true;
47 }
48 %}
49 
// Input typemap: fills the typemap-local vector `temp` from the Python
// argument via tf_vector_input_helper, converting each element with the
// NamedDevice specialization of _PyObjAs, and passes a pointer to it on.
%typemap(in) const std::vector<tensorflow::NamedDevice>& (std::vector<tensorflow::NamedDevice> temp) {
  if (tf_vector_input_helper($input, &temp, &_PyObjAs<tensorflow::NamedDevice>)) {
    $1 = &temp;
  } else {
    SWIG_fail;
  }
}
56 
// Input typemap: accepts a Python bytes object containing a serialized
// NamedDevice and hands the wrapped function a pointer to the parsed message,
// which lives in the typemap-local `temp`.
%typemap(in) const tensorflow::NamedDevice& (tensorflow::NamedDevice temp) {
  char* serialized;
  Py_ssize_t serialized_len;
  if (PyBytes_AsStringAndSize($input, &serialized, &serialized_len) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    SWIG_fail;
  }

  const string payload(serialized, serialized_len);
  if (!temp.ParseFromString(payload)) {
    PyErr_SetString(
        PyExc_TypeError,
        "The NamedDevice could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}
73 
// Input typemap: accepts a Python bytes object containing a serialized
// RunMetadata and hands the wrapped function a pointer to the parsed message,
// which lives in the typemap-local `temp`.
%typemap(in) const tensorflow::RunMetadata& (tensorflow::RunMetadata temp) {
  char* serialized;
  Py_ssize_t serialized_len;
  if (PyBytes_AsStringAndSize($input, &serialized, &serialized_len) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    SWIG_fail;
  }

  const string payload(serialized, serialized_len);
  if (!temp.ParseFromString(payload)) {
    PyErr_SetString(
        PyExc_TypeError,
        "The RunMetadata could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}
90 
// Input typemap: copies the contents of a Python bytes object into the
// typemap-local string `temp` and passes a pointer to it on. Embedded NUL
// bytes are preserved because the explicit length is used.
%typemap(in) const string& (string temp) {
  char *buf;
  Py_ssize_t len;
  if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) {
    // Python has already raised. Leave through SWIG's failure path, like the
    // other typemaps in this file, instead of returning directly so that the
    // wrapper's cleanup code still runs.
    SWIG_fail;
  }
  temp.assign(buf, len);
  $1 = &temp;
}
98 
99 %{
100 #include <memory>
101 #include <vector>
102 #include "tensorflow/core/grappler/devices.h"
103 #include "tensorflow/core/grappler/utils.h"
104 #include "tensorflow/core/grappler/clusters/single_machine.h"
105 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
106 #include "tensorflow/core/grappler/costs/graph_memory.h"
107 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
108 #include "tensorflow/core/grappler/costs/measuring_cost_estimator.h"
109 #include "tensorflow/core/grappler/costs/utils.h"
110 #include "tensorflow/core/protobuf/device_properties.pb.h"
111 #include "tensorflow/core/framework/kernel_def.pb.h"
112 #include "tensorflow/core/framework/memory_types.h"
113 
114 // Provide the implementation of the GCluster struct here.
115 struct GCluster {
GClusterGCluster116   GCluster() {}
GClusterGCluster117   GCluster(tensorflow::grappler::Cluster* cluster) : cluster_(cluster) {}
118 
119   tensorflow::grappler::Cluster* operator->() const {
120     return cluster_.get();
121   }
getGCluster122   tensorflow::grappler::Cluster* get() const {
123     return cluster_.get();
124   }
is_noneGCluster125   bool is_none() const {
126     return cluster_.get() == nullptr;
127   }
128 
129   std::shared_ptr<tensorflow::grappler::Cluster> cluster_;
130 };
131 
132 
TF_NewCluster(bool allow_soft_placement,bool disable_detailed_stats,TF_Status * out_status)133 static GCluster TF_NewCluster(bool allow_soft_placement,
134                    bool disable_detailed_stats, TF_Status* out_status) {
135   int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
136   int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
137   int timeout_s = 60 * 10;
138   tensorflow::grappler::Cluster* cluster_ =
139       new tensorflow::grappler::SingleMachine(
140           timeout_s, num_cpu_cores, num_gpus);
141   cluster_->DisableDetailedStats(disable_detailed_stats);
142   cluster_->AllowSoftPlacement(allow_soft_placement);
143   cluster_->SetNumWarmupSteps(10);
144   tensorflow::Status status = cluster_->Provision();
145   tensorflow::Set_TF_Status_from_Status(out_status, status);
146   return GCluster(cluster_);
147 }
148 
TF_NewVirtualCluster(const std::vector<tensorflow::NamedDevice> & named_devices,TF_Status * out_status)149 static GCluster TF_NewVirtualCluster(
150     const std::vector<tensorflow::NamedDevice>& named_devices,
151     TF_Status* out_status) {
152   std::unordered_map<string, tensorflow::DeviceProperties> devices;
153   for (const auto& named_device : named_devices) {
154     devices[named_device.name()]= named_device.properties();
155   }
156   tensorflow::grappler::Cluster* cluster_ =
157       new tensorflow::grappler::VirtualCluster(devices);
158   PyGILState_STATE gstate = PyGILState_Ensure();
159   tensorflow::Status status = cluster_->Provision();
160   PyGILState_Release(gstate);
161   tensorflow::Set_TF_Status_from_Status(out_status, status);
162   return GCluster(cluster_);
163 }
164 
TF_ShutdownCluster(GCluster cluster)165 static void TF_ShutdownCluster(GCluster cluster) {
166   PyGILState_STATE gstate = PyGILState_Ensure();
167   cluster->Shutdown();
168   PyGILState_Release(gstate);
169 }
170 
// Initializes `cost_measure` for `item`, asks it to predict the costs of the
// item's graph, and optionally converts the resulting cost graph into
// per-op performance data.
//
//   item:                the grappler item to evaluate.
//   cost_measure:        estimator to initialize and query.
//   op_performance_data: if non-null, receives the data derived from the cost
//                        graph; if null, that conversion is skipped.
//   costs:               receives the predicted costs.
//
// Returns the first error reported by Initialize or PredictCosts, or OK.
tensorflow::Status _GetOpPerformanceDataAndRunTime(
    const tensorflow::grappler::GrapplerItem& item,
    tensorflow::grappler::CostEstimator* cost_measure,
    tensorflow::OpPerformanceList* op_performance_data,
    tensorflow::grappler::Costs* costs) {
  TF_RETURN_IF_ERROR(cost_measure->Initialize(item));

  tensorflow::RunMetadata predicted_metadata;
  TF_RETURN_IF_ERROR(
      cost_measure->PredictCosts(item.graph, &predicted_metadata, costs));

  if (op_performance_data != nullptr) {
    *op_performance_data = tensorflow::grappler::CostGraphToOpPerformanceData(
        predicted_metadata.cost_graph(), item.graph);
  }
  return tensorflow::Status::OK();
}
189 
TF_ListDevices(GCluster cluster)190 static PyObject* TF_ListDevices(GCluster cluster) {
191   const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
192   PyGILState_STATE gstate = PyGILState_Ensure();
193   PyObject* result = PyList_New(devices.size());
194   int i = 0;
195   for (auto& dev : devices) {
196     tensorflow::NamedDevice d;
197     d.set_name(dev.first);
198     *d.mutable_properties() = dev.second;
199     string dev_str = d.SerializeAsString();
200     PyObject* dev_obj = PyBytes_FromStringAndSize(dev_str.data(),
201                                                   dev_str.size());
202     PyList_SetItem(result, i, dev_obj);
203     ++i;
204   }
205   PyGILState_Release(gstate);
206   return result;
207 }
208 
TF_ListAvailableOps()209 static PyObject* TF_ListAvailableOps() {
210   tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
211   std::vector<tensorflow::OpDef> ops;
212   registry->GetRegisteredOps(&ops);
213   std::vector<string> op_names;
214   for (const tensorflow::OpDef& op : ops) {
215     op_names.push_back(op.name());
216   }
217   std::sort(op_names.begin(), op_names.end());
218 
219   PyGILState_STATE gstate = PyGILState_Ensure();
220   PyObject* result = PyList_New(op_names.size());
221   for (int i = 0; i < op_names.size(); ++i) {
222     PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str()));
223   }
224   PyGILState_Release(gstate);
225   return result;
226 }
227 
TF_GetSupportedDevices(GCluster cluster,GItem item)228 static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
229   if (cluster.is_none() || item.is_none()) {
230     Py_RETURN_NONE;
231   }
232   const std::unordered_map<string, tensorflow::DeviceProperties>& devices = cluster->GetDevices();
233   std::unordered_map<string, std::vector<string>> device_types;
234   for (const auto& dev : devices) {
235     device_types[dev.second.type()].push_back(dev.first);
236   }
237 
238   std::unordered_map<string, std::set<string>> supported_device_types;
239   std::unordered_map<string, std::set<string>> device_restrictions;
240 
241   for (const auto& node : item->graph.node()) {
242     for (const auto& dev : device_types) {
243       const string& type = dev.first;
244       if (cluster->type() != "single_machine") {
245         // The actual kernel may not be linked in this binary.
246         supported_device_types[node.name()].insert(type);
247       } else {
248         // Check the kernel capabilities
249         const tensorflow::DeviceType dev_type(type);
250         tensorflow::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr);
251         if (s.ok()) {
252           supported_device_types[node.name()].insert(type);
253 
254           // Check which inputs are restricted to reside on the host.
255           // TODO: extends this to support outputs as well
256           tensorflow::MemoryTypeVector inp_mtypes;
257           tensorflow::MemoryTypeVector out_mtypes;
258           s = tensorflow::MemoryTypesForNode(tensorflow::OpRegistry::Global(), dev_type, node,
259                                              &inp_mtypes, &out_mtypes);
260           if (s.ok()) {
261             for (int i = 0; i < inp_mtypes.size(); ++i) {
262               if (inp_mtypes[i] == tensorflow::HOST_MEMORY) {
263                 device_restrictions[tensorflow::grappler::NodeName(node.input(i))].insert("CPU");
264                 break;
265               }
266             }
267           }
268         }
269       }
270     }
271   }
272 
273   PyGILState_STATE gstate = PyGILState_Ensure();
274   PyObject* result = PyDict_New();
275 
276   for (const auto& supported_dev : supported_device_types) {
277     const string& node = supported_dev.first;
278     std::set<string> feasible;
279     const auto it = device_restrictions.find(node);
280     if (it != device_restrictions.end()) {
281       const std::set<string>& candidates = supported_dev.second;
282       const std::set<string>& valid = it->second;
283       std::set_intersection(candidates.begin(), candidates.end(), valid.begin(), valid.end(),
284                             std::inserter(feasible, feasible.begin()));
285     } else {
286       feasible = supported_dev.second;
287     }
288 
289     std::vector<string> device_names;
290     for (const string& type : feasible) {
291       auto it = device_types.find(type);
292       CHECK(it != device_types.end());
293       for (const string& name : it->second) {
294         device_names.push_back(name);
295       }
296     }
297 
298     PyObject* dev = PyList_New(device_names.size());
299     for (int i = 0; i < device_names.size(); ++i) {
300       PyList_SetItem(dev, i, PyString_FromString(device_names[i].c_str()));
301     }
302     CHECK_EQ(0, PyDict_SetItem(result, PyString_FromString(node.c_str()), dev));
303   }
304   PyGILState_Release(gstate);
305   return result;
306 }
307 
308 
TF_EstimatePerformance(const tensorflow::NamedDevice & device)309 static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
310   tensorflow::grappler::OpLevelCostEstimator estimator;
311   tensorflow::grappler::DeviceInfo info =
312       estimator.GetDeviceInfo(device.properties());
313   return info.gigaops;
314 }
315 
TF_MeasureCosts(GItem item,GCluster cluster,bool generate_timeline,TF_Status * out_status)316 static PyObject* TF_MeasureCosts(
317     GItem item,
318     GCluster cluster,
319     bool generate_timeline, TF_Status* out_status) {
320   tensorflow::OpPerformanceList op_performance_data;
321   tensorflow::StepStats step_stats;
322 
323   const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
324   tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
325 
326   tensorflow::grappler::Costs costs;
327   tensorflow::Status status = _GetOpPerformanceDataAndRunTime(
328       *item, &cost_measure, &op_performance_data, &costs);
329   double run_time = FLT_MAX;
330   if (status.ok()) {
331     run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
332   }
333   if (generate_timeline) {
334     tensorflow::RunMetadata metadata;
335     tensorflow::Status s = cluster->Run(
336         item->graph, item->feed, item->fetch, &metadata);
337     if (s.ok()) {
338       step_stats = metadata.step_stats();
339     } else {
340       status = s;
341     }
342   }
343 
344   tensorflow::Set_TF_Status_from_Status(out_status, status);
345   if (!status.ok()) {
346     Py_RETURN_NONE;
347   }
348   PyGILState_STATE gstate = PyGILState_Ensure();
349   PyObject* op_perf_objs = PyList_New(
350       op_performance_data.op_performance_size());
351   for (int i = 0; i < op_performance_data.op_performance_size(); i++) {
352     string op_perf_str =
353         op_performance_data.op_performance(i).SerializeAsString();
354     PyObject* op_perf_obj = PyBytes_FromStringAndSize(op_perf_str.data(),
355                                                       op_perf_str.size());
356     PyList_SetItem(op_perf_objs, i, op_perf_obj);
357   }
358 
359   PyObject* run_time_obj = PyFloat_FromDouble(run_time);
360 
361   string step_stats_str = step_stats.SerializeAsString();
362   PyObject* metadata_obj = PyBytes_FromStringAndSize(step_stats_str.data(),
363                                                      step_stats_str.size());
364 
365   PyObject* ret = PyTuple_New(3);
366   if (PyTuple_SetItem(ret, 0, op_perf_objs) != 0 ||
367       PyTuple_SetItem(ret, 1, run_time_obj) != 0 ||
368       PyTuple_SetItem(ret, 2, metadata_obj) != 0) {
369     Py_DECREF(ret);
370     Py_XDECREF(op_perf_objs);
371     Py_XDECREF(run_time_obj);
372     Py_XDECREF(metadata_obj);
373     status = tensorflow::Status(tensorflow::error::Code::INTERNAL,
374                                 "Error setting return tuples.");
375     tensorflow::Set_TF_Status_from_Status(out_status, status);
376     Py_INCREF(Py_None);
377     ret = Py_None;
378   }
379   PyGILState_Release(gstate);
380   return ret;
381 }
382 
383 
TF_DeterminePeakMemoryUsage(GItem item,GCluster cluster,TF_Status * out_status)384 static PyObject* TF_DeterminePeakMemoryUsage(
385     GItem item,
386     GCluster cluster,
387     TF_Status* out_status) {
388   if (item.is_none() || cluster.is_none()) {
389     tensorflow::Status status(tensorflow::error::Code::INTERNAL,
390                               "You need both a cluster and an item to determine peak memory usage");
391     tensorflow::Set_TF_Status_from_Status(out_status, status);
392     Py_RETURN_NONE;
393   }
394   tensorflow::grappler::GraphMemory memory(*item);
395 
396   tensorflow::Status status;
397   if (cluster->DetailedStatsEnabled()) {
398     status = memory.InferDynamically(cluster.get());
399   } else {
400     status = memory.InferStatically(cluster->GetDevices());
401   }
402   if (!status.ok()) {
403     tensorflow::Set_TF_Status_from_Status(out_status, status);
404     Py_RETURN_NONE;
405   }
406 
407   PyGILState_STATE gstate = PyGILState_Ensure();
408   PyObject* result = PyDict_New();
409   for (const auto& device : cluster->GetDevices()) {
410     const tensorflow::grappler::GraphMemory::MemoryUsage& usage =
411         memory.GetPeakMemoryUsage(device.first);
412     PyObject* per_device = PyList_New(usage.live_tensors.size());
413     for (int i = 0; i < usage.live_tensors.size(); ++i) {
414       const auto& live_tensor = usage.live_tensors[i];
415       PyObject* live = PyTuple_New(5);
416       PyTuple_SetItem(live, 0, PyString_FromString(live_tensor.node.c_str()));
417       PyTuple_SetItem(live, 1, PyInt_FromLong(live_tensor.output_id));
418       PyTuple_SetItem(live, 2, PyLong_FromLong(live_tensor.memory_used));
419       PyTuple_SetItem(live, 3, PyLong_FromLong(live_tensor.allocation_time.count()));
420       PyTuple_SetItem(live, 4, PyLong_FromLong(live_tensor.deallocation_time.count()));
421       PyList_SetItem(per_device, i, live);
422 
423     }
424     PyObject* ret = PyTuple_New(2);
425     PyTuple_SetItem(ret, 0, PyLong_FromLong(usage.used_memory));
426     PyTuple_SetItem(ret, 1, per_device);
427     PyDict_SetItem(result, PyString_FromString(device.first.c_str()), ret);
428   }
429   PyGILState_Release(gstate);
430   return result;
431 }
432 
433 %}
434 
// Wrap these functions.
// The signatures below must match the definitions in the %{ %} section above;
// SWIG generates the Python wrappers from these declarations.
static GCluster TF_NewCluster(
    bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status);
static GCluster TF_NewVirtualCluster(
    const std::vector<tensorflow::NamedDevice>& named_devices,
    TF_Status* out_status);
static void TF_ShutdownCluster(GCluster cluster);
static PyObject* TF_ListDevices(GCluster cluster);
static PyObject* TF_ListAvailableOps();
static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
// Declared as double to agree with the definition; declaring it as float
// would make the generated wrapper narrow the result.
static double TF_EstimatePerformance(const tensorflow::NamedDevice& device);
static PyObject* TF_MeasureCosts(
    GItem item, GCluster cluster,
    bool generate_timeline, TF_Status* out_status);
static PyObject* TF_DeterminePeakMemoryUsage(
    GItem item, GCluster cluster,
    TF_Status* out_status);
452