• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#  Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utilities for `Model.compile`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import copy
21
22import six
23
24from tensorflow.python.distribute import distribution_strategy_context as ds_context
25from tensorflow.python.keras import losses as losses_mod
26from tensorflow.python.keras import metrics as metrics_mod
27from tensorflow.python.keras.utils import generic_utils
28from tensorflow.python.keras.utils import losses_utils
29from tensorflow.python.keras.utils import tf_utils
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.util import nest
33
34
class Container(object):
  """Base class shared by the loss and metric containers.

  Subclasses decide, via `_should_broadcast` and `_copy_object`, how a single
  user-supplied loss / metric is spread across several Model outputs.
  """

  def __init__(self, output_names=None):
    self._output_names = output_names

  def build(self, y_pred):
    """Fills in output names if none were supplied.

    Args:
      y_pred: Sample Model outputs.
    """
    if self._output_names is not None:
      return
    # Subclassed Models have no layer-derived output names, so synthesize
    # names such as 'output_1' to be used when naming `Metric`s.
    self._output_names = create_pseudo_output_names(y_pred)

  def _conform_to_outputs(self, outputs, struct):
    """Reshapes `struct` so that it lines up with `outputs`.

    Three mappings are applied, in order:

    (1) A dict is turned into a list of outputs, keyed by output name.
    (2) Keys missing from a dict are filled in with `None`.
    (3) A single (non-nested) item is repeated for every output.

    Args:
      outputs: Model predictions.
      struct: Arbitrary nested structure (e.g. of labels, sample_weights,
        losses, or metrics).

    Returns:
      `struct`, mapped onto the structure of `outputs`.
    """
    conformed = map_to_output_names(outputs, self._output_names, struct)
    conformed = map_missing_dict_keys(outputs, conformed)
    if nest.is_nested(outputs) and not nest.is_nested(conformed):
      # One object was supplied; reuse it for every output.
      return nest.map_structure(lambda _: conformed, outputs)
    return conformed

  def _maybe_broadcast_to_outputs(self, outputs, objects):
    """Applies a Loss / Metric to all outputs, when the subclass asks for it.

    NOTE: This method should only be called for Metrics / Losses, not for
    y_true / sample_weight.

    Args:
      outputs: Model predictions.
      objects: Arbitrary nested structure (e.g. of losses or metrics).

    Returns:
      `objects` unchanged, or `objects` repeated for each output (copied per
      output when the Model has more than one output).
    """
    if not self._should_broadcast(objects):
      return objects

    # Several outputs each need their own copy so every Metric / Loss keeps
    # separate state; a single output simply uses the user-supplied object.
    copy_per_output = len(nest.flatten(outputs)) > 1

    def _entry_for(_):
      if not copy_per_output:
        return objects
      return nest.map_structure(self._copy_object, objects)

    return nest.map_structure(_entry_for, outputs)

  def _should_broadcast(self, objects):
    """Returns whether `objects` should be repeated for every output."""
    raise NotImplementedError

  def _copy_object(self, obj):
    """Returns a per-output copy of `obj`."""
    raise NotImplementedError
