• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
16 #define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
17 
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/kernels/tensor_map.h"
20 #include "tensorflow/core/util/batch_util.h"
21 #include "tensorflow/core/util/tensor_ops_util.h"
22 
23 namespace tensorflow {
24 
GetInputMap(OpKernelContext * ctx,int index,const TensorMap ** ret_map)25 Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) {
26   if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) {
27     return errors::InvalidArgument("Input map must be a scalar. Saw: ",
28                                    ctx->input(index).shape().DebugString());
29   }
30   const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>();
31   if (map == nullptr) {
32     return errors::InvalidArgument(
33         "Input handle is not a map. Saw: '",
34         ctx->input(index).scalar<Variant>()().DebugString(), "'");
35   }
36   *ret_map = map;
37   return Status::OK();
38 }
39 
40 // TODO(kattian): change into templated function
ForwardInputOrCreateNewMap(OpKernelContext * ctx,int32_t input_index,int32_t output_index,const TensorMap & input_map,TensorMap ** output_map)41 Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32_t input_index,
42                                   int32_t output_index,
43                                   const TensorMap& input_map,
44                                   TensorMap** output_map) {
45   // Attempt to forward the input tensor to the output if possible.
46   std::unique_ptr<Tensor> maybe_output = ctx->forward_input(
47       input_index, output_index, DT_VARIANT, TensorShape{},
48       ctx->input_memory_type(input_index), AllocatorAttributes());
49   Tensor* output_tensor;
50   if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
51       maybe_output->NumElements() == 1) {
52     output_tensor = maybe_output.get();
53     TensorMap* tmp_out = output_tensor->scalar<Variant>()().get<TensorMap>();
54     if (tmp_out == nullptr) {
55       return errors::InvalidArgument(
56           "Expected input ", input_index, " to be a TensorMap but saw ",
57           output_tensor->scalar<Variant>()().TypeName());
58     }
59     if (tmp_out->RefCountIsOne()) {
60       // Woohoo, forwarding succeeded!
61       ctx->set_output(output_index, *output_tensor);
62       *output_map = tmp_out;
63       return Status::OK();
64     }
65   }
66 
67   // If forwarding is not possible allocate a new output tensor and copy
68   // the `input_map` to it.
69   AllocatorAttributes attr;
70   attr.set_on_host(true);
71   TF_RETURN_IF_ERROR(
72       ctx->allocate_output(output_index, {}, &output_tensor, attr));
73   output_tensor->scalar<Variant>()() = input_map.Copy();
74 
75   *output_map = output_tensor->scalar<Variant>()().get<TensorMap>();
76   return Status::OK();
77 }
78 
79 class EmptyTensorMap : public OpKernel {
80  public:
EmptyTensorMap(OpKernelConstruction * ctx)81   explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {}
82 
Compute(OpKernelContext * ctx)83   void Compute(OpKernelContext* ctx) override {
84     Tensor* result;
85     AllocatorAttributes attr;
86     attr.set_on_host(true);
87     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
88     TensorMap empty;
89     result->scalar<Variant>()() = std::move(empty);
90   }
91 };
92 
93 class TensorMapSize : public OpKernel {
94  public:
TensorMapSize(OpKernelConstruction * ctx)95   explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapSize()96   ~TensorMapSize() override {}
97 
Compute(OpKernelContext * ctx)98   void Compute(OpKernelContext* ctx) override {
99     const TensorMap* map = nullptr;
100     OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
101     Tensor* result;
102     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
103     result->scalar<int32>()() = map->tensors().size();
104   }
105 };
106 
107 class TensorMapLookup : public OpKernel {
108  public:
TensorMapLookup(OpKernelConstruction * ctx)109   explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapLookup()110   ~TensorMapLookup() override {}
111 
Compute(OpKernelContext * ctx)112   void Compute(OpKernelContext* ctx) override {
113     const TensorKey& key = ctx->input(1);
114     const TensorMap* map = nullptr;
115     OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
116 
117     OP_REQUIRES(
118         ctx, map->tensors().find(key) != map->tensors().end(),
119         errors::InvalidArgument("Trying to lookup non-existent key. Could not "
120                                 "find key \"" +
121                                 key.SummarizeValue(100) + "\"."));
122 
123     ctx->set_output(0, map->tensors().find(key)->second);
124   }
125 };
126 
127 class TensorMapInsert : public OpKernel {
128  public:
TensorMapInsert(OpKernelConstruction * ctx)129   explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapInsert()130   ~TensorMapInsert() override {}
131 
Compute(OpKernelContext * ctx)132   void Compute(OpKernelContext* ctx) override {
133     const TensorKey& key = ctx->input(1);
134     const Tensor& value = ctx->input(2);
135     const TensorMap* map = nullptr;
136     OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
137 
138     TensorMap* output_map = nullptr;
139     OP_REQUIRES_OK(ctx,
140                    ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
141     output_map->replace(key, value);
142   }
143 };
144 
145 class TensorMapErase : public OpKernel {
146  public:
TensorMapErase(OpKernelConstruction * ctx)147   explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
148 
Compute(OpKernelContext * ctx)149   void Compute(OpKernelContext* ctx) override {
150     const TensorKey& key = ctx->input(1);
151     const TensorMap* map = nullptr;
152     OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
153 
154     OP_REQUIRES(
155         ctx, map->tensors().find(key) != map->tensors().end(),
156         errors::InvalidArgument("Trying to erase non-existent item. Could not "
157                                 "find key \"" +
158                                 key.SummarizeValue(100) + "\"."));
159 
160     TensorMap* output_map = nullptr;
161     OP_REQUIRES_OK(ctx,
162                    ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
163     output_map->tensors().erase(key);
164   }
165 };
166 
167 class TensorMapHasKey : public OpKernel {
168  public:
TensorMapHasKey(OpKernelConstruction * ctx)169   explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapHasKey()170   ~TensorMapHasKey() override {}
171 
Compute(OpKernelContext * ctx)172   void Compute(OpKernelContext* ctx) override {
173     const TensorKey& key = ctx->input(1);
174     const TensorMap* map = nullptr;
175     OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
176     Tensor* result;
177     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
178     result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end();
179   }
180 };
181 
182 class TensorMapStackKeys : public OpKernel {
183  public:
TensorMapStackKeys(OpKernelConstruction * ctx)184   explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) {
185     OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_));
186   }
~TensorMapStackKeys()187   ~TensorMapStackKeys() override {}
188 
Compute(OpKernelContext * ctx)189   void Compute(OpKernelContext* ctx) override {
190     const TensorMap* map = nullptr;
191     OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
192 
193     OP_REQUIRES(ctx, map->size() != 0,
194                 errors::InvalidArgument(
195                     "TensorMapStackKeys cannot be called on empty map."));
196 
197     auto it = map->tensors().begin();
198     TensorShape output_shape = it->first.shape();
199     output_shape.InsertDim(0, map->tensors().size());
200     Tensor* result;
201     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
202 
203     int i = 0;
204     size_t sz = map->tensors().size();
205     TensorShape key_shape = it->first.shape();
206     while (it != map->tensors().end() && i < sz) {
207       OP_REQUIRES(
208           ctx, it->first.dtype() == key_dtype_,
209           errors::InvalidArgument("Key does not match requested dtype."));
210       OP_REQUIRES(
211           ctx, it->first.shape() == key_shape,
212           errors::InvalidArgument("Keys must all have the same shape."));
213       OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i));
214       i++;
215       it++;
216     }
217   }
218 
219  private:
220   DataType key_dtype_;
221 };
222 
223 template <typename Device>
TensorMapBinaryAdd(OpKernelContext * ctx,const TensorMap & a,const TensorMap & b,TensorMap * out)224 Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a,
225                           const TensorMap& b, TensorMap* out) {
226   // Binary add returns a map containing the union of keys.
227   // Values with keys in the intersection are added.
228   out->tensors() = a.tensors();
229   for (const std::pair<TensorKey, Tensor>& p : b.tensors()) {
230     absl::flat_hash_map<TensorKey, Tensor>::iterator it =
231         out->tensors().find(p.first);
232     if (it != out->tensors().end()) {
233       Tensor out_tensor;
234       TF_RETURN_IF_ERROR(
235           BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor));
236       it->second = out_tensor;
237     } else {
238       out->tensors().emplace(p.first, p.second);
239     }
240   }
241   return Status::OK();
242 }
243 
244 template <typename Device>
TensorMapZerosLike(OpKernelContext * ctx,const TensorMap & x,TensorMap * y)245 Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x,
246                           TensorMap* y) {
247   // Zeros like returns an empty map.
248   return Status::OK();
249 }
250 
251 }  // namespace tensorflow
252 
253 #endif  // TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
254