/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and comparing the
// results. This is a lower level of abstraction than using the client
// interface and enables, for example, explicitly building a graph of HLO
// instructions to run.
//
// This can also be used to write text/file-based test cases. Note that the
// test target is responsible for linking the needed backends. A convenient
// way to do this is to make it an xla_test: it will generate test targets
// linking with the respective backends, which will be used as the test
// backend; the interpreter backend is already linked with hlo_test_base, so
// it will be the default reference backend. For example, if you want to
// compare both cpu vs. interpreter and gpu vs. interpreter, you can write:
//
//  xla_test (
//    name = "sample_text_test",
//    srcs = ["sample_text_test.cc"],
//    backends = [
//      "cpu",
//      "gpu",
//    ],
//    deps = [
//      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
//      ...
//    ],
//  )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
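//
// A minimal sketch of such a text-based test (the test name and module text
// below are purely illustrative, not taken from an existing test):
//
//  class SampleTextTest : public HloTestBase {};
//
//  TEST_F(SampleTextTest, AddsTwoScalars) {
//    const char* const hlo_string = R"(
//      HloModule Add
//      ENTRY Add {
//        a = f32[] parameter(0)
//        b = f32[] parameter(1)
//        ROOT sum = f32[] add(a, b)
//      }
//    )";
//    // Runs on the test backend with fake inputs and compares against the
//    // reference backend within the given error bound.
//    EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-5, 1e-5}));
//  }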
class HloTestBase : public ::testing::Test {
 public:
  // Creates a new HLO module for a test. The module created will have
  // TestName() for its name; it will also automatically populate its debug
  // options from command-line flags. If you want a fresh HloModule object to
  // which you can then add HloComputations, it's recommended to use this
  // method in your tests.
  //
  // This returns a vanilla HloModule that doesn't run the HLO verifier on
  // destruction.
  ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
  std::unique_ptr<HloModule> CreateNewUnverifiedModule(
      const string& name = TestName());

  // Like CreateNewUnverifiedModule, except the HloModule returned here runs
  // the HLO verifier on destruction.
  std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
      const string& name = TestName());

  // Parses the given string and returns the module as a VerifiedHloModule.
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text);
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text, const HloModuleConfig& config);

  // Runs the given hlo_pass on the provided module and returns the result.
  // This function also verifies that the module remains unchanged when
  // hlo_pass returns false as the StatusOr value.
  static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
                                   HloModule* module);

  static PrecisionConfig DefaultPrecisionConfig(int operands);

 protected:
  // This uses the interpreter backend as the reference backend and
  // automatically finds another supported backend as the test backend. If the
  // interpreter is the only supported backend, it will be both the test
  // backend and the reference backend.
  HloTestBase(bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              std::function<bool(const HloInstruction*)>
                  instruction_can_change_layout_func = {});

  // If your test doesn't use the interpreter as the reference backend, you
  // can use this constructor. Note that your test target is responsible for
  // linking in both needed backends.
  HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
              bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              std::function<bool(const HloInstruction*)>
                  instruction_can_change_layout_func = {});

  ~HloTestBase() override {}

  // Populates debug options from command-line flags and adjusts the options
  // for testing. It is recommended to use this when you need to pass in
  // DebugOptions, e.g. when creating a module from a string or a file.
  //
  // This function is virtual so tests can specify an alternative set of debug
  // options (e.g. disabling additional passes).
  virtual DebugOptions GetDebugOptionsForTest();

  // Gets an HloModuleConfig with options appropriate for tests.
  HloModuleConfig GetModuleConfigForTest() {
    HloModuleConfig config;
    config.set_debug_options(GetDebugOptionsForTest());
    return config;
  }

  // Executes the given module and returns the result as a Literal.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<Literal* const> arguments);

  // Same as above, except the module will be executed without running any HLO
  // passes on it.
  Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);
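
  // A minimal usage sketch for Execute (the computation-building step and the
  // literal value are illustrative):
  //
  //  auto module = CreateNewVerifiedModule();
  //  // ... add computations to the module ...
  //  Literal arg = LiteralUtil::CreateR0<float>(42.0f);
  //  TF_ASSERT_OK_AND_ASSIGN(Literal result,
  //                          Execute(std::move(module), {&arg}));
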
  // Executes the given module on multiple replicas.
  //
  // use_threads indicates whether this replicated computation will be
  // executed with a thread-per-replica, vs. using an implicitly async call
  // such as Executable::ExecuteOnStreams.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64 num_replicas, bool use_threads, bool run_hlo_passes = false);

  // Same as above, but uses the specified device assignment.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64 num_replicas, DeviceAssignment* device_assignment,
      bool run_hlo_passes, bool use_threads);

  // Executes the given hlo module on two backends and compares the results.
  //
  // 'arguments': the inputs to the hlo module.
  //
  // 'error': if it has a value, expects the results to be near (within the
  // error bound); otherwise, expects the results to be equal.
  //
  // 'reference_preprocessor': the module should be ready to run on the test
  // backend, but it might need to be tailored so that it is able to run on
  // the reference backend. Note that the program shape of the module must not
  // be modified.
  ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Same as above, except that the module will be executed without HLO
  // optimization.
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Executes an hlo module with fake inputs and compares the results.
  ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Same as above, except that the module will be executed without HLO
  // optimization.
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Convenience wrappers for executing and comparing an hlo module with fake
  // input. The module can be passed in directly, parsed from an hlo_string,
  // or loaded from a file.
  ::testing::AssertionResult RunAndCompare(
      const absl::string_view hlo_string,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult Run(const absl::string_view hlo_string,
                                 bool run_hlo_passes = true,
                                 ExecutionProfile* profile = nullptr,
                                 string backend_config = "") TF_MUST_USE_RESULT;
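
  // A minimal usage sketch for Run, which executes on the test backend only,
  // without comparing against the reference backend (hlo_string is assumed to
  // hold a parseable HLO module; the profile is optional):
  //
  //  ExecutionProfile profile;
  //  EXPECT_TRUE(Run(hlo_string, /*run_hlo_passes=*/true, &profile));
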
  // If assert_determinism is true, the assertion will fail unless all runs
  // produce exactly the same output.
  ::testing::AssertionResult RunMultipleTimes(
      const absl::string_view hlo_string, bool run_hlo_passes,
      std::vector<ExecutionProfile>* profiles, string backend_config = "",
      bool assert_determinism = false) TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareFromFile(
      const string& filename, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      const absl::string_view hlo_string,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
      const string& filename, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Convenience method to force the layout of a given parameter in a module.
  // The layout of parameter number 'param_no' in the 'module' is set to
  // 'layout'.
  void ForceParameterLayout(HloModule* module, int64 param_no,
                            const Layout& layout) {
    ASSERT_LT(param_no,
              module->mutable_entry_computation_layout()->parameter_count());
    module->mutable_entry_computation_layout()
        ->mutable_parameter_layout(param_no)
        ->ResetLayout(layout);
  }

  // Convenience method to force the layout of the computation result in a
  // module. The result layout of 'module' is set to 'layout'.
  void ForceResultLayout(HloModule* module, const Layout& layout) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout);
  }

  void ForceResultLayout(HloModule* module, const Layout& layout,
                         ShapeIndexView shape_index) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout, shape_index);
  }

  // Convenience method to clear the layout of the computation result in
  // 'module'.
  void ForceClearResultLayout(HloModule* module) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->Clear();
  }

  // Gets the computation/instruction from the given module with the given
  // name.
  //
  // This is useful for tests which create HLOs from a string and then want to
  // inspect a particular computation or instruction.
  HloComputation* FindComputation(HloModule* module, absl::string_view name);
  HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
  // Gets the instruction from the given module with the given opcode.
  HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode);
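
  // A minimal usage sketch for FindInstruction (the instruction name "sum" is
  // illustrative; it must match a name in hlo_string):
  //
  //  TF_ASSERT_OK_AND_ASSIGN(auto module,
  //                          ParseAndReturnVerifiedModule(hlo_string));
  //  HloInstruction* sum = FindInstruction(module.get(), "sum");
  //  ASSERT_NE(sum, nullptr);
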
  // Returns an HLO verifier constructed for the test backend.
  HloVerifier& verifier() const { return *hlo_verifier_; }

  static string TestName();

  // Returns the backend owned by the test runner.
  Backend& backend();

  HloRunner test_runner_;
  HloRunner reference_runner_;

  bool verifier_layout_sensitive_;
  bool allow_mixed_precision_in_hlo_verifier_;
  std::unique_ptr<HloVerifier> hlo_verifier_;

  ErrorSpec error_spec_{0.0001};

 private:
  // Given the test module, makes a reference module that is ready to run on
  // the reference platform. This assumes that the given module is ready to
  // run on the test platform.
  StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
      const HloModule& test_module,
      const std::function<void(HloModule*)>& reference_preprocessor);

  // Runs the module on two platforms with or without running HLO passes and
  // compares the results. Returns whether the results are near or equal. If
  // any error happens before the results are computed, returns the error
  // status.
  StatusOr<::testing::AssertionResult> RunAndCompareInternal(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
      const std::function<void(HloModule*)>& reference_preprocessor);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_