106
107
class LossesContainer(Container):
  """A container class for losses passed to `Model.compile`."""

  def __init__(self, losses, loss_weights=None, output_names=None):
    super(LossesContainer, self).__init__(output_names=output_names)

    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_losses = losses
    self._user_loss_weights = loss_weights

    # Working values; `build` replaces them with flat per-output lists.
    self._losses = losses
    self._loss_weights = loss_weights
    self._per_output_metrics = None  # Per-output losses become metrics.
    self._loss_metric = metrics_mod.Mean(name='loss')  # Total loss.
    self._built = False

  @property
  def metrics(self):
    """Loss metrics: the total-loss `Mean` followed by any per-output ones.

    Returns an empty list until the container has been built.
    """
    if not self._built:
      return []
    per_output_metrics = [
        metric_obj for metric_obj in nest.flatten(self._per_output_metrics)
        if metric_obj is not None
    ]
    return [self._loss_metric] + per_output_metrics

  def build(self, y_pred):
    """One-time setup of loss objects.

    Args:
      y_pred: Sample Model outputs, used to determine the output structure.
    """
    super(LossesContainer, self).build(y_pred)

    self._losses = self._maybe_broadcast_to_outputs(y_pred, self._losses)
    self._losses = self._conform_to_outputs(y_pred, self._losses)
    self._losses = nest.map_structure(self._get_loss_object, self._losses)
    self._losses = nest.flatten(self._losses)

    self._loss_weights = self._maybe_broadcast_to_outputs(
        y_pred, self._loss_weights)
    self._loss_weights = self._conform_to_outputs(y_pred, self._loss_weights)
    self._loss_weights = nest.flatten(self._loss_weights)

    self._create_metrics()
    self._built = True

  def _create_metrics(self):
    """Creates per-output loss metrics, but only for multi-output Models.

    A single-output Model stores a `[None]` placeholder (only the total loss
    is tracked). Outputs that have no loss also get a `None` entry.
    """
    if len(self._output_names) == 1:
      self._per_output_metrics = [None]
    else:
      self._per_output_metrics = [
          None if loss_obj is None else metrics_mod.Mean(output_name + '_loss')
          for loss_obj, output_name in zip(self._losses, self._output_names)
      ]

  def __call__(self,
               y_true,
               y_pred,
               sample_weight=None,
               regularization_losses=None):
    """Computes the overall loss.

    Also updates the total-loss metric and, for multi-output Models, the
    per-output loss metrics.

    Args:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      The total loss `Tensor`: the sum of the per-output losses (after loss
      weights and distribution scaling) plus the regularization losses. A
      scalar zero is returned when no loss applies to any output.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred)

    y_pred = nest.flatten(y_pred)
    y_true = nest.flatten(y_true)
    sample_weight = nest.flatten(sample_weight)

    loss_values = []  # Used for gradient calculation.
    loss_metric_values = []  # Used for loss metric calculation.
    batch_dim = None
    zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
                self._per_output_metrics)
    for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
      if y_t is None or loss_obj is None:  # Ok to have no loss for an output.
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw, get_mask(y_p))
      loss_value = loss_obj(y_t, y_p, sample_weight=sw)

      loss_metric_value = loss_value
      # Correct for the `Mean` loss metrics counting each replica as a batch.
      if loss_obj.reduction == losses_utils.ReductionV2.SUM:
        loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync

      # The batch size of the first labelled output weights the `Mean`s.
      if batch_dim is None:
        if tf_utils.is_ragged(y_t):
          batch_dim = y_t.nrows()
        else:
          batch_dim = array_ops.shape(y_t)[0]

      # Per-output metrics report the loss before `loss_weight` is applied.
      if metric_obj is not None:
        metric_obj.update_state(loss_metric_value, sample_weight=batch_dim)

      if loss_weight is not None:
        loss_value *= loss_weight
        loss_metric_value *= loss_weight

      if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
          loss_obj.reduction == losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)

      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)

    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = math_ops.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = math_ops.add_n(loss_metric_values)
      self._loss_metric.update_state(
          total_loss_metric_value, sample_weight=batch_dim)

      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = math_ops.add_n(loss_values)
      return total_loss
    else:
      # Ok for a model to have no compiled loss.
      return array_ops.zeros(shape=())

  def _get_loss_object(self, loss):
    """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object, or `None` if `loss` is `None`.

    Raises:
      ValueError: If no name can be derived for a non-`Loss` callable.
    """
    if loss is None:
      return None  # Ok to have no loss for an output.

    loss = losses_mod.get(loss)
    if not isinstance(loss, losses_mod.Loss):
      loss_name = get_custom_object_name(loss)
      if loss_name is None:
        raise ValueError('Loss should be a callable, found: {}'.format(loss))
      loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
    loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss

  def _should_broadcast(self, obj):
    """A single (non-nested) loss is applied to every output."""
    return not nest.is_nested(obj)

  def _copy_object(self, obj):
    """Returns `obj` itself; loss objects can be shared between outputs."""
    return obj  # Losses don't need to be copied.
282
283
class MetricsContainer(Container):
  """Holds the `metrics` and `weighted_metrics` passed to `Model.compile`.

  The user-supplied values are converted to `Metric` objects lazily, the first
  time `update_state` runs (see `build`).
  """

  def __init__(self, metrics=None, weighted_metrics=None, output_names=None):
    super(MetricsContainer, self).__init__(output_names=output_names)

    # Preserve the raw user-supplied values so the Model can be recompiled or
    # serialized later.
    self._user_metrics = metrics
    self._user_weighted_metrics = weighted_metrics

    # Working values; `build` turns them into per-output lists of `Metric`s.
    self._metrics = metrics
    self._weighted_metrics = weighted_metrics
    self._built = False

  @property
  def metrics(self):
    """Every metric in this container; `[]` before the container is built."""
    return self._metrics_in_order if self._built else []

  @property
  def unweighted_metrics(self):
    """Flat list of metrics that are not given the user's `sample_weight`.

    `None` until the container is built.
    """
    return nest.flatten(self._metrics) if self._built else None

  @property
  def weighted_metrics(self):
    """Flat list of metrics that are given `sample_weight`.

    `None` until the container is built.
    """
    return nest.flatten(self._weighted_metrics) if self._built else None

  def build(self, y_pred, y_true):
    """Converts the user-supplied metrics to `Metric` objects, once.

    Args:
      y_pred: Sample Model outputs; determines the per-output structure.
      y_true: Sample labels, used to resolve shortcuts such as 'accuracy'.
    """
    super(MetricsContainer, self).build(y_pred)

    self._metrics = self._conform_to_outputs(
        y_pred, self._maybe_broadcast_to_outputs(y_pred, self._metrics))
    self._weighted_metrics = self._conform_to_outputs(
        y_pred,
        self._maybe_broadcast_to_outputs(y_pred, self._weighted_metrics))

    # `tf.data` turns lists into `Tensor`s, so standardize on tuples.
    y_pred = nest.list_to_tuple(y_pred)
    y_true = nest.list_to_tuple(y_true)
    self._metrics = nest.list_to_tuple(self._metrics)
    self._weighted_metrics = nest.list_to_tuple(self._weighted_metrics)

    # Build `Metric` objects; label / output shapes disambiguate shortcuts.
    self._metrics = nest.map_structure_up_to(
        y_pred, self._get_metric_objects, self._metrics, y_true, y_pred)
    self._weighted_metrics = nest.map_structure_up_to(
        y_pred, self._get_metric_objects, self._weighted_metrics, y_true,
        y_pred)

    self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
    self._weighted_metrics = nest.flatten_up_to(
        y_pred, self._weighted_metrics, check_types=False)

    # Both helpers below expect metrics flattened up to the outputs.
    self._set_metric_names()
    self._create_ordered_metrics()
    self._built = True

  def _set_metric_names(self):
    """Gives every metric a unique name.

    Multi-output Models get the output name prepended to each metric name.
    A weighted metric whose name would collide gets a "weighted_" marker.

    Raises:
      ValueError: If two metrics still end up with the same name.
    """
    # pylint: disable=protected-access
    seen = set()
    multi_output = len(self._output_names) > 1

    def _register(metric):
      if metric._name in seen:
        raise ValueError('Found two metrics with the same name: {}'.format(
            metric._name))
      seen.add(metric._name)

    per_output = zip(self._output_names, self._metrics, self._weighted_metrics)
    for out_name, plain_metrics, weighted_metrics in per_output:
      for metric in plain_metrics:
        if metric is None:
          continue
        if multi_output:
          metric._name = out_name + '_' + metric._name
        _register(metric)

      for metric in weighted_metrics:
        if metric is None:
          continue
        if multi_output:
          prefixed = out_name + '_' + metric._name
          if prefixed in seen:
            metric._name = out_name + '_weighted_' + metric._name
          else:
            metric._name = prefixed
        elif metric._name in seen:
          metric._name = 'weighted_' + metric._name
        _register(metric)
    # pylint: enable=protected-access

  def _create_ordered_metrics(self):
    """Caches the flat metric order returned by `metrics` (backwards compat)."""
    ordered = []
    for plain_metrics, weighted_metrics in zip(self._metrics,
                                               self._weighted_metrics):
      ordered.extend(m for m in nest.flatten(plain_metrics) if m is not None)
      ordered.extend(m for m in nest.flatten(weighted_metrics) if m is not None)
    self._metrics_in_order = ordered

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Updates the state of per-output metrics.

    Unweighted metrics receive only the Keras mask (if any) as
    `sample_weight`; weighted metrics receive `sample_weight` combined with
    the mask.
    """
    y_true = self._conform_to_outputs(y_pred, y_true)
    sample_weight = self._conform_to_outputs(y_pred, sample_weight)

    if not self._built:
      self.build(y_pred, y_true)

    flat_pred = nest.flatten(y_pred)
    flat_true = [] if y_true is None else nest.flatten(y_true)
    flat_weights = nest.flatten(sample_weight)

    per_output = zip(flat_true, flat_pred, flat_weights, self._metrics,
                     self._weighted_metrics)
    for y_t, y_p, sw, plain_metrics, weighted_metrics in per_output:
      # Ok to have no labels, or no metrics at all, for an output.
      if y_t is None:
        continue
      if (all(m is None for m in plain_metrics) and
          all(m is None for m in weighted_metrics)):
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      mask = get_mask(y_p)
      sw = apply_mask(y_p, sw, mask)

      for metric in plain_metrics:
        if metric is not None:
          metric.update_state(y_t, y_p, sample_weight=mask)

      for metric in weighted_metrics:
        if metric is not None:
          metric.update_state(y_t, y_p, sample_weight=sw)

  def _get_metric_objects(self, metrics, y_t, y_p):
    """Converts every user-supplied metric for one output to a `Metric`."""
    return [self._get_metric_object(m, y_t, y_p) for m in nest.flatten(metrics)]

  def _get_metric_object(self, metric, y_t, y_p):
    """Converts user-supplied metric to a `Metric` object.

    Args:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object, or `None` if `metric` is `None`.

    Raises:
      ValueError: If no name can be derived for a non-`Metric` callable.
    """
    if metric is None:
      return None  # Ok to have no metric for an output.

    if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
      # Convenience feature: pick binary, sparse categorical or categorical
      # based on the label / output shapes.
      true_dims = y_t.shape.as_list()
      pred_dims = y_p.shape.as_list()
      true_last = true_dims[-1]
      pred_last = pred_dims[-1]

      is_binary = pred_last == 1
      is_sparse_categorical = (
          len(true_dims) < len(pred_dims) or true_last == 1 and pred_last > 1)

      if metric in ('accuracy', 'acc'):
        binary_fn, sparse_fn, dense_fn = (
            metrics_mod.binary_accuracy,
            metrics_mod.sparse_categorical_accuracy,
            metrics_mod.categorical_accuracy)
      else:
        binary_fn, sparse_fn, dense_fn = (
            metrics_mod.binary_crossentropy,
            metrics_mod.sparse_categorical_crossentropy,
            metrics_mod.categorical_crossentropy)

      if is_binary:
        metric_obj = binary_fn
      elif is_sparse_categorical:
        metric_obj = sparse_fn
      else:
        metric_obj = dense_fn
    else:
      metric_obj = metrics_mod.get(metric)

    if isinstance(metric_obj, losses_mod.Loss):
      metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access

    if isinstance(metric_obj, metrics_mod.Metric):
      return metric_obj

    if isinstance(metric, six.string_types):
      metric_name = metric
    else:
      metric_name = get_custom_object_name(metric)
      if metric_name is None:
        raise ValueError(
            'Metric should be a callable, found: {}'.format(metric))
    return metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)

  def _should_broadcast(self, obj):
    """Returns whether `obj` applies to every output.

    True for a single item (e.g. 'mse') or a flat list / tuple of items
    (e.g. ['mse'] or ['mse', 'mae']).
    """
    if not nest.is_nested(obj):
      return True
    return (isinstance(obj, (list, tuple)) and
            all(not nest.is_nested(o) for o in obj))

  def _copy_object(self, obj):
    """Returns a fresh `Metric` from `obj`'s config; other values are shared."""
    if not isinstance(obj, metrics_mod.Metric):
      return obj  # Can be a function or `None`.
    return obj.__class__.from_config(obj.get_config())
