1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_
17 #define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_
18 
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
37 
38 namespace tensorflow {
39 namespace dtensor {
40 
// Sets the C status `status` to (`code`, `message`) and returns from the
// enclosing function. The expansion ends in a bare `return;`, so the macro can
// only be used inside a function returning void.
// NOTE(review): the expansion is a braced block rather than
// `do { ... } while (0)`, so `if (c) RETURN_STATUS(...); else ...` does not
// compile (the trailing `;` terminates the `if`). Brace such call sites.
#define RETURN_STATUS(status, code, message)   \
  {                                            \
    TF_SetStatus((status), (code), (message)); \
    return;                                    \
  }
46 
// Evaluates the C++ status expression `cpp_status` exactly once. If it is not
// OK, copies its code (cast to TF_Code) and error message into the C status
// `c_status` and returns from the enclosing void function via RETURN_STATUS.
// Shares RETURN_STATUS's braced-block caveat for unbraced `if`/`else`.
#define RETURN_C_STATUS_IF_NOT_OK(cpp_status, c_status)                   \
  {                                                                       \
    auto return_if_not_ok_status = (cpp_status);                          \
    if (!return_if_not_ok_status.ok()) {                                  \
      RETURN_STATUS((c_status),                                           \
                    static_cast<TF_Code>(return_if_not_ok_status.code()), \
                    return_if_not_ok_status.error_message().c_str());     \
    }                                                                     \
  }
56 
// Evaluates `cpp_status` (a StatusOr-like expression exposing `.status()` and
// `.ValueOrDie()`) once. On error, copies the error into the C status
// `c_status` and returns from the enclosing void function. Otherwise it
// move-assigns the contained value into `var`, which may itself be a
// declaration such as `auto x`.
//
// Using a counter to uniquify instead of a new block allows `var` to declare a
// new variable.
#define ASSIGN_OR_RETURN_C_STATUS(var, cpp_status, c_status)               \
  ASSIGN_OR_RETURN_C_STATUS_IMPL(                                          \
      TF_STATUS_MACROS_CONCAT_NAME(_dtensor_status_or_value, __COUNTER__), \
      var, cpp_status, c_status)

// Implementation detail of ASSIGN_OR_RETURN_C_STATUS. `statusor` is a unique
// temporary name generated from __COUNTER__. The expansion is several
// statements, so it must not be used as the unbraced body of an `if` or loop.
#define ASSIGN_OR_RETURN_C_STATUS_IMPL(statusor, var, cpp_status, c_status) \
  auto statusor = (cpp_status);                                             \
  RETURN_C_STATUS_IF_NOT_OK(statusor.status(), (c_status));                 \
  var = std::move(statusor.ValueOrDie());
68 
// One function to execute as part of a DTensor computation (an entry of
// ExecutionFunctions::function_list), plus the bookkeeping needed to map its
// local inputs/outputs back onto the global graph.
struct TranslatedFunction {
  // Mesh for which specified function will run.
  Mesh function_mesh;

  // StatefulPartitionedCall op to run the mesh function.
  const Node* node_to_execute = nullptr;

  // Maps i-th local input index to input index in global graph.
  std::vector<int> input_index_map;

  // Maps i-th local output to output index of global graph.
  std::vector<int> output_index_map;

  // Name of the translated function.
  // NOTE(review): presumably the name it is registered under in the function
  // library — confirm.
  std::string translated_function_name;
  // For resource ops, layouts of resource handles are inferred lazily
  // during SPMD expansion of resource assign ops. In that case,
  // inferred layouts of resource handles are attached to arg nodes
  // of the returned graph.
  std::map<int, Layout> resource_input_layouts;
  // Record some metadata for output of a shape op. This would help recover
  // local shape on future operations over the Tensor.
  std::map<int, Layout> shape_output_metadata;
  // Layouts of the function outputs.
  // NOTE(review): presumably indexed by local output — confirm.
  std::vector<Layout> output_layouts;
  // Local shapes inferred for function outputs; these may be partially known.
  std::vector<PartialTensorShape> local_output_shapes;
  // Output data types.
  std::vector<TF_DataType> output_dtypes;
};
97 
98 struct ExecutionFunctions {
99   // Stores information about all functions to execute for provided computation.
100   std::vector<TranslatedFunction> function_list;
101   // Number of device ids args added to translated functions.
102   // During translation, we insert one device id arg node per mesh.
103   // For a single mesh function, it equals 1.
104   // For a multi-mesh function (e.g. pipelining), it equals the number of
105   // meshes.
106   int num_device_ids;
107 
108   // Mesh fingerprint of function_list. Set only when ExecutionFunctions refers
109   // to a function for performance reason, since an eager op doesn't use it.
110   uint64 function_mesh_fingerprint = 0;
111 };
112 
113 // TODO(yujingzhang): move FingerprintCat128 to tensorflow/platform.
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)114 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
115                                                const tensorflow::Fprint128& b) {
116   return {tensorflow::FingerprintCat64(a.low64, b.low64),
117           tensorflow::FingerprintCat64(a.high64, b.high64)};
118 }
119 
// Folds a 64-bit integer into a 128-bit fingerprint. The low half is mixed with
// `b` first, and that mixed value is then folded into the high half, so both
// halves depend on `b`.
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
                                               const int64 b) {
  tensorflow::Fprint128 result;
  result.low64 = tensorflow::FingerprintCat64(a.low64, b);
  result.high64 = tensorflow::FingerprintCat64(a.high64, result.low64);
  return result;
}
125 
126 struct DTensorOperation {
127   // For both fields: not owned. lifetime covers the whole usage.
128   const char* name;
129   const FunctionDef* function_def;
130 
is_funcDTensorOperation131   inline bool is_func() const { return function_def != nullptr; }
132 };
133 
134 struct EmbeddingResourceAttrs {
135   int64_t table_id;
136   absl::optional<int64_t> slot_id;  // NOLINT
137   bool is_dirty = false;
138 };
139 
140 // Contains a mesh bundled with a parallel device over all of the devices in
141 // that mesh.
142 class MeshWithParallelDevice {
143  public:
144   MeshWithParallelDevice(
145       const Mesh& mesh_config,
146       std::unique_ptr<parallel_device::ParallelDevice> parallel_device,
147       const std::string& composite_device_name = "")
mesh_config_(mesh_config)148       : mesh_config_(mesh_config),
149         parallel_device_(std::move(parallel_device)),
150         composite_device_name_(composite_device_name),
151         // Device IDs are constructed lazily because we don't have a context
152         // until we start executing ops.
153         device_ids_tensor_(nullptr) {}
154 
155   // A parallel tensor containing scalar integer device IDs for underlying
156   // devices, each placed on its corresponding device.
157   //
158   // TODO(allenl): It would be nice if DeviceID worked as an op inside the
159   // function's graph. Then we wouldn't need to feed it as an argument.
160   parallel_device::ParallelTensor* DeviceIDs(TFE_Context* context,
161                                              TF_Status* status) const;
parallel_device()162   const parallel_device::ParallelDevice& parallel_device() const {
163     return *parallel_device_;
164   }
165 
mesh_config()166   const dtensor::Mesh& mesh_config() const { return mesh_config_; }
167 
168   // Creates a CompositeDevice in eager context if it not exists.
169   // Called when parallel_device_ contains a subset of global devices, e.g.
170   // pipelining is enabled.
FindOrCreateCompositeDevice(TFE_Context * context)171   StatusOr<CompositeDevice*> FindOrCreateCompositeDevice(TFE_Context* context) {
172     if (composite_device_ == nullptr && !composite_device_name_.empty()) {
173       if (mesh_config_.global_devices().empty()) {
174         return errors::InvalidArgument(
175             "Expect non-empty global devices when creating a CompositeDevice.");
176       }
177       TF_RETURN_IF_ERROR(ContextFromInterface(tensorflow::unwrap(context))
178                              ->FindOrCreateCompositeDevice(
179                                  mesh_config_.global_devices(),
180                                  composite_device_name_, &composite_device_));
181     }
182     return composite_device_;
183   }
184 
composite_device()185   CompositeDevice* composite_device() const { return composite_device_; }
186 
187  private:
188   dtensor::Mesh mesh_config_;
189   std::unique_ptr<parallel_device::ParallelDevice> parallel_device_;
190 
191   // Set when parallel_device_ contains a subset of global devices, e.g.
192   // pipelining is enabled.
193   const std::string composite_device_name_;
194   // A tensorflow::Device that represents underlying devices of
195   // parallel_device_. Set when composite_device_name_ is not empty.
196   CompositeDevice* composite_device_ = nullptr;  // owned by eager context
197 
198   // Constructed lazily; contains a parallel tensor with scalar integer device
199   // IDs for each device.
200   mutable std::unique_ptr<parallel_device::ParallelTensor> device_ids_tensor_;
201 };
202 
// Kind of payload held by a TensorWithLayout, as reported by tensor_type().
// Enumerators take the implicit values 0, 1, 2 in declaration order.
enum TensorType {
  kDense,     // Base TensorWithLayout wrapping a single ParallelTensor.
  kResource,  // ResourceHandleWithLayout.
  kSparse,    // SparseTensorWithLayout (indices, values, dense shapes).
};
208 
209 class TensorWithLayout {
210  public:
211   // Broadcast a single non-parallel tensor onto `mesh` with a fully replicated
212   // sharding spec. Does not take ownership of `tensor`.
213   static std::unique_ptr<TensorWithLayout> Broadcast(
214       TFE_Context* context, TFE_TensorHandle* tensor,
215       const MeshWithParallelDevice& mesh,
216       const std::string& dtensor_device_name, TF_Status* status);
217 
218   // Given an already-parallel tensor, wraps it with a mesh and a layout.
219   static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap(
220       std::unique_ptr<parallel_device::ParallelTensor> tensor,
221       const MeshWithParallelDevice& mesh, const Layout& layout);
222 
223   // A dummy TensorWithLayout without holding a ParallelTensor.
224   static std::unique_ptr<TensorWithLayout> Dummy(
225       const std::vector<int64_t>& local_shape, const TF_DataType dtype,
226       const MeshWithParallelDevice& mesh, const Layout& layout);
227 
~TensorWithLayout()228   virtual ~TensorWithLayout() {}
229 
layout()230   virtual const Layout& layout() const { return layout_; }
231 
tensor_type()232   virtual TensorType tensor_type() const { return TensorType::kDense; }
233 
dtype()234   virtual TF_DataType dtype() const {
235     if (dtype_.has_value()) {
236       return dtype_.value();
237     } else {
238       return tensor_->dtype();
239     }
240   }
241 
242   // Small constant value optimization for non-resource-handle tensors.
set_const_value(NodeDef & const_node)243   virtual void set_const_value(NodeDef& const_node) {
244     // If we extracted a constant value from the tensor, check if this
245     // value was the output from `tf.shape`. In this case, we need to
246     // forward the kShapeOpInputLayout attribute to the new node def. This
247     // is needed for layout propagation when running in op-by-op mode.
248     //
249     // TODO(b/162747667): Improve the presentation for Shape input Op
250     //                    layout.
251     if (shape_metadata_layout().has_value()) {
252       AddNodeAttr(kShapeOpInputLayout, {shape_metadata_layout()->ToString()},
253                   &(const_node));
254     }
255     const_value_.emplace(const_node);
256   }
257 
258   // Clears the cached const value if present.
reset_const_value()259   void reset_const_value() { const_value_.reset(); }
260 
261   // Encodes the NodeDef via provided builder, if applicable.
EncodeAttributes(tensorflow::NodeDefBuilder & builder)262   virtual void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const {}
263 
264   virtual tensorflow::Fprint128 CacheKey() const;
265 
266   // Updates layout for this Tensor.
UpdateLayout(const Layout & new_layout,TF_Status * status)267   virtual void UpdateLayout(const Layout& new_layout, TF_Status* status) {
268     TF_SetStatus(status, TF_INTERNAL,
269                  "Attempt to update layout on non-resource-handle");
270   }
271 
272   // Update shape and dtype.
UpdateShapeAndDType(const TensorShapeProto & shape,const DataType & dtype,TF_Status * status)273   virtual void UpdateShapeAndDType(const TensorShapeProto& shape,
274                                    const DataType& dtype, TF_Status* status) {
275     TF_SetStatus(status, TF_INTERNAL,
276                  "Attempt to update shape and layout on non-resource-handle");
277   }
278 
279   // Update Attrs for this Tensor.
UpdateAttrs(const EmbeddingResourceAttrs & attrs,TF_Status * status)280   virtual void UpdateAttrs(const EmbeddingResourceAttrs& attrs,
281                            TF_Status* status) {
282     TF_SetStatus(status, TF_INTERNAL,
283                  "Attempt to update layout on non-resource-handle");
284   }
285 
get_tensor(size_t index)286   virtual TFE_TensorHandle* get_tensor(size_t index) const {
287     return tensor()->tensor(index);
288   }
289 
num_tensors()290   virtual size_t num_tensors() const { return tensor()->num_tensors(); }
291 
tensor()292   virtual parallel_device::ParallelTensor* tensor() const {
293     return tensor_.get();
294   }
295 
296   // Returns a string which includes just the value and layout of the tensor.
297   virtual std::string SummarizeValue() const;
298   // Returns a string which includes `SummarizeValue` along with shape and type
299   // information.
300   virtual std::string DebugString() const;
301 
set_input_layout_for_shape_op_result(const Layout & layout)302   void set_input_layout_for_shape_op_result(const Layout& layout) {
303     input_layout_for_shape_op_result_.emplace(layout);
304   }
305 
shape_metadata_layout()306   const absl::optional<Layout> shape_metadata_layout() const {
307     return input_layout_for_shape_op_result_;
308   }
309 
mesh()310   const MeshWithParallelDevice& mesh() const { return mesh_; }
311 
312   // Compute global shape from layout & local tensor shape.
313   //
314   // For replicated layout tensors, global shape is simply the shape of local
315   // tensors on each device. For sharded tensor, this is the global shape
316   // encodes layout & local shape on each device.
global_shape()317   const std::vector<int64_t> global_shape() const {
318     return layout().GlobalShapeFromLocalShape(local_shape());
319   }
320 
local_shape()321   const std::vector<int64_t> local_shape() const { return local_shape_; }
322 
const_value()323   const absl::optional<NodeDef> const_value() const { return const_value_; }
324 
attrs()325   const absl::optional<EmbeddingResourceAttrs>& attrs() const { return attrs_; }
326 
327  protected:
328   TensorWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor,
329                    const MeshWithParallelDevice& mesh, const Layout& layout,
330                    std::vector<int64_t> local_shape,
331                    absl::optional<TF_DataType> dtype = absl::nullopt,
332                    absl::optional<NodeDef> const_value = absl::nullopt)
tensor_(std::move (tensor))333       : tensor_(std::move(tensor)),
334         layout_(layout),
335         mesh_(mesh),
336         const_value_(std::move(const_value)),
337         local_shape_(local_shape),
338         dtype_(dtype) {}
339 
340   std::unique_ptr<parallel_device::ParallelTensor> tensor_;
341 
342   Layout layout_;
343 
344   const MeshWithParallelDevice& mesh_;
345 
346   // Optionally holds the value of a small, non-resource tensor. Small constants
347   // are directly folded into the SPMD graph instead of being passed as inputs.
348   // This provides extra information to the layout propagation and SPMD passes
349   // during op-by-op execution. (For example, the reduction indices for Sum,
350   // target shapes for Rng/Reshape, etc).
351   absl::optional<NodeDef> const_value_;
352 
353   // Optionally holds the original input layout for a shape Op returned Tensor.
354   // This is used to preserve information for a shape op output so that future
355   // uses could recover local shape.
356   // TODO(hthu,allenl,xiejw): Move this into a separate class for clarity.
357   absl::optional<Layout> input_layout_for_shape_op_result_ = absl::nullopt;
358 
359   // The local shape of tensors placed on each of `tensor_`'s component devices.
360   std::vector<int64_t> local_shape_;
361 
362   absl::optional<TF_DataType> dtype_;
363 
364   // Resource input attributes for embedding inputs.
365   absl::optional<EmbeddingResourceAttrs> attrs_;  // NOLINT
366 };
367 
368 // Extension of TensorWithLayout which holds resource handle with layout.
369 //
370 // The major differences are
371 // 1. The layout, shape, dtype are lazily set as they are unavailable upon
372 //    creation.
373 // 2. Small const optimization should be disabled.
374 class ResourceHandleWithLayout : public TensorWithLayout {
375  public:
376   // The layout of uninitialized resource tensors, or the layout of the tensor
377   // contained in an initialized resource.
layout()378   const Layout& layout() const override {
379     return dereferenced_layout_.has_value() ? dereferenced_layout_.value()
380                                             : layout_;
381   }
382 
tensor_type()383   TensorType tensor_type() const override { return TensorType::kResource; }
384 
set_const_value(NodeDef & const_node)385   void set_const_value(NodeDef& const_node) override {
386     // Just a no-op for resource handle. Maybe we should error out.
387   }
388 
389   void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const override;
390 
391   tensorflow::Fprint128 CacheKey() const override;
392 
393   void UpdateLayout(const Layout& new_layout, TF_Status* status) override;
394 
UpdateShapeAndDType(const TensorShapeProto & shape,const DataType & dtype,TF_Status * status)395   void UpdateShapeAndDType(const TensorShapeProto& shape, const DataType& dtype,
396                            TF_Status* status) override {
397     set_dereferenced_shape(shape);
398     set_dereferenced_dtype(dtype);
399   }
400 
401   void UpdateAttrs(const EmbeddingResourceAttrs& attrs,
402                    TF_Status* status) override;
403 
UpdateDirtyness(bool is_dirty,TF_Status * status)404   void UpdateDirtyness(bool is_dirty, TF_Status* status) {
405     if (!attrs_.has_value()) {
406       TF_SetStatus(status, TF_INTERNAL,
407                    "Attempt to update dirtyness on non embedding resource");
408     }
409     attrs_.value().is_dirty = is_dirty;
410   }
411 
set_dereferenced_shape(const TensorShapeProto & shape)412   void set_dereferenced_shape(const TensorShapeProto& shape) {
413     dereferenced_shape_.emplace(shape);
414   }
set_dereferenced_dtype(const DataType & dtype)415   void set_dereferenced_dtype(const DataType& dtype) {
416     dereferenced_dtype_.emplace(dtype);
417   }
418 
dereferenced_shape()419   const absl::optional<TensorShapeProto>& dereferenced_shape() const {
420     return dereferenced_shape_;
421   }
dereferenced_dtype()422   const absl::optional<DataType>& dereferenced_dtype() const {
423     return dereferenced_dtype_;
424   }
425 
426  public:
ResourceHandleWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor,const MeshWithParallelDevice & mesh,const Layout & layout,std::vector<int64_t> local_shape)427   ResourceHandleWithLayout(
428       std::unique_ptr<parallel_device::ParallelTensor> tensor,
429       const MeshWithParallelDevice& mesh, const Layout& layout,
430       std::vector<int64_t> local_shape)
431       : TensorWithLayout(std::move(tensor), mesh, layout, local_shape,
432                          TF_RESOURCE) {}
433 
434  private:
435   // The layout of the tensor pointed to by this handle, if any.
436   absl::optional<Layout> dereferenced_layout_;
437   // The shape and dtype of the tensor pointed to by this resource tensor.
438   absl::optional<TensorShapeProto> dereferenced_shape_;
439   absl::optional<DataType> dereferenced_dtype_;
440 };
441 
442 // TensorWithLayout for SparseTensors.
443 //
444 // The main difference between this and TensorWithLayout is this
445 // contains 3 lists of tensors as opposed to one (values, indices, shapes).
446 // The shapes of the SparseTensors will always be the dense view of the shapes,
447 // and thus will have no difference with the TensorWithLayout in terms of
448 // shapes.
449 class SparseTensorWithLayout : public TensorWithLayout {
450  public:
451   static StatusOr<std::unique_ptr<TensorWithLayout>> Wrap(
452       std::unique_ptr<parallel_device::ParallelTensor> indices_tensor,
453       std::unique_ptr<parallel_device::ParallelTensor> values_tensor,
454       std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor,
455       const MeshWithParallelDevice& mesh, const Layout& layout,
456       std::vector<int64_t> local_shape);
457 
458   // A dummy TensorWithLayout without holding a ParallelTensor.
Dummy(const std::vector<int64_t> & local_shape,const MeshWithParallelDevice & mesh,const Layout & layout)459   static std::unique_ptr<TensorWithLayout> Dummy(
460       const std::vector<int64_t>& local_shape,
461       const MeshWithParallelDevice& mesh, const Layout& layout) {
462     return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout(
463         /*indices=*/nullptr, /*values=*/nullptr, /*dense_shapes=*/nullptr, mesh,
464         layout, local_shape));
465   }
466 
set_const_value(NodeDef & const_node)467   void set_const_value(NodeDef& const_node) override {
468     // No-op for SparseTensors, consider erroring out.
469   }
470 
471   // Add attribute '_sparse' to the NodeDefBuilder so that the mlir::Value
472   // that originate from SparseTensorWithLayout are marked as '_sparse'.
EncodeAttributes(tensorflow::NodeDefBuilder & builder)473   void EncodeAttributes(tensorflow::NodeDefBuilder& builder) const override {
474     builder.Attr("_sparse", true);
475   }
476 
tensor_type()477   TensorType tensor_type() const override { return TensorType::kSparse; }
478 
num_tensors()479   size_t num_tensors() const override { return 3 * indices()->num_tensors(); }
480 
481   TFE_TensorHandle* get_tensor(size_t index) const override;
482 
483   std::string SummarizeValue() const override;
484 
485   std::string DebugString() const override;
486 
487   TF_DataType dtype() const override;
488 
indices()489   parallel_device::ParallelTensor* indices() const { return indices_.get(); }
490 
values()491   parallel_device::ParallelTensor* values() const { return values_.get(); }
492 
dense_shapes()493   parallel_device::ParallelTensor* dense_shapes() const {
494     return dense_shapes_.get();
495   }
496 
497  protected:
498   SparseTensorWithLayout(
499       std::unique_ptr<parallel_device::ParallelTensor> indices,
500       std::unique_ptr<parallel_device::ParallelTensor> values,
501       std::unique_ptr<parallel_device::ParallelTensor> dense_shapes,
502       const MeshWithParallelDevice& mesh, const Layout& layout,
503       std::vector<int64_t> local_shape,
504       absl::optional<TF_DataType> dtype = absl::nullopt,
505       absl::optional<NodeDef> const_value = absl::nullopt)
TensorWithLayout(nullptr,mesh,layout,local_shape)506       : TensorWithLayout(nullptr, mesh, layout, local_shape),
507         indices_(std::move(indices)),
508         values_(std::move(values)),
509         dense_shapes_(std::move(dense_shapes)) {}
510   std::unique_ptr<parallel_device::ParallelTensor> indices_;
511   std::unique_ptr<parallel_device::ParallelTensor> values_;
512   std::unique_ptr<parallel_device::ParallelTensor> dense_shapes_;
513 };
514 
515 template <typename T>
ShapeToDebugString(const std::vector<T> shape_vector)516 std::string ShapeToDebugString(const std::vector<T> shape_vector) {
517   std::vector<tensorflow::int64> cast_shape(shape_vector.begin(),
518                                             shape_vector.end());
519   tensorflow::PartialTensorShape shape;
520   if (!tensorflow::PartialTensorShape::MakePartialShape(
521            cast_shape.data(), cast_shape.size(), &shape)
522            .ok()) {
523     return "<error displaying shape>";
524   } else {
525     return shape.DebugString();
526   }
527 }
// Class that holds information about DTensor functions that have been run,
// including cached lowered functions and constant folding input information per
// function.
//
// The caching policy for constant folded inputs is the following:
//   In the first call to a function, we assume that all the inputs that
//   are constant foldable are constant folded and save these values. In the
//   next call to the same function, we compare the values of constant
//   folded inputs to the previous constant folded inputs. We disable constant
//   folding for the changed values, and save these new inputs.
// TODO(b/169348205) Support cache eviction if the cache gets bloated.
class FunctionManager {
 public:
  FunctionManager() = default;

