1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TFRT_UTILS_H_ 16 #define TENSORFLOW_CORE_TFRT_UTILS_H_ 17 18 #include <string> 19 20 #include "tensorflow/core/framework/types.pb.h" 21 #include "tensorflow/core/lib/gtl/array_slice.h" 22 #include "tensorflow/core/platform/status.h" 23 #include "tfrt/dtype/dtype.h" // from @tf_runtime 24 #include "tfrt/support/forward_decls.h" // from @tf_runtime 25 26 namespace tensorflow { 27 class Device; 28 class EagerContext; 29 } 30 31 namespace tfrt { 32 33 class BEFFile; 34 class ExecutionContext; 35 class HostContext; 36 37 typedef tensorflow::gtl::InlinedVector<tfrt::DType, 4> TfrtDataTypeVector; 38 typedef tensorflow::gtl::ArraySlice<tfrt::DType> TfrtDataTypeSlice; 39 40 // TODO(b/161370736): Have a formal method to convert between TF's and TFRT's 41 // device name. Currently TFRT adopts the suffix of TF's device name, 42 // e.g. CPU:0. 43 Expected<const char*> ConvertTfDeviceNameToTfrt( 44 const char* device_name, tensorflow::EagerContext* eager_context); 45 46 DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype); 47 48 // Run the runtime initialization function. A runtime initialization function is 49 // added by runtime/compiler workflow and is not present in the original 50 // savedmodel. 51 // 52 // TODO(b/178714905): We should avoid special handling on initialization by 53 // letting compiler to handle it. 54 tensorflow::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, 55 tfrt::BEFFile* bef_file, 56 absl::string_view fallback_init_func); 57 58 // Create dummy TF devices from the input device names. Currently this method 59 // is used to create the TPU_SYSTEM device for worker server. 60 void CreateDummyTfDevices( 61 const std::vector<std::string>& device_names, 62 std::vector<std::unique_ptr<tensorflow::Device>>* dummy_tf_devices); 63 64 // Create and add dummy TFRT devices from the input device names. Currently 65 // this method is used to create the TPU_SYSTEM device for worker server. 66 void AddDummyTfrtDevices(const std::vector<std::string>& device_names, 67 HostContext* host_ctx); 68 69 } // namespace tfrt 70 71 #endif // TENSORFLOW_CORE_TFRT_UTILS_H_ 72