1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 17 #define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <vector> 22 23 #include "absl/base/macros.h" 24 #include "absl/types/optional.h" 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/service/backend.h" 27 #include "tensorflow/compiler/xla/service/computation_layout.h" 28 #include "tensorflow/compiler/xla/service/hlo_module.h" 29 #include "tensorflow/compiler/xla/service/hlo_runner.h" 30 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 31 #include "tensorflow/compiler/xla/service/platform_util.h" 32 #include "tensorflow/compiler/xla/shape_layout.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 35 #include "tensorflow/compiler/xla/tests/manifest_checking_test.h" 36 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" 37 #include "tensorflow/compiler/xla/types.h" 38 #include "tensorflow/compiler/xla/xla_data.pb.h" 39 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 40 #include "tensorflow/core/platform/test.h" 41 42 namespace xla { 43 44 // A base class for tests which build and/or run HLO code. The class includes 45 // support for running an HLO module on two platforms and compare the results. 46 // This is a lower level of abstraction than using the client interface and 47 // enables, for one, explicitly building a graph of HLO instructions to run. 48 // 49 // This can also be used to write text/file-based test cases. Note that the test 50 // target is responsible for linking the needed backends. A convenient way to do 51 // this is to make it an xla_test: it will generate test targets linking with 52 // the respective backends, which will be used as the test backend; the 53 // interpreter backend is already linked with hlo_test_base so it will be the 54 // default reference backend. For example, if you want to compare both cpu vs. 55 // interpreter, and gpu vs. interpreter, you can: 56 // 57 // xla_test ( 58 // name = "sample_text_test", 59 // srcs = ["sample_text_test.cc"], 60 // backends = [ 61 // "cpu", 62 // "gpu", 63 // ], 64 // deps = [ 65 // "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", 66 // ... 67 // ], 68 // ) 69 // 70 // For a more detailed example, see "../tests/sample_text_test.cc". 71 class HloTestBase : public ManifestCheckingTest { 72 public: 73 // Creates a new HLO module for a test. The module created will have 74 // TestName() for its name; it will also automatically populate its debug 75 // options from command-line flags. If you want a fresh HloModule object and 76 // then add HloComputations to it, it's recommended to use this method in your 77 // tests. 78 // 79 // This returns a vanilla HloModule that doesn't run the HLO verifier on 80 // destruction. 81 ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") 82 std::unique_ptr<HloModule> CreateNewUnverifiedModule( 83 const string& name = TestName()); 84 85 // Like CreateNewUnverifiedModule, except the HloModule returned here runs the 86 // HLO verifier on destruction. 87 std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule( 88 const string& name = TestName(), int64_t replica_count = 1); 89 90 // Parses the given string and returns module as a VerifiedHloModule. 91 StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule( 92 absl::string_view hlo_text, int64_t replica_count = 1, 93 int64_t num_partitions = 1); 94 StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule( 95 absl::string_view hlo_text, const HloModuleConfig& config); 96 97 // Runs the hlo_pass with the provided module and returns the result. This 98 // function also verifies that the module remains unchanged when hlo_pass 99 // returns false as the StatusOr value. 100 static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass, 101 HloModule* module); 102 103 static PrecisionConfig DefaultPrecisionConfig(int operands); 104 105 // Sets most fath math options to be enabled to model the fast math flags 106 // generally used for CPU:AOT compilation. 107 static void SetAotFastMathDebugOptions(DebugOptions* options); 108 109 protected: 110 // This uses the interpreter backend as the reference backend and 111 // automatically finds another supported backend as the test backend. If the 112 // interpreter is the only supported backend, it will be both the test backend 113 // and the reference backend. 114 HloTestBase(bool verifier_layout_sensitive = false, 115 bool allow_mixed_precision_in_hlo_verifier = true, 116 std::function<bool(const HloInstruction*)> 117 instruction_can_change_layout_func = {}); 118 119 // If your test doesn't use interpreter as the reference backend, you can use 120 // this constructor. Note that your test target is responsible for linking in 121 // both needed backends. 122 HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, 123 bool verifier_layout_sensitive = false, 124 bool allow_mixed_precision_in_hlo_verifier = true, 125 std::function<bool(const HloInstruction*)> 126 instruction_can_change_layout_func = {}); 127 ~HloTestBase()128 ~HloTestBase() override {} 129 130 // Populates debug options from command-line flags and adjusts the options for 131 // testing. It is recommended to use this when you need to pass in 132 // DebugOptions, e.g. when creating a module from a string or a file. 133 // 134 // This function is virtual so tests can specify an alternative set of debug 135 // options (e.g. disabling additional passes). 136 virtual DebugOptions GetDebugOptionsForTest(); 137 138 // Gets an HloModuleConfig with options appropriate for tests. 139 HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1, 140 int64_t num_partitions = 1) { 141 HloModuleConfig config; 142 config.set_debug_options(GetDebugOptionsForTest()); 143 config.set_replica_count(replica_count); 144 config.set_num_partitions(num_partitions); 145 return config; 146 } 147 148 // Executes the given module and return the result as a Literal. 149 StatusOr<Literal> Execute(std::unique_ptr<HloModule> module, 150 absl::Span<Literal* const> arguments); 151 152 // Same as above, except the module will be executed without running any HLO 153 // passes on it. 154 Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module, 155 absl::Span<Literal* const> arguments); 156 157 Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module, 158 absl::Span<Literal* const> arguments); 159 160 // Executes the given module on multiple replicas. 161 // 162 // use_threads indicates whether this replicated computation will be executed 163 // with a thread-per-replica, vs using an implicitly async call such as 164 // Executable::ExecuteOnStreams. 165 StatusOr<std::vector<Literal>> ExecuteReplicated( 166 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments, 167 int64_t num_replicas, bool use_threads, bool run_hlo_passes = false); 168 169 // Same as above, but uses specified device assignment. 170 StatusOr<std::vector<Literal>> ExecuteReplicated( 171 std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments, 172 int64_t num_replicas, DeviceAssignment* device_assignment, 173 bool run_hlo_passes, bool use_threads); 174 175 // Same as above, but allows passing different programs for replicas. 176 StatusOr<std::vector<Literal>> ExecuteReplicated( 177 std::function<Executable*(int64_t)> executable_provider, 178 std::function<int64(int64_t)> argument_count_provider, 179 std::function<const Literal*(int64_t, int64_t)> argument_provider, 180 int64_t num_replicas, bool run_hlo_passes); 181 182 // Executes the given hlo module on two backends and compares results. 183 // 184 // 'arguments': the input of the hlo module. 185 // 186 // 'error': if has value, expects the results to be near (within the error 187 // bound). Otherwise, expects the results to be equal. 188 // 189 // 'reference_preprocessor': the module should be ready to run on the test 190 // backend, but it might need to be tailored so that it is able to run on the 191 // reference backend. Note that the program shape of the module must not be 192 // modified. 193 ::testing::AssertionResult RunAndCompare( 194 std::unique_ptr<HloModule> module, 195 const absl::Span<Literal* const> arguments, 196 const absl::optional<ErrorSpec>& error, 197 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 198 TF_MUST_USE_RESULT; 199 200 // Same as above, except that the module will be executed without Hlo 201 // optimization. 202 ::testing::AssertionResult RunAndCompareNoHloPasses( 203 std::unique_ptr<HloModule> module, 204 const absl::Span<Literal* const> arguments, 205 const absl::optional<ErrorSpec>& error, 206 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 207 TF_MUST_USE_RESULT; 208 209 // Executes an hlo module with fake inputs and compares the results. 210 ::testing::AssertionResult RunAndCompare( 211 std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error, 212 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 213 TF_MUST_USE_RESULT; 214 215 // Same as above, except that the module will be executed without Hlo 216 // optimization. 217 ::testing::AssertionResult RunAndCompareNoHloPasses( 218 std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error, 219 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 220 TF_MUST_USE_RESULT; 221 222 // Executes an hlo module with fake inputs and checks that the execution is 223 // successful. 224 ::testing::AssertionResult Run(std::unique_ptr<HloModule> module, 225 bool run_hlo_passes) TF_MUST_USE_RESULT; 226 227 // Convenient wrappers for executing and comparing an hlo module with fake 228 // input. Module can be passed in directly, or parsed from an hlo_string, 229 // or loaded from a file. 230 ::testing::AssertionResult RunAndCompare( 231 const absl::string_view hlo_string, 232 const absl::optional<ErrorSpec>& error, 233 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 234 TF_MUST_USE_RESULT; 235 ::testing::AssertionResult Run(const absl::string_view hlo_string, 236 bool run_hlo_passes = true, 237 ExecutionProfile* profile = nullptr, 238 string backend_config = "") TF_MUST_USE_RESULT; 239 240 // Executes an hlo module with fake inputs on multiple replicas. 241 ::testing::AssertionResult RunReplicated( 242 const absl::string_view hlo_string, bool run_hlo_passes = true, 243 int64_t num_replicas = 1, string backend_config = "") TF_MUST_USE_RESULT; 244 245 // If assert_determinism is true, the assertion will fail unless all runs 246 // produce exactly the same output. 247 ::testing::AssertionResult RunMultipleTimes( 248 const absl::string_view hlo_string, bool run_hlo_passes, 249 std::vector<ExecutionProfile>* profiles, string backend_config = "", 250 bool assert_determinism = false) TF_MUST_USE_RESULT; 251 ::testing::AssertionResult RunAndCompareFromFile( 252 const string& filename, const absl::optional<ErrorSpec>& error, 253 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 254 TF_MUST_USE_RESULT; 255 ::testing::AssertionResult RunAndCompareNoHloPasses( 256 const absl::string_view hlo_string, 257 const absl::optional<ErrorSpec>& error, 258 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 259 TF_MUST_USE_RESULT; 260 ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( 261 const string& filename, const absl::optional<ErrorSpec>& error, 262 const std::function<void(HloModule*)>& reference_preprocessor = nullptr) 263 TF_MUST_USE_RESULT; 264 265 // Convenience method to force the layout of a given parameter in a module. 266 // The layout of parameter number 'param_no' in the 'module' is set to 267 // 'layout'. ForceParameterLayout(HloModule * module,int64_t param_no,const Layout & layout)268 void ForceParameterLayout(HloModule* module, int64_t param_no, 269 const Layout& layout) { 270 ASSERT_LT(param_no, 271 module->mutable_entry_computation_layout()->parameter_count()); 272 module->mutable_entry_computation_layout() 273 ->mutable_parameter_layout(param_no) 274 ->ResetLayout(layout); 275 } 276 277 // Convenience method to force the layout of the computation result in a 278 // module. The result layout of 'module' is set to 'layout'. ForceResultLayout(HloModule * module,const Layout & layout)279 void ForceResultLayout(HloModule* module, const Layout& layout) { 280 module->mutable_entry_computation_layout() 281 ->mutable_result_layout() 282 ->ResetLayout(layout); 283 } 284 ForceResultLayout(HloModule * module,const Layout & layout,ShapeIndexView shape_index)285 void ForceResultLayout(HloModule* module, const Layout& layout, 286 ShapeIndexView shape_index) { 287 module->mutable_entry_computation_layout() 288 ->mutable_result_layout() 289 ->ResetLayout(layout, shape_index); 290 } 291 292 // Convenience method to clear the layout of the computation result in 293 // 'module'. ForceClearResultLayout(HloModule * module)294 void ForceClearResultLayout(HloModule* module) { 295 module->mutable_entry_computation_layout() 296 ->mutable_result_layout() 297 ->Clear(); 298 } 299 300 // Gets the computation/instruction from the given module with the given name. 301 // 302 // This is useful for tests which create HLOs from a string and then want to 303 // inspect a particular computation or instruction. 304 HloComputation* FindComputation(HloModule* module, absl::string_view name); 305 HloInstruction* FindInstruction(HloModule* module, absl::string_view name); 306 // Gets the instruction from the given module with the given opcode. 307 HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); 308 309 // Return an HLO verifier constructed for the test backend. verifier()310 HloVerifier& verifier() const { return *hlo_verifier_; } 311 312 static string TestName(); 313 314 // Returns the backend owned by the test runner. 315 Backend& backend(); 316 317 HloRunner test_runner_; 318 HloRunner reference_runner_; 319 320 bool verifier_layout_sensitive_; 321 bool allow_mixed_precision_in_hlo_verifier_; 322 std::unique_ptr<HloVerifier> hlo_verifier_; 323 324 ErrorSpec error_spec_{0.0001}; 325 326 protected: 327 // Helper functions to get test and reference platforms. 328 static se::Platform* GetReferencePlatform(); 329 static se::Platform* GetTestPlatform(); 330 331 private: 332 // Given the test module, makes a reference module that is ready to run on the 333 // reference platform. This assumes that the given module is ready to run on 334 // the test platform. 335 StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule( 336 const HloModule& test_module, 337 const std::function<void(HloModule*)>& reference_preprocessor); 338 339 // Runs the module on two platforms with or without running hlo passes and 340 // compares the results. Returns whether the results are near or equal. If any 341 // error happens before the results are computed, returns the error status. 342 StatusOr<::testing::AssertionResult> RunAndCompareInternal( 343 std::unique_ptr<HloModule> module, 344 const absl::Span<Literal* const> arguments, 345 const absl::optional<ErrorSpec>& error, bool run_hlo_passes, 346 const std::function<void(HloModule*)>& reference_preprocessor); 347 }; 348 349 } // namespace xla 350 351 #endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_ 352