• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_H_
16 #define TENSORFLOW_CORE_TFRT_UTILS_H_
17 
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/status.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
25 
26 namespace tensorflow {
27 class Device;
28 class EagerContext;
29 }
30 
31 namespace tfrt {
32 
33 class BEFFile;
34 class ExecutionContext;
35 class HostContext;
36 
37 typedef tensorflow::gtl::InlinedVector<tfrt::DType, 4> TfrtDataTypeVector;
38 typedef tensorflow::gtl::ArraySlice<tfrt::DType> TfrtDataTypeSlice;
39 
// TODO(b/161370736): Have a formal method to convert between TF's and TFRT's
// device name. Currently TFRT adopts the suffix of TF's device name,
// e.g. CPU:0.
//
// Converts the TF device name `device_name` to its TFRT counterpart.
// `eager_context` is presumably used to look up/validate the device — confirm
// against the implementation. On failure the returned Expected carries an
// Error instead of a value.
// NOTE(review): the lifetime of the returned `const char*` is not visible from
// this header; it likely points into storage owned elsewhere (e.g. the eager
// context) — verify before caching it.
Expected<const char*> ConvertTfDeviceNameToTfrt(
    const char* device_name, tensorflow::EagerContext* eager_context);
45 
46 DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype);
47 
// Run the runtime initialization function. A runtime initialization function is
// added by runtime/compiler workflow and is not present in the original
// savedmodel.
//
// `exec_ctx` supplies the TFRT execution environment, `bef_file` is the
// compiled BEF program, and `fallback_init_func` presumably names the
// initialization function inside `bef_file` to invoke — confirm with the
// implementation. Returns a non-OK status if initialization fails.
//
// TODO(b/178714905): We should avoid special handling on initialization by
// letting compiler to handle it.
tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx,
                                         tfrt::BEFFile* bef_file,
                                         absl::string_view fallback_init_func);
57 
// Create dummy TF devices from the input device names. Currently this method
// is used to create the TPU_SYSTEM device for worker server.
//
// One device is created per entry of `device_names` and placed into
// `dummy_tf_devices`, which must be non-null. NOTE(review): whether the
// output vector is appended to or overwritten is not visible from this
// header — confirm with the implementation before calling with a non-empty
// vector.
void CreateDummyTfDevices(
    const std::vector<std::string>& device_names,
    std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices);
63 
// Create and add dummy TFRT devices from the input device names. Currently
// this method is used to create the TPU_SYSTEM device for worker server.
// The created devices are presumably registered with `host_ctx` (which must
// be non-null) — confirm against the implementation.
void AddDummyTfrtDevices(const std::vector<std::string>& device_names,
                         HostContext* host_ctx);
68 
69 }  // namespace tfrt
70 
71 #endif  // TENSORFLOW_CORE_TFRT_UTILS_H_
72