1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_util.h"
16
17 #include "absl/strings/str_format.h"
18 #include "absl/strings/str_split.h"
19 #include "tensorflow/core/platform/random.h"
20 #include "tensorflow/core/tpu/tpu_api.h"
21
22 namespace tensorflow {
23 namespace tpu {
24
SessionNameFromMetadata(const SessionMetadata * session_metadata)25 std::string SessionNameFromMetadata(const SessionMetadata* session_metadata) {
26 return session_metadata ? session_metadata->name() : "";
27 }
28
// Returns the per-core proto key for computation `key`, formatted as
// "<key>:<core>" (core written in decimal, with a leading '-' if negative).
std::string ProtoKeyForComputation(const std::string& key, int core) {
  std::string proto_key = key;
  proto_key += ':';
  proto_key += std::to_string(core);
  return proto_key;
}
32
ParseCompilationCacheKey(const std::string & key)33 xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
34 const std::string& key) {
35 const std::vector<std::string> splits = absl::StrSplit(key, '|');
36 if (splits.size() == 1) {
37 // No guaranteed_const.
38 return TpuCompilationCacheKey(key);
39 } else if (splits.size() != 3) {
40 return errors::InvalidArgument("Invalid TPU compilation cache key:", key);
41 }
42
43 TpuCompilationCacheKey parsed_key(splits.at(0));
44 parsed_key.has_guaranteed_const = true;
45 parsed_key.session_handle = splits.at(1);
46 const string fingerprint = splits.at(2);
47 parsed_key.guaranteed_const_fingerprint = [fingerprint] {
48 return fingerprint;
49 };
50 return parsed_key;
51 }
52
53 xla::CompileOnlyClient::AotXlaComputationInstance
BuildAotXlaComputationInstance(const XlaCompiler::CompilationResult & compilation_result)54 BuildAotXlaComputationInstance(
55 const XlaCompiler::CompilationResult& compilation_result) {
56 xla::CompileOnlyClient::AotXlaComputationInstance instance;
57 instance.computation = compilation_result.computation.get();
58 for (const xla::Shape& shape : compilation_result.xla_input_shapes) {
59 instance.argument_layouts.push_back(&shape);
60 }
61 instance.result_layout = &compilation_result.xla_output_shape;
62 return instance;
63 }
64
ShapeTensorToTensorShape(const Tensor & tensor,TensorShape * shape)65 Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape) {
66 if (tensor.dtype() != DT_INT64 ||
67 !TensorShapeUtils::IsVector(tensor.shape())) {
68 return errors::InvalidArgument("Shape tensor must be an int64 vector.");
69 }
70 const int64 rank = tensor.NumElements();
71 auto tensor_dims = tensor.flat<int64>();
72 std::vector<int64> dims(rank);
73 for (int64 i = 0; i < rank; ++i) {
74 dims[i] = tensor_dims(i);
75 }
76 return TensorShapeUtils::MakeShape(dims, shape);
77 }
78
DynamicShapesToTensorShapes(const OpInputList & dynamic_shapes,std::vector<TensorShape> * shapes)79 Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
80 std::vector<TensorShape>* shapes) {
81 shapes->resize(dynamic_shapes.size());
82 for (int i = 0; i < dynamic_shapes.size(); ++i) {
83 TF_RETURN_IF_ERROR(
84 ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
85 }
86 return Status::OK();
87 }
88
DynamicShapesToTensorShapes(const InputList & dynamic_shapes,std::vector<TensorShape> * shapes)89 Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
90 std::vector<TensorShape>* shapes) {
91 shapes->resize(dynamic_shapes.end() - dynamic_shapes.begin());
92 size_t i = 0;
93 for (auto& dynamic_shape : dynamic_shapes) {
94 TF_RETURN_IF_ERROR(
95 ShapeTensorToTensorShape(dynamic_shape.tensor(), &(*shapes)[i]));
96 ++i;
97 }
98 return Status::OK();
99 }
100
CreateServerBuilder(int serving_port)101 xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> CreateServerBuilder(
102 int serving_port) {
103 auto server_builder = absl::make_unique<::grpc::ServerBuilder>();
104 server_builder->AddListeningPort(
105 absl::StrFormat("[::]:%d", serving_port),
106 ::grpc::InsecureServerCredentials()); // NOLINT
107 return std::move(server_builder);
108 }
109 } // namespace tpu
110 } // namespace tensorflow
111