1 /* 2 * Copyright 2019 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <string> 18 #include <utility> 19 20 #include "absl/strings/str_format.h" 21 #include "fcp/base/random_token.h" 22 #include "fcp/tensorflow/external_dataset.h" 23 #include "fcp/tensorflow/status.h" 24 #include "tensorflow/core/framework/common_shape_fns.h" 25 #include "tensorflow/core/framework/dataset.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/shape_inference.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 #include "tensorflow/core/public/version.h" 30 31 namespace fcp { 32 33 /** 34 * ExternalDataset op-kernel. Delegates to an ExternalDatasetProvider, found 35 * from the ExternalDatasetProviderRegistry (a HostObjectRegistry). 36 * 37 * Inputs: 38 * selector: An opaque string scalar. Forwarded to the stub. 39 * token: String scalar. It should encode a token obtained from 40 * ExternalDatasetProviderRegistry::Register. 41 * 42 * See TensorFlow's guide to making custom dataset ops: 43 * https://www.tensorflow.org/guide/extend/formats 44 */ 45 class ExternalDatasetOp : public tensorflow::data::DatasetOpKernel { 46 public: 47 using tensorflow::data::DatasetOpKernel::DatasetOpKernel; 48 MakeDataset(tensorflow::OpKernelContext * ctx,tensorflow::data::DatasetBase ** output)49 void MakeDataset(tensorflow::OpKernelContext* ctx, 50 tensorflow::data::DatasetBase** output) override { 51 tensorflow::tstring token_str; 52 OP_REQUIRES_OK(ctx, 53 tensorflow::data::ParseScalarArgument<tensorflow::tstring>( 54 ctx, "token", &token_str)); 55 absl::Span<char const> token_bytes = token_str; 56 OP_REQUIRES(ctx, token_bytes.size() == kRandomTokenSizeInBytes, 57 tensorflow::errors::InvalidArgument(absl::StrFormat( 58 "Tokens have a fixed size. Expected: %d; Actual %d", 59 kRandomTokenSizeInBytes, token_bytes.size()))); 60 RandomToken token = RandomToken::FromBytes(token_bytes); 61 62 tensorflow::tstring selector_str; 63 OP_REQUIRES_OK(ctx, 64 tensorflow::data::ParseScalarArgument<tensorflow::tstring>( 65 ctx, "selector", &selector_str)); 66 67 std::optional<std::shared_ptr<ExternalDatasetProvider>> maybe_provider = 68 ExternalDatasetProviderRegistry::TryLookup(token); 69 OP_REQUIRES(ctx, maybe_provider.has_value(), 70 tensorflow::errors::InvalidArgument( 71 "A dataset provider is not currently registered for the " 72 "provided token: ", 73 token.ToPrintableString())); 74 75 std::shared_ptr<ExternalDatasetProvider> provider = 76 *std::move(maybe_provider); 77 StatusOr<std::unique_ptr<ExternalDataset>> maybe_dataset = 78 provider->MakeDataset(selector_str); 79 // The provider might not like the given selector. 80 if (!maybe_dataset.ok()) { 81 ctx->SetStatus(ConvertToTensorFlowStatus(maybe_dataset.status())); 82 return; 83 } 84 85 *output = new Dataset(ctx, std::move(maybe_dataset).value()); 86 } 87 88 private: 89 class Dataset : public tensorflow::data::DatasetBase { 90 public: Dataset(tensorflow::OpKernelContext * ctx,std::unique_ptr<ExternalDataset> stub)91 Dataset(tensorflow::OpKernelContext* ctx, 92 std::unique_ptr<ExternalDataset> stub) 93 : DatasetBase(tensorflow::data::DatasetContext(ctx)), 94 stub_(std::move(stub)) {} 95 MakeIteratorInternal(const std::string & prefix) const96 std::unique_ptr<tensorflow::data::IteratorBase> MakeIteratorInternal( 97 const std::string& prefix) const override { 98 std::unique_ptr<ExternalDatasetIterator> iter = stub_->MakeIterator(); 99 Iterator::Params params{ 100 this, tensorflow::strings::StrCat(prefix, "::ExternalDataset")}; 101 return std::unique_ptr<tensorflow::data::IteratorBase>( 102 new Iterator(params, std::move(iter))); 103 } 104 105 // Each iterator element is just a scalar string. 106 output_dtypes() const107 const tensorflow::DataTypeVector& output_dtypes() const override { 108 static auto* const dtypes = 109 new tensorflow::DataTypeVector({tensorflow::DT_STRING}); 110 return *dtypes; 111 } 112 output_shapes() const113 const std::vector<tensorflow::PartialTensorShape>& output_shapes() 114 const override { 115 static std::vector<tensorflow::PartialTensorShape>* shapes = 116 new std::vector<tensorflow::PartialTensorShape>({{}}); 117 return *shapes; 118 } 119 DebugString() const120 std::string DebugString() const override { 121 return "ExternalDatasetOp::Dataset"; 122 } 123 InputDatasets(std::vector<const DatasetBase * > * inputs) const124 tensorflow::Status InputDatasets( 125 std::vector<const DatasetBase*>* inputs) const override { 126 // ExternalDatast has no input datasets, so just return OK. 127 return tensorflow::OkStatus(); 128 } 129 130 // The `DatasetBase::CheckExternalState()` method was introduced on 8/7/2019. We 131 // use the `TF_GRAPH_DEF_VERSION` value (which is updated daily) to determine if 132 // we should add its override. 133 #if TF_GRAPH_DEF_VERSION > 125 CheckExternalState() const134 tensorflow::Status CheckExternalState() const override { 135 return tensorflow::OkStatus(); 136 } 137 #endif 138 139 protected: AsGraphDefInternal(tensorflow::data::SerializationContext * ctx,DatasetGraphDefBuilder * b,tensorflow::Node ** output) const140 tensorflow::Status AsGraphDefInternal( 141 tensorflow::data::SerializationContext* ctx, DatasetGraphDefBuilder* b, 142 tensorflow::Node** output) const override { 143 return ::tensorflow::errors::Unimplemented( 144 DebugString(), " does not support serialization."); 145 } 146 147 private: 148 class Iterator : public tensorflow::data::DatasetIterator<Dataset> { 149 public: Iterator(const Params & params,std::unique_ptr<ExternalDatasetIterator> stub)150 explicit Iterator(const Params& params, 151 std::unique_ptr<ExternalDatasetIterator> stub) 152 : DatasetIterator<Dataset>(params), stub_(std::move(stub)) {} 153 GetNextInternal(tensorflow::data::IteratorContext * ctx,std::vector<tensorflow::Tensor> * out_tensors,bool * end_of_sequence)154 tensorflow::Status GetNextInternal( 155 tensorflow::data::IteratorContext* ctx, 156 std::vector<tensorflow::Tensor>* out_tensors, 157 bool* end_of_sequence) override { 158 StatusOr<std::string> maybe_element; 159 { 160 absl::MutexLock _(&mu_); 161 maybe_element = stub_->GetNext(); 162 } 163 164 if (maybe_element.ok()) { 165 std::string element = std::move(maybe_element).value(); 166 167 // The {} at the end specifies a scalar tensor. 168 tensorflow::Tensor element_tensor(ctx->allocator({}), 169 tensorflow::DT_STRING, {}); 170 element_tensor.scalar<tensorflow::tstring>()() = element; 171 172 *end_of_sequence = false; 173 out_tensors->push_back(std::move(element_tensor)); 174 return tensorflow::OkStatus(); 175 } else { 176 *end_of_sequence = true; 177 if (maybe_element.status().code() == StatusCode::kOutOfRange) { 178 return tensorflow::OkStatus(); 179 } else { 180 return ConvertToTensorFlowStatus(maybe_element.status()); 181 } 182 } 183 } 184 185 protected: SaveInternal(tensorflow::data::SerializationContext * ctx,tensorflow::data::IteratorStateWriter * writer)186 tensorflow::Status SaveInternal( 187 // `::tensorflow::data::SerializationContext` argument was added on 188 // 2020-03-17 when `TF_GRAPH_DEF_VERSION` was defined to 343. 189 #if TF_GRAPH_DEF_VERSION > 343 190 tensorflow::data::SerializationContext* ctx, 191 #endif 192 tensorflow::data::IteratorStateWriter* writer) override { 193 return ::tensorflow::errors::Unimplemented( 194 "Save / Restore of an ExternalDataset iterator is not supported"); 195 } RestoreInternal(tensorflow::data::IteratorContext * ctx,tensorflow::data::IteratorStateReader * reader)196 tensorflow::Status RestoreInternal( 197 tensorflow::data::IteratorContext* ctx, 198 tensorflow::data::IteratorStateReader* reader) override { 199 return ::tensorflow::errors::Unimplemented( 200 "Save / Restore of an ExternalDataset iterator is not supported"); 201 } 202 203 private: 204 std::unique_ptr<ExternalDatasetIterator> stub_; 205 absl::Mutex mu_; 206 }; 207 208 // Private members of Dataset 209 210 std::unique_ptr<ExternalDataset> stub_; 211 }; 212 }; 213 214 REGISTER_OP("ExternalDataset") 215 .Input("token: string") 216 .Input("selector: string") 217 .Output("handle: variant") 218 .SetIsStateful() 219 .SetShapeFn(tensorflow::shape_inference::ScalarShape); 220 221 REGISTER_KERNEL_BUILDER(Name("ExternalDataset").Device(tensorflow::DEVICE_CPU), 222 ExternalDatasetOp); 223 224 } // namespace fcp 225