1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
17 #define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
18
19 #include "google/protobuf/duration.pb.h"
20 #include "absl/strings/string_view.h"
21 #include "absl/time/time.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/platform/protobuf.h"
25
26 namespace tensorflow {
27 namespace proto_utils {
28
29 using tensorflow::protobuf::FieldDescriptor;
30
31 // Returns true if the proto field type can be converted to the tensor dtype.
32 bool IsCompatibleType(FieldDescriptor::Type field_type, DataType dtype);
33
34 // Parses a text-formatted protobuf from a string into the given Message* output
35 // and returns status OK if valid, or INVALID_ARGUMENT with an accompanying
36 // parser error message if the text format is invalid.
37 Status ParseTextFormatFromString(absl::string_view input,
38 protobuf::Message* output);
39
40 class StringErrorCollector : public protobuf::io::ErrorCollector {
41 public:
42 // String error_text is unowned and must remain valid during the use of
43 // StringErrorCollector.
44 explicit StringErrorCollector(string* error_text);
45 // If one_indexing is set to true, all line and column numbers will be
46 // increased by one for cases when provided indices are 0-indexed and
47 // 1-indexed error messages are desired
48 StringErrorCollector(string* error_text, bool one_indexing);
49 StringErrorCollector(const StringErrorCollector&) = delete;
50 StringErrorCollector& operator=(const StringErrorCollector&) = delete;
51
52 // Implementation of protobuf::io::ErrorCollector::AddError.
53 void AddError(int line, int column, const string& message) override;
54
55 // Implementation of protobuf::io::ErrorCollector::AddWarning.
56 void AddWarning(int line, int column, const string& message) override;
57
58 private:
59 string* const error_text_;
60 const int index_offset_;
61 };
62
63 // Converts an absl::Duration to a google::protobuf::Duration.
ToDurationProto(absl::Duration duration)64 inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
65 google::protobuf::Duration proto;
66 proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
67 proto.set_nanos(
68 absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
69 return proto;
70 }
71
72 // Converts a google::protobuf::Duration to an absl::Duration.
FromDurationProto(google::protobuf::Duration proto)73 inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
74 return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
75 }
76
77 } // namespace proto_utils
78 } // namespace tensorflow
79
80 #endif // TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_
81