1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 17 #define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "absl/base/macros.h" 24 #include "absl/types/optional.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/backend.h" 27 #include "tensorflow/compiler/xla/service/computation_layout.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_runner.h" 30 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 31 #include "tensorflow/compiler/xla/service/platform_util.h" 32 #include "tensorflow/compiler/xla/shape_layout.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 35 #include "tensorflow/compiler/xla/tests/manifest_checking_test.h" 36 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/compiler/xla/xla_data.pb.h" 39 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 40 #include "tensorflow/core/platform/test.h" 41 42 namespace xla { 43 44 // A base class for tests which build and/or run HLO code. The class includes 45 // support for running an HLO module on two platforms and compare the results. 46 // This is a lower level of abstraction than using the client interface and 47 // enables, for one, explicitly building a graph of HLO instructions to run. 48 // 49 // This can also be used to write text/file-based test cases. Note that the test 50 // target is responsible for linking the needed backends. A convenient way to do 51 // this is to make it an xla_test: it will generate test targets linking with 52 // the respective backends, which will be used as the test backend; the 53 // interpreter backend is already linked with hlo_test_base so it will be the 54 // default reference backend. For example, if you want to compare both cpu vs. 55 // interpreter, and gpu vs. interpreter, you can: 56 // 57 // xla_test ( 58 // name = "sample_text_test", 59 // srcs = ["sample_text_test.cc"], 60 // backends = [ 61 // "cpu", 62 // "gpu", 63 // ], 64 // deps = [ 65 // "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", 66 // ... 67 // ], 68 // ) 69 // 70 // For a more detailed example, see "../tests/sample_text_test.cc". 71 class HloTestBase : public ManifestCheckingTest { 72 public: 73 // Creates a new HLO module for a test. The module created will have 74 // TestName() for its name; it will also automatically populate its debug 75 // options from command-line flags. If you want a fresh HloModule object and 76 // then add HloComputations to it, it's recommended to use this method in your 77 // tests. 78 // 79 // This returns a vanilla HloModule that doesn't run the HLO verifier on 80 // destruction. 81 ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") 82 std::unique_ptr<HloModule> CreateNewUnverifiedModule( 83 const string& name = TestName()); 84 85 // Like CreateNewUnverifiedModule, except the HloModule returned here runs the 86 // HLO verifier on destruction. 87 std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule( 88 const string& name = TestName(), int64 replica_count = 1); 89 90 // Parses the given string and returns module as a VerifiedHloModule. 91 StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule( 92 absl::string_view hlo_text, int64 replica_count = 1); 93 StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule( 94 absl::string_view hlo_text, const HloModuleConfig& config); 95 96 // Runs the hlo_pass with the provided module and returns the result. This 97 // function also verifies that the module remains unchanged when hlo_pass 98 // returns false as the StatusOr value. 99 static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass, 100 HloModule* module); 101 102 static PrecisionConfig DefaultPrecisionConfig(int operands); 103 104 // Sets most fath math options to be enabled to model the fast math flags 105 // generally used for CPU:AOT compilation. 106 static void SetAotFastMathDebugOptions(DebugOptions* options); 107 108 protected: 109 // This uses the interpreter backend as the reference backend and 110 // automatically finds another supported backend as the test backend. If the 111 // interpreter is the only supported backend, it will be both the test backend 112 // and the reference backend. 113 HloTestBase(bool verifier_layout_sensitive = false, 114 bool allow_mixed_precision_in_hlo_verifier = true, 115 std::function<bool(const HloInstruction*)> 116 instruction_can_change_layout_func = {}); 117 118 // If your test doesn't use interpreter as the reference backend, you can use 119 // this constructor. Note that your test target is responsible for linking in 120 // both needed backends. 121 HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, 122 bool verifier_layout_sensitive = false, 123 bool allow_mixed_precision_in_hlo_verifier = true, 124 std::function<bool(const HloInstruction*)> 125 instruction_can_change_layout_func = {}); 126 ~HloTestBase()127 ~HloTestBase() override {} 128 129 // Populates debug options from command-line flags and adjusts the options for 130 // testing. It is recommended to use this when you need to pass in 131 // DebugOptions, e.g. when creating a module from a string or a file. 132 // 133 // This function is virtual so tests can specify an alternative set of debug 134 // options (e.g. disabling additional passes). 135 virtual DebugOptions GetDebugOptionsForTest(); 136 137 // Gets an HloModuleConfig with options appropriate for tests. 138 HloModuleConfig GetModuleConfigForTest(int64 replica_count = 1) { 139 HloModuleConfig config; 140 config.set_debug_options(GetDebugOptionsForTest()); 141 config.set_replica_count(replica_count); 142 return config; 143 } 144 145 // Executes the given module and return the result as a Literal. 146 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 147 absl::Span<Literal* const> arguments); 148 149 // Same as above, except the module will be executed without running any HLO 150 // passes on it. 151 Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module, 152 absl::Span<Literal* const> arguments); 153 154 Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module, 155 absl::Span<Literal* const> arguments); 156 157 // Executes the given module on multiple replicas. 158 // 159 // use_threads indicates whether this replicated computation will be executed 160 // with a thread-per-replica, vs using an implicitly async call such as 161 // Executable::ExecuteOnStreams. 162 StatusOr<std::vector<Literal>> ExecuteReplicated( 163 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments, 164 int64 num_replicas, bool use_threads, bool run_hlo_passes = false); 165 166 // Same as above, but uses specified device assignment. 167 StatusOr<std::vector<Literal>> ExecuteReplicated( 168 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments, 169 int64 num_replicas, DeviceAssignment* device_assignment, 170 bool run_hlo_passes, bool use_threads); 171 172 // Same as above, but allows passing different programs for replicas. 173 StatusOr<std::vector<Literal>> ExecuteReplicated( 174 std::function<Executable*(int64)> executable_provider, 175 std::function<int64(int64)> argument_count_provider, 176 std::function<const Literal*(int64, int64)> argument_provider, 177 int64 num_replicas, bool run_hlo_passes); 178 179 // Executes the given hlo module on two backends and compares results. 180 // 181 // 'arguments': the input of the hlo module. 182 // 183 // 'error': if has value, expects the results to be near (within the error 184 // bound). Otherwise, expects the results to be equal. 185 // 186 // 'reference_preprocessor': the module should be ready to run on the test 187 // backend, but it might need to be tailored so that it is able to run on the 188 // reference backend. Note that the program shape of the module must not be 189 // modified. 190 ::testing::AssertionResult RunAndCompare( 191 std::unique_ptr<HloModule> module, 192 const absl::Span<Literal* const> arguments, 193 const absl::optional<ErrorSpec>& error, 194 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 195 TF_MUST_USE_RESULT; 196 197 // Same as above, except that the module will be executed without Hlo 198 // optimization. 199 ::testing::AssertionResult RunAndCompareNoHloPasses( 200 std::unique_ptr<HloModule> module, 201 const absl::Span<Literal* const> arguments, 202 const absl::optional<ErrorSpec>& error, 203 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 204 TF_MUST_USE_RESULT; 205 206 // Executes an hlo module with fake inputs and compares the results. 207 ::testing::AssertionResult RunAndCompare( 208 std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error, 209 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 210 TF_MUST_USE_RESULT; 211 212 // Same as above, except that the module will be executed without Hlo 213 // optimization. 214 ::testing::AssertionResult RunAndCompareNoHloPasses( 215 std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error, 216 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 217 TF_MUST_USE_RESULT; 218 219 // Executes an hlo module with fake inputs and checks that the execution is 220 // successful. 221 ::testing::AssertionResult Run(std::unique_ptr<HloModule> module, 222 bool run_hlo_passes) TF_MUST_USE_RESULT; 223 224 // Convenient wrappers for executing and comparing an hlo module with fake 225 // input. Module can be passed in directly, or parsed from an hlo_string, 226 // or loaded from a file. 227 ::testing::AssertionResult RunAndCompare( 228 const absl::string_view hlo_string, 229 const absl::optional<ErrorSpec>& error, 230 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 231 TF_MUST_USE_RESULT; 232 ::testing::AssertionResult Run(const absl::string_view hlo_string, 233 bool run_hlo_passes = true, 234 ExecutionProfile* profile = nullptr, 235 string backend_config = "") TF_MUST_USE_RESULT; 236 237 // Executes an hlo module with fake inputs on multiple replicas. 238 ::testing::AssertionResult RunReplicated( 239 const absl::string_view hlo_string, bool run_hlo_passes = true, 240 int64 num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT; 241 242 // If assert_determinism is true, the assertion will fail unless all runs 243 // produce exactly the same output. 244 ::testing::AssertionResult RunMultipleTimes( 245 const absl::string_view hlo_string, bool run_hlo_passes, 246 std::vector<ExecutionProfile>* profiles, string backend_config = "", 247 bool assert_determinism = false) TF_MUST_USE_RESULT; 248 ::testing::AssertionResult RunAndCompareFromFile( 249 const string& filename, const absl::optional<ErrorSpec>& error, 250 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 251 TF_MUST_USE_RESULT; 252 ::testing::AssertionResult RunAndCompareNoHloPasses( 253 const absl::string_view hlo_string, 254 const absl::optional<ErrorSpec>& error, 255 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 256 TF_MUST_USE_RESULT; 257 ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( 258 const string& filename, const absl::optional<ErrorSpec>& error, 259 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 260 TF_MUST_USE_RESULT; 261 262 // Convenience method to force the layout of a given parameter in a module. 263 // The layout of parameter number 'param_no' in the 'module' is set to 264 // 'layout'. ForceParameterLayout(HloModule * module,int64 param_no,const Layout & layout)265 void ForceParameterLayout(HloModule* module, int64 param_no, 266 const Layout& layout) { 267 ASSERT_LT(param_no, 268 module->mutable_entry_computation_layout()->parameter_count()); 269 module->mutable_entry_computation_layout() 270 ->mutable_parameter_layout(param_no) 271 ->ResetLayout(layout); 272 } 273 274 // Convenience method to force the layout of the computation result in a 275 // module. The result layout of 'module' is set to 'layout'. ForceResultLayout(HloModule * module,const Layout & layout)276 void ForceResultLayout(HloModule* module, const Layout& layout) { 277 module->mutable_entry_computation_layout() 278 ->mutable_result_layout() 279 ->ResetLayout(layout); 280 } 281 ForceResultLayout(HloModule * module,const Layout & layout,ShapeIndexView shape_index)282 void ForceResultLayout(HloModule* module, const Layout& layout, 283 ShapeIndexView shape_index) { 284 module->mutable_entry_computation_layout() 285 ->mutable_result_layout() 286 ->ResetLayout(layout, shape_index); 287 } 288 289 // Convenience method to clear the layout of the computation result in 290 // 'module'. ForceClearResultLayout(HloModule * module)291 void ForceClearResultLayout(HloModule* module) { 292 module->mutable_entry_computation_layout() 293 ->mutable_result_layout() 294 ->Clear(); 295 } 296 297 // Gets the computation/instruction from the given module with the given name. 298 // 299 // This is useful for tests which create HLOs from a string and then want to 300 // inspect a particular computation or instruction. 301 HloComputation* FindComputation(HloModule* module, absl::string_view name); 302 HloInstruction* FindInstruction(HloModule* module, absl::string_view name); 303 // Gets the instruction from the given module with the given opcode. 304 HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); 305 306 // Return an HLO verifier constructed for the test backend. verifier()307 HloVerifier& verifier() const { return *hlo_verifier_; } 308 309 static string TestName(); 310 311 // Returns the backend owned by the test runner. 312 Backend& backend(); 313 314 HloRunner test_runner_; 315 HloRunner reference_runner_; 316 317 bool verifier_layout_sensitive_; 318 bool allow_mixed_precision_in_hlo_verifier_; 319 std::unique_ptr<HloVerifier> hlo_verifier_; 320 321 ErrorSpec error_spec_{0.0001}; 322 323 private: 324 // Given the test module, makes a reference module that is ready to run on the 325 // reference platform. This assumes that the given module is ready to run on 326 // the test platform. 327 StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule( 328 const HloModule& test_module, 329 const std::function<void(HloModule*)>& reference_preprocessor); 330 331 // Runs the module on two platforms with or without running hlo passes and 332 // compares the results. Returns whether the results are near or equal. If any 333 // error happens before the results are computed, returns the error status. 334 StatusOr<::testing::AssertionResult> RunAndCompareInternal( 335 std::unique_ptr<HloModule> module, 336 const absl::Span<Literal* const> arguments, 337 const absl::optional<ErrorSpec>& error, bool run_hlo_passes, 338 const std::function<void(HloModule*)>& reference_preprocessor); 339 }; 340 341 } // namespace xla 342 343 #endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 344