513
514
def create_pseudo_output_names(outputs):
  """Returns default output names (e.g. 'output_1') for a subclassed Model.

  Args:
    outputs: The Model's outputs (an arbitrary nested structure).

  Returns:
    A flat list of names, one per flattened output.
  """
  return _create_pseudo_names(outputs, 'output_')
518
519
def create_pseudo_input_names(inputs):
  """Returns default input names (e.g. 'input_1') for a subclassed Model.

  Args:
    inputs: The Model's inputs (an arbitrary nested structure).

  Returns:
    A flat list of names, one per flattened input.
  """
  return _create_pseudo_names(inputs, 'input_')
523
524
def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def one_index(ele):
    # Start with "output_1" instead of "output_0".
    if isinstance(ele, int):
      return ele + 1
    return ele

  flat_paths = list(nest.yield_flat_paths(tensors))
  flat_paths = nest.map_structure(one_index, flat_paths)
  names = []
  for path in flat_paths:
    if not path:
      name = prefix + '1'  # Single output.
    else:
      name = '_'.join(str(p) for p in path)
      # Paths that start with a list / tuple index get the prefix; paths
      # that start with a dict key use the key itself.
      if isinstance(path[0], int):
        name = prefix + name
    names.append(name)
  return names
568
569
def map_to_output_names(y_pred, output_names, struct):
  """Turns a dict keyed by output name into a list ordered like the outputs.

  Convenience feature kept for backwards compatibility of `compile` and
  `fit`: when a `Model`'s outputs are a single tensor or a flat list, the
  per-output losses / metrics may be given as a dict keyed by output name.
  A `struct` that already mirrors the outputs' structure (recommended) is
  left alone.

  For the Functional API the output names come from the last layer of each
  output; for the Subclass API they come from `create_pseudo_output_names`
  (e.g. `['output_1', 'output_2']` for a list of outputs).

  Args:
    y_pred: Sample outputs of the Model. The mapping is only applied when
      this is a single (non-nested) value or a flat list / tuple.
    output_names: List. The names of the outputs of the Model.
    struct: The structure to map.

  Returns:
    `struct` reordered into a list following `output_names` (unwrapped when
    that list has a single entry), or `struct` unchanged.

  Raises:
    ValueError: If `struct` has keys that match no output name.
  """
  if not isinstance(struct, dict):
    return struct

  single_output = not nest.is_nested(y_pred)
  flat_list_outputs = (not single_output and
                       isinstance(y_pred, (list, tuple)) and
                       all(not nest.is_nested(y_p) for y_p in y_pred))
  if not (single_output or flat_list_outputs):
    return struct

  names = output_names or create_pseudo_output_names(y_pred)
  leftover = copy.copy(struct)
  ordered = [leftover.pop(name, None) for name in names]
  if leftover:
    raise ValueError('Found unexpected keys that do not correspond '
                     'to any Model output: {}. Expected: {}'.format(
                         leftover.keys(), names))
  return ordered[0] if len(ordered) == 1 else ordered
