• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A python interface for Grappler clusters."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22
23from tensorflow.core.framework import step_stats_pb2
24from tensorflow.core.grappler.costs import op_performance_data_pb2
25from tensorflow.core.protobuf import device_properties_pb2
26from tensorflow.python import _pywrap_tf_cluster as tf_cluster
27
28
class Cluster(object):
  """Grappler Clusters.

  Owns a native cluster handle created through the `tf_cluster` wrapper
  module. The handle is released by `Shutdown()`, which is also invoked from
  `__del__`. After shutdown, methods that need the handle raise
  `RuntimeError` instead of handing `None` to the native wrapper.
  """

  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
      disable_detailed_stats: If True, detailed statistics will not be
        available.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    # Initialize the handle first so Shutdown()/__del__ stay safe even if
    # cluster creation below raises.
    self._tf_cluster = None
    self._generate_timeline = not disable_timeline

    if devices is None:
      self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
                                                  disable_detailed_stats)
    else:
      # The native wrapper takes serialized NamedDevice protos.
      # NOTE: allow_soft_placement and disable_detailed_stats are not
      # forwarded on this path.
      devices_serialized = [device.SerializeToString() for device in devices]
      self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)

  def _require_cluster(self):
    """Returns the native cluster handle.

    Raises:
      RuntimeError: If the cluster has already been shut down.
    """
    if self._tf_cluster is None:
      raise RuntimeError("Cluster has been shut down.")
    return self._tf_cluster

  def Shutdown(self):
    """Shuts down the native cluster; calling it again is a no-op."""
    if self._tf_cluster is not None:
      tf_cluster.TF_ShutdownCluster(self._tf_cluster)
      # Drop the handle so a second Shutdown() (e.g. from __del__) does
      # nothing.
      self._tf_cluster = None

  def __del__(self):
    """Shuts the cluster down when the object is garbage collected."""
    self.Shutdown()

  @property
  def tf_cluster(self):
    """The native cluster handle, or None once the cluster is shut down."""
    return self._tf_cluster

  def ListDevices(self):
    """Returns a list of available hardware devices.

    Returns:
      A list of device_properties_pb2.NamedDevice protos; empty if the
      cluster has been shut down.
    """
    if self._tf_cluster is None:
      return []
    return [device_properties_pb2.NamedDevice.FromString(device)
            for device in tf_cluster.TF_ListDevices(self._tf_cluster)]

  def ListAvailableOps(self):
    """Returns a list of all available operations (sorted alphabetically)."""
    return tf_cluster.TF_ListAvailableOps()

  def GetSupportedDevices(self, item):
    """Returns the devices supported for the specified item.

    Args:
      item: The item to query; its `tf_item` attribute is passed to the
        native wrapper.
    Returns: The value produced by tf_cluster.TF_GetSupportedDevices.
      NOTE(review): presumably a mapping from node name to device names;
      confirm against the wrapper.
    Raises:
      RuntimeError: If the cluster has already been shut down.
    """
    cluster = self._require_cluster()
    return tf_cluster.TF_GetSupportedDevices(cluster, item.tf_item)

  def EstimatePerformance(self, device):
    """Returns the performance estimate for the specified device.

    Args:
      device: A proto exposing SerializeToString(); presumably a
        device_properties_pb2.NamedDevice (confirm against callers).
    Returns: The value produced by tf_cluster.TF_EstimatePerformance.
    """
    return tf_cluster.TF_EstimatePerformance(device.SerializeToString())

  def MeasureCosts(self, item):
    """Returns the cost of running the specified item.

    Args:
      item: The item for which to measure the costs.
    Returns: The triplet op_perfs, runtime, step_stats.
    Raises:
      RuntimeError: If the cluster has already been shut down.
    """
    cluster = self._require_cluster()
    op_perf_bytes_list, run_time, step_stats_bytes = tf_cluster.TF_MeasureCosts(
        item.tf_item, cluster, self._generate_timeline)

    # The wrapper returns serialized protos; parse them for the caller.
    op_perfs = [op_performance_data_pb2.OpPerformance.FromString(op_perf_bytes)
                for op_perf_bytes in op_perf_bytes_list]
    return (op_perfs, run_time,
            step_stats_pb2.StepStats.FromString(step_stats_bytes))

  def DeterminePeakMemoryUsage(self, item):
    """Returns a snapshot of the peak memory usage.

    Args:
      item: The item for which to measure the costs.
    Returns: A hashtable indexed by device name.
    Raises:
      RuntimeError: If the cluster has already been shut down.
    """
    cluster = self._require_cluster()
    return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item, cluster)
112
113
@contextlib.contextmanager
def Provision(allow_soft_placement=True,
              disable_detailed_stats=True,
              disable_timeline=True,
              devices=None):
  """Context manager that creates a Cluster and shuts it down on exit.

  The arguments are forwarded unchanged to the `Cluster` constructor.

  Args:
    allow_soft_placement: Forwarded to `Cluster`.
    disable_detailed_stats: Forwarded to `Cluster`.
    disable_timeline: Forwarded to `Cluster`.
    devices: Forwarded to `Cluster`.

  Yields:
    The newly created `Cluster`.
  """
  cluster = Cluster(allow_soft_placement, disable_detailed_stats,
                    disable_timeline, devices)
  try:
    yield cluster
  finally:
    # Release the cluster even when the body of the `with` statement raises.
    cluster.Shutdown()
123