/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_resource.h"

#include <functional>
#include <memory>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {

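// Returns a human-readable name for the given resource kind, for use in
// error messages and debugging output.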
/*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
  switch (kind) {
    case XlaResource::kInvalid:
      return "invalid";
    case XlaResource::kVariable:
      return "variable";
    case XlaResource::kStack:
      return "stack";
    case XlaResource::kTensorArray:
      return "tensorarray";
  }
}

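// Creates an uninitialized stack resource holding at most `max_size` elements
// of type `type`. The initial value is left empty, and arg_num is -1 because
// the stack does not correspond to a compilation argument.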
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
    string name, DataType type, int64 max_size) {
  return absl::make_unique<XlaResource>(
      XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
      /*initial_value=*/xla::XlaOp(),
      /*max_array_size=*/max_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

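// Creates a TensorArray resource holding up to `max_array_size` elements of
// shape `shape` and type `type`, seeded with `initial_value`.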
/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
    string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
    int64 max_array_size) {
  return absl::make_unique<XlaResource>(
      XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
      initial_value, max_array_size,
      /*tensor_array_gradients=*/std::set<string>{},
      /*tensor_array_multiple_writes_aggregate=*/false);
}

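// General-purpose constructor shared by the factory functions above. `shape`
// is the per-element shape for stacks and TensorArrays, and the full value
// shape for variables.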
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
                         TensorShape shape, const xla::XlaOp& initial_value,
                         int64 max_array_size,
                         const std::set<string>& tensor_array_gradients,
                         bool tensor_array_multiple_writes_aggregate)
    : kind_(kind),
      arg_num_(arg_num),
      name_(std::move(name)),
      type_(type),
      shape_(std::move(shape)),
      value_(initial_value),
      initial_value_(initial_value),
      max_array_size_(max_array_size),
      tensor_array_multiple_writes_aggregate_(
          tensor_array_multiple_writes_aggregate) {
  CHECK(kind_ != kInvalid);

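  // Eagerly create any gradient TensorArrays requested by the caller. Each
  // starts uninitialized (an empty XlaOp) and aggregates multiple writes.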
  for (const string& gradient : tensor_array_gradients) {
    tensor_array_gradients_[gradient].reset(new XlaResource(
        /*kind=*/kTensorArray, /*arg_num=*/-1,
        /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
        xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
        /*tensor_array_multiple_writes_aggregate=*/true));
  }
}

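// Records the type and shape of the resource. Once the resource is
// initialized, neither may change; later calls must pass matching values.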
Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
  if (type == DT_INVALID) {
    return errors::InvalidArgument("Attempted to set type of resource '", name_,
                                   "' to an invalid type");
  }
  if (initialized() && type_ != type) {
    return errors::InvalidArgument("Type of resource ", name_,
                                   " cannot be changed after initialization: "
                                   "old type was ",
                                   DataTypeString(type_), ", new type is ",
                                   DataTypeString(type));
  }
  if (initialized() && shape_ != shape) {
    return errors::InvalidArgument("Shape of resource ", name_,
                                   " cannot be changed after initialization: "
                                   "old shape was ",
                                   shape_.DebugString(), ", new shape is ",
                                   shape.DebugString());
  }
  type_ = type;
  shape_ = shape;
  return Status::OK();
}

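// Replaces the current value of the resource. The type must already have
// been set via SetTypeAndShape.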
Status XlaResource::SetValue(const xla::XlaOp& value) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  value_ = value;
  return Status::OK();
}

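// Initializes the resource with an all-zeros value whose layout depends on
// the resource kind.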
Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  switch (kind_) {
    case kVariable: {
      value_ =
          xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
      break;
    }
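    // A TensorArray is represented as a single array of shape
    // [max_array_size, <element shape>], zero-filled here.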
    case kTensorArray: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
                              ta_shape.dim_sizes());
      break;
    }
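    // A stack is a (buffer, size) pair: the zero-filled buffer plus an int32
    // scalar tracking the current stack size, initially 0.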
    case kStack: {
      TensorShape ta_shape;
      ta_shape.AddDim(max_array_size_);
      ta_shape.AppendShape(shape_);
      value_ =
          xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
                                              ta_shape.dim_sizes()),
                               xla::ConstantR0<int32>(builder, 0)});
      break;
    }

    case kInvalid:
    default:
      LOG(FATAL) << "Invalid resource type";
  }
  return Status::OK();
}

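// Looks up the gradient TensorArray for `source`, creating a zero-initialized
// one on first use. Only valid for TensorArray resources.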
Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
                                                   xla::XlaBuilder* builder,
                                                   XlaResource** gradient_out) {
  VLOG(2) << "Gradient lookup for resource: " << name_
          << " gradient: " << source;
  TF_RET_CHECK(kind_ == kTensorArray);
  std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
  if (!gradient) {
    TensorShape ta_shape;
    ta_shape.AddDim(max_array_size_);
    ta_shape.AppendShape(shape_);
    xla::XlaOp gradient_value =
        xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    gradient.reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
                        type_, shape_, gradient_value, max_array_size_,
                        /*tensor_array_gradients=*/{},
                        /*tensor_array_multiple_writes_aggregate=*/true));
  }
  *gradient_out = gradient.get();
  return Status::OK();
}

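// Flattens the resource into a single XlaOp: the value itself if there are
// no gradients, otherwise a tuple of the value followed by the value of each
// gradient TensorArray.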
Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
  if (tensor_array_gradients_.empty()) {
    *pack = value_;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    std::vector<xla::XlaOp> elems;
    elems.push_back(value_);
    for (const auto& gradient : tensor_array_gradients_) {
      elems.push_back(gradient.second->value_);
    }
    *pack = xla::Tuple(builder, elems);
  }
  return Status::OK();
}

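// Inverse of Pack: restores the value (and any gradient values) from a packed
// XlaOp, recording initial values for resources not yet initialized. The
// tuple layout must match what Pack produces for the same gradient_sources.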
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                const xla::XlaOp& pack,
                                xla::XlaBuilder* builder) {
  if (gradient_sources.empty()) {
    if (!initialized()) {
      initial_value_ = pack;
    }
    value_ = pack;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    int pos = 0;
    auto v = xla::GetTupleElement(pack, pos++);
    if (!initialized()) {
      initial_value_ = v;
    }
    value_ = v;

    for (const auto& source : gradient_sources) {
      XlaResource* gradient;
      TF_RETURN_IF_ERROR(
          GetOrCreateTensorArrayGradient(source, builder, &gradient));
      auto v = xla::GetTupleElement(pack, pos++);
      if (!gradient->initialized()) {
        gradient->initial_value_ = v;
      }
      gradient->value_ = v;
    }
  }
  return Status::OK();
}

}  // namespace tensorflow