615
616
def map_missing_dict_keys(y_pred, struct):
  """Replaces missing dict keys in `struct` with `None` placeholders.

  The caller's dict is never modified: when keys have to be added, a shallow
  copy is filled in and returned instead. This keeps user-supplied `loss` /
  `metrics` dicts untouched, which the containers rely on for recompiling
  and serialization.

  Args:
    y_pred: Model outputs. Only dict outputs trigger the mapping.
    struct: The structure to complete (e.g. losses, metrics, labels).

  Returns:
    `struct` itself if no mapping is needed or no key is missing; otherwise
    a copy of `struct` with every key of `y_pred` present, missing ones set
    to `None`.
  """
  if not isinstance(y_pred, dict) or not isinstance(struct, dict):
    return struct
  missing_keys = [k for k in y_pred if k not in struct]
  if not missing_keys:
    return struct
  # Copy so the caller's (possibly user-owned) dict is left untouched.
  struct = copy.copy(struct)
  for k in missing_keys:
    struct[k] = None
  return struct
625
626
def match_dtype_and_rank(y_t, y_p, sw):
  """Aligns the rank and dtype of labels / sample weights with predictions.

  Rank-1 labels or sample weights are given a trailing axis when the
  predictions are rank 2. Labels are cast to the prediction dtype when both
  are floating point or both are integer. Sample weights are always cast to
  the prediction dtype.

  Returns:
    The `(y_t, y_p, sw)` triple after alignment; `y_p` is returned as is.
  """
  pred_is_rank_2 = y_p.shape.rank == 2

  if y_t.shape.rank == 1 and pred_is_rank_2:
    y_t = array_ops.expand_dims_v2(y_t, axis=-1)
  if sw is not None and sw.shape.rank == 1 and pred_is_rank_2:
    sw = array_ops.expand_dims_v2(sw, axis=-1)

  # Custom loss functions frequently skip casting, so do it for them.
  both_float = y_t.dtype.is_floating and y_p.dtype.is_floating
  both_int = y_t.dtype.is_integer and y_p.dtype.is_integer
  if both_float or both_int:
    y_t = math_ops.cast(y_t, y_p.dtype)

  if sw is not None:
    sw = math_ops.cast(sw, y_p.dtype)
  return y_t, y_p, sw
