• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A python interface for Grappler clusters."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22
23from tensorflow.core.framework import step_stats_pb2
24from tensorflow.core.grappler.costs import op_performance_data_pb2
25from tensorflow.core.protobuf import device_properties_pb2
26from tensorflow.python import pywrap_tensorflow as tf_cluster
27from tensorflow.python.framework import errors
28
29
class Cluster(object):
  """Python wrapper around a native Grappler cluster handle.

  The handle is created by the constructor through the `tf_cluster` bindings
  and released by `Shutdown()`, which is also invoked from `__del__`.
  """

  def __init__(self,
               allow_soft_placement=True,
               disable_detailed_stats=True,
               disable_timeline=True,
               devices=None):
    """Creates a Cluster.

    Args:
      allow_soft_placement: If True, TF will automatically fix illegal
        placements instead of erroring out if the placement isn't legal.
        Only forwarded when `devices` is None.
      disable_detailed_stats: If True, detailed statistics will not be
        available. Only forwarded when `devices` is None.
      disable_timeline: If True, the timeline information will not be reported.
      devices: A list of devices of type device_properties_pb2.NamedDevice.
        If None, a device list will be created based on the spec of
        the local machine.
    """
    # Assign the handle before any native call so that Shutdown()/__del__ stay
    # safe even if the construction below raises.
    self._tf_cluster = None
    self._generate_timeline = not disable_timeline
    with errors.raise_exception_on_not_ok_status() as status:
      if devices is not None:
        # Virtual cluster: built purely from the serialized device
        # descriptions supplied by the caller.
        serialized_devices = []
        for named_device in devices:
          serialized_devices.append(named_device.SerializeToString())
        self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
            serialized_devices, status)
      else:
        self._tf_cluster = tf_cluster.TF_NewCluster(
            allow_soft_placement, disable_detailed_stats, status)

  def Shutdown(self):
    """Releases the native cluster, if any. Safe to call more than once."""
    handle = self._tf_cluster
    if handle is None:
      return
    tf_cluster.TF_ShutdownCluster(handle)
    self._tf_cluster = None

  def __del__(self):
    # Make sure the native cluster is released when the wrapper is collected.
    self.Shutdown()

  @property
  def tf_cluster(self):
    """The underlying native cluster handle, or None once shut down."""
    return self._tf_cluster

  def ListDevices(self):
    """Returns a list of available hardware devices.

    Returns:
      A list of device_properties_pb2.NamedDevice protos; the list is empty
      when the cluster has no native handle.
    """
    named_devices = []
    if self._tf_cluster is not None:
      for serialized in tf_cluster.TF_ListDevices(self._tf_cluster):
        named_devices.append(
            device_properties_pb2.NamedDevice.FromString(serialized))
    return named_devices

  def ListAvailableOps(self):
    """Returns a list of all available operations (sorted alphabetically)."""
    available_ops = tf_cluster.TF_ListAvailableOps()
    return available_ops

  def GetSupportedDevices(self, item):
    """Returns the devices supported for `item` on this cluster.

    Args:
      item: An object exposing a `tf_item` attribute, which is forwarded to
        the native `TF_GetSupportedDevices` call.
    Returns: The value produced by `TF_GetSupportedDevices`, unmodified.
    """
    supported = tf_cluster.TF_GetSupportedDevices(
        self._tf_cluster, item.tf_item)
    return supported

  def EstimatePerformance(self, device):
    """Returns the native performance estimate for `device`.

    Args:
      device: An object exposing `SerializeToString()`; presumably a
        device_properties_pb2.NamedDevice -- confirm against callers.
    Returns: The value produced by `TF_EstimatePerformance`, unmodified.
    """
    serialized_device = device.SerializeToString()
    return tf_cluster.TF_EstimatePerformance(serialized_device)

  def MeasureCosts(self, item):
    """Returns the cost of running the specified item.

    Args:
      item: The item for which to measure the costs.
    Returns: None if the native call yields no result; otherwise the triplet
      (op_perfs, runtime, step_stats), where op_perfs is a list of
      OpPerformance protos and step_stats is a StepStats proto.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      measured = tf_cluster.TF_MeasureCosts(
          item.tf_item, self._tf_cluster, self._generate_timeline, status)

    if measured is None:
      return None

    # The bindings hand back serialized protos; decode them for the caller.
    serialized_op_perfs, run_time, serialized_step_stats = measured
    op_perfs = []
    for serialized_op_perf in serialized_op_perfs:
      op_perfs.append(
          op_performance_data_pb2.OpPerformance.FromString(serialized_op_perf))
    step_stats = step_stats_pb2.StepStats.FromString(serialized_step_stats)
    return (op_perfs, run_time, step_stats)

  def DeterminePeakMemoryUsage(self, item):
    """Returns a snapshot of the peak memory usage.

    Args:
      item: The item for which to measure the costs.
    Returns: A hashtable indexed by device name.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      peak_usage = tf_cluster.TF_DeterminePeakMemoryUsage(
          item.tf_item, self._tf_cluster, status)
    return peak_usage
120
121
@contextlib.contextmanager
def Provision(allow_soft_placement=True,
              disable_detailed_stats=True,
              disable_timeline=True,
              devices=None):
  """Context manager that creates a Cluster and shuts it down on exit.

  Args:
    allow_soft_placement: Forwarded to the Cluster constructor.
    disable_detailed_stats: Forwarded to the Cluster constructor.
    disable_timeline: Forwarded to the Cluster constructor.
    devices: Forwarded to the Cluster constructor.

  Yields:
    The newly created Cluster. Its Shutdown() method is called when the
    `with` block exits, whether normally or through an exception.
  """
  cluster = Cluster(allow_soft_placement, disable_detailed_stats,
                    disable_timeline, devices)
  try:
    yield cluster
  finally:
    # An exception raised inside the caller's `with` body is re-raised at the
    # yield above; without the finally clause Shutdown() would be skipped and
    # the native cluster would only be released whenever __del__ happens to run.
    cluster.Shutdown()
131