  // Caches the graph with the lowered 'function' under `cache_key`.
  // NOTE(review): presumably returns a pointer to the cached entry, valid for
  // the lifetime of this FunctionManager — confirm in the implementation.
  const ExecutionFunctions* AddCachedFunction(const DTensorOperation& op,
                                              tensorflow::Fprint128 cache_key,
                                              ExecutionFunctions function);

  // Returns the cache key and the cached lowered graph for the function.
  // Returns a nullptr for the lowered graph if there is a cache miss.
  // Upon a cache miss, this will save some metadata about the function
  // and the small inputs to keep track of information for constant folding.
  std::pair<tensorflow::Fprint128, const ExecutionFunctions*> GetCachedFunction(
      const DTensorOperation& doperation, const NameAttrList& attributes,
      const std::vector<TensorWithLayout*>& inputs,
      const std::vector<const Layout*>& output_layouts);

  // Returns whether the input at `input_index` is known to be constant
  // foldable for function `doperation`. An input is not constant foldable if we
  // have run this function at least twice and the small input value changed
  // across separate runs.
  bool IsConstantFoldable(const DTensorOperation& doperation,
                          const int input_index) const;

 private:
  // Cache key for dtensor operation name, which includes the op name
  // and the input shapes. This is needed as a higher level cache for constant
  // folding.
  const tensorflow::Fprint128 CacheKeyForDTensorOperation(
      const DTensorOperation& doperation) const;

