1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
16
17 #include <cstdint>
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 #include "tensorflow/core/tpu/tpu_compile_interface.h"
23 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
24
25 namespace tensorflow {
26 namespace tpu {
27 namespace {
CreateShapePrefix(const std::vector<tensorflow::TensorShape> & dynamic_shapes)28 std::string CreateShapePrefix(
29 const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
30 std::string shapes_prefix;
31 for (const TensorShape& shape : dynamic_shapes) {
32 for (int64_t size : shape.dim_sizes()) {
33 absl::StrAppend(&shapes_prefix, size, ",");
34 }
35 absl::StrAppend(&shapes_prefix, ";");
36 }
37 return shapes_prefix;
38 }
39
40 // Include compilation configurations of the arguments that are not captured
41 // by the called graph.
CreateConfigPrefix(const TPUCompileMetadataProto & metadata)42 std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
43 std::string config_prefix;
44 for (const auto& arg : metadata.args()) {
45 if (arg.is_same_data_across_replicas()) {
46 absl::StrAppend(&config_prefix, ":s");
47 // Same.
48 } else {
49 // Different.
50 absl::StrAppend(&config_prefix, ":");
51 }
52 if (arg.enable_xla_sharding() ==
53 tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
54 // Enabled.
55 absl::StrAppend(&config_prefix, "e");
56 }
57 if (arg.unrestricted_layout()) {
58 // Unrestricted.
59 absl::StrAppend(&config_prefix, ":u");
60 }
61 absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
62 if (arg.has_shape()) {
63 absl::StrAppend(&config_prefix, ",shape(");
64 for (const auto& dim : arg.shape().dim()) {
65 absl::StrAppend(&config_prefix, dim.size(), ",");
66 }
67 absl::StrAppend(&config_prefix, ")");
68 }
69 }
70 return config_prefix;
71 }
72 } // namespace
73
CreateFingerprintWithNameAndShapes(uint64 name,const std::vector<tensorflow::TensorShape> & shapes)74 uint64 CreateFingerprintWithNameAndShapes(
75 uint64 name, const std::vector<tensorflow::TensorShape>& shapes) {
76 std::string shape_prefix = CreateShapePrefix(shapes);
77 VLOG(2) << "CreateFingerprintWithNameAndShapes, name: " << name
78 << ", shape_prefix: " << shape_prefix;
79 return TpuCompileInterface::Get()->FingerprintString(
80 absl::StrCat(name, "_", shape_prefix));
81 }
82
83 // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
84 // data to compute the fingerprint.
GuaranteedConstFingerprint(const string & fingerprint_in_metadata,const OpInputList & guaranteed_constants)85 std::string GuaranteedConstFingerprint(
86 const string& fingerprint_in_metadata,
87 const OpInputList& guaranteed_constants) {
88 if (fingerprint_in_metadata.empty()) {
89 uint64_t fingerprint = 0;
90 for (const Tensor& constant : guaranteed_constants) {
91 fingerprint =
92 tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
93 fingerprint, constant.tensor_data().data(),
94 constant.tensor_data().size());
95 }
96 return std::to_string(fingerprint);
97 } else {
98 return fingerprint_in_metadata;
99 }
100 }
101
102 // The `guaranteed_constants` must be passed as reference due to the lazy
103 // evaluation of `guaranteed_const_fingerprint()` callback.
CreateCompilationCacheKey(absl::string_view function_name,uint64 function_library_fingerprint,uint64 mlir_module_fingerprint,const OpInputList & guaranteed_constants,const std::vector<TensorShape> & dynamic_shapes,const TPUCompileMetadataProto & metadata,const TpuMeshStateInterface & mesh_state)104 TpuCompilationCacheKey CreateCompilationCacheKey(
105 absl::string_view function_name, uint64 function_library_fingerprint,
106 uint64 mlir_module_fingerprint, const OpInputList& guaranteed_constants,
107 const std::vector<TensorShape>& dynamic_shapes,
108 const TPUCompileMetadataProto& metadata,
109 const TpuMeshStateInterface& mesh_state) {
110 VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
111 std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
112 VLOG(1) << "shapes_prefix = " << shapes_prefix;
113 std::string config_prefix = CreateConfigPrefix(metadata);
114 VLOG(1) << "config_prefix = " << config_prefix;
115 std::vector<int32_t> flattened_device_ids;
116 if (metadata.has_device_assignment()) {
117 for (const auto& device :
118 metadata.device_assignment().computation_devices()) {
119 flattened_device_ids.insert(flattened_device_ids.end(),
120 device.replica_device_ids().begin(),
121 device.replica_device_ids().end());
122 }
123 }
124 CompilationCacheKeyResult result =
125 tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
126 CompilationCacheKeyProperty{
127 config_prefix.data(),
128 shapes_prefix.data(),
129 function_name.data(),
130 mlir_module_fingerprint,
131 flattened_device_ids.data(),
132 flattened_device_ids.size(),
133 guaranteed_constants.size(),
134 function_library_fingerprint,
135 metadata.num_cores_per_replica(),
136 metadata.num_replicas(),
137 mesh_state.data(),
138 });
139 auto buffer_cleanup = gtl::MakeCleanup([result]() {
140 tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
141 });
142 TpuCompilationCacheKey key;
143 key.prefix = result.key;
144 key.debug_string = result.debug_string;
145
146 // Guaranteed constants can be different across sessions. Use session_handle
147 // and guaranteed_const fingerprint to guarantee no collision.
148 if (guaranteed_constants.size() > 0) {
149 key.has_guaranteed_const = true;
150 key.session_handle = metadata.session_handle();
151 // Both `metadata` and `guaranteed_constants` lifetime are captured by
152 // reference based on the assumption that these variables lifetime is
153 // managed through the `TPUCompileOpKernelImpl` that outlives the
154 // lifetime of the compilation cache lookups.
155 string fingerprint;
156 key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
157 fingerprint]() mutable {
158 if (fingerprint.empty()) {
159 fingerprint = GuaranteedConstFingerprint(
160 metadata.guaranteed_const_fingerprint(), guaranteed_constants);
161 }
162 return fingerprint;
163 };
164 }
165 return key;
166 }
167 } // namespace tpu
168 } // namespace tensorflow
169