• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/lib/scatter.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/tf2xla/lib/util.h"
23 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 
31 namespace tensorflow {
32 
XlaScatter(const xla::XlaOp & buffer,const xla::XlaOp & updates,const xla::XlaOp & indices,bool indices_are_vectors,const std::function<xla::XlaOp (xla::XlaOp,xla::XlaOp,xla::XlaBuilder *)> & combiner,xla::XlaBuilder * builder)33 xla::StatusOr<xla::XlaOp> XlaScatter(
34     const xla::XlaOp& buffer, const xla::XlaOp& updates,
35     const xla::XlaOp& indices, bool indices_are_vectors,
36     const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
37         combiner,
38     xla::XlaBuilder* builder) {
39   TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
40   TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
41   TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
42   absl::Span<const int64> indices_dims =
43       xla::AsInt64Slice(indices_shape.dimensions());
44 
45   // If the indices are N-dimensional, the minor dimension of indices contains
46   // the indices to update. Otherwise the indices are all scalars.
47   int64 num_index_dims = 1;
48   if (indices_are_vectors) {
49     TF_RET_CHECK(!indices_dims.empty());
50     num_index_dims = indices_dims.back();
51     if (num_index_dims > buffer_shape.rank()) {
52       return errors::InvalidArgument(
53           "The size of the minor dimension of the indices (shape: ",
54           xla::ShapeUtil::HumanString(indices_shape),
55           ") must be <= the rank of the buffer (shape: ",
56           xla::ShapeUtil::HumanString(buffer_shape), ")");
57     }
58     indices_dims.remove_suffix(1);
59   }
60 
61   int64 num_indices = 1;
62   for (int64 dim : indices_dims) {
63     num_indices *= dim;
64   }
65 
66   // Degenerate case: nothing to update. Return the buffer unchanged.
67   if (num_indices == 0) {
68     return buffer;
69   }
70 
71   // If any of the indexed dimensions are zero in the buffer, the update cannot
72   // succeed since it updates a slice of size 1.
73   for (int64 i = 0; i < num_index_dims; ++i) {
74     if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) {
75       return errors::InvalidArgument("Scatter dimension ", i,
76                                      " is of size zero in tensor with shape ",
77                                      xla::ShapeUtil::HumanString(buffer_shape));
78     }
79   }
80 
81   // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
82   // shape [3,3]:
83   // NOTE: ***This case will not be generated by any of the tf.scatter ops.***
84   //
85   //   operand = s32[3,3] parameter(0)
86   //   indices = s32[2] parameter(1)
87   //   updates = s32[3,2] parameter(2)
88   //   scatter = s32[3,3] scatter(operand, indices, updates),
89   //       to_apply=update_computation,
90   //       update_window_dims={0},
91   //       inserted_window_dims={1},
92   //       scatter_dims_to_operand_dims={1},
93   //       index_vector_dim=1
94   //
95   //
96   // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
97   // shape [3,3]:
98   //
99   //   operand = s32[3,3] parameter(0)
100   //   indices = s32[2] parameter(1)
101   //   updates = s32[2,3] parameter(2)
102   //   scatter = s32[3,3] scatter(operand, indices, updates),
103   //       to_apply=update_computation,
104   //       update_window_dims={1},
105   //       inserted_window_dims={0},
106   //       scatter_dims_to_operand_dims={0},
107   //       index_vector_dim=1
108   //
109   //
110   // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
111   // shape [3,3,2]
112   //
113   //   operand = s32[3,3,2] parameter(0)
114   //   indices = s32[2,2] parameter(1)
115   //   updates = s32[2,2] parameter(2)
116   //   scatter = s32[3,3,2] scatter(operand, indices, updates),
117   //       to_apply=update_computation,
118   //       update_window_dims={1},
119   //       inserted_window_dims={0,1},
120   //       scatter_dims_to_operand_dims={0,1},
121   //       index_vector_dim=1
122   //
123   //
124   // Example of a scatter updating slices of shape [] in a tensor of shape [1,1]
125   //
126   //   operand = s32[1,1] parameter(0)
127   //   indices = s32[1] parameter(1)
128   //   updates = s32[1] parameter(2)
129   //   scatter = s32[1,1] scatter(operand, indices, updates),
130   //       to_apply=update_computation,
131   //       update_window_dims={},
132   //       inserted_window_dims={0,1},
133   //       scatter_dims_to_operand_dims={0},
134   //       index_vector_dim=1
135   // Note that updates operand would be broadcasted into [1] in this case.
136   //
137 
138   xla::ScatterDimensionNumbers dim_numbers;
139   dim_numbers.set_index_vector_dim(indices_are_vectors
140                                        ? indices_shape.dimensions_size() - 1
141                                        : indices_shape.dimensions_size());
142 
143   int64 updates_rank = updates_shape.rank();
144   int64 buffer_rank = buffer_shape.rank();
145   int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
146 
147   // If the rank of `updates` is 0 and does not match the expected rank of
148   // updates, broadcast `updates` to the expected shape of updates.
149   auto new_updates = updates;
150   std::vector<int64> expected_updates_dims(indices_dims.begin(),
151                                            indices_dims.end());
152   for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
153     expected_updates_dims.push_back(buffer_shape.dimensions(dim));
154   }
155   int64 expected_updates_rank = expected_updates_dims.size();
156   if (updates_rank == 0 && expected_updates_rank != 0) {
157     new_updates = xla::Broadcast(updates, expected_updates_dims);
158     TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
159     updates_rank = updates_shape.rank();
160   }
161 
162   if (updates_rank > 0) {
163     for (int64 i = (updates_rank - num_window_dims_in_updates);
164          i < updates_rank; ++i) {
165       dim_numbers.add_update_window_dims(i);
166     }
167   }
168 
169   for (int64 i = 0; i < num_index_dims; ++i) {
170     dim_numbers.add_inserted_window_dims(i);
171     dim_numbers.add_scatter_dims_to_operand_dims(i);
172   }
173 
174   // Build the combiner computation.
175   xla::XlaComputation combiner_computation;
176   {
177     xla::XlaBuilder cb("scatter-combiner");
178     auto xla_scalar_shape =
179         xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
180     auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
181     auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
182     if (combiner) {
183       combiner(p0, p1, &cb);
184     }
185     combiner_computation = cb.Build().ConsumeValueOrDie();
186   }
187 
188   VLOG(3) << "Scatter op:";
189   VLOG(3) << "  Input: " << xla::ShapeUtil::HumanString(buffer_shape);
190   VLOG(3) << "  Indices: " << xla::ShapeUtil::HumanString(indices_shape);
191   VLOG(3) << "  Updates: " << xla::ShapeUtil::HumanString(updates_shape);
192   VLOG(3) << "  Scatter Dimension Numbers: ";
193   VLOG(3) << "    index_vector_dim: " << dim_numbers.index_vector_dim();
194   VLOG(3) << "    update_window_dims: ["
195           << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
196   VLOG(3) << "    inserted_window_dims: ["
197           << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
198   VLOG(3) << "    scatter_dims_to_operand_dims: ["
199           << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
200           << "]";
201 
202   return xla::Scatter(buffer, indices, new_updates, combiner_computation,
203                       dim_numbers);
204 }
205 
206 }  // namespace tensorflow
207