  // Generates a cache key for the graph, including its attributes,
  // inputs, and outputs.
  tensorflow::Fprint128 CacheKeyForGraph(
      const DTensorOperation& doperation, const NameAttrList& attributes,
      const std::vector<TensorWithLayout*>& inputs,
      const std::vector<const Layout*>& output_layouts);

  // Maps the hash of a graph to the lowered graph.
  absl::flat_hash_map<tensorflow::Fprint128, ExecutionFunctions,
                      tensorflow::Fprint128Hasher>
      function_cache_;

  // Maps the hash of dtensor_operation and its input shapes to a map
  // representing the small constant indices and values to the function. The
  // small constant indices are saved to make faster comparisons for constant
  // folding validation.
  absl::flat_hash_map<tensorflow::Fprint128, absl::flat_hash_map<int, NodeDef>,
                      tensorflow::Fprint128Hasher>
      dtensor_op_and_small_inputs_;
};
591 
// Returns the shape of a given tensor.
// NOTE(review): `status` presumably reports failures (TF C-API convention);
// the contents of the returned vector on failure are not specified here —
// confirm in the implementation.
std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor,
                                         TF_Status* status);

// Builds into `graph` a node for `doperation` (a single op or a function, see
// DTensorOperation::is_func()) surrounded by _Arg and _Retval nodes, one _Arg
// per entry of `inputs`.
// NOTE(review): `global_output_shapes` and `output_layouts` are pointer
// parameters and look like outputs filled per operation output — confirm.
Status PrepareGraphForMlir(
    const FunctionManager& function_manager,
    const std::vector<TensorWithLayout*>& inputs,
    const DTensorOperation& doperation,
    const tensorflow::FunctionLibraryDefinition& flib_def,
    const NameAttrList& attributes,
    const absl::optional<Layout>& default_layout, tensorflow::Graph* graph,
    std::vector<PartialTensorShape>* global_output_shapes,
    std::vector<const Layout*>* output_layouts);
