• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_resource.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/sharding_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 
27 namespace tensorflow {
28 
KindToString(XlaResource::Kind kind)29 /*static*/ absl::string_view XlaResource::KindToString(XlaResource::Kind kind) {
30   switch (kind) {
31     case XlaResource::kInvalid:
32       return "invalid";
33     case XlaResource::kVariable:
34       return "variable";
35     case XlaResource::kStack:
36       return "stack";
37     case XlaResource::kTensorArray:
38       return "tensorarray";
39   }
40 }
41 
CreateStack(string name,DataType type,int64_t max_size)42 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
43     string name, DataType type, int64_t max_size) {
44   return absl::make_unique<XlaResource>(
45       XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
46       /*initial_value=*/xla::XlaOp(),
47       /*max_array_size=*/max_size,
48       /*tensor_array_gradients=*/std::set<string>{},
49       /*tensor_array_multiple_writes_aggregate=*/false);
50 }
51 
CreateTensorArray(string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64_t max_array_size)52 /*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
53     string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
54     int64_t max_array_size) {
55   return absl::make_unique<XlaResource>(
56       XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
57       initial_value, max_array_size,
58       /*tensor_array_gradients=*/std::set<string>{},
59       /*tensor_array_multiple_writes_aggregate=*/false);
60 }
61 
XlaResource(Kind kind,int arg_num,string name,DataType type,TensorShape shape,xla::XlaOp initial_value,int64_t max_array_size,const std::set<string> & tensor_array_gradients,bool tensor_array_multiple_writes_aggregate,const absl::optional<ManagedStackTrace> & definition_stack_trace)62 XlaResource::XlaResource(
63     Kind kind, int arg_num, string name, DataType type, TensorShape shape,
64     xla::XlaOp initial_value, int64_t max_array_size,
65     const std::set<string>& tensor_array_gradients,
66     bool tensor_array_multiple_writes_aggregate,
67     const absl::optional<ManagedStackTrace>& definition_stack_trace)
68     : kind_(kind),
69       arg_num_(arg_num),
70       name_(std::move(name)),
71       type_(type),
72       shape_(std::move(shape)),
73       value_(initial_value),
74       initial_value_(initial_value),
75       max_array_size_(max_array_size),
76       tensor_array_multiple_writes_aggregate_(
77           tensor_array_multiple_writes_aggregate),
78       definition_stack_trace_(definition_stack_trace) {
79   CHECK(kind_ != kInvalid);
80 
81   for (const string& gradient : tensor_array_gradients) {
82     tensor_array_gradients_[gradient].reset(new XlaResource(
83         /*kind=*/kTensorArray, /*arg_num=*/-1,
84         /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
85         xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
86         /*tensor_array_multiple_writes_aggregate=*/true));
87   }
88 }
89 
SetTypeAndShape(DataType type,const TensorShape & shape)90 Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
91   if (type == DT_INVALID) {
92     return errors::InvalidArgument(
93         "Attempted to set type of resource '", name_, "'' to an invalid type",
94         DefinitionLocationMsg(definition_stack_trace_));
95   }
96   if (initialized() && type_ != type) {
97     return errors::InvalidArgument(
98         "Type of resource ", name_,
99         " cannot be changed after initialization: "
100         "old type was ",
101         DataTypeString(type_), ", new type is ", DataTypeString(type),
102         DefinitionLocationMsg(definition_stack_trace_));
103   }
104   if (initialized() && shape_ != shape) {
105     return errors::InvalidArgument(
106         "Shape of resource ", name_,
107         " cannot be changed after initialization: "
108         "old shape was ",
109         shape_.DebugString(), ", new shape is ", shape.DebugString(),
110         DefinitionLocationMsg(definition_stack_trace_));
111   }
112   type_ = type;
113   shape_ = shape;
114   return Status::OK();
115 }
116 
SetValue(const xla::XlaOp & value)117 Status XlaResource::SetValue(const xla::XlaOp& value) {
118   if (type_ == DT_INVALID) {
119     return errors::InvalidArgument(
120         "Resource '", name_,
121         "' must be initialized with a valid type before use.");
122   }
123   value_ = value;
124   is_overwritten_ = true;
125   return Status::OK();
126 }
127 
SetZeroValue(xla::XlaBuilder * builder)128 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
129   is_overwritten_ = true;
130   if (type_ == DT_INVALID) {
131     return errors::InvalidArgument(
132         "Resource '", name_,
133         "' must be initialized with a valid type before use.");
134   }
135   switch (kind_) {
136     case kVariable: {
137       value_ =
138           xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
139       break;
140     }
141     case kTensorArray: {
142       TensorShape ta_shape;
143       ta_shape.AddDim(max_array_size_);
144       ta_shape.AppendShape(shape_);
145       value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
146                               ta_shape.dim_sizes());
147       break;
148     }
149     case kStack: {
150       TensorShape ta_shape;
151       ta_shape.AddDim(max_array_size_);
152       ta_shape.AppendShape(shape_);
153       value_ =
154           xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
155                                               ta_shape.dim_sizes()),
156                                xla::ConstantR0<int32>(builder, 0)});
157       break;
158     }
159 
160     case kInvalid:
161     default:
162       LOG(FATAL) << "Invalid resource type";
163   }
164   return Status::OK();
165 }
166 
GetOrCreateTensorArrayGradient(const string & source,xla::XlaBuilder * builder,XlaResource ** gradient_out)167 Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
168                                                    xla::XlaBuilder* builder,
169                                                    XlaResource** gradient_out) {
170   VLOG(2) << "Gradient lookup for resource: " << name_
171           << " gradient: " << source;
172   TF_RET_CHECK(kind_ == kTensorArray);
173   std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
174   if (!gradient) {
175     TensorShape ta_shape;
176     ta_shape.AddDim(max_array_size_);
177     ta_shape.AppendShape(shape_);
178     xla::XlaOp gradient_value =
179         xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
180     gradient.reset(
181         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
182                         /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
183                         type_, shape_, gradient_value, max_array_size_,
184                         /*tensor_array_gradients=*/{},
185                         /*tensor_array_multiple_writes_aggregate=*/true));
186   }
187   *gradient_out = gradient.get();
188   return Status::OK();
189 }
190 
Pack(xla::XlaOp * pack,xla::XlaBuilder * builder) const191 Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
192   if (tensor_array_gradients_.empty()) {
193     *pack = value_;
194   } else {
195     TF_RET_CHECK(kind_ == kTensorArray);
196     std::vector<xla::XlaOp> elems;
197     elems.push_back(value_);
198     for (const auto& gradient : tensor_array_gradients_) {
199       elems.push_back(gradient.second->value_);
200     }
201     *pack = xla::Tuple(builder, elems);
202   }
203   return Status::OK();
204 }
205 
SetFromPack(const std::set<string> & gradient_sources,const xla::XlaOp & pack,xla::XlaBuilder * builder)206 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
207                                 const xla::XlaOp& pack,
208                                 xla::XlaBuilder* builder) {
209   if (gradient_sources.empty()) {
210     if (!initialized()) {
211       initial_value_ = pack;
212     }
213     value_ = pack;
214   } else {
215     TF_RET_CHECK(kind_ == kTensorArray);
216     int pos = 0;
217     auto v = xla::GetTupleElement(pack, pos++);
218     if (!initialized()) {
219       initial_value_ = v;
220     }
221     value_ = v;
222 
223     for (const auto& source : gradient_sources) {
224       XlaResource* gradient;
225       TF_RETURN_IF_ERROR(
226           GetOrCreateTensorArrayGradient(source, builder, &gradient));
227       auto v = xla::GetTupleElement(pack, pos++);
228       if (!gradient->initialized()) {
229         gradient->initial_value_ = v;
230       }
231       gradient->value_ = v;
232     }
233   }
234   return Status::OK();
235 }
236 
237 }  // namespace tensorflow
238