/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Utility functions in support of the XRT API.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/compiler/xrt/xrt_state.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Factory class which creates NCCL unique IDs based on the replicas
// participating in a given communication. This is only used for GPU backends.
struct NcclUniqueIdFactory {
  virtual ~NcclUniqueIdFactory() {}

  // Generates the NCCL unique ID for the given set of replica IDs.
  virtual std::string GetUniqueId(absl::Span<const xla::int64> replicas) = 0;
};

// Sets the process-wide factory used to create NCCL unique IDs.
void SetNcclUniqueIdFactory(std::shared_ptr<NcclUniqueIdFactory> factory);

// Retrieves the currently installed NCCL unique ID factory, if any.
std::shared_ptr<NcclUniqueIdFactory> GetNcclUniqueIdFactory();

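// A minimal sketch of providing a custom factory. MyNcclUniqueIdFactory and
// the LookupOrCreateId() helper are illustrative assumptions, not part of
// this API:
//
//   class MyNcclUniqueIdFactory : public NcclUniqueIdFactory {
//    public:
//     std::string GetUniqueId(absl::Span<const xla::int64> replicas) override {
//       // All participants in the same replica set must receive the same ID,
//       // so a real implementation would cache/distribute by replica key.
//       return LookupOrCreateId(replicas);
//     }
//   };
//
//   SetNcclUniqueIdFactory(std::make_shared<MyNcclUniqueIdFactory>());
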
// Coordinates of a computation input: the handle of the allocation holding
// the input, plus the shape index of the sub-buffer within that allocation.
struct InputCoords {
  explicit InputCoords(int64 handle) : handle(handle) {}
  InputCoords(int64 handle, xla::ShapeIndex index)
      : handle(handle), index(std::move(index)) {}

  int64 handle = 0;
  xla::ShapeIndex index;
};

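// For example (the handle value 42 is hypothetical), the coordinates of the
// sub-buffer at tuple index {1, 0} within allocation 42:
//
//   InputCoords coords(/*handle=*/42, xla::ShapeIndex({1, 0}));
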
// Filters the debug options provided as argument according to the value of
// the TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If the variable
// is set to "1" or "true", the debug options are returned as is. Otherwise
// only a subset of them is propagated to the returned options, and all the
// paths they contain are limited to gs:// and bigstore:// ones.
xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options);

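// Usage sketch; here ref_options is assumed to come from the client's compile
// request:
//
//   xla::DebugOptions debug_options = BuildXlaDebugOptions(ref_options);
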
// Returns the list of input coordinates extracted from the input_name op
// argument.
xla::StatusOr<std::vector<InputCoords>> GetComputationInputs(
    OpKernelContext* context, const char* input_name);

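// A hedged usage sketch inside an XRT op kernel; the "input_handles" argument
// name is an assumption for illustration:
//
//   xla::StatusOr<std::vector<InputCoords>> input_coords =
//       GetComputationInputs(context, "input_handles");
//   OP_REQUIRES_OK(context, input_coords.status());
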
// Checks whether the shape of an op input matches the shape expected by the
// corresponding computation parameter, allowing for dynamic dimensions.
bool InputShapeMatches(const xla::Shape& parameter_shape,
                       const xla::Shape& input_shape);

// Resolves the input allocations identified by input_coords through the
// memory manager working set, checking them against the num_input_shapes
// parameter shapes returned by shape_getter. If release_inputs is true, the
// input handles are released once resolved.
xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetInputTupleAllocations(
    const std::vector<InputCoords>& input_coords,
    XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend,
    int64 num_input_shapes,
    const std::function<xla::Shape(int64)>& shape_getter, bool release_inputs);

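// A minimal sketch, assuming memory_manager, backend, input_coords and
// parameter_shapes are in scope (all assumptions for illustration):
//
//   XRTMemoryManager::WorkingSet working_set(memory_manager);
//   TF_ASSIGN_OR_RETURN(
//       std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
//       GetInputTupleAllocations(
//           input_coords, &working_set, backend, parameter_shapes.size(),
//           [&](int64 i) { return parameter_shapes[i]; },
//           /*release_inputs=*/false));
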
// Rewires the buffers of the output tuple which the computation aliases to
// input buffers, according to the input_output_alias configuration.
Status RebuildOutputAliases(
    const RefPtr<XRTTupleAllocation>& output_tuple,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const xla::HloInputOutputAliasConfig& input_output_alias);

// Builds the xla::ExecutionInput arguments for an execution from the given
// input tuple allocations, honoring the input/output aliasing configuration.
xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
    const std::vector<bool>& input_is_dynamic, bool release_inputs);

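// A hedged sketch, assuming an executable and input_tuples are in scope:
//
//   TF_ASSIGN_OR_RETURN(
//       std::vector<xla::ExecutionInput> arguments,
//       GetArgumentsBuffers(executable->module().input_output_alias_config(),
//                           input_tuples, input_is_dynamic,
//                           /*release_inputs=*/release_inputs));
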
// Creates the XRT execute output tensor given the computation result
// (output_tuple). The return_exploded_tuple flag tells whether a tuple result
// should be returned as a vector of handles, one per tuple child.
Status CreateExecuteOutput(OpKernelContext* context,
                           XRTMemoryManager* memory_manager,
                           RefPtr<XRTTupleAllocation> output_tuple,
                           bool return_exploded_tuple);

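// Usage sketch; with return_exploded_tuple set to true, a tuple result is
// emitted as one output handle per tuple child:
//
//   TF_RETURN_IF_ERROR(CreateExecuteOutput(context, memory_manager.get(),
//                                          std::move(output_tuple),
//                                          /*return_exploded_tuple=*/true));
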
// Drives the XRT chained computation execution given the supplied core execute
// function.
using ChainedExecuteFn =
    std::function<xla::StatusOr<RefPtr<XRTTupleAllocation>>(
        const xrt::XRTChainedExecuteOp&,
        absl::Span<const RefPtr<XRTTupleAllocation>>)>;
Status ExecuteChained(OpKernelContext* context,
                      const RefPtr<XRTMemoryManager>& memory_manager,
                      xla::Backend* backend, int device_ordinal,
                      const xrt::XRTChainedExecutePlan& plan,
                      const xrt::XRTChainedExecuteConfig& config,
                      const ChainedExecuteFn& execute_op);

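// A minimal sketch of a ChainedExecuteFn; RunComputation is a hypothetical
// helper standing in for the backend-specific execution:
//
//   ChainedExecuteFn execute_op =
//       [&](const xrt::XRTChainedExecuteOp& op,
//           absl::Span<const RefPtr<XRTTupleAllocation>> inputs)
//           -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
//     return RunComputation(op, inputs);
//   };
//   TF_RETURN_IF_ERROR(ExecuteChained(context, memory_manager, backend,
//                                     device_ordinal, plan, config,
//                                     execute_op));
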
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_