607 
// Returns the set of functions to run to execute the DTensor computation in
// `graph`, given the global shapes of its outputs.
// NOTE(review): `global_output_shapes` is presumably the vector produced by
// PrepareGraphForMlir — confirm.
StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute(
    const tensorflow::Graph& graph,
    const std::vector<PartialTensorShape>& global_output_shapes);

// For functions with control outputs, add identity nodes between
// StatefulPartitionedCall and _Retvals, in order to preserve control output
// dependencies after StatefulPartitionedCall is inlined at runtime.
// Consider calling this in PrepareGraphForMlir, once the identity nodes won't
// be dropped during MLIR lowering.
// TODO(b/171265131): fix the underlying issue to avoid inserting identity
// nodes.
Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph);
621 
// Add DTensor specific function attributes to be compatible with eager runtime.
// `function_def` is taken by non-const reference and updated in place.
void AddDTensorFunctionAttr(FunctionDef& function_def);

// Prepare inputs of embeddings for checkpoint functions.
// NOTE(review): the returned ParallelTensor pointers are raw; presumably
// non-owning views into `inputs` — confirm their lifetime.
StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
    const std::vector<TensorWithLayout*>& inputs);

// NOTE(review): undocumented. From its name and signature it presumably
// inserts into `graph` a call to the TPU embedding checkpoint function
// `checkpoint_fn_name`, fed by `inputs`. It both takes a TF_Status* and
// returns a Status — confirm which one callers are expected to check.
Status InsertFunctionForTPUEmbeddingCheckpoint(
    TF_Status* status, Graph* graph,
    const std::vector<TensorWithLayout*>& inputs,
    const std::string& checkpoint_fn_name);
633 
634 }  // namespace dtensor
635 }  // namespace tensorflow
636 
637 #endif  // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_UTIL_H_
638