/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_

#include <memory>
#include <unordered_map>

#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// PartialRunMgr keeps track of pending partial run requests, and ensures that
// the partial run is only marked complete when the corresponding executor is
// run to completion.
//
// In tensorflow workers, the executor runs operations asynchronously until
// specified fetches (operations that return tensors) or targets (operations
// that don't return tensors) are reached. A PartialRun has two components: a
// setup which specifies all desired fetches and targets, and run calls that
// specify fetch values (from the setup calls) to retrieve.
// On the last partial run call, it is possible to satisfy the
// required fetches before the executor has completed running the graph to all
// the desired targets.
42 // PartialRunMgr is used to ensure that we don't complete and return the final 43 // partial run call to the user until both the partial run and executor have 44 // completed. 45 // 46 // PartialRunMgr is thread-safe. 47 class PartialRunMgr { 48 public: 49 // Find or create the CancellationManager associated with step_id. 50 // The PartialRunMgr owns the cancellation_manager. 51 // Returns true if a new CancellationManager was created 52 // (i.e this is a new partial run). 53 bool FindOrCreate(int step_id, CancellationManager** cancellation_manager); 54 55 // Calls the final callback if the PartialRunRequest has already completed. 56 // Otherwise stores the executor_status to be propagated when the 57 // PartialRunRequest completes (PartialRunDone has been called). 58 void ExecutorDone(int step_id, const Status& executor_status); 59 60 // Calls done if the executor has already completed (ExecutorDone has been 61 // called). Otherwise, stores the status and done callback, calling them when 62 // ExecutorDone is called. The callback will either be called by the calling 63 // thread of either PartialRunDone or ExecutorDone. 64 // If executor_status in ExecutorDone is not OK, it takes precedence over 65 // status and is passed to the done callback. 66 void PartialRunDone(int step_id, StatusCallback done, const Status& status); 67 68 private: 69 // PartialRunState stores state associated with a pending partial run request. 70 // This is protected by the mutex in PartialRunMgr. 71 struct PartialRunState { 72 std::unique_ptr<CancellationManager> cancellation_manager; 73 74 bool executor_done = false; 75 StatusCallback final_callback = nullptr; 76 Status final_status; 77 }; 78 79 mutex mu_; 80 81 std::unordered_map<int, std::unique_ptr<PartialRunState>> 82 step_id_to_partial_run_ GUARDED_BY(mu_); 83 }; 84 85 } // namespace tensorflow 86 87 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ 88