/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {

// Number of examples to precalculate.
const int kPrecalc = 3000;
// Number of words to read into a sentence before processing.
const int kSentenceSize = 1000;

namespace {

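// Skips leading whitespace in *input and copies the next
// whitespace-delimited token into *word. Returns false once the input is
// exhausted.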
bool ScanWord(StringPiece* input, string* word) {
  str_util::RemoveLeadingWhitespace(input);
  StringPiece tmp;
  if (str_util::ConsumeNonWhitespace(input, &tmp)) {
    word->assign(tmp.data(), tmp.size());
    return true;
  } else {
    return false;
  }
}

}  // end namespace

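// Produces skip-gram training batches from a text file: outputs the
// vocabulary, per-word frequencies, epoch/progress counters, and a batch of
// (center word, context word) id pairs.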
class SkipgramOp : public OpKernel {
 public:
  explicit SkipgramOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), rng_(&philox_) {
    string filename;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("filename", &filename));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("batch_size", &batch_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size", &window_size_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("min_count", &min_count_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("subsample", &subsample_));
    OP_REQUIRES_OK(ctx, Init(ctx->env(), filename));

    mutex_lock l(mu_);
    example_pos_ = corpus_size_;
    label_pos_ = corpus_size_;
    label_limit_ = corpus_size_;
    sentence_index_ = kSentenceSize;
    for (int i = 0; i < kPrecalc; ++i) {
      NextExample(&precalc_examples_[i].input, &precalc_examples_[i].label);
    }
  }

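  // Emits one training batch: copies batch_size_ (input, label) pairs from
  // the precomputed examples, refilling the buffer whenever it runs out.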
  void Compute(OpKernelContext* ctx) override {
    Tensor words_per_epoch(DT_INT64, TensorShape({}));
    Tensor current_epoch(DT_INT32, TensorShape({}));
    Tensor total_words_processed(DT_INT64, TensorShape({}));
    Tensor examples(DT_INT32, TensorShape({batch_size_}));
    auto Texamples = examples.flat<int32>();
    Tensor labels(DT_INT32, TensorShape({batch_size_}));
    auto Tlabels = labels.flat<int32>();
    {
      mutex_lock l(mu_);
      for (int i = 0; i < batch_size_; ++i) {
        Texamples(i) = precalc_examples_[precalc_index_].input;
        Tlabels(i) = precalc_examples_[precalc_index_].label;
        precalc_index_++;
        if (precalc_index_ >= kPrecalc) {
          precalc_index_ = 0;
          for (int j = 0; j < kPrecalc; ++j) {
            NextExample(&precalc_examples_[j].input,
                        &precalc_examples_[j].label);
          }
        }
      }
      words_per_epoch.scalar<int64>()() = corpus_size_;
      current_epoch.scalar<int32>()() = current_epoch_;
      total_words_processed.scalar<int64>()() = total_words_processed_;
    }
    ctx->set_output(0, word_);
    ctx->set_output(1, freq_);
    ctx->set_output(2, words_per_epoch);
    ctx->set_output(3, current_epoch);
    ctx->set_output(4, total_words_processed);
    ctx->set_output(5, examples);
    ctx->set_output(6, labels);
  }

 private:
  struct Example {
    int32 input;
    int32 label;
  };

  int32 batch_size_ = 0;
  int32 window_size_ = 5;
  float subsample_ = 1e-3;
  int min_count_ = 5;
  int32 vocab_size_ = 0;
  Tensor word_;
  Tensor freq_;
  int64 corpus_size_ = 0;
  std::vector<int32> corpus_;
  std::vector<Example> precalc_examples_;
  int precalc_index_ = 0;
  std::vector<int32> sentence_;
  int sentence_index_ = 0;

  mutex mu_;
  random::PhiloxRandom philox_ GUARDED_BY(mu_);
  random::SimplePhilox rng_ GUARDED_BY(mu_);
  int32 current_epoch_ GUARDED_BY(mu_) = -1;
  int64 total_words_processed_ GUARDED_BY(mu_) = 0;
  int32 example_pos_ GUARDED_BY(mu_);
  int32 label_pos_ GUARDED_BY(mu_);
  int32 label_limit_ GUARDED_BY(mu_);

  // {example_pos_, label_pos_} is the cursor for the next example.
  // example_pos_ wraps around at the end of corpus_. For each
  // example, we randomly generate [label_pos_, label_limit_) for
  // labels.
  void NextExample(int32* example, int32* label) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    while (true) {
      if (label_pos_ >= label_limit_) {
        ++total_words_processed_;
        ++sentence_index_;
        if (sentence_index_ >= kSentenceSize) {
          sentence_index_ = 0;
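          // Refill sentence_ with the next kSentenceSize words from the
          // corpus, applying frequency subsampling if enabled.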
          for (int i = 0; i < kSentenceSize; ++i, ++example_pos_) {
            if (example_pos_ >= corpus_size_) {
              ++current_epoch_;
              example_pos_ = 0;
            }
            if (subsample_ > 0) {
              int32 word_freq = freq_.flat<int32>()(corpus_[example_pos_]);
              // See Eq. 5 in http://arxiv.org/abs/1310.4546
              float keep_prob =
                  (std::sqrt(word_freq / (subsample_ * corpus_size_)) + 1) *
                  (subsample_ * corpus_size_) / word_freq;
              if (rng_.RandFloat() > keep_prob) {
                i--;
                continue;
              }
            }
            sentence_[i] = corpus_[example_pos_];
          }
        }
        const int32 skip = 1 + rng_.Uniform(window_size_);
        label_pos_ = std::max<int32>(0, sentence_index_ - skip);
        label_limit_ =
            std::min<int32>(kSentenceSize, sentence_index_ + skip + 1);
      }
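      // Never emit the center word as its own label.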
      if (sentence_index_ != label_pos_) {
        break;
      }
      ++label_pos_;
    }
    *example = sentence_[sentence_index_];
    *label = sentence_[label_pos_++];
  }

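  // Reads the corpus from filename, builds a vocabulary of words occurring
  // at least min_count_ times (all other words map to "UNK", id 0), and
  // encodes the corpus as a sequence of word ids.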
  Status Init(Env* env, const string& filename) {
    string data;
    TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data));
    StringPiece input = data;
    string w;
    corpus_size_ = 0;
    std::unordered_map<string, int32> word_freq;
    while (ScanWord(&input, &w)) {
      ++(word_freq[w]);
      ++corpus_size_;
    }
    if (corpus_size_ < window_size_ * 10) {
      return errors::InvalidArgument(
          "The text file ", filename,
          " contains too little data: ", corpus_size_, " words");
    }
    typedef std::pair<string, int32> WordFreq;
    std::vector<WordFreq> ordered;
    for (const auto& p : word_freq) {
      if (p.second >= min_count_) ordered.push_back(p);
    }
    LOG(INFO) << "Data file: " << filename << " contains " << data.size()
              << " bytes, " << corpus_size_ << " words, " << word_freq.size()
              << " unique words, " << ordered.size()
              << " unique frequent words.";
    word_freq.clear();
    std::sort(ordered.begin(), ordered.end(),
              [](const WordFreq& x, const WordFreq& y) {
                return x.second > y.second;
              });
    vocab_size_ = static_cast<int32>(1 + ordered.size());
    Tensor word(DT_STRING, TensorShape({vocab_size_}));
    Tensor freq(DT_INT32, TensorShape({vocab_size_}));
    word.flat<string>()(0) = "UNK";
    static const int32 kUnkId = 0;
    std::unordered_map<string, int32> word_id;
    int64 total_counted = 0;
    for (std::size_t i = 0; i < ordered.size(); ++i) {
      const auto& w = ordered[i].first;
      auto id = i + 1;
      word.flat<string>()(id) = w;
      auto word_count = ordered[i].second;
      freq.flat<int32>()(id) = word_count;
      total_counted += word_count;
      word_id[w] = id;
    }
    freq.flat<int32>()(kUnkId) = corpus_size_ - total_counted;
    word_ = word;
    freq_ = freq;
    corpus_.reserve(corpus_size_);
    input = data;
    while (ScanWord(&input, &w)) {
      corpus_.push_back(gtl::FindWithDefault(word_id, w, kUnkId));
    }
    precalc_examples_.resize(kPrecalc);
    sentence_.resize(kSentenceSize);
    return Status::OK();
  }
};

REGISTER_KERNEL_BUILDER(Name("Skipgram").Device(DEVICE_CPU), SkipgramOp);

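// Performs one SGD step of word2vec training with negative sampling,
// updating the input and output embedding matrices in place.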
class NegTrainOp : public OpKernel {
 public:
  explicit NegTrainOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    base_.Init(0, 0);

    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_negative_samples", &num_samples_));

    std::vector<int32> vocab_count;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("vocab_count", &vocab_count));

    std::vector<float> vocab_weights;
    vocab_weights.reserve(vocab_count.size());
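    // Negatives are drawn from the unigram distribution raised to the 3/4
    // power, as in the word2vec paper.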
    for (const auto& f : vocab_count) {
      float r = std::pow(static_cast<float>(f), 0.75f);
      vocab_weights.push_back(r);
    }
    sampler_ = new random::DistributionSampler(vocab_weights);
  }

  ~NegTrainOp() override { delete sampler_; }

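  // Inputs: mutable embedding matrices w_in and w_out, parallel vectors of
  // example and label word ids, and a scalar learning rate.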
  void Compute(OpKernelContext* ctx) override {
    Tensor w_in = ctx->mutable_input(0, false);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(w_in.shape()),
                errors::InvalidArgument("Must be a matrix"));
    Tensor w_out = ctx->mutable_input(1, false);
    OP_REQUIRES(ctx, w_in.shape() == w_out.shape(),
                errors::InvalidArgument("w_in.shape == w_out.shape"));
    const Tensor& examples = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(examples.shape()),
                errors::InvalidArgument("Must be a vector"));
    const Tensor& labels = ctx->input(3);
    OP_REQUIRES(ctx, examples.shape() == labels.shape(),
                errors::InvalidArgument("examples.shape == labels.shape"));
    const Tensor& learning_rate = ctx->input(4);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(learning_rate.shape()),
                errors::InvalidArgument("Must be a scalar"));

    auto Tw_in = w_in.matrix<float>();
    auto Tw_out = w_out.matrix<float>();
    auto Texamples = examples.flat<int32>();
    auto Tlabels = labels.flat<int32>();
    auto lr = learning_rate.scalar<float>()();
    const int64 vocab_size = w_in.dim_size(0);
    const int64 dims = w_in.dim_size(1);
    const int64 batch_size = examples.dim_size(0);
    OP_REQUIRES(ctx, vocab_size == sampler_->num(),
                errors::InvalidArgument("vocab_size mismatches: ", vocab_size,
                                        " vs. ", sampler_->num()));

    // Gradient accumulator for v_in.
    Tensor buf(DT_FLOAT, TensorShape({dims}));
    auto Tbuf = buf.flat<float>();

    // Scalar buffer to hold sigmoid(+/- dot).
    Tensor g_buf(DT_FLOAT, TensorShape({}));
    auto g = g_buf.scalar<float>();

    // The following loop needs 2 random 32-bit values per negative
    // sample.  We reserve 8 values per sample just in case the
    // underlying implementation changes.
    auto rnd = base_.ReserveSamples32(batch_size * num_samples_ * 8);
    random::SimplePhilox srnd(&rnd);

    for (int64 i = 0; i < batch_size; ++i) {
      const int32 example = Texamples(i);
      DCHECK(0 <= example && example < vocab_size) << example;
      const int32 label = Tlabels(i);
      DCHECK(0 <= label && label < vocab_size) << label;
      auto v_in = Tw_in.chip<0>(example);

      // Positive: example predicts label.
      //   forward: x = v_in' * v_out
      //            l = log(sigmoid(x))
      //   backward: dl/dx = g = sigmoid(-x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      {
        auto v_out = Tw_out.chip<0>(label);
        auto dot = (v_in * v_out).sum();
        g = (dot.exp() + 1.f).inverse();
        Tbuf = v_out * (g() * lr);
        v_out += v_in * (g() * lr);
      }

      // Negative samples:
      //   forward: x = v_in' * v_sample
      //            l = log(sigmoid(-x))
      //   backward: dl/dx = g = -sigmoid(x)
      //             dl/d(v_in) = g * v_out'
      //             dl/d(v_out) = v_in' * g
      for (int j = 0; j < num_samples_; ++j) {
        const int sample = sampler_->Sample(&srnd);
        if (sample == label) continue;  // Skip.
        auto v_sample = Tw_out.chip<0>(sample);
        auto dot = (v_in * v_sample).sum();
        g = -((-dot).exp() + 1.f).inverse();
        Tbuf += v_sample * (g() * lr);
        v_sample += v_in * (g() * lr);
      }

      // Applies the gradient on v_in.
      v_in += Tbuf;
    }
  }

 private:
  int32 num_samples_ = 0;
  random::DistributionSampler* sampler_ = nullptr;
  GuardedPhiloxRandom base_;
};

REGISTER_KERNEL_BUILDER(Name("NegTrain").Device(DEVICE_CPU), NegTrainOp);

}  // end namespace tensorflow