• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tfrt/utils/utils.h"
16 
17 #include "tensorflow/core/common_runtime/eager/context.h"
18 #include "tensorflow/core/framework/device.h"
19 #include "tensorflow/core/tfrt/eager/virtual_device.h"
20 #include "tensorflow/core/tpu/virtual_device.h"
21 #include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
22 #include "tfrt/host_context/chain.h"  // from @tf_runtime
23 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
24 #include "tfrt/host_context/function.h"  // from @tf_runtime
25 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
26 #include "tfrt/support/error_util.h"  // from @tf_runtime
27 
28 namespace tfrt {
29 
ConvertTfDeviceNameToTfrt(const char * device_name,tensorflow::EagerContext * eager_context)30 Expected<const char*> ConvertTfDeviceNameToTfrt(
31     const char* device_name, tensorflow::EagerContext* eager_context) {
32   // NOTE(fishx): We need to get tf_device first because DeviceMgr in current TF
33   // allows us get the device with simplified name like "CPU:0". However, TFRT
34   // DeviceManager only allows get device via its fullname.
35   tensorflow::Device* tf_device;
36   tensorflow::Status s =
37       eager_context->FindDeviceFromName(device_name, &tf_device);
38   if (!s.ok()) {
39     return MakeStringError(s.error_message());
40   }
41   return tf_device->name().c_str();
42 }
43 
// Maps a tensorflow::DataType onto the corresponding TFRT DType.
//
// The case labels are generated by the X-macro include below: dtype.def
// expands DTYPE(TFRT_DTYPE, TF_DTYPE) once per supported pairing, producing
// one `case tensorflow::TF_DTYPE: return DType(DType::TFRT_DTYPE);` each.
// Any TF dtype without an entry in dtype.def falls through to the default
// and yields a default-constructed (invalid/unsupported) DType.
DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype) {
  switch (dtype) {
#define DTYPE(TFRT_DTYPE, TF_DTYPE) \
  case tensorflow::TF_DTYPE:        \
    return DType(DType::TFRT_DTYPE);
#include "tensorflow/core/tfrt/utils/dtype.def"  // NOLINT
    default:
      // Unsupported dtype: signal with a default-constructed DType.
      return DType();
  }
}
54 
RunRuntimeInitializer(const tfrt::ExecutionContext & exec_ctx,tfrt::BEFFile * bef_file,absl::string_view fallback_init_func)55 tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx,
56                                          tfrt::BEFFile* bef_file,
57                                          absl::string_view fallback_init_func) {
58   auto* host = exec_ctx.host();
59 
60   auto* func = bef_file->GetFunction(
61       {fallback_init_func.data(), fallback_init_func.size()});
62   if (func == nullptr) return tensorflow::Status::OK();
63 
64   auto ready_chain = GetReadyChain(host);
65 
66   DCHECK_EQ(func->argument_types().size(), 1);
67 
68   llvm::SmallVector<RCReference<AsyncValue>, 1> results;
69   results.resize(func->result_types().size());
70   DCHECK_EQ(results.size(), 1);
71 
72   func->Execute(exec_ctx, ready_chain.GetAsyncValue(), results);
73 
74   host->Await(results);
75 
76   if (auto* error = results[0]->GetErrorIfPresent()) {
77     return tensorflow::errors::Internal(error->message);
78   }
79 
80   return tensorflow::Status::OK();
81 }
82 
CreateDummyTfDevices(const std::vector<std::string> & device_names,std::vector<std::unique_ptr<tensorflow::Device>> * dummy_tf_devices)83 void CreateDummyTfDevices(
84     const std::vector<std::string>& device_names,
85     std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices) {
86   for (const auto& name : device_names) {
87     tensorflow::DeviceAttributes device_attrs =
88         tensorflow::Device::BuildDeviceAttributes(
89             name, tensorflow::DEVICE_TPU_SYSTEM, tensorflow::Bytes(16ULL << 30),
90             tensorflow::DeviceLocality(), "device: TFRT TPU SYSTEM device");
91     dummy_tf_devices->push_back(std::make_unique<tensorflow::VirtualDevice>(
92         tensorflow::Env::Default(), device_attrs));
93   }
94 }
95 
AddDummyTfrtDevices(const std::vector<std::string> & device_names,HostContext * host_ctx)96 void AddDummyTfrtDevices(const std::vector<std::string>& device_names,
97                          HostContext* host_ctx) {
98   for (const auto& name : device_names) {
99     host_ctx->GetDeviceManager()->MaybeAddDevice(
100         TakeRef(new tfrt::VirtualDevice(name)));
101   }
102 }
103 
104 }  // namespace tfrt
105