1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/utils.h"
16
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
21
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/common_runtime/eager/context.h"
24 #include "tensorflow/core/framework/device.h"
25 #include "tensorflow/core/tfrt/eager/virtual_device.h"
26 #include "tensorflow/core/tfrt/utils/error_util.h"
27 #include "tensorflow/core/tpu/virtual_device.h"
28 #include "tfrt/bef/bef_encoding.h" // from @tf_runtime
29 #include "tfrt/bef_executor/bef_file.h" // from @tf_runtime
30 #include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime
31 #include "tfrt/host_context/chain.h" // from @tf_runtime
32 #include "tfrt/host_context/execution_context.h" // from @tf_runtime
33 #include "tfrt/host_context/function.h" // from @tf_runtime
34 #include "tfrt/host_context/host_context.h" // from @tf_runtime
35 #include "tfrt/support/error_util.h" // from @tf_runtime
36 #include "tfrt/support/string_util.h" // from @tf_runtime
37
38 namespace tfrt {
39
40 using ::tensorflow::StatusOr;
41
ConvertTfDeviceNameToTfrt(const char * device_name,tensorflow::EagerContext * eager_context)42 Expected<const char*> ConvertTfDeviceNameToTfrt(
43 const char* device_name, tensorflow::EagerContext* eager_context) {
44 // NOTE(fishx): We need to get tf_device first because DeviceMgr in current TF
45 // allows us get the device with simplified name like "CPU:0". However, TFRT
46 // DeviceManager only allows get device via its fullname.
47 tensorflow::Device* tf_device;
48 tensorflow::Status s =
49 eager_context->FindDeviceFromName(device_name, &tf_device);
50 if (!s.ok()) {
51 return MakeStringError(s.error_message());
52 }
53 return tf_device->name().c_str();
54 }
55
// Maps a TF `DataType` to the corresponding TFRT `DType`.
// The switch cases are generated by an X-macro: each DTYPE(TFRT_DTYPE,
// TF_DTYPE) entry in dtype.def expands to a `case` returning the matching
// TFRT dtype. Any TF dtype without an entry in dtype.def falls through to
// the default and yields a default-constructed (invalid) DType().
DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype) {
  switch (dtype) {
#define DTYPE(TFRT_DTYPE, TF_DTYPE) \
  case tensorflow::TF_DTYPE:        \
    return DType(DType::TFRT_DTYPE);
#include "tensorflow/core/tfrt/utils/dtype.def"  // NOLINT
    default:
      return DType();
  }
}
66
RunRuntimeInitializer(const tfrt::ExecutionContext & exec_ctx,tfrt::BEFFile * bef_file,absl::string_view fallback_init_func)67 tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx,
68 tfrt::BEFFile* bef_file,
69 absl::string_view fallback_init_func) {
70 auto* host = exec_ctx.host();
71
72 auto* func = bef_file->GetFunction(
73 {fallback_init_func.data(), fallback_init_func.size()});
74 if (func == nullptr) return ::tensorflow::OkStatus();
75
76 if (func->function_kind() == FunctionKind::kBEFFunction) {
77 auto ready_chain = GetReadyChain();
78
79 DCHECK_EQ(func->argument_types().size(), 1);
80
81 llvm::SmallVector<RCReference<AsyncValue>, 1> results;
82 results.resize(func->result_types().size());
83 DCHECK_EQ(results.size(), 1);
84
85 func->Execute(exec_ctx, ready_chain.GetAsyncValue(), results);
86
87 host->Await(results);
88
89 if (auto* error = results[0]->GetErrorIfPresent()) {
90 return CreateTfErrorStatus(*error);
91 }
92 } else {
93 DCHECK_EQ(func->result_types().size(), 0);
94 if (auto err = ExecuteSyncBEFFunction(*func, exec_ctx, {}, {})) {
95 return tensorflow::errors::Internal(
96 tfrt::StrCat("Failed to run function: ", func->name(), err));
97 }
98 }
99
100 return ::tensorflow::OkStatus();
101 }
102
CreateDummyTfDevices(const std::vector<std::string> & device_names,std::vector<std::unique_ptr<tensorflow::Device>> * dummy_tf_devices)103 void CreateDummyTfDevices(
104 const std::vector<std::string>& device_names,
105 std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices) {
106 for (const auto& name : device_names) {
107 tensorflow::DeviceAttributes device_attrs =
108 tensorflow::Device::BuildDeviceAttributes(
109 name, tensorflow::DEVICE_TPU_SYSTEM, tensorflow::Bytes(16ULL << 30),
110 tensorflow::DeviceLocality(), "device: TFRT TPU SYSTEM device");
111 dummy_tf_devices->push_back(std::make_unique<tensorflow::VirtualDevice>(
112 tensorflow::Env::Default(), device_attrs));
113 }
114 }
115
AddDummyTfrtDevices(const std::vector<std::string> & device_names,HostContext * host_ctx)116 void AddDummyTfrtDevices(const std::vector<std::string>& device_names,
117 HostContext* host_ctx) {
118 for (const auto& name : device_names) {
119 host_ctx->GetDeviceManager()->MaybeAddDevice(
120 TakeRef(new tfrt::VirtualDevice(name)));
121 }
122 }
123
CreateBefFileFromBefBuffer(const tensorflow::tfrt_stub::Runtime & runtime,const tfrt::BefBuffer & bef)124 StatusOr<RCReference<tfrt::BEFFile>> CreateBefFileFromBefBuffer(
125 const tensorflow::tfrt_stub::Runtime& runtime, const tfrt::BefBuffer& bef) {
126 auto* core_runtime = runtime.core_runtime();
127 DCHECK(core_runtime);
128 auto* host_context = core_runtime->GetHostContext();
129 DCHECK(host_context);
130 auto bef_file =
131 BEFFile::Open(bef, host_context->GetKernelRegistry(),
132 host_context->diag_handler(), host_context->allocator());
133 TF_RET_CHECK(bef_file) << "failed to open BEF";
134 return bef_file;
135 }
136
// Returns a process-wide unique integer, incrementing from 0 on each call.
// Thread-safe: the counter is a std::atomic, and relaxed ordering suffices
// because callers only need distinct values, not cross-thread ordering.
// NOTE(review): this relied on transitive includes of <atomic>/<cstdint>;
// the file's include block now lists them explicitly (IWYU).
int64_t GetUniqueInt() {
  static std::atomic<int64_t> next_id(0);
  return next_id.fetch_add(1, std::memory_order_relaxed);
}
141
142 } // namespace tfrt
143