/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_

#include "absl/base/call_once.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/core/util/reffed_status_callback.h"
#include "absl/container/flat_hash_map.h"

namespace tensorflow {
// Holds node's shape information for Concat/Split.
using EdgeShapes = absl::flat_hash_map<const Edge*, std::vector<int>>;
using GroupedEdges =
    absl::flat_hash_map<std::string, std::vector<const Edge*>>;
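// Illustrative only (not part of the rewrite API): how these aliases are
// typically populated. The exact grouping key format is an implementation
// detail of the .cc file.
//
//   EdgeShapes input_shapes;
//   input_shapes[edge] = {8, 128};          // shape of the tensor on `edge`
//
//   GroupedEdges grouped;
//   grouped["<dtype-and-shape key>"].push_back(edge);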

// Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op that
// we find as part of searching for inputs to models that are replicated.
using XlaShardingInfoMap = absl::flat_hash_map<
    std::string, std::tuple<tensorflow::DataType, std::string, std::string>>;

// Contains the attr "T" and a pointer to the tpu_replicated_metadata node (for
// a control dependency) for each TpuReplicatedInput op that we find as part of
// searching for inputs to models that are replicated.
using TpuReplicatedInputInfoMap =
    absl::flat_hash_map<std::string, std::tuple<tensorflow::DataType, Node*>>;

namespace tpu_functional_internal {

// Helper functions for graph rewrites.
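//
// As a rough illustration of the input/output packing these helpers perform
// (a sketch, not a contract): several small tensors of the same dtype and
// compatible shape that feed the same TPU cluster can be concatenated into a
// single host-to-TPU transfer and split back into the original shapes inside
// the cluster; outputs are handled symmetrically. The grouping criteria and
// thresholds (e.g. `minimum_input_tensors_packing`) live in the .cc file.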
GroupedEdges GroupTensorsForInputPacking(
    const EdgeShapes& tpu_input_shapes,
    const absl::flat_hash_map<const Edge*, DataType>& tpu_input_dtypes,
    bool input_shape_opt, bool group_tensors_for_packing);
GroupedEdges GroupTensorsForOutputPacking(Graph* graph,
                                          EdgeShapes& tpu_output_shapes,
                                          GraphShapeInfo* shape_info);

Status CreateConcatAndSplitNodesForInputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes,
    const absl::flat_hash_map<std::string, std::vector<const Edge*>>&
        grouped_input_edges,
    int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded,
    const XlaShardingInfoMap& xla_sharding_info,
    const TpuReplicatedInputInfoMap& tpu_replicated_input_info);
Status CreateConcatAndSplitNodesForOutputTensor(
    Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes,
    GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output,
    int32_t minimum_output_tensors_packing);

Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name,
                              EdgeShapes* tpu_input_shapes,
                              int num_cores_per_replica);

}  // namespace tpu_functional_internal

typedef FunctionLibraryRuntime::Handle FHandle;

// A `TPUPartitionedCallOp` asynchronously executes a function on exactly one
// TPU core and potentially across multiple other devices, but within a single
// process. The kernel places and partitions the function's underlying graph,
// executing each of the partitioned subgraphs as a function.
//
// The core on which the TPU computation is executed must be specified via the
// `device_ordinal` input. Different invocations of this op may specify
// different device ordinals, making it possible to map TPU computations to
// different cores at runtime. Currently, macro-substitution of device ordinals
// is only supported for the following whitelisted ops:
//   * TPUExecute
//   * InfeedEnqueue
//   * InfeedEnqueueTuple
//
// Attempting to compute a TPUPartitionedCallOp whose function body has a
// non-whitelisted node bearing an attribute named "device_ordinal" will result
// in an error.
//
// TODO(akshayka): This class duplicates most of the logic of
// `PartitionedCallOp`; once that class and this one have evolved to stable
// states, and if at that time they remain sufficiently similar, either unify
// them in one op or set up an inheritance structure that allows for code reuse.
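//
// A minimal graph-construction sketch (illustrative only, not part of this
// header's API): `fn` is a hypothetical NameAttrList naming an
// already-defined function, and the input/attr names follow the
// "TPUPartitionedCall" op registration.
//
//   NameAttrList fn;
//   fn.set_name("my_tpu_function");
//   std::vector<NodeDefBuilder::NodeOut> args = {{"x", 0, DT_FLOAT}};
//   NodeDef call_def;
//   TF_CHECK_OK(NodeDefBuilder("tpu_call", "TPUPartitionedCall")
//                   .Input(args)                           // "args: Tin"
//                   .Input("device_ordinal", 0, DT_INT32)  // which TPU core
//                   .Attr("Tout", std::vector<DataType>{DT_FLOAT})
//                   .Attr("f", fn)
//                   .Finalize(&call_def));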
class TPUPartitionedCallOp : public AsyncOpKernel {
 public:
  explicit TPUPartitionedCallOp(OpKernelConstruction* ctx)
      : AsyncOpKernel(ctx),
        pool_(ctx->env(), "InitializeVarOnTPUPool", 1),
        library_runtime_(nullptr) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    auto status = ctx->GetAttr("autotuner_thresh", &autotuner_thresh_);
    if (!status.ok()) {
      autotuner_thresh_ = 0;
    }
    tensorflow::tpu::OpsApiFn()->TfTpu_GetTpuPartitionedCallParamsFn(
        &runtime_params_);
  }

  ~TPUPartitionedCallOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  struct DeviceAndFHandle {
    std::string device;
    FHandle handle;

    // The FLD passed to `library_runtime_` as an overlay function library for
    // instantiation of function `handle`. This is a snapshot of the current
    // `flib_def_`. Since `flib_def_` can be changed concurrently by another
    // graph rewrite when executing `handle`, we need to make sure each
    // `handle` uses a different FLD to avoid races. See b/181149591.
    std::unique_ptr<FunctionLibraryDefinition> flib_def;
  };

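  // Resolves the TPU core ordinal to run this invocation on (e.g. via
  // `ordinal_selector_`, keyed by `input_hash`), returning the chosen core in
  // `core_ordinal` and a request id for later ordinal-selector bookkeeping.
  // (Summary inferred from the signature; see the .cc file for details.)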
  Status GetTpuCoreOrdinal(OpKernelContext* ctx, uint64 input_hash,
                           int64_t* ordinal_selector_req_id,
                           int32_t* core_ordinal)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper to create and initialize a TPU variable given a CPU variable
  // var: the CPU variable created by the user
  // ndef: the node def of the corresponding TPU var handle that we created
  // device_ordinal: TPU device ordinal on which to initialize this variable
  Status InitializeVarOnTPU(OpKernelContext* ctx,
                            const core::RefCountPtr<Var>& var, NodeDef* ndef,
                            int device_ordinal, bool fast_mem)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Helper to create and initialize partitioned TPU variables given a CPU
  // variable with XLA sharding annotation.
  // var: the CPU variable created by the user.
  // ndefs: the node defs of the corresponding TPU var handles on all the
  //   logical cores.
  // split_dim: the partition dimension of the variable. If -1, the variable is
  //   replicated.
  // device_ordinal: The index of the TPU core that is scheduled to run
  //   the computation. In the case of XLA SPMD, it is the "primary" core,
  //   which is the smallest index of all the cores.
  Status InitializeShardedVarOnTPU(OpKernelContext* ctx,
                                   const core::RefCountPtr<Var>& var,
                                   std::vector<NodeDef>& ndefs, int split_dim,
                                   int device_ordinal)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Checks whether any of the immediate successors of `node` has the attribute
  // "_tpu_replicate".
  bool IsInputToTPUReplicate(Node* node) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node of type DT_RESOURCE with a VarHandleOp on the TPU.
  Status ReplaceResourceArgsWithVarHandleOps(Graph* graph, OpKernelContext* ctx,
                                             int device_ordinal,
                                             int num_cores_per_replica,
                                             bool enable_spmd_xla_partitioning)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Replaces an _Arg node that refers to a variable on the CPU host with
  // sharded/replicated variables on all logical TPU devices.
  Status ReplaceAndPartitionXLAShardingVariable(
      Graph* graph, OpKernelContext* ctx, int device_ordinal,
      ResourceHandle& handle, Node* variable, int num_cores_per_replica)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status ShardInputsWithXlaSharding(Graph* graph, int num_cores_per_replica,
                                    OpKernelContext* ctx)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites the graph for input and output optimizations.
  // TODO(ylc): Move this function to a graph optimization pass.
  Status OptimizeTpuInputOutputTensors(
      Graph* graph, bool enable_spmd_xla_partitioning,
      int num_cores_per_replica,
      std::map<std::string, std::vector<int>>& named_input_shapes,
      OpKernelContext* ctx) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

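  // Runs shape inference over `graph`, additionally seeding it with the shapes
  // of the resource variables referenced by the function's arguments; results
  // are written to `arg_shapes` and `tpu_inferred_info`. (Summary inferred
  // from the signature; see the .cc file for details.)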
  Status InferShapesWithResourceVar(Graph* graph, OpKernelContext* ctx,
                                    std::map<int, InferredShape>& arg_shapes,
                                    GraphShapeInfo* tpu_inferred_info);

  // Copies the graph backing `func_` into `graph`.
  Status GetGraphFromFunction(Graph* graph, int device_ordinal,
                              int* num_core_per_replica,
                              bool* use_spmd_for_xla_partitioning)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Places the graph carried by `optimization_options` and runs graph
  // optimization passes (pre-placement, post-placement, and post-rewrite).
  Status PlacementHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options,
      const string& function_name);
  // Partitions `graph`, populates `subgraphs` with the partitions, and runs
  // the post-partitioning graph optimization passes.
  Status PartitionHelper(
      const DeviceSet& device_set,
      const GraphOptimizationPassOptions& optimization_options, Graph* graph,
      std::unordered_map<std::string, std::unique_ptr<Graph>>* subgraphs);

  // Adds and instantiates a function backed by `graph` with name
  // `function_name` on device `target_device`, storing the handle in `handle`.
  // If `out_flib_def` is not null, it will be set to a copy of `flib_def_` and
  // used for instantiation.
  Status InstantiatePartition(
      const Graph& graph, const string& function_name,
      const string& target_device, FHandle* handle,
      std::unique_ptr<FunctionLibraryDefinition>* out_flib_def)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Adds and instantiates functions for each subgraph in `subgraphs` after
  // rewriting nodes' `device_ordinal` attributes to match `replica_id` when
  // num_cores_per_replica == 1.
  Status InstantiateFunctionsFromSubgraphs(
      const DeviceSet& device_set, int replica_id, uint64 cache_hash,
      int num_cores_per_replica,
      std::unordered_map<std::string, std::unique_ptr<Graph>> subgraphs)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Rewrites `graph` such that the device ordinal attributes of all whitelisted
  // nodes (see `IsSupportedTPUOp`) are set to `device_ordinal`;
  // `*modified` is set to true if the graph is modified in the process (i.e.,
  // if it contains a whitelisted node), and otherwise is left unmodified.
  //
  // Returns an error if
  //   (1) the graph contains a non-whitelisted node that carries an attribute
  //       with name "device_ordinal", or
  //   (2) the set of device ordinals found among the graph's nodes has
  //       cardinality greater than 1.
  Status SetDeviceOrdinal(const DeviceSet& device_set, int device_ordinal,
                          Graph* graph, bool* modified)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  void ExecuteRemoteFunction(const FunctionLibraryRuntime::Options& opts,
                             FHandle handle, OpKernelContext* ctx,
                             ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
  void ExecuteLocalFunction(const FunctionLibraryRuntime::Options& opts,
                            const OpInputList& arguments, FHandle handle,
                            OpKernelContext* ctx, ReffedStatusCallback* done)
      ABSL_LOCKS_EXCLUDED(mu_);
  void ExecuteFunctions(const std::vector<DeviceAndFHandle>& functions,
                        OpKernelContext* ctx, int device_ordinal,
                        int64_t ordinal_selector_req_id, DoneCallback done)
      ABSL_LOCKS_EXCLUDED(mu_);

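  // Determines whether `target_device` refers to a device outside this
  // process by comparing its parsed name against `local_device_name_`; sets
  // `*remote_execution` to true iff the two names are not compatible.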
  Status ShouldUseRemoteExecutionForFn(const std::string& target_device,
                                       bool* remote_execution) {
    DeviceNameUtils::ParsedName target_device_parsed;
    DeviceNameUtils::ParsedName local_device_parsed;

    if (!DeviceNameUtils::ParseFullOrLocalName(target_device,
                                               &target_device_parsed)) {
      return errors::InvalidArgument("Cannot parse target device ",
                                     target_device);
    }
    if (!DeviceNameUtils::ParseFullOrLocalName(local_device_name_,
                                               &local_device_parsed)) {
      return errors::InvalidArgument("Cannot parse local device ",
                                     local_device_name_);
    }

    if (DeviceNameUtils::AreCompatibleDevNames(target_device_parsed,
                                               local_device_parsed)) {
      *remote_execution = false;
    } else {
      *remote_execution = true;
    }
    return Status::OK();
  }

  // Init-once flags.
  absl::once_flag once_;
  absl::once_flag ordinal_selector_once_;

  // Device manager and device set.
  const DeviceMgr* device_mgr_;
  DeviceSet device_set_;

  // Threadpool.
  thread::ThreadPool pool_;

  // `func_` is the original function supplied to this OpKernel.
  NameAttrList func_;
  string local_device_name_;
  // Maps from cache keys to their corresponding functions, which are
  // represented as (device, handle) pairs.
  gtl::FlatMap<uint64, std::vector<DeviceAndFHandle>> partition_cache_
      ABSL_GUARDED_BY(mu_);

  // A set of the device ordinals seen so far. Used by variable initialization
  // on TPU.
  absl::flat_hash_set<int> seen_ordinals_;

  // Records the indices of the _Arg nodes with type DT_RESOURCE that go into
  // a TPU op.
  std::vector<bool> replaced_input_indices_;

  absl::Mutex mu_;
  // Function shards are added to `flib_def_`; later, a copy of `flib_def_` is
  // passed to `library_runtime_` as an overlay function library for
  // instantiation.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  FunctionLibraryRuntime* library_runtime_;

  // Used to uniquify function names in `flib_def_`.
  uint32 suffix_ = 0;

  // Minimum number of run steps (batches) necessary to trigger the XLA
  // autotuner.
  int autotuner_thresh_ = 0;

  // TPU core selection.
  std::shared_ptr<tpu::TPUOrdinalSelector> ordinal_selector_;

  // Maps input hash to TF fingerprint.
  absl::flat_hash_map<uint64, uint64> inputs_to_fingerprint_;

  // List of TPU devices.
  std::vector<Device*> tpu_devices_;

  TpuPartitionedCall_Params runtime_params_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_