/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Utility functions in support of the XRT API.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Factory class which creates NCCL unique IDs based on the replicas
// participating in a given communication. This is only used for GPU backends.
struct NcclUniqueIdFactory {
  virtual ~NcclUniqueIdFactory() {}

  // Generates the NCCL unique ID for the given set of replica IDs.
  virtual std::string GetUniqueId(absl::Span<const xla::int64> replicas) = 0;
};

void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory);

std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory();

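// Example (illustrative sketch, not part of the XRT API): a single-process
// NcclUniqueIdFactory which creates one NCCL unique ID per distinct replica
// group and caches it, so every replica in a group observes the same ID. The
// class name is hypothetical, and the sketch assumes nccl.h (for
// ncclGetUniqueId) and tensorflow/core/platform/mutex.h are available.
//
//   class LocalNcclUniqueIdFactory : public NcclUniqueIdFactory {
//    public:
//     std::string GetUniqueId(absl::Span<const xla::int64> replicas) override {
//       // Key the cache off the ordered replica list.
//       std::string key;
//       for (xla::int64 replica : replicas) {
//         key += std::to_string(replica) + ",";
//       }
//       mutex_lock lock(mu_);
//       std::string& id = ids_[key];
//       if (id.empty()) {
//         // First request for this replica group: mint a fresh NCCL unique ID
//         // (error handling omitted for brevity).
//         ncclUniqueId nccl_id;
//         ncclGetUniqueId(&nccl_id);
//         id.assign(nccl_id.internal, sizeof(nccl_id.internal));
//       }
//       return id;
//     }
//
//    private:
//     mutex mu_;
//     std::unordered_map<std::string, std::string> ids_;
//   };
//
//   SetNcclUniqueIdFactory(std::make_shared<LocalNcclUniqueIdFactory>());
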
struct InputCoords {
  explicit InputCoords(int64 handle) : handle(handle) {}
  InputCoords(int64 handle, xla::ShapeIndex index)
      : handle(handle), index(std::move(index)) {}

  int64 handle = 0;
  xla::ShapeIndex index;
};

// Filters the debug options provided as argument according to the value of the
// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If the variable is
// set to "1" or "true", the debug options are returned as is. Otherwise only a
// subset of them is copied into the returned options, and any paths they
// contain are restricted to gs:// and bigstore:// ones.
xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options);

// Populates the input_coords with a list of input coordinates from the
// input_name op argument.
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name);

bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape);

xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64 num_input_shapes,
    const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs);

Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias);

xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs);

// Creates the XRT execute output tensor given the computation result
// (output_tuple). The return_exploded_tuple flag tells whether a tuple result
// should be returned as a vector of handles representing each tuple child.
Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple);

// Drives the XRT chained computation execution given the supplied core execute
// function.
using ChainedExecuteFn =
    std::function<xla::StatusOr<RefPtr<XRTTupleAllocation>>(
        const xrt::XRTChainedExecuteOp&,
        absl::Span<const RefPtr<XRTTupleAllocation>>)>;
Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
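// Illustrative sketch of wiring a core execute function into ExecuteChained()
// from inside an XRT op kernel. The kernel-local names (context,
// memory_manager, backend, device_ordinal, plan, config) and the
// RunComputation() helper are assumptions for the example, not part of the
// XRT API.
//
//   ChainedExecuteFn execute_op =
//       [&](const xrt::XRTChainedExecuteOp& op,
//           absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
//           -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
//     // Run the compiled computation referenced by `op` over the already
//     // materialized input allocations and return the result tuple.
//     return RunComputation(op, op_inputs);
//   };
//   OP_REQUIRES_OK(context,
//                  ExecuteChained(context, memory_manager, backend,
//                                 device_ordinal, plan, config, execute_op));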