• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_
16 #define TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_
17 
#include <memory>
#include <string>
#include <utility>

#include "absl/flags/declare.h"
#include "absl/flags/flag.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
23 
24 // TODO(chky): Move these flags to test-only targets.
25 ABSL_DECLARE_FLAG(std::string, tfrt_default_device);
26 ABSL_DECLARE_FLAG(bool, tfrt_enable_sync_logging);
27 ABSL_DECLARE_FLAG(bool, tfrt_enable_fallback);
28 ABSL_DECLARE_FLAG(int, tfrt_num_threads);
29 ABSL_DECLARE_FLAG(int, tfrt_num_blocking_threads);
30 
// Forward declarations of TFRT types so this header does not have to include
// the TFRT headers. Only CoreRuntime is referenced below; ConcurrentWorkQueue
// is kept in case code including this header relies on the declaration.
namespace tfrt {
class CoreRuntime;
class ConcurrentWorkQueue;
}  // namespace tfrt
35 
36 namespace tensorflow {
37 namespace tfrt_stub {
38 
39 // This defines the runtime abstraction in tensorflow for TFRT. It is supposed
40 // to provide tensorflow specific functionalities that are implemented using
41 // TFRT. Currently, the only intended uses for this class are:
42 //  1) Creating the runtime instance with user specified dependencies (eg.
43 //  thread pool).
44 //  2) Creating tensors that can be used by the runtime.
45 //
46 // It is temporary and will be replaced by the official
47 // tensorflow::experimental::cc::Runtime when it lands.
48 class Runtime {
49  public:
50   // Creates a runtime instance with default configuration. Returns null upon
51   // creation error.
52   static std::unique_ptr<Runtime> Create();
53 
54   // Creates a runtime instance with the specified work_queue. Returns null upon
55   // creation error.
56   static std::unique_ptr<Runtime> Create(
57       std::unique_ptr<WorkQueueInterface> work_queue);
58 
59   ~Runtime();
60 
61   Runtime(Runtime&&) = default;
62   Runtime& operator=(Runtime&&) = default;
63 
64   // TODO(tfrt-devs): Add methods for creating TFRT tensors.
65 
66   // TODO(chky): Make this method private as it should be only used by
67   // tfrt::SavedModel. Simply making tfrt::SavedModel a friend class does not
68   // work because the it resides in a different namespace. But we should
69   // consider moving it to the same namespace.
core_runtime()70   tfrt::CoreRuntime* core_runtime() const { return core_runtime_.get(); }
work_queue()71   WorkQueueInterface* work_queue() const { return work_queue_; }
72 
73  private:
74   explicit Runtime(std::unique_ptr<tfrt::CoreRuntime> core_runtime,
75                    WorkQueueInterface* work_queue);
76 
77   std::unique_ptr<tfrt::CoreRuntime> core_runtime_;
78   WorkQueueInterface* work_queue_ = nullptr;
79 };
80 
81 }  // namespace tfrt_stub
82 }  // namespace tensorflow
83 
84 #endif  // TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_
85