1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_ 18 19 #include <map> 20 #include <memory> 21 #include <utility> 22 23 #include "tensorflow/compiler/xla/executable_run_options.h" 24 #include "tensorflow/compiler/xla/service/backend.h" 25 #include "tensorflow/compiler/xla/service/stream_pool.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/util.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 #include "tensorflow/core/platform/logging.h" 30 #include "tensorflow/core/platform/macros.h" 31 #include "tensorflow/core/platform/mutex.h" 32 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace xla { 37 38 // Represents an asynchronously launched execution. Owns the stream (from the 39 // passed run_options->stream()) on which the execution is launched and releases 40 // the stream when destructed. 41 class AsyncExecution { 42 public: 43 AsyncExecution(Backend* backend, std::vector<StreamPool::Ptr> streams, 44 const ExecutionProfile& profile, GlobalDataHandle result); 45 46 Status BlockUntilDone() const; 47 result()48 const GlobalDataHandle& result() const { return result_; } 49 profile()50 const ExecutionProfile& profile() const { return profile_; } 51 52 private: 53 // Backend to execute the computation on. 54 Backend* backend_; 55 56 // Stream on which the execution is launched. 57 std::vector<StreamPool::Ptr> streams_; 58 59 // Profile object of the execution to be returned to the user. 60 ExecutionProfile profile_; 61 62 // Data handle to the result of the execution. Data represented by this handle 63 // is valid only after BlockUntilDone() is called. 64 GlobalDataHandle result_; 65 }; 66 67 // Tracks asynchronously launched executions for the XLA service. 68 class ExecutionTracker { 69 public: 70 ExecutionTracker(); 71 72 // Registers an execution with its backend, streams, and data handle to the 73 // execution result. Returns a handle for the registered execution. 74 ExecutionHandle Register(Backend* backend, 75 std::vector<StreamPool::Ptr> stream, 76 const ExecutionProfile& profile, 77 GlobalDataHandle data); 78 79 // Unregisters the execution for the given handle. 80 Status Unregister(const ExecutionHandle& handle); 81 82 // Resolves the given ExecutionHandle to an AsyncExecution. Returns an 83 // error status if the given handle is not found, which means that the 84 // execution is not yet registered or already unregistered. 85 StatusOr<const AsyncExecution*> Resolve(const ExecutionHandle& handle); 86 87 private: 88 // The next handle to assign to an execution. 89 int64 next_handle_ TF_GUARDED_BY(execution_mutex_); 90 91 // Mapping from ExecutionHandle handle to the corresponding registered 92 // AsyncExecution object. 93 std::map<int64, std::unique_ptr<AsyncExecution>> handle_to_execution_ 94 TF_GUARDED_BY(execution_mutex_); 95 96 tensorflow::mutex execution_mutex_; // Guards the execution mapping. 97 98 TF_DISALLOW_COPY_AND_ASSIGN(ExecutionTracker); 99 }; 100 101 } // namespace xla 102 103 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_ 104