• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/utils.h"
16 
#include <atomic>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
21 
22 #include "tensorflow/compiler/xla/status_macros.h"
23 #include "tensorflow/core/common_runtime/eager/context.h"
24 #include "tensorflow/core/framework/device.h"
25 #include "tensorflow/core/tfrt/eager/virtual_device.h"
26 #include "tensorflow/core/tfrt/utils/error_util.h"
27 #include "tensorflow/core/tpu/virtual_device.h"
28 #include "tfrt/bef/bef_encoding.h"  // from @tf_runtime
29 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
30 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
31 #include "tfrt/host_context/chain.h"  // from @tf_runtime
32 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
33 #include "tfrt/host_context/function.h"  // from @tf_runtime
34 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
35 #include "tfrt/support/error_util.h"  // from @tf_runtime
36 #include "tfrt/support/string_util.h"  // from @tf_runtime
37 
38 namespace tfrt {
39 
40 using ::tensorflow::StatusOr;
41 
ConvertTfDeviceNameToTfrt(const char * device_name,tensorflow::EagerContext * eager_context)42 Expected<const char*> ConvertTfDeviceNameToTfrt(
43     const char* device_name, tensorflow::EagerContext* eager_context) {
44   // NOTE(fishx): We need to get tf_device first because DeviceMgr in current TF
45   // allows us get the device with simplified name like "CPU:0". However, TFRT
46   // DeviceManager only allows get device via its fullname.
47   tensorflow::Device* tf_device;
48   tensorflow::Status s =
49       eager_context->FindDeviceFromName(device_name, &tf_device);
50   if (!s.ok()) {
51     return MakeStringError(s.error_message());
52   }
53   return tf_device->name().c_str();
54 }
55 
ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype)56 DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype) {
57   switch (dtype) {
58 #define DTYPE(TFRT_DTYPE, TF_DTYPE) \
59   case tensorflow::TF_DTYPE:        \
60     return DType(DType::TFRT_DTYPE);
61 #include "tensorflow/core/tfrt/utils/dtype.def"  // NOLINT
62     default:
63       return DType();
64   }
65 }
66 
RunRuntimeInitializer(const tfrt::ExecutionContext & exec_ctx,tfrt::BEFFile * bef_file,absl::string_view fallback_init_func)67 tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx,
68                                          tfrt::BEFFile* bef_file,
69                                          absl::string_view fallback_init_func) {
70   auto* host = exec_ctx.host();
71 
72   auto* func = bef_file->GetFunction(
73       {fallback_init_func.data(), fallback_init_func.size()});
74   if (func == nullptr) return ::tensorflow::OkStatus();
75 
76   if (func->function_kind() == FunctionKind::kBEFFunction) {
77     auto ready_chain = GetReadyChain();
78 
79     DCHECK_EQ(func->argument_types().size(), 1);
80 
81     llvm::SmallVector<RCReference<AsyncValue>, 1> results;
82     results.resize(func->result_types().size());
83     DCHECK_EQ(results.size(), 1);
84 
85     func->Execute(exec_ctx, ready_chain.GetAsyncValue(), results);
86 
87     host->Await(results);
88 
89     if (auto* error = results[0]->GetErrorIfPresent()) {
90       return CreateTfErrorStatus(*error);
91     }
92   } else {
93     DCHECK_EQ(func->result_types().size(), 0);
94     if (auto err = ExecuteSyncBEFFunction(*func, exec_ctx, {}, {})) {
95       return tensorflow::errors::Internal(
96           tfrt::StrCat("Failed to run function: ", func->name(), err));
97     }
98   }
99 
100   return ::tensorflow::OkStatus();
101 }
102 
CreateDummyTfDevices(const std::vector<std::string> & device_names,std::vector<std::unique_ptr<tensorflow::Device>> * dummy_tf_devices)103 void CreateDummyTfDevices(
104     const std::vector<std::string>& device_names,
105     std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices) {
106   for (const auto& name : device_names) {
107     tensorflow::DeviceAttributes device_attrs =
108         tensorflow::Device::BuildDeviceAttributes(
109             name, tensorflow::DEVICE_TPU_SYSTEM, tensorflow::Bytes(16ULL << 30),
110             tensorflow::DeviceLocality(), "device: TFRT TPU SYSTEM device");
111     dummy_tf_devices->push_back(std::make_unique<tensorflow::VirtualDevice>(
112         tensorflow::Env::Default(), device_attrs));
113   }
114 }
115 
AddDummyTfrtDevices(const std::vector<std::string> & device_names,HostContext * host_ctx)116 void AddDummyTfrtDevices(const std::vector<std::string>& device_names,
117                          HostContext* host_ctx) {
118   for (const auto& name : device_names) {
119     host_ctx->GetDeviceManager()->MaybeAddDevice(
120         TakeRef(new tfrt::VirtualDevice(name)));
121   }
122 }
123 
CreateBefFileFromBefBuffer(const tensorflow::tfrt_stub::Runtime & runtime,const tfrt::BefBuffer & bef)124 StatusOr<RCReference<tfrt::BEFFile>> CreateBefFileFromBefBuffer(
125     const tensorflow::tfrt_stub::Runtime& runtime, const tfrt::BefBuffer& bef) {
126   auto* core_runtime = runtime.core_runtime();
127   DCHECK(core_runtime);
128   auto* host_context = core_runtime->GetHostContext();
129   DCHECK(host_context);
130   auto bef_file =
131       BEFFile::Open(bef, host_context->GetKernelRegistry(),
132                     host_context->diag_handler(), host_context->allocator());
133   TF_RET_CHECK(bef_file) << "failed to open BEF";
134   return bef_file;
135 }
136 
// Returns a process-wide unique integer: 0 on the first call, then 1, 2, ...
//
// The counter is a function-local static atomic, so concurrent callers each
// receive a distinct value. Relaxed ordering is used because only the
// atomicity of the increment matters, not ordering relative to other memory.
int64_t GetUniqueInt() {
  static std::atomic<int64_t> next_id{0};
  const int64_t assigned = next_id.fetch_add(1, std::memory_order_relaxed);
  return assigned;
}
141 
142 }  // namespace tfrt
143