1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/proto/descriptors.h"
17
#include <limits>

#include "absl/strings/match.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
25
26 namespace tensorflow {
27 namespace {
28
CreatePoolFromSet(const protobuf::FileDescriptorSet & set,std::unique_ptr<protobuf::DescriptorPool> * out_pool)29 Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set,
30 std::unique_ptr<protobuf::DescriptorPool>* out_pool) {
31 *out_pool = absl::make_unique<protobuf::DescriptorPool>();
32 for (const auto& file : set.file()) {
33 if ((*out_pool)->BuildFile(file) == nullptr) {
34 return errors::InvalidArgument("Failed to load FileDescriptorProto: ",
35 file.DebugString());
36 }
37 }
38 return OkStatus();
39 }
40
41 // Build a `DescriptorPool` from the named file or URI. The file or URI
42 // must be available to the current TensorFlow environment.
43 //
44 // The file must contain a serialized `FileDescriptorSet`. See
45 // `GetDescriptorPool()` for more information.
GetDescriptorPoolFromFile(tensorflow::Env * env,const string & filename,std::unique_ptr<protobuf::DescriptorPool> * owned_desc_pool)46 Status GetDescriptorPoolFromFile(
47 tensorflow::Env* env, const string& filename,
48 std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
49 Status st = env->FileExists(filename);
50 if (!st.ok()) {
51 return st;
52 }
53 // Read and parse the FileDescriptorSet.
54 protobuf::FileDescriptorSet descs;
55 std::unique_ptr<ReadOnlyMemoryRegion> buf;
56 st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
57 if (!st.ok()) {
58 return st;
59 }
60 if (!descs.ParseFromArray(buf->data(), buf->length())) {
61 return errors::InvalidArgument(
62 "descriptor_source contains invalid FileDescriptorSet: ", filename);
63 }
64 return CreatePoolFromSet(descs, owned_desc_pool);
65 }
66
GetDescriptorPoolFromBinary(const string & source,std::unique_ptr<protobuf::DescriptorPool> * owned_desc_pool)67 Status GetDescriptorPoolFromBinary(
68 const string& source,
69 std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
70 if (!absl::StartsWith(source, "bytes://")) {
71 return errors::InvalidArgument(absl::StrCat(
72 "Source does not represent serialized file descriptor set proto. ",
73 "This may be due to a missing dependency on the file containing ",
74 "REGISTER_DESCRIPTOR_POOL(\"", source, "\", ...);"));
75 }
76 // Parse the FileDescriptorSet.
77 protobuf::FileDescriptorSet proto;
78 if (!proto.ParseFromString(string(absl::StripPrefix(source, "bytes://")))) {
79 return errors::InvalidArgument(absl::StrCat(
80 "Source does not represent serialized file descriptor set proto. ",
81 "This may be due to a missing dependency on the file containing ",
82 "REGISTER_DESCRIPTOR_POOL(\"", source, "\", ...);"));
83 }
84 return CreatePoolFromSet(proto, owned_desc_pool);
85 }
86
87 } // namespace
88
GetDescriptorPool(Env * env,string const & descriptor_source,protobuf::DescriptorPool const ** desc_pool,std::unique_ptr<protobuf::DescriptorPool> * owned_desc_pool)89 Status GetDescriptorPool(
90 Env* env, string const& descriptor_source,
91 protobuf::DescriptorPool const** desc_pool,
92 std::unique_ptr<protobuf::DescriptorPool>* owned_desc_pool) {
93 // Attempt to lookup the pool in the registry.
94 auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
95 if (pool_fn != nullptr) {
96 return (*pool_fn)(desc_pool, owned_desc_pool);
97 }
98
99 // If there is no pool function registered for the given source, let the
100 // runtime find the file or URL.
101 Status status =
102 GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
103 if (status.ok()) {
104 *desc_pool = owned_desc_pool->get();
105 return OkStatus();
106 }
107
108 status = GetDescriptorPoolFromBinary(descriptor_source, owned_desc_pool);
109 *desc_pool = owned_desc_pool->get();
110 return status;
111 }
112
113 } // namespace tensorflow
114