• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
16 
17 #include <cstdint>
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/core/lib/gtl/cleanup.h"
22 #include "tensorflow/core/tpu/tpu_compile_interface.h"
23 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
24 
25 namespace tensorflow {
26 namespace tpu {
27 namespace {
CreateShapePrefix(const std::vector<tensorflow::TensorShape> & dynamic_shapes)28 std::string CreateShapePrefix(
29     const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
30   std::string shapes_prefix;
31   for (const TensorShape& shape : dynamic_shapes) {
32     for (int64_t size : shape.dim_sizes()) {
33       absl::StrAppend(&shapes_prefix, size, ",");
34     }
35     absl::StrAppend(&shapes_prefix, ";");
36   }
37   return shapes_prefix;
38 }
39 
40 // Include compilation configurations of the arguments that are not captured
41 // by the called graph.
CreateConfigPrefix(const TPUCompileMetadataProto & metadata)42 std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
43   std::string config_prefix;
44   for (const auto& arg : metadata.args()) {
45     if (arg.is_same_data_across_replicas()) {
46       absl::StrAppend(&config_prefix, ":s");
47       // Same.
48     } else {
49       // Different.
50       absl::StrAppend(&config_prefix, ":");
51     }
52     if (arg.enable_xla_sharding() ==
53         tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
54       // Enabled.
55       absl::StrAppend(&config_prefix, "e");
56     }
57     if (arg.unrestricted_layout()) {
58       // Unrestricted.
59       absl::StrAppend(&config_prefix, ":u");
60     }
61     absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
62     if (arg.has_shape()) {
63       absl::StrAppend(&config_prefix, ",shape(");
64       for (const auto& dim : arg.shape().dim()) {
65         absl::StrAppend(&config_prefix, dim.size(), ",");
66       }
67       absl::StrAppend(&config_prefix, ")");
68     }
69   }
70   return config_prefix;
71 }
72 }  // namespace
73 
CreateFingerprintWithNameAndShapes(uint64 name,const std::vector<tensorflow::TensorShape> & shapes)74 uint64 CreateFingerprintWithNameAndShapes(
75     uint64 name, const std::vector<tensorflow::TensorShape>& shapes) {
76   std::string shape_prefix = CreateShapePrefix(shapes);
77   VLOG(2) << "CreateFingerprintWithNameAndShapes, name: " << name
78           << ", shape_prefix: " << shape_prefix;
79   return TpuCompileInterface::Get()->FingerprintString(
80       absl::StrCat(name, "_", shape_prefix));
81 }
82 
83 // Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
84 // data to compute the fingerprint.
GuaranteedConstFingerprint(const string & fingerprint_in_metadata,const OpInputList & guaranteed_constants)85 std::string GuaranteedConstFingerprint(
86     const string& fingerprint_in_metadata,
87     const OpInputList& guaranteed_constants) {
88   if (fingerprint_in_metadata.empty()) {
89     uint64_t fingerprint = 0;
90     for (const Tensor& constant : guaranteed_constants) {
91       fingerprint =
92           tpu::OpsApiFn()->TpuCompile_CreateGuaranteedConstFingerprintFn(
93               fingerprint, constant.tensor_data().data(),
94               constant.tensor_data().size());
95     }
96     return std::to_string(fingerprint);
97   } else {
98     return fingerprint_in_metadata;
99   }
100 }
101 
102 // The `guaranteed_constants` must be passed as reference due to the lazy
103 // evaluation of `guaranteed_const_fingerprint()` callback.
CreateCompilationCacheKey(absl::string_view function_name,uint64 function_library_fingerprint,uint64 mlir_module_fingerprint,const OpInputList & guaranteed_constants,const std::vector<TensorShape> & dynamic_shapes,const TPUCompileMetadataProto & metadata,const TpuMeshStateInterface & mesh_state)104 TpuCompilationCacheKey CreateCompilationCacheKey(
105     absl::string_view function_name, uint64 function_library_fingerprint,
106     uint64 mlir_module_fingerprint, const OpInputList& guaranteed_constants,
107     const std::vector<TensorShape>& dynamic_shapes,
108     const TPUCompileMetadataProto& metadata,
109     const TpuMeshStateInterface& mesh_state) {
110   VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
111   std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
112   VLOG(1) << "shapes_prefix = " << shapes_prefix;
113   std::string config_prefix = CreateConfigPrefix(metadata);
114   VLOG(1) << "config_prefix = " << config_prefix;
115   std::vector<int32_t> flattened_device_ids;
116   if (metadata.has_device_assignment()) {
117     for (const auto& device :
118          metadata.device_assignment().computation_devices()) {
119       flattened_device_ids.insert(flattened_device_ids.end(),
120                                   device.replica_device_ids().begin(),
121                                   device.replica_device_ids().end());
122     }
123   }
124   CompilationCacheKeyResult result =
125       tpu::OpsApiFn()->TpuCompile_CreateCompilationCacheKeyFn(
126           CompilationCacheKeyProperty{
127               config_prefix.data(),
128               shapes_prefix.data(),
129               function_name.data(),
130               mlir_module_fingerprint,
131               flattened_device_ids.data(),
132               flattened_device_ids.size(),
133               guaranteed_constants.size(),
134               function_library_fingerprint,
135               metadata.num_cores_per_replica(),
136               metadata.num_replicas(),
137               mesh_state.data(),
138           });
139   auto buffer_cleanup = gtl::MakeCleanup([result]() {
140     tpu::OpsApiFn()->TpuCompile_DestroyCompilationCacheKeyFn(result);
141   });
142   TpuCompilationCacheKey key;
143   key.prefix = result.key;
144   key.debug_string = result.debug_string;
145 
146   // Guaranteed constants can be different across sessions. Use session_handle
147   // and guaranteed_const fingerprint to guarantee no collision.
148   if (guaranteed_constants.size() > 0) {
149     key.has_guaranteed_const = true;
150     key.session_handle = metadata.session_handle();
151     // Both `metadata` and `guaranteed_constants` lifetime are captured by
152     // reference based on the assumption that these variables lifetime is
153     // managed through the `TPUCompileOpKernelImpl` that outlives the
154     // lifetime of the compilation cache lookups.
155     string fingerprint;
156     key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
157                                         fingerprint]() mutable {
158       if (fingerprint.empty()) {
159         fingerprint = GuaranteedConstFingerprint(
160             metadata.guaranteed_const_fingerprint(), guaranteed_constants);
161       }
162       return fingerprint;
163     };
164   }
165   return key;
166 }
167 }  // namespace tpu
168 }  // namespace tensorflow
169