/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"

namespace xla {

// An HLO module derived class which verifies itself on destruction. This class
// is intended to be used in unit tests. Any verification errors are raised via
// ADD_FAILURE.
class VerifiedHloModule : public HloModule {
 public:
  VerifiedHloModule(const string& name, const HloModuleConfig& config,
                    bool verifier_layout_sensitive,
                    bool allow_mixed_precision_in_hlo_verifier,
                    std::function<int64(const Shape&)> shape_size_function)
      : HloModule(name, config),
        verifier_(
            verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier,
            /*instruction_can_change_layout_func=*/{}, shape_size_function) {}

  ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }

  // Verifies the module using HloVerifier and returns the status.
  Status Verify();

  // Verifies the module and flags any error with ADD_FAILURE. 'message' is
  // included in the failure message.
  void VerifyOrAddFailure(const string& message);

 private:
  HloVerifier verifier_;
};

// A base class for tests which build and/or run HLO code. The class includes
// support for running an HLO module on two platforms and comparing the
// results. This is a lower level of abstraction than using the client
// interface and enables, for one, explicitly building a graph of HLO
// instructions to run.
//
// This can also be used to write text/file-based test cases. Note that the
// test target is responsible for linking the needed backends. A convenient way
// to do this is to make it an xla_test: it will generate test targets linking
// with the respective backends, which will be used as the test backend; the
// interpreter backend is already linked with hlo_test_base so it will be the
// default reference backend. For example, if you want to compare both cpu vs.
// interpreter, and gpu vs. interpreter, you can:
//
//  xla_test (
//    name = "sample_text_test",
//    srcs = ["sample_text_test.cc"],
//    backends = [
//      "cpu",
//      "gpu",
//    ],
//    deps = [
//      "//third_party/tensorflow/compiler/xla/tests:hlo_test_base",
//      ...
//    ],
//  )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
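//
// As a rough sketch of what such a text-based test looks like (the fixture,
// test name, and HLO text below are hypothetical; see sample_text_test.cc for
// a real, complete example):
//
//  class SampleTextTest : public HloTestBase {};
//
//  TEST_F(SampleTextTest, AddsTwoScalars) {
//    const char* const hlo_text = R"(
//      HloModule Add
//      ENTRY main {
//        x = f32[] parameter(0)
//        y = f32[] parameter(1)
//        ROOT sum = f32[] add(x, y)
//      }
//    )";
//    // Runs the module on the test backend and the reference backend with
//    // fake arguments and compares the results within the error bound.
//    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0.0001}));
//  }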
class HloTestBase : public ::testing::Test {
 public:
  // Creates a new HLO module for a test. The module created will have
  // TestName() for its name; it will also automatically populate its debug
  // options from command-line flags. If you want a fresh HloModule object to
  // which you can then add HloComputations, it's recommended to use this
  // method in your tests.
  //
  // This returns a vanilla HloModule that doesn't run the HLO verifier on
  // destruction.
  ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.")
  std::unique_ptr<HloModule> CreateNewUnverifiedModule(
      const string& name = TestName());

  // Like CreateNewUnverifiedModule, except the HloModule returned here runs
  // the HLO verifier on destruction.
  std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
      const string& name = TestName());

  // Parses the given string and returns the module as a VerifiedHloModule.
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text,
      const HloModuleConfig& config = HloModuleConfig());

  // Runs the hlo_pass with the provided module and returns the result. This
  // function also verifies that the module remains unchanged when hlo_pass
  // returns false as the StatusOr value.
  static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
                                   HloModule* module);
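  //
  // A minimal sketch of how ParseAndReturnVerifiedModule and RunHloPass can be
  // combined in a pass test (the fixture, pass type, and HLO text below are
  // hypothetical):
  //
  //  TEST_F(MyPassTest, RemovesDeadAdd) {
  //    TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
  //      HloModule m
  //      ENTRY main {
  //        p = f32[8] parameter(0)
  //        dead = f32[8] add(p, p)
  //        ROOT copy = f32[8] copy(p)
  //      }
  //    )"));
  //    MyDeadCodePass pass;
  //    TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  //    EXPECT_TRUE(changed);
  //  }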

  static PrecisionConfig DefaultPrecisionConfig(int operands);

 protected:
  // This uses the interpreter backend as the reference backend and
  // automatically finds another supported backend as the test backend. If the
  // interpreter is the only supported backend, it will be both the test
  // backend and the reference backend.
  HloTestBase(bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              std::function<bool(const HloInstruction*)>
                  instruction_can_change_layout_func = {});

  // If your test doesn't use interpreter as the reference backend, you can use
  // this constructor. Note that your test target is responsible for linking in
  // both needed backends.
  HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
              bool verifier_layout_sensitive = false,
              bool allow_mixed_precision_in_hlo_verifier = true,
              std::function<bool(const HloInstruction*)>
                  instruction_can_change_layout_func = {});

  ~HloTestBase() override {}

  // Populates debug options from command-line flags and adjusts the options
  // for testing. It is recommended to use this when you need to pass in
  // DebugOptions, e.g. when creating a module from a string or a file.
  //
  // This function is virtual so tests can specify an alternative set of debug
  // options (e.g. disabling additional passes).
  virtual DebugOptions GetDebugOptionsForTest();

  // Gets an HloModuleConfig with options appropriate for tests.
  HloModuleConfig GetModuleConfigForTest() {
    HloModuleConfig config;
    config.set_debug_options(GetDebugOptionsForTest());
    return config;
  }
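  //
  // A sketch of how a test can customize the options used by these helpers
  // (the subclass name and the disabled pass name are illustrative only):
  //
  //  class MyTestWithoutSomePass : public HloTestBase {
  //   protected:
  //    DebugOptions GetDebugOptionsForTest() override {
  //      DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest();
  //      // Disables an additional pass for this test only.
  //      debug_options.add_xla_disable_hlo_passes("some-pass-name");
  //      return debug_options;
  //    }
  //  };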

  // Executes the given module and returns the result as a Literal.
  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
                            absl::Span<Literal* const> arguments);

  // Same as above, except the module will be executed without running any HLO
  // passes on it.
  Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
                             absl::Span<Literal* const> arguments);

  // Executes the given module on multiple replicas.
  //
  // use_threads indicates whether this replicated computation will be executed
  // with a thread-per-replica, vs using an implicitly async call such as
  // Executable::ExecuteOnStreams.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64 num_replicas, bool use_threads);

  // Same as above, but uses the specified device assignment.
  StatusOr<std::vector<Literal>> ExecuteReplicated(
      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments,
      int64 num_replicas, DeviceAssignment* device_assignment,
      bool run_hlo_passes, bool use_threads);

  // Executes the given hlo module on two backends and compares results.
  //
  // 'arguments': the input of the hlo module.
  //
  // 'error': if it has a value, expects the results to be near (within the
  // error bound). Otherwise, expects the results to be equal.
  //
  // 'reference_preprocessor': the module should be ready to run on the test
  // backend, but it might need to be tailored so that it is able to run on the
  // reference backend. Note that the program shape of the module must not be
  // modified.
  ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Same as above, except that the module will be executed without HLO
  // optimization.
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Executes an hlo module with fake inputs and compares the results.
  ::testing::AssertionResult RunAndCompare(
      std::unique_ptr<HloModule> module,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Same as above, except that the module will be executed without HLO
  // optimization.
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      std::unique_ptr<HloModule> module,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Convenient wrappers for executing and comparing an hlo module with fake
  // input. The module can be passed in directly, parsed from an hlo_string,
  // or loaded from a file.
  ::testing::AssertionResult RunAndCompare(
      const absl::string_view hlo_string,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult Run(const absl::string_view hlo_string,
                                 bool run_hlo_passes = true,
                                 ExecutionProfile* profile = nullptr,
                                 string backend_config = "") TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunMultipleTimes(
      const absl::string_view hlo_string, bool run_hlo_passes,
      std::vector<ExecutionProfile>* profiles,
      string backend_config = "") TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareFromFile(
      const string& filename, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareNoHloPasses(
      const absl::string_view hlo_string,
      const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;
  ::testing::AssertionResult RunAndCompareNoHloPassesFromFile(
      const string& filename, const absl::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr)
      TF_MUST_USE_RESULT;

  // Convenience method to force the layout of a given parameter in a module.
  // The layout of parameter number 'param_no' in the 'module' is set to
  // 'layout'.
  void ForceParameterLayout(HloModule* module, int64 param_no,
                            const Layout& layout) {
    ASSERT_LT(param_no,
              module->mutable_entry_computation_layout()->parameter_count());
    module->mutable_entry_computation_layout()
        ->mutable_parameter_layout(param_no)
        ->ResetLayout(layout);
  }

  // Convenience method to force the layout of the computation result in a
  // module. The result layout of 'module' is set to 'layout'.
  void ForceResultLayout(HloModule* module, const Layout& layout) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout);
  }

  void ForceResultLayout(HloModule* module, const Layout& layout,
                         ShapeIndexView shape_index) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->ResetLayout(layout, shape_index);
  }

  // Convenience method to clear the layout of the computation result in
  // 'module'.
  void ForceClearResultLayout(HloModule* module) {
    module->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->Clear();
  }

  // Gets the computation/instruction from the given module with the given
  // name.
  //
  // This is useful for tests which create HLOs from a string and then want to
  // inspect a particular computation or instruction.
  HloComputation* FindComputation(HloModule* module, absl::string_view name);
  HloInstruction* FindInstruction(HloModule* module, absl::string_view name);
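  //
  // For instance (the variable names and instruction name are hypothetical):
  //
  //  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  //  HloInstruction* sum = FindInstruction(module.get(), "sum");
  //  ASSERT_NE(sum, nullptr);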

  // Returns an HLO verifier constructed for the test backend.
  HloVerifier& verifier() const { return *hlo_verifier_; }

  static string TestName();

  // Returns the backend owned by the test runner.
  Backend& backend();

  HloRunner test_runner_;
  HloRunner reference_runner_;

  bool verifier_layout_sensitive_;
  bool allow_mixed_precision_in_hlo_verifier_;
  std::unique_ptr<HloVerifier> hlo_verifier_;

  ErrorSpec error_spec_{0.0001};

 private:
  // Given the test module, makes a reference module that is ready to run on
  // the reference platform. This assumes that the given module is ready to run
  // on the test platform.
  StatusOr<std::unique_ptr<HloModule>> MakeReferenceModule(
      const HloModule& test_module,
      const std::function<void(HloModule*)>& reference_preprocessor);

  // Runs the module on two platforms with or without running hlo passes and
  // compares the results. Returns whether the results are near or equal. If
  // any error happens before the results are computed, returns the error
  // status.
  StatusOr<::testing::AssertionResult> RunAndCompareInternal(
      std::unique_ptr<HloModule> module,
      const absl::Span<Literal* const> arguments,
      const absl::optional<ErrorSpec>& error, bool run_hlo_passes,
      const std::function<void(HloModule*)>& reference_preprocessor);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_