1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ 16 #define TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ 17 18 #include <memory> 19 20 #include "absl/flags/declare.h" 21 #include "absl/flags/flag.h" 22 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" 23 24 // TODO(chky): Move these flags to test-only targets. 25 ABSL_DECLARE_FLAG(std::string, tfrt_default_device); 26 ABSL_DECLARE_FLAG(bool, tfrt_enable_sync_logging); 27 ABSL_DECLARE_FLAG(bool, tfrt_enable_fallback); 28 ABSL_DECLARE_FLAG(int, tfrt_num_threads); 29 ABSL_DECLARE_FLAG(int, tfrt_num_blocking_threads); 30 31 namespace tfrt { 32 class CoreRuntime; 33 class ConcurrentWorkQueue; 34 } // namespace tfrt 35 36 namespace tensorflow { 37 namespace tfrt_stub { 38 39 // This defines the runtime abstraction in tensorflow for TFRT. It is supposed 40 // to provide tensorflow specific functionalities that are implemented using 41 // TFRT. Currently, the only intended uses for this class are: 42 // 1) Creating the runtime instance with user specified dependencies (eg. 43 // thread pool). 44 // 2) Creating tensors that can be used by the runtime. 45 // 46 // It is temporary and will be replaced by the official 47 // tensorflow::experimental::cc::Runtime when it lands. 48 class Runtime { 49 public: 50 // Creates a runtime instance with default configuration. Returns null upon 51 // creation error. 52 static std::unique_ptr<Runtime> Create(); 53 54 // Creates a runtime instance with the specified work_queue. Returns null upon 55 // creation error. 56 static std::unique_ptr<Runtime> Create( 57 std::unique_ptr<WorkQueueInterface> work_queue); 58 59 ~Runtime(); 60 61 Runtime(Runtime&&) = default; 62 Runtime& operator=(Runtime&&) = default; 63 64 // TODO(tfrt-devs): Add methods for creating TFRT tensors. 65 66 // TODO(chky): Make this method private as it should be only used by 67 // tfrt::SavedModel. Simply making tfrt::SavedModel a friend class does not 68 // work because the it resides in a different namespace. But we should 69 // consider moving it to the same namespace. core_runtime()70 tfrt::CoreRuntime* core_runtime() const { return core_runtime_.get(); } work_queue()71 WorkQueueInterface* work_queue() const { return work_queue_; } 72 73 private: 74 explicit Runtime(std::unique_ptr<tfrt::CoreRuntime> core_runtime, 75 WorkQueueInterface* work_queue); 76 77 std::unique_ptr<tfrt::CoreRuntime> core_runtime_; 78 WorkQueueInterface* work_queue_ = nullptr; 79 }; 80 81 } // namespace tfrt_stub 82 } // namespace tensorflow 83 84 #endif // TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ 85