1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A python interface for Grappler clusters.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from tensorflow.core.framework import step_stats_pb2 24from tensorflow.core.grappler.costs import op_performance_data_pb2 25from tensorflow.core.protobuf import device_properties_pb2 26from tensorflow.python import pywrap_tensorflow as tf_cluster 27from tensorflow.python.framework import errors 28 29 30class Cluster(object): 31 """Grappler Clusters.""" 32 33 def __init__(self, 34 allow_soft_placement=True, 35 disable_detailed_stats=True, 36 disable_timeline=True, 37 devices=None): 38 """Creates a Cluster. 39 40 Args: 41 allow_soft_placement: If True, TF will automatically fix illegal 42 placements instead of erroring out if the placement isn't legal. 43 disable_detailed_stats: If True, detailed statistics will not be 44 available. 45 disable_timeline: If True, the timeline information will not be reported. 46 devices: A list of devices of type device_properties_pb2.NamedDevice. 47 If None, a device list will be created based on the spec of 48 the local machine. 49 """ 50 self._tf_cluster = None 51 self._generate_timeline = not disable_timeline 52 with errors.raise_exception_on_not_ok_status() as status: 53 if devices is None: 54 self._tf_cluster = tf_cluster.TF_NewCluster( 55 allow_soft_placement, disable_detailed_stats, status) 56 else: 57 devices_serialized = [device.SerializeToString() for device in devices] 58 self._tf_cluster = tf_cluster.TF_NewVirtualCluster( 59 devices_serialized, status) 60 61 def Shutdown(self): 62 if self._tf_cluster is not None: 63 tf_cluster.TF_ShutdownCluster(self._tf_cluster) 64 self._tf_cluster = None 65 66 def __del__(self): 67 self.Shutdown() 68 69 @property 70 def tf_cluster(self): 71 return self._tf_cluster 72 73 def ListDevices(self): 74 """Returns a list of available hardware devices.""" 75 if self._tf_cluster is None: 76 return [] 77 return [device_properties_pb2.NamedDevice.FromString(device) 78 for device in tf_cluster.TF_ListDevices(self._tf_cluster)] 79 80 def ListAvailableOps(self): 81 """Returns a list of all available operations (sorted alphabetically).""" 82 return tf_cluster.TF_ListAvailableOps() 83 84 def GetSupportedDevices(self, item): 85 return tf_cluster.TF_GetSupportedDevices(self._tf_cluster, item.tf_item) 86 87 def EstimatePerformance(self, device): 88 return tf_cluster.TF_EstimatePerformance(device.SerializeToString()) 89 90 def MeasureCosts(self, item): 91 """Returns the cost of running the specified item. 92 93 Args: 94 item: The item for which to measure the costs. 95 Returns: The triplet op_perfs, runtime, step_stats. 96 """ 97 with errors.raise_exception_on_not_ok_status() as status: 98 ret_from_swig = tf_cluster.TF_MeasureCosts( 99 item.tf_item, self._tf_cluster, self._generate_timeline, status) 100 101 if ret_from_swig is None: 102 return None 103 104 op_perf_bytes_list, run_time, step_stats_bytes = ret_from_swig 105 op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes) 106 for op_perf_bytes in op_perf_bytes_list] 107 return (op_perfs, run_time, 108 step_stats_pb2.StepStats.FromString(step_stats_bytes)) 109 110 def DeterminePeakMemoryUsage(self, item): 111 """Returns a snapshot of the peak memory usage. 112 113 Args: 114 item: The item for which to measure the costs. 115 Returns: A hashtable indexed by device name. 116 """ 117 with errors.raise_exception_on_not_ok_status() as status: 118 return tf_cluster.TF_DeterminePeakMemoryUsage( 119 item.tf_item, self._tf_cluster, status) 120 121 122@contextlib.contextmanager 123def Provision(allow_soft_placement=True, 124 disable_detailed_stats=True, 125 disable_timeline=True, 126 devices=None): 127 cluster = Cluster(allow_soft_placement, disable_detailed_stats, 128 disable_timeline, devices) 129 yield cluster 130 cluster.Shutdown() 131