645
646
def get_mask(y_p):
  """Returns the Keras mask attached to `y_p`, or `None` if it has none."""
  try:
    return y_p._keras_mask  # pylint: disable=protected-access
  except AttributeError:
    return None
650
651
def apply_mask(y_p, sw, mask):
  """Folds any Keras mask on the predictions into the sample weights.

  Args:
    y_p: Predictions; the mask is cast to their dtype.
    sw: Sample weights, or `None`.
    mask: The Keras mask of `y_p`, or `None`.

  Returns:
    `sw` unchanged when there is no mask, the cast mask when there are no
    sample weights, otherwise the sample weights multiplied by the mask.
  """
  if mask is None:
    return sw
  mask = math_ops.cast(mask, y_p.dtype)
  if sw is None:
    return mask
  # Bring mask and weights to compatible ranks before combining them.
  mask, _, sw = (
      losses_utils.squeeze_or_expand_dimensions(mask, sample_weight=sw))
  sw *= mask
  return sw
663
664
def get_custom_object_name(obj):
  """Returns the name to use for a custom loss or metric callable.

  Lookup order: a `name` attribute (so `Loss` instances can double as
  metrics), then `__name__` (functions), then the snake-cased class name.

  Args:
    obj: Custom loss or metric callable.

  Returns:
    Name to use, or `None` if the object was not recognized. Every Python
    object has a `__class__`, so in practice the class-name fallback applies
    before `None` is reached.
  """
  for attr in ('name', '__name__'):
    if hasattr(obj, attr):
      return getattr(obj, attr)
  if hasattr(obj, '__class__'):
    return generic_utils.to_snake_case(obj.__class__.__name__)
  return None
682