• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_
18 
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/device_set.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/grappler/grappler_item.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/strings/strcat.h"
29 #include "tensorflow/core/protobuf/device_properties.pb.h"
30 #include "tensorflow/core/public/session_options.h"
31 
32 namespace tensorflow {
33 namespace grappler {
34 
35 // A cluster represents of collection of hardware resources available to run
36 // the TensorFlow model.
37 // A process can only create a single cluster at a time.
38 class Cluster {
39  public:
40   explicit Cluster(int timeout_s);
41   virtual ~Cluster();
42 
43   // Returns a string that represent the type of cluster that was instantiated.
44   virtual string type() const = 0;
45 
46   // Provision the hardware resources needed to run TensorFlow and start a
47   // TensorFlow session that can take advantage of these resources.
48   // The actual resources that are leveraged depend on the type of cluster
49   // instantiated.
50   // Returns OK iff all the requested resources could be reserved and a
51   // TensorFlow session successfully created. Returns an error otherwise.
52   // There is no graceful degradation to handle the case where only a subset
53   // of the requested resources are available.
54   virtual Status Provision() = 0;
55 
56   // Attempts to shutdown the cluster.
57   // Returns OK iff there are no pending calls to the Run() method and all the
58   // resources used by the cluster could be released. Returns an error
59   // otherwise.
Shutdown()60   virtual Status Shutdown() { return Status::OK(); }
61 
62   // Whether soft placement is allowed. If allow_soft_placement is true,
63   // an op will be placed on CPU if there's no GPU implementation for the OP
64   // or if no GPU devices are known or registered or if we need to co-locate
65   // with reftype input(s) which are from CPU.
66   void AllowSoftPlacement(bool soft_placement_state);
67 
68   // Update the number of inter-op threads for each per-session threadpool
69   void SetNumInterOpThreads(int num_threads);
70 
71   // Set the number of steps required to warmup TensorFlow. Must be called
72   // before Provision().
73   void SetNumWarmupSteps(int num_steps);
74 
75   // Set executor type to instantiate
76   void SetExecutorType(const string* executor_type);
77 
78   // Returns the number of warmup steps.
79   int NumWarmupSteps() const;
80 
81   // Disable the collection of detailed statistics. Must be called
82   // before Provision().
83   void DisableDetailedStats(bool disable);
84 
85   // Returns true iff the collection of detailed statistics is enabled.
86   bool DetailedStatsEnabled() const;
87 
88   // Disable the TensorFlow optimizer. This ensures that the graph that TF
89   // executes is similar to the input graph. Must be called before Provision().
90   void DisableOptimizer(bool disable);
91 
92   // Return the list of TensorFlow devices that are available to execute a
93   // graph. This is empty until provision() is called.
GetDevices()94   const std::unordered_map<string, DeviceProperties>& GetDevices() const {
95     return devices_;
96   }
97 
98   // Convenience method that returns the set of device names. These names are
99   // sorted alphabetically.
100   const std::vector<string> GetDeviceNames() const;
101 
102   // The DeviceSet is not always available, but when it is it contains a
103   // superset of the devices listed in GetDevices/GetDeviceNames().
GetDeviceSet()104   virtual const DeviceSet* GetDeviceSet() const { return nullptr; }
105 
106   // Enables collecting the allocator stats. Call with enable=true must be made
107   // before Provision().
EnablePeakMemoryStats(bool enable)108   virtual Status EnablePeakMemoryStats(bool enable) {
109     return errors::Unimplemented(strings ::StrCat(
110         "Peak Memory Stats are not supported on ", type(), " clusters"));
111   }
112 
113   // Returns peak memory of all devices during the session creation and session
114   // runs.
GetPeakMemoryUsage(std::unordered_map<string,uint64> * device_peak_memory)115   virtual Status GetPeakMemoryUsage(
116       std::unordered_map<string, uint64>* device_peak_memory) const {
117     return errors::Unimplemented(
118         "GetPeakMemoryUsage is not implemented for this type of cluster.");
119   }
120 
121   // Prepare the session to run the specified grappler item. This include
122   // initializing all the model variables.
123   virtual Status Initialize(const GrapplerItem& item) = 0;
124 
125   // Run the specified graph_def and return the corresponding metadata.
126   virtual Status Run(const GraphDef& graph_def,
127                      const std::vector<std::pair<string, Tensor>>& feed,
128                      const std::vector<string>& fetch,
129                      RunMetadata* metadata) = 0;
130 
131  protected:
132   std::unordered_map<string, DeviceProperties> devices_;
133   const int timeout_s_;
134   SessionOptions options_;
135   RunOptions run_options_;
136 };
137 
138 }  // end namespace grappler
139 }  // end namespace tensorflow
140 
141 #endif  // TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_
142