1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_ 17 #define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_ 18 19 #include <functional> 20 21 #include "absl/types/optional.h" 22 #include "tensorflow/compiler/xla/service/hlo.pb.h" 23 #include "tensorflow/compiler/xla/status_macros.h" 24 #include "tensorflow/compiler/xla/util.h" 25 #include "tensorflow/core/platform/env.h" 26 27 namespace xla { 28 29 // Wrapper class for HloModuleMetadataProto to avoid allowing callers to mutate 30 // arbitrary fields. Specifically, callers cannot set timestamps or ids or 31 // set the fields of any pass not currently running. 32 class HloModuleMetadata { 33 public: HloModuleMetadata(tensorflow::Env * env)34 explicit HloModuleMetadata(tensorflow::Env* env) : env_(env) {} 35 proto()36 const HloModuleMetadataProto& proto() const { return module_metadata_; } 37 38 // Creates a new HloPassMetadata. All calls to RecordPassStart should be 39 // matched by a later call to RecordPassEnd. 40 void RecordPassStart(); 41 42 // Marks the currently running pass as finished. Returns NotFound if metadata 43 // for the currently running pass cannot be found. 44 Status RecordPassEnd(); 45 prepartitioning_metadata()46 const absl::optional<HloModuleMetadataProto>& prepartitioning_metadata() 47 const { 48 return prepartitioning_metadata_; 49 } 50 void set_prepartitioning_metadata( 51 const HloModuleMetadata& prepartitioning_metadata); 52 53 // Setters for HloModuleMetadataProto. set_module_group_name(const std::string & name)54 void set_module_group_name(const std::string& name) { 55 module_metadata_.set_module_group_name(name); 56 } set_canonical_module_id(int64 id)57 void set_canonical_module_id(int64 id) { 58 module_metadata_.set_canonical_module_id(id); 59 } add_partitioned_module_id(int64 id)60 void add_partitioned_module_id(int64 id) { 61 module_metadata_.add_partitioned_module_ids(id); 62 } 63 current_pass_id()64 StatusOr<int64> current_pass_id() { 65 TF_ASSIGN_OR_RETURN(HloPassMetadata * pass_metadata, 66 GetCurrentHloPassMetadata()); 67 return pass_metadata->pass_id(); 68 } 69 70 // Setters for the current HloPassMetadata. set_current_pass_name(const std::string & pass_name)71 Status set_current_pass_name(const std::string& pass_name) { 72 return MutateCurrentHloPassMetadata( 73 [&pass_name](HloPassMetadata* pass_metadata) { 74 pass_metadata->set_pass_name(pass_name); 75 }); 76 } set_current_pass_pipeline_name(const std::string & pipeline_name)77 Status set_current_pass_pipeline_name(const std::string& pipeline_name) { 78 return MutateCurrentHloPassMetadata( 79 [&pipeline_name](HloPassMetadata* pass_metadata) { 80 pass_metadata->set_pipeline_name(pipeline_name); 81 }); 82 } add_current_pass_dump_filename(const std::string & dump_filename)83 Status add_current_pass_dump_filename(const std::string& dump_filename) { 84 return MutateCurrentHloPassMetadata( 85 [&dump_filename](HloPassMetadata* pass_metadata) { 86 pass_metadata->add_dump_filenames(dump_filename); 87 }); 88 } set_current_pass_module_changed(bool module_changed)89 Status set_current_pass_module_changed(bool module_changed) { 90 return MutateCurrentHloPassMetadata( 91 [&module_changed](HloPassMetadata* pass_metadata) { 92 pass_metadata->set_module_changed(module_changed); 93 }); 94 } set_current_pass_module_id(int64 module_id)95 Status set_current_pass_module_id(int64 module_id) { 96 return MutateCurrentHloPassMetadata( 97 [&module_id](HloPassMetadata* pass_metadata) { 98 pass_metadata->set_module_id(module_id); 99 }); 100 } add_current_pass_module_group_module_id(int64 module_id)101 Status add_current_pass_module_group_module_id(int64 module_id) { 102 return MutateCurrentHloPassMetadata( 103 [&module_id](HloPassMetadata* pass_metadata) { 104 pass_metadata->add_module_group_module_ids(module_id); 105 }); 106 } 107 108 private: 109 // Gets mutable metadata for the currently running pass. If passes are nested, 110 // finds the deepest one still running. Returns NotFound if metadata for the 111 // currently running pass cannot be found. 112 StatusOr<HloPassMetadata*> GetCurrentHloPassMetadata(); 113 114 Status MutateCurrentHloPassMetadata( 115 const std::function<void(HloPassMetadata*)>& mutator); 116 117 HloModuleMetadataProto module_metadata_; 118 tensorflow::Env* env_; 119 int64 next_pass_id_ = 1; 120 121 // Stack of metadata for passes that are currently running. Size > 1 iff 122 // passes are nested. 123 std::vector<HloPassMetadata*> running_passes_; 124 125 // Metadata from before the module was partitioned, if applicable. 126 absl::optional<HloModuleMetadataProto> prepartitioning_metadata_ = 127 absl::nullopt; 128 }; 129 130 } // namespace xla 131 132 #endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_METADATA_H_ 133