1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A python interface for Grappler clusters.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from tensorflow.core.framework import step_stats_pb2 24from tensorflow.core.grappler.costs import op_performance_data_pb2 25from tensorflow.core.protobuf import device_properties_pb2 26from tensorflow.python import _pywrap_tf_cluster as tf_cluster 27 28 29class Cluster(object): 30 """Grappler Clusters.""" 31 32 def __init__(self, 33 allow_soft_placement=True, 34 disable_detailed_stats=True, 35 disable_timeline=True, 36 devices=None): 37 """Creates a Cluster. 38 39 Args: 40 allow_soft_placement: If True, TF will automatically fix illegal 41 placements instead of erroring out if the placement isn't legal. 42 disable_detailed_stats: If True, detailed statistics will not be 43 available. 44 disable_timeline: If True, the timeline information will not be reported. 45 devices: A list of devices of type device_properties_pb2.NamedDevice. 46 If None, a device list will be created based on the spec of 47 the local machine. 48 """ 49 self._tf_cluster = None 50 self._generate_timeline = not disable_timeline 51 52 if devices is None: 53 self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement, 54 disable_detailed_stats) 55 else: 56 devices_serialized = [device.SerializeToString() for device in devices] 57 self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized) 58 59 def Shutdown(self): 60 if self._tf_cluster is not None: 61 tf_cluster.TF_ShutdownCluster(self._tf_cluster) 62 self._tf_cluster = None 63 64 def __del__(self): 65 self.Shutdown() 66 67 @property 68 def tf_cluster(self): 69 return self._tf_cluster 70 71 def ListDevices(self): 72 """Returns a list of available hardware devices.""" 73 if self._tf_cluster is None: 74 return [] 75 return [device_properties_pb2.NamedDevice.FromString(device) 76 for device in tf_cluster.TF_ListDevices(self._tf_cluster)] 77 78 def ListAvailableOps(self): 79 """Returns a list of all available operations (sorted alphabetically).""" 80 return tf_cluster.TF_ListAvailableOps() 81 82 def GetSupportedDevices(self, item): 83 return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item) 84 85 def EstimatePerformance(self, device): 86 return tf_cluster.TF_EstimatePerformance(device.SerializeToString()) 87 88 def MeasureCosts(self, item): 89 """Returns the cost of running the specified item. 90 91 Args: 92 item: The item for which to measure the costs. 93 Returns: The triplet op_perfs, runtime, step_stats. 94 """ 95 op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts( 96 item.tf_item, self._tf_cluster, self._generate_timeline) 97 98 op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes) 99 for op_perf_bytes in op_perf_bytes_list] 100 return (op_perfs, run_time, 101 step_stats_pb2.StepStats.FromString(step_stats_bytes)) 102 103 def DeterminePeakMemoryUsage(self, item): 104 """Returns a snapshot of the peak memory usage. 105 106 Args: 107 item: The item for which to measure the costs. 108 Returns: A hashtable indexed by device name. 109 """ 110 return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item, 111 self._tf_cluster) 112 113 114@contextlib.contextmanager 115def Provision(allow_soft_placement=True, 116 disable_detailed_stats=True, 117 disable_timeline=True, 118 devices=None): 119 cluster = Cluster(allow_soft_placement, disable_detailed_stats, 120 disable_timeline, devices) 121 yield cluster 122 cluster.Shutdown() 123