• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_TOOLS_TFG_GRAPH_TRANSFORMS_UTILS_H_
17 #define TENSORFLOW_TOOLS_TFG_GRAPH_TRANSFORMS_UTILS_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/platform/env.h"
22 #include "tensorflow/core/platform/errors.h"
23 #include "tensorflow/core/platform/path.h"
24 #include "tensorflow/core/platform/status.h"
25 #include "tensorflow/core/protobuf/saved_model.pb.h"
26 
27 namespace mlir {
28 namespace tfg {
29 namespace graph_transforms {
30 
31 // Reads the model proto from `input_file`.
32 // If the format of proto cannot be identified based on the file extension,
33 // attempts to load in a binary format first and then in a text format.
34 template <class T>
ReadModelProto(const std::string & input_file,T & model_proto)35 tensorflow::Status ReadModelProto(const std::string& input_file,
36                                   T& model_proto) {
37   // Proto might be either in binary or text format.
38   tensorflow::StringPiece extension = tensorflow::io::Extension(input_file);
39   bool binary_extenstion = !extension.compare("pb");
40   bool text_extension = !extension.compare("pbtxt");
41 
42   if (!binary_extenstion && !text_extension) {
43     LOG(WARNING) << "Proto type cannot be identified based on the extension";
44     // Try load binary first.
45     auto status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
46                                               input_file, &model_proto);
47     if (status.ok()) {
48       return status;
49     }
50 
51     // Binary proto loading failed, attempt to load text proto.
52     return tensorflow::ReadTextProto(tensorflow::Env::Default(), input_file,
53                                      &model_proto);
54   }
55 
56   if (binary_extenstion) {
57     return tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input_file,
58                                        &model_proto);
59   }
60 
61   if (text_extension) {
62     return tensorflow::ReadTextProto(tensorflow::Env::Default(), input_file,
63                                      &model_proto);
64   }
65 
66   return tensorflow::errors::InvalidArgument(
67       "Expected either binary or text protobuf");
68 }
69 
// Heuristically decides whether the protobuf file named by `input_file` is
// (or, when used on an output path, should be) in text format rather than
// binary format. The answer is a best-effort guess, not a guarantee.
bool IsTextProto(const std::string& input_file);
73 
74 template <class T>
SerializeProto(T model_proto,const std::string & output_file)75 tensorflow::Status SerializeProto(T model_proto,
76                                   const std::string& output_file) {
77   auto output_dir = tensorflow::io::Dirname(output_file);
78 
79   TF_RETURN_IF_ERROR(tensorflow::Env::Default()->RecursivelyCreateDir(
80       {output_dir.data(), output_dir.length()}));
81   if (IsTextProto(output_file)) {
82     TF_RETURN_WITH_CONTEXT_IF_ERROR(
83         tensorflow::WriteTextProto(tensorflow::Env::Default(), output_file,
84                                    model_proto),
85         "Error while writing the resulting model proto");
86   } else {
87     TF_RETURN_WITH_CONTEXT_IF_ERROR(
88         tensorflow::WriteBinaryProto(tensorflow::Env::Default(), output_file,
89                                      model_proto),
90         "Error while writing the resulting model proto");
91   }
92   return ::tensorflow::OkStatus();
93 }
94 
95 }  // namespace graph_transforms
96 }  // namespace tfg
97 }  // namespace mlir
98 
99 #endif  // TENSORFLOW_TOOLS_TFG_GRAPH_TRANSFORMS_UTILS_H_
100