1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ 18 19 #include <string> 20 21 #include "absl/container/flat_hash_set.h" 22 #include "llvm/ADT/MapVector.h" 23 #include "llvm/ADT/StringMap.h" 24 #include "tensorflow/core/framework/tensor_shape.pb.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/lib/core/status.h" 28 29 namespace tensorflow { 30 31 struct ArrayInfo { 32 // The node type when the input node is imported. Typically needs to be 33 // specified when passing arbitrary nodes (some node attributes are removed). 34 DataType imported_dtype; 35 36 // Node "shape" attribute value. 37 TensorShapeProto shape; 38 }; 39 40 struct GraphImportConfig { 41 // Returns string representation of config. 42 std::string str() const; 43 44 using InputArrays = 45 llvm::MapVector<std::string, ArrayInfo, llvm::StringMap<unsigned>>; 46 // The name assigned to the function which is the import result of the given 47 // graph. If empty, a default one will be used. 48 std::string graph_func_name; 49 // Maps input node names to node data types and shapes. 50 InputArrays inputs; 51 // name:index strings for the data outputs. 52 std::vector<string> outputs; 53 // name strings for the control outputs. 54 std::vector<string> control_outputs; 55 // Setting prune_unused_nodes to true, would prune unreachable nodes if 56 // output_arrays is specified. 57 bool prune_unused_nodes = false; 58 // If true, inputs of type LegacyFedInput are replaced with Placeholder ops. 59 // LegacyFedInput ops have two outputs unlike Placeholder which has only one 60 // output, so if both outputs of the LegacyFedInput ops are used then returns 61 // an error. 62 bool convert_legacy_fed_inputs = false; 63 // If true, the main graph will be treated as a function. 64 bool graph_as_function = false; 65 // If true, upgrade legacy features of the graph (for instance, functionalize 66 // control-flow). 67 bool upgrade_legacy = false; 68 // If true, functionalization is restricted to TPU nodes. This is only needed 69 // if upgrade_legacy is true and if upgrading legacy features of the graph 70 // (which includes functionalization) runs before TPU cluster extraction, as 71 // for example in the MLIR-based TPU bridge. Otherwise, this parameter should 72 // stay false. 73 bool restrict_functionalization_to_tpu_nodes = false; 74 // If true, enables shape inference on input. 75 // TODO(jpienaar): This will be removed shortly. 76 bool enable_shape_inference = true; 77 }; 78 79 struct GraphExportConfig { 80 // Whether to export shape attribute for the NodeDefs in the GraphDef. 81 bool export_shapes = true; 82 // Whether to export library field in the GraphDef. 83 bool export_library = true; 84 // Whether to export debug original node name in the GraphDef. 85 bool export_debug_info = true; 86 // Whether to export the entry function to function library instead of the 87 // graph. 88 bool export_entry_func_to_flib = false; 89 }; 90 91 // Parses the command line flag strings to the specification of nodes in 92 // the Graph. 93 Status ParseOutputArrayInfo(absl::string_view array_names, 94 std::vector<string>* outputs); 95 96 Status ParseOutputArrayInfo(const std::vector<string>& output_names, 97 std::vector<string>* outputs); 98 99 // Parses the command line flag strings to the specification of nodes in 100 // the Graph. `data_types` input string can be empty since the flag is optional. 101 Status ParseInputArrayInfo(absl::string_view array_names, 102 absl::string_view data_types, 103 absl::string_view shapes, 104 GraphImportConfig::InputArrays* inputs); 105 106 Status ParseInputArrayInfo( 107 const std::vector<string>& node_names, 108 const std::vector<string>& node_dtypes, 109 const std::vector<llvm::Optional<std::vector<int>>>& node_shapes, 110 GraphImportConfig::InputArrays* inputs); 111 112 // Parses shapes from the given string into shapes_vector which is a structured 113 // format. 114 // NOTE: If shapes_str is empty, shapes_vector will also be empty. 115 Status ParseNodeShapes( 116 absl::string_view shapes_str, 117 std::vector<llvm::Optional<std::vector<int>>>& shapes_vector); 118 119 // Parses names from the given string into the names_vector. 120 // NOTE: If names_str is empty, names_vector will also be empty. 121 Status ParseNodeNames(absl::string_view names_str, 122 std::vector<std::string>& names_vector); 123 124 // Parses data types from the given string into the data_type_vector. 125 // NOTE: If data_types_str is empty, data_type_vector will also be empty. 126 Status ParseNodeDataTypes(absl::string_view data_types_str, 127 std::vector<std::string>& data_type_vector); 128 129 } // namespace tensorflow 130 131 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ 132