• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_resource.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/sharding_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 
27 namespace tensorflow {
28 
KindToString(XlaResource::Kind kind)29 /*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
30   switch (kind) {
31     case XlaResource::kInvalid:
32       return "invalid";
33     case XlaResource::kVariable:
34       return "variable";
35     case XlaResource::kStack:
36       return "stack";
37     case XlaResource::kTensorArray:
38       return "tensorarray";
39   }
40 }
41 
CreateStack(string name,DataType type,int64 max_size)42 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
43     string name, DataType type, int64 max_size) {
44   return absl::make_unique<XlaResource>(
45       XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
46       /*initial_value=*/xla::XlaOp(),
47       /*max_array_size=*/max_size,
48       /*tensor_array_gradients=*/std::set<string>{},
49       /*tensor_array_multiple_writes_aggregate=*/false);
50 }
51 
CreateTensorArray(string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64 max_array_size)52 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
53     string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
54     int64 max_array_size) {
55   return absl::make_unique<XlaResource>(
56       XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
57       initial_value, max_array_size,
58       /*tensor_array_gradients=*/std::set<string>{},
59       /*tensor_array_multiple_writes_aggregate=*/false);
60 }
61 
XlaResource(Kind kind,int arg_num,string name,DataType type,TensorShape shape,const xla::XlaOp & initial_value,int64 max_array_size,const std::set<string> & tensor_array_gradients,bool tensor_array_multiple_writes_aggregate)62 XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
63                          TensorShape shape, const xla::XlaOp& initial_value,
64                          int64 max_array_size,
65                          const std::set<string>& tensor_array_gradients,
66                          bool tensor_array_multiple_writes_aggregate)
67     : kind_(kind),
68       arg_num_(arg_num),
69       name_(std::move(name)),
70       type_(type),
71       shape_(std::move(shape)),
72       value_(initial_value),
73       initial_value_(initial_value),
74       max_array_size_(max_array_size),
75       tensor_array_multiple_writes_aggregate_(
76           tensor_array_multiple_writes_aggregate) {
77   CHECK(kind_ != kInvalid);
78 
79   for (const string& gradient : tensor_array_gradients) {
80     tensor_array_gradients_[gradient].reset(new XlaResource(
81         /*kind=*/kTensorArray, /*arg_num=*/-1,
82         /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
83         xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
84         /*tensor_array_multiple_writes_aggregate=*/true));
85   }
86 }
87 
SetTypeAndShape(DataType type,const TensorShape & shape)88 Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
89   if (type == DT_INVALID) {
90     return errors::InvalidArgument("Attempted to set type of resource '", name_,
91                                    "'' to an invalid type");
92   }
93   if (initialized() && type_ != type) {
94     return errors::Unimplemented("Type of resource ", name_,
95                                  " cannot be changed after initialization: "
96                                  "old type was ",
97                                  DataTypeString(type_), ", new type is ",
98                                  DataTypeString(type));
99   }
100   if (initialized() && shape_ != shape) {
101     return errors::Unimplemented("Shape of resource ", name_,
102                                  " cannot be changed after initialization: "
103                                  "old shape was ",
104                                  shape_.DebugString(), ", new shape is ",
105                                  shape.DebugString());
106   }
107   type_ = type;
108   shape_ = shape;
109   return Status::OK();
110 }
111 
SetValue(const xla::XlaOp & value)112 Status XlaResource::SetValue(const xla::XlaOp& value) {
113   if (type_ == DT_INVALID) {
114     return errors::InvalidArgument(
115         "Resource '", name_,
116         "' must be initialized with a valid type before use.");
117   }
118   value_ = value;
119   is_overwritten_ = true;
120   return Status::OK();
121 }
122 
SetZeroValue(xla::XlaBuilder * builder)123 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
124   is_overwritten_ = true;
125   if (type_ == DT_INVALID) {
126     return errors::InvalidArgument(
127         "Resource '", name_,
128         "' must be initialized with a valid type before use.");
129   }
130   switch (kind_) {
131     case kVariable: {
132       value_ =
133           xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
134       break;
135     }
136     case kTensorArray: {
137       TensorShape ta_shape;
138       ta_shape.AddDim(max_array_size_);
139       ta_shape.AppendShape(shape_);
140       value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
141                               ta_shape.dim_sizes());
142       break;
143     }
144     case kStack: {
145       TensorShape ta_shape;
146       ta_shape.AddDim(max_array_size_);
147       ta_shape.AppendShape(shape_);
148       value_ =
149           xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
150                                               ta_shape.dim_sizes()),
151                                xla::ConstantR0<int32>(builder, 0)});
152       break;
153     }
154 
155     case kInvalid:
156     default:
157       LOG(FATAL) << "Invalid resource type";
158   }
159   return Status::OK();
160 }
161 
GetOrCreateTensorArrayGradient(const string & source,xla::XlaBuilder * builder,XlaResource ** gradient_out)162 Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
163                                                    xla::XlaBuilder* builder,
164                                                    XlaResource** gradient_out) {
165   VLOG(2) << "Gradient lookup for resource: " << name_
166           << " gradient: " << source;
167   TF_RET_CHECK(kind_ == kTensorArray);
168   std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
169   if (!gradient) {
170     TensorShape ta_shape;
171     ta_shape.AddDim(max_array_size_);
172     ta_shape.AppendShape(shape_);
173     xla::XlaOp gradient_value =
174         xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
175     gradient.reset(
176         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
177                         /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
178                         type_, shape_, gradient_value, max_array_size_,
179                         /*tensor_array_gradients=*/{},
180                         /*tensor_array_multiple_writes_aggregate=*/true));
181   }
182   *gradient_out = gradient.get();
183   return Status::OK();
184 }
185 
Pack(xla::XlaOp * pack,xla::XlaBuilder * builder) const186 Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
187   if (tensor_array_gradients_.empty()) {
188     *pack = value_;
189   } else {
190     TF_RET_CHECK(kind_ == kTensorArray);
191     std::vector<xla::XlaOp> elems;
192     elems.push_back(value_);
193     for (const auto& gradient : tensor_array_gradients_) {
194       elems.push_back(gradient.second->value_);
195     }
196     *pack = xla::Tuple(builder, elems);
197   }
198   return Status::OK();
199 }
200 
SetFromPack(const std::set<string> & gradient_sources,const xla::XlaOp & pack,xla::XlaBuilder * builder)201 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
202                                 const xla::XlaOp& pack,
203                                 xla::XlaBuilder* builder) {
204   if (gradient_sources.empty()) {
205     if (!initialized()) {
206       initial_value_ = pack;
207     }
208     value_ = pack;
209   } else {
210     TF_RET_CHECK(kind_ == kTensorArray);
211     int pos = 0;
212     auto v = xla::GetTupleElement(pack, pos++);
213     if (!initialized()) {
214       initial_value_ = v;
215     }
216     value_ = v;
217 
218     for (const auto& source : gradient_sources) {
219       XlaResource* gradient;
220       TF_RETURN_IF_ERROR(
221           GetOrCreateTensorArrayGradient(source, builder, &gradient));
222       auto v = xla::GetTupleElement(pack, pos++);
223       if (!gradient->initialized()) {
224         gradient->initial_value_ = v;
225       }
226       gradient->value_ = v;
227     }
228   }
229   return Status::OK();
230 }
231 
232 }  // namespace tensorflow
233