1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_ 18 19 #include <string> 20 #include <unordered_map> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/device_set.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/grappler/grappler_item.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/strings/strcat.h" 29 #include "tensorflow/core/protobuf/device_properties.pb.h" 30 #include "tensorflow/core/public/session_options.h" 31 32 namespace tensorflow { 33 namespace grappler { 34 35 // A cluster represents of collection of hardware resources available to run 36 // the TensorFlow model. 37 // A process can only create a single cluster at a time. 38 class Cluster { 39 public: 40 explicit Cluster(int timeout_s); 41 virtual ~Cluster(); 42 43 // Returns a string that represent the type of cluster that was instantiated. 44 virtual string type() const = 0; 45 46 // Provision the hardware resources needed to run TensorFlow and start a 47 // TensorFlow session that can take advantage of these resources. 48 // The actual resources that are leveraged depend on the type of cluster 49 // instantiated. 50 // Returns OK iff all the requested resources could be reserved and a 51 // TensorFlow session successfully created. Returns an error otherwise. 52 // There is no graceful degradation to handle the case where only a subset 53 // of the requested resources are available. 54 virtual Status Provision() = 0; 55 56 // Attempts to shutdown the cluster. 57 // Returns OK iff there are no pending calls to the Run() method and all the 58 // resources used by the cluster could be released. Returns an error 59 // otherwise. Shutdown()60 virtual Status Shutdown() { return Status::OK(); } 61 62 // Whether soft placement is allowed. If allow_soft_placement is true, 63 // an op will be placed on CPU if there's no GPU implementation for the OP 64 // or if no GPU devices are known or registered or if we need to co-locate 65 // with reftype input(s) which are from CPU. 66 void AllowSoftPlacement(bool soft_placement_state); 67 68 // Update the number of inter-op threads for each per-session threadpool 69 void SetNumInterOpThreads(int num_threads); 70 71 // Set the number of steps required to warmup TensorFlow. Must be called 72 // before Provision(). 73 void SetNumWarmupSteps(int num_steps); 74 75 // Set executor type to instantiate 76 void SetExecutorType(const string* executor_type); 77 78 // Returns the number of warmup steps. 79 int NumWarmupSteps() const; 80 81 // Disable the collection of detailed statistics. Must be called 82 // before Provision(). 83 void DisableDetailedStats(bool disable); 84 85 // Returns true iff the collection of detailed statistics is enabled. 86 bool DetailedStatsEnabled() const; 87 88 // Disable the TensorFlow optimizer. This ensures that the graph that TF 89 // executes is similar to the input graph. Must be called before Provision(). 90 void DisableOptimizer(bool disable); 91 92 // Return the list of TensorFlow devices that are available to execute a 93 // graph. This is empty until provision() is called. GetDevices()94 const std::unordered_map<string, DeviceProperties>& GetDevices() const { 95 return devices_; 96 } 97 98 // Convenience method that returns the set of device names. These names are 99 // sorted alphabetically. 100 const std::vector<string> GetDeviceNames() const; 101 102 // The DeviceSet is not always available, but when it is it contains a 103 // superset of the devices listed in GetDevices/GetDeviceNames(). GetDeviceSet()104 virtual const DeviceSet* GetDeviceSet() const { return nullptr; } 105 106 // Enables collecting the allocator stats. Call with enable=true must be made 107 // before Provision(). EnablePeakMemoryStats(bool enable)108 virtual Status EnablePeakMemoryStats(bool enable) { 109 return errors::Unimplemented(strings ::StrCat( 110 "Peak Memory Stats are not supported on ", type(), " clusters")); 111 } 112 113 // Returns peak memory of all devices during the session creation and session 114 // runs. GetPeakMemoryUsage(std::unordered_map<string,uint64> * device_peak_memory)115 virtual Status GetPeakMemoryUsage( 116 std::unordered_map<string, uint64>* device_peak_memory) const { 117 return errors::Unimplemented( 118 "GetPeakMemoryUsage is not implemented for this type of cluster."); 119 } 120 121 // Prepare the session to run the specified grappler item. This include 122 // initializing all the model variables. 123 virtual Status Initialize(const GrapplerItem& item) = 0; 124 125 // Run the specified graph_def and return the corresponding metadata. 126 virtual Status Run(const GraphDef& graph_def, 127 const std::vector<std::pair<string, Tensor>>& feed, 128 const std::vector<string>& fetch, 129 RunMetadata* metadata) = 0; 130 131 protected: 132 std::unordered_map<string, DeviceProperties> devices_; 133 const int timeout_s_; 134 SessionOptions options_; 135 RunOptions run_options_; 136 }; 137 138 } // end namespace grappler 139 } // end namespace tensorflow 140 141 #endif // TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_ 142