# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Model evaluation tools for TFGAN.
16
17These methods come from https://arxiv.org/abs/1606.03498 and
18https://arxiv.org/abs/1706.08500.
19
20NOTE: This implementation uses the same weights as in
21https://github.com/openai/improved-gan/blob/master/inception_score/model.py,
22but is more numerically stable and is an unbiased estimator of the true
23Inception score even when splitting the inputs into batches.
24"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import sys
import tarfile

from six.moves import urllib

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader


__all__ = [
    'get_graph_def_from_disk',
    'get_graph_def_from_resource',
    'get_graph_def_from_url_tarball',
    'preprocess_image',
    'run_image_classifier',
    'run_inception',
    'inception_score',
    'classifier_score',
    'classifier_score_from_logits',
    'frechet_inception_distance',
    'frechet_classifier_distance',
    'frechet_classifier_distance_from_activations',
    'INCEPTION_DEFAULT_IMAGE_SIZE',
]


INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
INCEPTION_INPUT = 'Mul:0'
INCEPTION_OUTPUT = 'logits:0'
INCEPTION_FINAL_POOL = 'pool_3:0'
INCEPTION_DEFAULT_IMAGE_SIZE = 299


def _validate_images(images, image_size):
  images = ops.convert_to_tensor(images)
  images.shape.with_rank(4)
  images.shape.assert_is_compatible_with(
      [None, image_size, image_size, None])
  return images


def _symmetric_matrix_square_root(mat, eps=1e-10):
  """Compute square root of a symmetric matrix.

  Note that this is different from an elementwise square root. We want to
  compute M' where M' = sqrt(mat) such that M' * M' = mat.

  Also note that this method **only** works for symmetric matrices.

  Args:
    mat: Matrix to take the square root of.
    eps: Small epsilon such that any element less than eps will not be square
      rooted to guard against numerical instability.

  Returns:
    Matrix square root of mat.
  """
  # Unlike numpy, TensorFlow's return order is (s, u, v).
  s, u, v = linalg_ops.svd(mat)
  # sqrt is unstable around 0; for singular values below eps, keep the
  # (near-zero) value itself rather than taking its square root.
  si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
  # Note that the v returned by TensorFlow is V (when referencing the
  # equation A = U S V^T), unlike numpy, which returns V^T.
  return math_ops.matmul(
      math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)


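# A minimal numpy sketch (not part of the original module) of the property
# computed above: an eigendecomposition-based square root M' of a symmetric
# PSD matrix satisfies M' * M' = mat, mirroring the SVD construction used in
# `_symmetric_matrix_square_root`.
def _example_matrix_square_root_check():
  import numpy as np  # Local import keeps this illustrative sketch self-contained.
  rng = np.random.RandomState(0)
  x = rng.randn(4, 4)
  mat = x.dot(x.T)  # Symmetric positive semi-definite by construction.
  s, u = np.linalg.eigh(mat)
  sqrt_mat = u.dot(np.diag(np.sqrt(np.maximum(s, 0)))).dot(u.T)
  assert np.allclose(sqrt_mat.dot(sqrt_mat), mat)

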
def preprocess_image(
    images, height=INCEPTION_DEFAULT_IMAGE_SIZE,
    width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None):
  """Prepare a batch of images for evaluation.

  This is the preprocessing portion of the graph from
  http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.

  Note that it expects Tensors in [0, 255]. This function maps pixel values to
  [-1, 1] and resizes to match the InceptionV1 network.

  Args:
    images: 3-D or 4-D Tensor of images. Values are in [0, 255].
    height: Integer. Height of resized output image.
    width: Integer. Width of resized output image.
    scope: Optional scope for name_scope.

  Returns:
    3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1].
  """
  is_single = images.shape.ndims == 3
  with ops.name_scope(scope, 'preprocess', [images, height, width]):
    if not images.dtype.is_floating:
      images = math_ops.to_float(images)
    if is_single:
      images = array_ops.expand_dims(images, axis=0)
    resized = image_ops.resize_bilinear(images, [height, width])
    resized = (resized - 128.0) / 128.0
    if is_single:
      resized = array_ops.squeeze(resized, axis=0)
    return resized


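# A quick numpy sketch (not part of the original module) of the pixel mapping
# performed above: values in [0, 255] land in [-1, ~0.992].
def _example_preprocess_range_check():
  import numpy as np  # Local import keeps this illustrative sketch self-contained.
  pixels = np.array([0.0, 128.0, 255.0])
  mapped = (pixels - 128.0) / 128.0
  assert np.allclose(mapped, [-1.0, 0.0, 0.9921875])

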
def _kl_divergence(p, p_logits, q):
  """Computes the Kullback-Leibler divergence between p and q.

  This function uses p's logits in some places to improve numerical stability.

  Specifically:

  KL(p || q) = sum[ p * log(p / q) ]
    = sum[ p * ( log(p)                - log(q) ) ]
    = sum[ p * ( log_softmax(p_logits) - log(q) ) ]

  Args:
    p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch
      example and `j` corresponds to the probability of being in class `j`.
    p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
    q: A 1-D floating-point Tensor, where q_j corresponds to the probability
      of class `j`.

  Returns:
    KL divergence between two distributions. Output dimension is 1D, one entry
    per distribution in `p`.

  Raises:
    ValueError: If any of the inputs aren't floating-point.
    ValueError: If p or p_logits aren't 2D.
    ValueError: If q isn't 1D.
  """
  for tensor in [p, p_logits, q]:
    if not tensor.dtype.is_floating:
      raise ValueError('Input %s must be floating type.' % tensor.name)
  p.shape.assert_has_rank(2)
  p_logits.shape.assert_has_rank(2)
  q.shape.assert_has_rank(1)
  return math_ops.reduce_sum(
      p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)


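# A numpy sketch (not part of the original module) of the log-softmax identity
# used above: log_softmax(p_logits) equals log(p) for any logits, so the two
# ways of writing KL(p || q) agree.
def _example_kl_divergence_check():
  import numpy as np  # Local import keeps this illustrative sketch self-contained.
  p_logits = np.array([[2.0, 1.0, -1.0]])
  p = np.exp(p_logits) / np.exp(p_logits).sum(axis=1, keepdims=True)
  q = np.array([1.0 / 3, 1.0 / 3, 1.0 / 3])
  log_softmax = p_logits - np.log(np.exp(p_logits).sum(axis=1, keepdims=True))
  stable_kl = (p * (log_softmax - np.log(q))).sum(axis=1)
  direct_kl = (p * (np.log(p) - np.log(q))).sum(axis=1)
  assert np.allclose(stable_kl, direct_kl)

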
def get_graph_def_from_disk(filename):
  """Get a GraphDef proto from a disk location."""
  with gfile.FastGFile(filename, 'rb') as f:
    return graph_pb2.GraphDef.FromString(f.read())


def get_graph_def_from_resource(filename):
  """Get a GraphDef proto from within a .par file."""
  return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename))


def get_graph_def_from_url_tarball(url, filename, tar_filename=None):
  """Get a GraphDef proto from a tarball on the web.

  Args:
    url: Web address of the tarball.
    filename: Filename of the graph definition within the tarball.
    tar_filename: Temporary download filename (None = always download).

  Returns:
    A GraphDef loaded from a file in the downloaded tarball.
  """
  if not (tar_filename and os.path.exists(tar_filename)):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (url,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress)
  with tarfile.open(tar_filename, 'r:gz') as tar:
    proto_str = tar.extractfile(filename).read()
  return graph_pb2.GraphDef.FromString(proto_str)


def _default_graph_def_fn():
  return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH,
                                        os.path.basename(INCEPTION_URL))


def run_inception(images,
                  graph_def=None,
                  default_graph_def_fn=_default_graph_def_fn,
                  image_size=INCEPTION_DEFAULT_IMAGE_SIZE,
                  input_tensor=INCEPTION_INPUT,
                  output_tensor=INCEPTION_OUTPUT):
  """Run images through a pretrained Inception classifier.

  Args:
    images: Input tensors. Must be [batch, height, width, channels]. Values
      must be in [-1, 1], which can be achieved using `preprocess_image`.
    graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
      call `default_graph_def_fn` to get GraphDef.
    default_graph_def_fn: A function that returns a GraphDef. Used if
      `graph_def` is `None`. By default, returns a pretrained InceptionV1
      graph.
    image_size: Required image width and height. Defaults to
      INCEPTION_DEFAULT_IMAGE_SIZE.
    input_tensor: Name of input Tensor.
    output_tensor: Name or list of output Tensors. This function will compute
      activations at the specified layer. Examples include INCEPTION_OUTPUT
      and INCEPTION_FINAL_POOL, which would result in this function computing
      the final logits or the penultimate pooling layer.

  Returns:
    Tensor or Tensors corresponding to computed `output_tensor`.

  Raises:
    ValueError: If images are not the correct size.
    ValueError: If neither `graph_def` nor `default_graph_def_fn` is provided.
  """
  images = _validate_images(images, image_size)

  if graph_def is None:
    if default_graph_def_fn is None:
      raise ValueError('If `graph_def` is `None`, must provide '
                       '`default_graph_def_fn`.')
    graph_def = default_graph_def_fn()

  activations = run_image_classifier(images, graph_def, input_tensor,
                                     output_tensor)
  # Use the statically known rank here: comparing the symbolic
  # `array_ops.rank` Tensor against a Python int is never true at
  # graph-construction time.
  if isinstance(activations, list):
    for i, activation in enumerate(activations):
      if activation.shape.ndims != 2:
        activations[i] = layers.flatten(activation)
  else:
    if activations.shape.ndims != 2:
      activations = layers.flatten(activations)

  return activations


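# An illustrative usage sketch (not part of the original API): `images` is
# assumed to be a [batch, height, width, 3] Tensor with values in [0, 255].
# The first call to `run_inception` downloads the frozen graph from
# INCEPTION_URL unless a `graph_def` is supplied.
def _example_run_inception(images):
  preprocessed = preprocess_image(images)
  logits = run_inception(preprocessed)  # Defaults to INCEPTION_OUTPUT.
  pool = run_inception(preprocessed, output_tensor=INCEPTION_FINAL_POOL)
  return logits, pool

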
def run_image_classifier(tensor, graph_def, input_tensor,
                         output_tensor, scope='RunClassifier'):
  """Runs a network from a frozen graph.

  Args:
    tensor: An input tensor.
    graph_def: A GraphDef proto.
    input_tensor: Name of input tensor in graph def.
    output_tensor: A tensor name or list of tensor names in graph def.
    scope: Name scope for classifier.

  Returns:
    Classifier output if `output_tensor` is a string, or a list of outputs if
    `output_tensor` is a list.

  Raises:
    ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def.
  """
  input_map = {input_tensor: tensor}
  is_singleton = isinstance(output_tensor, str)
  if is_singleton:
    output_tensor = [output_tensor]
  classifier_outputs = importer.import_graph_def(
      graph_def, input_map, output_tensor, name=scope)
  if is_singleton:
    classifier_outputs = classifier_outputs[0]

  return classifier_outputs


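# An illustrative sketch (not part of the original API) of running a custom
# frozen classifier. The file path and tensor names below are hypothetical
# placeholders; any frozen GraphDef with known input/output tensor names works.
def _example_run_image_classifier(images):
  graph_def = get_graph_def_from_disk('/tmp/my_frozen_classifier.pb')
  return run_image_classifier(
      images, graph_def, input_tensor='input:0', output_tensor='logits:0')

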
def classifier_score(images, classifier_fn, num_batches=1):
  """Classifier score for evaluating a conditional generative model.

  This is based on the Inception Score, but for an arbitrary classifier.

  This technique is described in detail in https://arxiv.org/abs/1606.03498. In
  summary, this function calculates

  exp( E[ KL(p(y|x) || p(y)) ] )

  which captures how different the network's classification prediction is from
  the prior distribution over classes.

  NOTE: This function consumes images, computes their logits, and then
  computes the classifier score. If you would like to precompute many logits
  for large batches, use classifier_score_from_logits(), which this method
  also uses.

  Args:
    images: Images to calculate the classifier score for.
    classifier_fn: A function that takes images and produces logits based on a
      classifier.
    num_batches: Number of batches to split `images` into in order to
      efficiently run them through the classifier network.

  Returns:
    The classifier score. A floating-point scalar of the same type as the
    output of `classifier_fn`.
  """
  generated_images_list = array_ops.split(
      images, num_or_size_splits=num_batches)

  # Compute the classifier splits using the memory-efficient `map_fn`.
  logits = functional_ops.map_fn(
      fn=classifier_fn,
      elems=array_ops.stack(generated_images_list),
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')
  logits = array_ops.concat(array_ops.unstack(logits), 0)

  return classifier_score_from_logits(logits)


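# An illustrative usage sketch (not part of the original API): the Inception
# logits classifier below reproduces the `inception_score` partial defined
# later in this module. `num_batches` must evenly divide the batch size.
def _example_classifier_score(images, num_batches=4):
  return classifier_score(
      preprocess_image(images),
      classifier_fn=functools.partial(
          run_inception, output_tensor=INCEPTION_OUTPUT),
      num_batches=num_batches)

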
def classifier_score_from_logits(logits):
  """Classifier score for evaluating a generative model from logits.

  This method computes the classifier score for a set of logits. It can be
  used independently of the classifier_score() method, especially in the case
  of using large batches during evaluation where we would like to precompute
  all of the logits before computing the classifier score.

  This technique is described in detail in https://arxiv.org/abs/1606.03498. In
  summary, this function calculates:

  exp( E[ KL(p(y|x) || p(y)) ] )

  which captures how different the network's classification prediction is from
  the prior distribution over classes.

  Args:
    logits: Precomputed 2D tensor of logits that will be used to
      compute the classifier score.

  Returns:
    The classifier score. A floating-point scalar of the same type as the
    output of `logits`.
  """
  logits.shape.assert_has_rank(2)

  # Use maximum precision for best results.
  logits_dtype = logits.dtype
  if logits_dtype != dtypes.float64:
    logits = math_ops.to_double(logits)

  p = nn_ops.softmax(logits)
  q = math_ops.reduce_mean(p, axis=0)
  kl = _kl_divergence(p, logits, q)
  kl.shape.assert_has_rank(1)
  log_score = math_ops.reduce_mean(kl)
  final_score = math_ops.exp(log_score)

  if logits_dtype != dtypes.float64:
    final_score = math_ops.cast(final_score, logits_dtype)

  return final_score


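# A numpy sketch (not part of the original module) of the score's range:
# identical uniform predictions give 1.0, while confident predictions spread
# evenly over classes approach num_classes.
def _example_classifier_score_bounds():
  import numpy as np  # Local import keeps this illustrative sketch self-contained.

  def np_score(logits):
    p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    q = p.mean(axis=0)
    kl = (p * (np.log(p) - np.log(q))).sum(axis=1)
    return np.exp(kl.mean())

  uniform = np.zeros((8, 10))    # Every prediction matches the marginal.
  confident = 50.0 * np.eye(10)  # One near-one-hot prediction per class.
  assert np.isclose(np_score(uniform), 1.0)
  assert np.isclose(np_score(confident), 10.0, atol=0.01)

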
inception_score = functools.partial(
    classifier_score,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_OUTPUT))


def trace_sqrt_product(sigma, sigma_v):
  """Find the trace of the positive sqrt of product of covariance matrices.

  '_symmetric_matrix_square_root' only works for symmetric matrices, so we
  cannot just take _symmetric_matrix_square_root(sigma * sigma_v).
  ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily
  so.)

  Let sigma = A A so A = sqrt(sigma), and sigma_v = B B.
  We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B)).
  Note the following properties:
  (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1)
     => eigenvalues(A A B B) = eigenvalues(A B B A)
  (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2))
     => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A))
  (iii) forall M: trace(M) = sum(eigenvalues(M))
     => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
                                   = sum(sqrt(eigenvalues(A B B A)))
                                   = sum(eigenvalues(sqrt(A B B A)))
                                   = trace(sqrt(A B B A))
                                   = trace(sqrt(A sigma_v A))
  A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can**
  use the _symmetric_matrix_square_root function to find the roots of these
  matrices.

  Args:
    sigma: A square, symmetric, real, positive semi-definite covariance
      matrix.
    sigma_v: Same as sigma.

  Returns:
    The trace of the positive square root of sigma*sigma_v.
  """

  # Note sqrt_sigma is called "A" in the proof above.
  sqrt_sigma = _symmetric_matrix_square_root(sigma)

  # This is A sigma_v A above; its square root is taken in the return below.
  sqrt_a_sigmav_a = math_ops.matmul(
      sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma))

  return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))


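# A numpy sketch (not part of the original module) checking the identity
# proved above: trace(sqrt(sigma sigma_v)) computed via the symmetric form
# A sigma_v A equals the sum of sqrt(eigenvalues(sigma sigma_v)).
def _example_trace_sqrt_product_check():
  import numpy as np  # Local import keeps this illustrative sketch self-contained.
  rng = np.random.RandomState(0)
  a, b = rng.randn(6, 3), rng.randn(6, 3)
  sigma, sigma_v = a.T.dot(a), b.T.dot(b)  # Symmetric PSD by construction.
  direct = np.sqrt(
      np.maximum(np.linalg.eigvals(sigma.dot(sigma_v)).real, 0)).sum()
  s, u = np.linalg.eigh(sigma)
  sqrt_sigma = u.dot(np.diag(np.sqrt(np.maximum(s, 0)))).dot(u.T)  # "A" above.
  inner = sqrt_sigma.dot(sigma_v).dot(sqrt_sigma)
  symmetric = np.sqrt(np.maximum(np.linalg.eigvalsh(inner), 0)).sum()
  assert np.isclose(direct, symmetric)

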
def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Note that when computed using sample means and sample covariance matrices,
  Frechet distance is biased, and more so for small sample sizes (e.g., even
  if the two distributions are the same, for a small sample size the expected
  Frechet distance is large). It is important to use the same sample size to
  compute the Frechet classifier distance when comparing two generative
  models.

  NOTE: This function consumes images, computes their activations, and then
  computes the classifier distance. If you would like to precompute many
  activations for real and generated images for large batches, please use
  frechet_classifier_distance_from_activations(), which this method also uses.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance. Assumed to have the same batch size as `real_images`.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images into in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as the output of `classifier_fn`.
  """

  real_images_list = array_ops.split(
      real_images, num_or_size_splits=num_batches)
  generated_images_list = array_ops.split(
      generated_images, num_or_size_splits=num_batches)

  # Stacking requires the real and generated splits to have identical shapes.
  imgs = array_ops.stack(real_images_list + generated_images_list)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Split the activations by the real and generated images.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)

  return frechet_classifier_distance_from_activations(real_a, gen_a)


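# An illustrative sketch (not part of the original API) of the precomputation
# workflow mentioned in the NOTE above: compute pool-layer activations once,
# then feed them to frechet_classifier_distance_from_activations(), defined
# below.
def _example_precomputed_fid(real_images, generated_images):
  real_acts = run_inception(preprocess_image(real_images),
                            output_tensor=INCEPTION_FINAL_POOL)
  gen_acts = run_inception(preprocess_image(generated_images),
                           output_tensor=INCEPTION_FINAL_POOL)
  return frechet_classifier_distance_from_activations(real_acts, gen_acts)

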
def frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model from activations.

  This method computes the Frechet classifier distance from activations of
  real images and generated images. It can be used independently of the
  frechet_classifier_distance() method, especially in the case of using large
  batches during evaluation where we would like to precompute all of the
  activations before computing the classifier distance.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance
  matrices C and C_w, this function calculates

  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_activations: 2D Tensor containing activations of real data. Shape is
      [batch_size, activation_size].
    generated_activations: 2D Tensor containing activations of generated data.
      Shape is [batch_size, activation_size].

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as the output of the activations.
  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Normalize each covariance by its own sample count, so the estimates stay
  # correct even if the two batches differ in size.
  num_real = math_ops.to_double(array_ops.shape(real_activations)[0])
  num_gen = math_ops.to_double(array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (num_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (num_gen - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID.
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B).
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the distance between means.
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid


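# A numpy sketch (not part of the original module) of the formula above:
# the distance between a sample and itself is 0, and shifting every
# activation by 1 contributes exactly |m - m_w|^2 = activation_size.
def _example_fid_sanity_check():
  import numpy as np  # Local import keeps this illustrative sketch self-contained.

  def np_fid(real, gen):
    m, m_v = real.mean(axis=0), gen.mean(axis=0)
    c, c_v = np.cov(real, rowvar=False), np.cov(gen, rowvar=False)
    s, u = np.linalg.eigh(c)
    sqrt_c = u.dot(np.diag(np.sqrt(np.maximum(s, 0)))).dot(u.T)
    eigs = np.maximum(np.linalg.eigvalsh(sqrt_c.dot(c_v).dot(sqrt_c)), 0)
    return ((m - m_v) ** 2).sum() + np.trace(c + c_v) - 2 * np.sqrt(eigs).sum()

  rng = np.random.RandomState(0)
  acts = rng.randn(64, 8)
  assert abs(np_fid(acts, acts)) < 1e-6
  assert np.isclose(np_fid(acts, acts + 1.0), 8.0)

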
frechet_inception_distance = functools.partial(
    frechet_classifier_distance,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_FINAL_POOL))

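
# An illustrative usage sketch (not part of the original API) of the two
# Inception-based partials: `real` and `generated` are assumed to be
# [batch, height, width, 3] Tensors with values in [0, 255].
def _example_inception_metrics(real, generated):
  score = inception_score(preprocess_image(generated))
  fid = frechet_inception_distance(preprocess_image(real),
                                   preprocess_image(generated))
  return score, fid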