• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <utility>
18 #include <vector>
19 
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/tensor_slice.h"
26 #include "tensorflow/core/util/saved_tensor_slice_util.h"
27 
28 namespace tensorflow {
29 namespace dtensor {
30 
31 using shape_inference::InferenceContext;
32 using shape_inference::ShapeHandle;
33 using shape_inference::UnchangedShape;
34 
35 // Change layout of input to target layout inside the same mesh cluster.
36 REGISTER_OP("Relayout")
37     .Input("input: T")
38     .Output("output: T")
39     .Attr("layout: string")
40     .Attr("T: type")
41     .SetShapeFn(UnchangedShape);
42 
43 // Copy `input` to the given mesh and layout.
44 REGISTER_OP("CopyToMesh")
45     .Input("input: T")
46     .Output("output: T")
47     .Attr("layout: string")
48     .Attr("source_layout: string = ''")
49     .Attr("T: type")
50     .SetShapeFn(UnchangedShape);
51 
52 // Queries the generated sharded prefix that is used to in SaveV2 op in a
53 // multi-client setup. Should take exact same inputs as the original SaveV2 is
54 // invoked or the value returned won't match the ones generated.
55 REGISTER_OP("DTensorShardedPrefix")
56     .Input("prefix: string")
57     .Input("tensor_names: string")
58     .Input("shape_and_slices: string")
59     .Input("mesh: string")
60     .Input("layouts: string")
61     .Input("tensors: dtypes")
62     .Attr("dtypes: list(type)")
63     .Output("output: string")
__anon1714fc320102(InferenceContext* c) 64     .SetShapeFn([](InferenceContext* c) {
65       // Always output a one d vector of strings.
66       // We could calculate the exact numbers of output here as well but that's
67       // the whole logic of the op itself.
68       c->set_output(0, c->Vector(c->UnknownDim()));
69       return OkStatus();
70     });
71 
72 // DTensorRestoreV2 that is pretty much RestoreV2 but with extra global shapes
73 // and layouts.
74 REGISTER_OP("DTensorRestoreV2")
75     .Input("prefix: string")
76     .Input("tensor_names: string")
77     .Input("shape_and_slices: string")
78     .Output("tensors: dtypes")
79     .Attr("input_shapes: list(shape)")
80     .Attr("input_layouts: list(string)")
81     .Attr("dtypes: list(type)")
82     .SetIsStateful()
__anon1714fc320202(InferenceContext* c) 83     .SetShapeFn([](InferenceContext* c) {
84       ShapeHandle shape0, shape1, shape2;
85       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
86       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
87       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
88       TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));
89 
90       std::vector<PartialTensorShape> input_shapes;
91       TF_RETURN_IF_ERROR(c->GetAttr("input_shapes", &input_shapes));
92       std::vector<std::string> input_layouts;
93       TF_RETURN_IF_ERROR(c->GetAttr("input_layouts", &input_layouts));
94 
95       if (input_shapes.size() != input_layouts.size()) {
96         return errors::InvalidArgument(
97             "Size of input_shapes and input_layouts is expected to match, but "
98             "got ",
99             input_shapes.size(), " for input_shapes and ", input_layouts.size(),
100             " for input_layouts");
101       }
102 
103       // TODO(hthu): We should be able to infer from layout and global_shape
104       // field.
105       return UnknownShape(c);
106     });
107 
108 }  // namespace dtensor
109 }  // namespace tensorflow
110