• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Hyperparameter values."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import json
21import numbers
22import re
23
24import six
25
26from tensorflow.contrib.training.python.training import hparam_pb2
27from tensorflow.python.framework import ops
28from tensorflow.python.util import compat
29from tensorflow.python.util import deprecation
30
# Define the regular expression for parsing a single clause of the input
# (delimited by commas).  A legal clause looks like:
#   <variable name>[<index>]? = <rhs>
# where <rhs> is either a single token or [] enclosed list of tokens.
# For example:  "var[1] = a" or "x = [1,2,3]"
#
# Notes on the pattern below (compiled with re.VERBOSE, so its layout
# whitespace and '#' notes are ignored by the regex engine):
#   * A variable name starts with an ASCII letter and may then contain word
#     characters and '.'.
#   * The optional index is a run of digits inside brackets; the named group
#     `index` is None when no index is given.
#   * Exactly one of the groups `val` (single value) or `vals` (contents of a
#     bracketed list) is set on a match.  `val` excludes ',' and '[' but may
#     contain spaces, and may also match the empty string (e.g. "x=").
#   * The trailing `($|,\s*)` consumes the separating comma and any following
#     whitespace, so the end of a match is where the next clause begins.
PARAM_RE = re.compile(r"""
  (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])?  # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*)            # single value: "a" or None
   |
   \[(?P<vals>[^\]]*)\])       # list of values: None or "1,2,3"
  ($|,\s*)""", re.VERBOSE)
44
45
46def _parse_fail(name, var_type, value, values):
47  """Helper function for raising a value error for bad assignment."""
48  raise ValueError(
49      'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
50      (name, var_type.__name__, value, values))
51
52
53def _reuse_fail(name, values):
54  """Helper function for raising a value error for reuse of name."""
55  raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
56                                                                      values))
57
58
59def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
60                          results_dictionary):
61  """Update results_dictionary with a scalar value.
62
63  Used to update the results_dictionary to be returned by parse_values when
64  encountering a clause with a scalar RHS (e.g.  "s=5" or "arr[0]=5".)
65
66  Mutates results_dictionary.
67
68  Args:
69    name: Name of variable in assignment ("s" or "arr").
70    parse_fn: Function for parsing the actual value.
71    var_type: Type of named variable.
72    m_dict: Dictionary constructed from regex parsing.
73      m_dict['val']: RHS value (scalar)
74      m_dict['index']: List index value (or None)
75    values: Full expression being parsed
76    results_dictionary: The dictionary being updated for return by the parsing
77      function.
78
79  Raises:
80    ValueError: If the name has already been used.
81  """
82  try:
83    parsed_value = parse_fn(m_dict['val'])
84  except ValueError:
85    _parse_fail(name, var_type, m_dict['val'], values)
86
87  # If no index is provided
88  if not m_dict['index']:
89    if name in results_dictionary:
90      _reuse_fail(name, values)
91    results_dictionary[name] = parsed_value
92  else:
93    if name in results_dictionary:
94      # The name has already been used as a scalar, then it
95      # will be in this dictionary and map to a non-dictionary.
96      if not isinstance(results_dictionary.get(name), dict):
97        _reuse_fail(name, values)
98    else:
99      results_dictionary[name] = {}
100
101    index = int(m_dict['index'])
102    # Make sure the index position hasn't already been assigned a value.
103    if index in results_dictionary[name]:
104      _reuse_fail('{}[{}]'.format(name, index), values)
105    results_dictionary[name][index] = parsed_value
106
107
108def _process_list_value(name, parse_fn, var_type, m_dict, values,
109                        results_dictionary):
110  """Update results_dictionary from a list of values.
111
112  Used to update results_dictionary to be returned by parse_values when
113  encountering a clause with a list RHS (e.g.  "arr=[1,2,3]".)
114
115  Mutates results_dictionary.
116
117  Args:
118    name: Name of variable in assignment ("arr").
119    parse_fn: Function for parsing individual values.
120    var_type: Type of named variable.
121    m_dict: Dictionary constructed from regex parsing.
122      m_dict['val']: RHS value (scalar)
123    values: Full expression being parsed
124    results_dictionary: The dictionary being updated for return by the parsing
125      function.
126
127  Raises:
128    ValueError: If the name has an index or the values cannot be parsed.
129  """
130  if m_dict['index'] is not None:
131    raise ValueError('Assignment of a list to a list index.')
132  elements = filter(None, re.split('[ ,]', m_dict['vals']))
133  # Make sure the name hasn't already been assigned a value
134  if name in results_dictionary:
135    raise _reuse_fail(name, values)
136  try:
137    results_dictionary[name] = [parse_fn(e) for e in elements]
138  except ValueError:
139    _parse_fail(name, var_type, m_dict['vals'], values)
140
141
def _cast_to_type_if_compatible(name, param_type, value):
  """Cast hparam to the provided type, if compatible.

  Args:
    name: Name of the hparam to be cast.
    param_type: The type of the hparam.
    value: The value to be cast, if compatible.

  Returns:
    The result of casting `value` to `param_type`.  A value whose type is
    exactly `param_type` is returned as-is.  Text and bytes values are
    converted to the other string kind by UTF-8 decoding/encoding.

  Raises:
    ValueError: If the type of `value` is not compatible with param_type.
      * If `param_type` is a string type, but `value` is not.
      * If `param_type` is a boolean, but `value` is not, or vice versa.
      * If `param_type` is an integer type, but `value` is not.
      * If `param_type` is a float type, but `value` is not a numeric type.
      * If a bytes `value` is not valid UTF-8 for a text `param_type`
        (raised as `UnicodeDecodeError`, a `ValueError` subclass).
  """
  fail_msg = (
      "Could not cast hparam '%s' of type '%s' from value %r" %
      (name, param_type, value))

  # A value that already has exactly the requested type needs no cast.  An
  # exact type comparison is used because isinstance is too permissive here
  # (e.g. isinstance(True, int) is True).
  if type(value) is param_type:  # pylint: disable=unidiomatic-typecheck
    return value

  # Some callers use None, for which we can't do any casting/checking. :(
  if issubclass(param_type, type(None)):
    return value

  # Avoid converting a non-string type to a string.
  if (issubclass(param_type, (six.string_types, six.binary_type)) and
      not isinstance(value, (six.string_types, six.binary_type))):
    raise ValueError(fail_msg)

  # Convert between text and bytes explicitly.  On Python 3, str(b'x') yields
  # the repr "b'x'" and bytes('x') raises TypeError, so the generic
  # param_type(value) cast at the end would be wrong for these combinations.
  if (issubclass(param_type, six.text_type) and
      isinstance(value, six.binary_type)):
    return param_type(value.decode('utf-8'))
  if (issubclass(param_type, six.binary_type) and
      isinstance(value, six.text_type)):
    return param_type(value.encode('utf-8'))

  # Avoid converting a number or string type to a boolean or vice versa.
  if issubclass(param_type, bool) != isinstance(value, bool):
    raise ValueError(fail_msg)

  # Avoid converting float to an integer (the reverse is fine).
  if (issubclass(param_type, numbers.Integral) and
      not isinstance(value, numbers.Integral)):
    raise ValueError(fail_msg)

  # Avoid converting a non-numeric type to a numeric type.
  if (issubclass(param_type, numbers.Number) and
      not isinstance(value, numbers.Number)):
    raise ValueError(fail_msg)

  return param_type(value)
188
189
def _make_bool_parser(name, values):
  """Returns a parser for the right-hand side of a boolean hparam clause.

  The returned function accepts 'true'/'True' and 'false'/'False', and
  otherwise falls back to interpreting the value as an integer ('0' is False,
  any other integer is True).  Anything else is reported via `_parse_fail`.

  Args:
    name: Name of the boolean hyperparameter (used in error messages).
    values: Full expression being parsed (used in error messages).

  Returns:
    A function mapping a string to a bool.
  """

  def parse_bool(value):
    if value in ('true', 'True'):
      return True
    if value in ('false', 'False'):
      return False
    try:
      return bool(int(value))
    except ValueError:
      _parse_fail(name, bool, value, values)

  return parse_bool


def parse_values(values, type_map, ignore_unknown=False):
  """Parses hyperparameter values from a string into a python map.

  `values` is a string containing comma-separated `name=value` pairs.
  For each pair, the value of the hyperparameter named `name` is set to
  `value`.

  If a hyperparameter name appears multiple times in `values`, a ValueError
  is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').

  If a hyperparameter name appears in both an index assignment and a scalar
  assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0] = 1').

  The hyperparameter name may contain '.' symbols, which will result in an
  attribute name that is only accessible through the getattr and setattr
  functions.  (Such names must first be explicitly added through
  add_hparam.)

  WARNING: Use of '.' in your variable names is allowed, but is not well
  supported and not recommended.

  The `value` in `name=value` must follow the syntax according to the
  type of the parameter:

  *  Scalar integer: A Python-parsable integer value.  E.g.: 1,
     100, -12.
  *  Scalar float: A Python-parsable floating point value.  E.g.: 1.0,
     -.54e89.
  *  Boolean: Either true or false (True and False are also accepted, as is
     an integer, where 0 means false and any other integer means true).
  *  Scalar string: A non-empty sequence of characters, excluding comma,
     spaces, and square brackets.  E.g.: foo, bar_1.
  *  List: A comma separated list of scalar values of the parameter type
     enclosed in square brackets.  E.g.: [1,2,3], [1.0,1e-12], [high,low].

  When index assignment is used, the corresponding type_map key should be the
  list name.  E.g. for "arr[1]=0" the type_map must have the key "arr" (not
  "arr[1]").

  Args:
    values: String.  Comma separated list of `name=value` pairs where
      'value' must follow the syntax described above.
    type_map: A dictionary mapping hyperparameter names to types.  Note every
      parameter name in values must be a key in type_map.  The values must
      conform to the types indicated, where a value V is said to conform to a
      type T if either V has type T, or V is a list of elements of type T.
      Hence, for a multidimensional parameter 'x' taking float values,
      'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
    ignore_unknown: Bool. Whether values that are missing a type in type_map
      should be ignored. If set to True, a ValueError will not be raised for
      unknown hyperparameter type.

  Returns:
    A python map mapping each name to either:
    * A scalar value.
    * A list of scalar values.
    * A dictionary mapping index numbers to scalar values.
    (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")

  Raises:
    ValueError: If there is a problem with input.
    * If `values` cannot be parsed.
    * If a name is missing from `type_map` and `ignore_unknown` is False.
    * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
    * If the same name (lvalue) is assigned two different values (e.g.
      'a=1,a=2', 'a[1]=1,a[1]=2', or 'a=1,a=[1]')
  """
  results_dictionary = {}
  pos = 0
  while pos < len(values):
    m = PARAM_RE.match(values, pos)
    if not m:
      raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
    # PARAM_RE also consumes the separating comma, so the end of this match is
    # the start of the next clause.
    pos = m.end()
    m_dict = m.groupdict()
    name = m_dict['name']
    if name not in type_map:
      if ignore_unknown:
        continue
      raise ValueError('Unknown hyperparameter type for %s' % name)
    type_ = type_map[name]

    # Booleans need a dedicated parser: bool('false') would evaluate to True.
    parse = _make_bool_parser(name, values) if type_ is bool else type_

    if m_dict['val'] is not None:
      # A single value, e.g. "x=5" or "arr[1]=5".
      _process_scalar_value(name, parse, type_, m_dict, values,
                            results_dictionary)
    elif m_dict['vals'] is not None:
      # A bracketed list, e.g. "arr=[1,2,3]".
      _process_list_value(name, parse, type_, m_dict, values,
                          results_dictionary)
    else:
      # Neither a single value nor a list was assigned.
      _parse_fail(name, type_, '', values)

  return results_dictionary
297                          results_dictionary)
298
299    else:  # Not assigned a list or value
300      _parse_fail(name, type_, '', values)
301
302  return results_dictionary
303
304
class HParams(object):
  """Class to hold a set of hyperparameters as name-value pairs.

  A `HParams` object holds hyperparameters used to build and train a model,
  such as the number of hidden units in a neural net layer or the learning rate
  to use when training.

  You first create a `HParams` object by specifying the names and values of the
  hyperparameters.

  To make them easily accessible the parameter names are added as direct
  attributes of the class.  A typical usage is as follows:

  ```python
  # Create a HParams object specifying names and values of the model
  # hyperparameters:
  hparams = HParams(learning_rate=0.1, num_hidden_units=100)

  # The hyperparameters are available as attributes of the HParams object:
  hparams.learning_rate ==> 0.1
  hparams.num_hidden_units ==> 100
  ```

  Hyperparameters have a type, which is inferred from the type of their value
  passed at construction time.  The currently supported types are: integer,
  float, boolean, string, and list of integer, float, boolean, or string.

  You can override hyperparameter values by calling the
  [`parse()`](#HParams.parse) method, passing a string of comma separated
  `name=value` pairs.  This is intended to make it possible to override
  any hyperparameter values from a single command-line flag to which
  the user passes 'hyper-param=value' pairs.  It avoids having to define
  one flag for each hyperparameter.

  The syntax expected for each value depends on the type of the parameter.
  See `parse()` for a description of the syntax.

  Example:

  ```python
  # Define a command line flag to pass name=value pairs.
  # For example using argparse:
  import argparse
  parser = argparse.ArgumentParser(description='Train my model.')
  parser.add_argument('--hparams', type=str,
                      help='Comma separated list of "name=value" pairs.')
  args = parser.parse_args()
  ...
  def my_program():
    # Create a HParams object specifying the names and values of the
    # model hyperparameters:
    hparams = tf.contrib.training.HParams(
        learning_rate=0.1,
        num_hidden_units=100,
        activations=['relu', 'tanh'])

    # Override hyperparameter values by parsing the command line
    hparams.parse(args.hparams)

    # If the user passed `--hparams=learning_rate=0.3` on the command line
    # then 'hparams' has the following attributes:
    hparams.learning_rate ==> 0.3
    hparams.num_hidden_units ==> 100
    hparams.activations ==> ['relu', 'tanh']

    # If the hyperparameters are in json format use parse_json.  Multi-valued
    # hyperparameters must be given as lists:
    hparams.parse_json('{"learning_rate": 0.3, "activations": ["relu"]}')
  ```
  """

  _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.

  def __init__(self, hparam_def=None, model_structure=None, **kwargs):
    """Create an instance of `HParams` from keyword arguments.

    The keyword arguments specify name-value pairs for the hyperparameters.
    The parameter types are inferred from the type of the values passed.

    The parameter names are added as attributes of `HParams` object, so they
    can be accessed directly with the dot notation `hparams._name_`.

    Example:

    ```python
    # Define 3 hyperparameters: 'learning_rate' is a float parameter,
    # 'num_hidden_units' an integer parameter, and 'activation' a string
    # parameter.
    hparams = tf.contrib.training.HParams(
        learning_rate=0.1, num_hidden_units=100, activation='relu')

    hparams.activation ==> 'relu'
    ```

    Note that a few names are reserved and cannot be used as hyperparameter
    names.  If you use one of the reserved names the constructor raises a
    `ValueError`.

    Args:
      hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
        protocol buffer. If provided, this object is initialized by
        deserializing hparam_def.  Otherwise **kwargs is used.
      model_structure: An instance of ModelStructure, defining the feature
        crosses to be used in the Trial.
      **kwargs: Key-value pairs where the key is the hyperparameter name and
        the value is the value for the parameter.

    Raises:
      ValueError: If both `hparam_def` and initialization values are provided,
        or if one of the arguments is invalid.
    """
    # Register the hyperparameters and their type in _hparam_types.
    # This simplifies the implementation of parse().
    # _hparam_types maps the parameter name to a tuple (type, bool).
    # The type value is the type of the parameter for scalar hyperparameters,
    # or the type of the list elements for multidimensional hyperparameters.
    # The bool value is True if the value is a list, False otherwise.
    self._hparam_types = {}
    self._model_structure = model_structure
    if hparam_def:
      # Reject conflicting arguments before doing any deserialization work.
      if kwargs:
        raise ValueError('hparam_def and initialization values are '
                         'mutually exclusive')
      self._init_from_proto(hparam_def)
    else:
      for name, value in six.iteritems(kwargs):
        self.add_hparam(name, value)

  def _init_from_proto(self, hparam_def):
    """Initializes the hyperparameters from an `HParamDef` protocol buffer.

    int64 values are converted to Python `int` and bytes values to `str`
    (UTF-8 encoding is assumed) so the resulting types are the same on
    Python 2 and Python 3.

    Args:
      hparam_def: `HParamDef` protocol buffer.

    Raises:
      ValueError: If an entry of `hparam_def` has no value set, or if
        `add_hparam` rejects an entry.
    """
    assert isinstance(hparam_def, hparam_pb2.HParamDef)
    for name, value in hparam_def.hparam.items():
      kind = value.WhichOneof('kind')
      if kind is None:
        # No field of the 'kind' oneof is set, so there is no value to load.
        raise ValueError('Hyperparameter %s in hparam_def has no value' % name)
      if kind.endswith('_value'):
        # Single value.
        raw_value = getattr(value, kind)
        if kind.startswith('int64'):
          self.add_hparam(name, int(raw_value))
        elif kind.startswith('bytes'):
          self.add_hparam(name, compat.as_str(raw_value))
        else:
          self.add_hparam(name, raw_value)
      else:
        # List of values.
        raw_list = getattr(value, kind).value
        if kind.startswith('int64'):
          self.add_hparam(name, [int(v) for v in raw_list])
        elif kind.startswith('bytes'):
          self.add_hparam(name, [compat.as_str(v) for v in raw_list])
        else:
          self.add_hparam(name, list(raw_list))

  def add_hparam(self, name, value):
    """Adds {name, value} pair to hyperparameters.

    Args:
      name: Name of the hyperparameter.
      value: Value of the hyperparameter. Can be one of the following types:
        int, float, bool, string, or a non-empty list/tuple of those.  For a
        list/tuple, the type of its first element is recorded as the
        hyperparameter type.

    Raises:
      ValueError: if one of the arguments is invalid.
    """
    # Keys in kwargs are unique, but 'name' could be the name of a
    # pre-existing attribute of this object.  In that case we refuse to use it
    # as a hyperparameter name.
    if getattr(self, name, None) is not None:
      raise ValueError('Hyperparameter name is reserved: %s' % name)
    if isinstance(value, (list, tuple)):
      if not value:
        raise ValueError(
            'Multi-valued hyperparameters cannot be empty: %s' % name)
      self._hparam_types[name] = (type(value[0]), True)
    else:
      self._hparam_types[name] = (type(value), False)
    setattr(self, name, value)

  def set_hparam(self, name, value):
    """Set the value of an existing hyperparameter.

    This function verifies that the type of the value matches the type of the
    existing hyperparameter.

    Args:
      name: Name of the hyperparameter.
      value: New value of the hyperparameter.

    Raises:
      KeyError: If the hyperparameter doesn't exist.
      ValueError: If there is a type mismatch.
    """
    param_type, is_list = self._hparam_types[name]
    if isinstance(value, list):
      if not is_list:
        raise ValueError(
            'Must not pass a list for single-valued parameter: %s' % name)
      setattr(self, name, [
          _cast_to_type_if_compatible(name, param_type, v) for v in value])
    else:
      if is_list:
        raise ValueError(
            'Must pass a list for multi-valued parameter: %s.' % name)
      setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))

  def del_hparam(self, name):
    """Removes the hyperparameter with key 'name'.

    Does nothing if it isn't present.  Only registered hyperparameters are
    removed; other attributes of this object are never touched.

    Args:
      name: Name of the hyperparameter.
    """
    # Checking _hparam_types (rather than hasattr) keeps this from deleting
    # non-hyperparameter attributes such as _model_structure, and from failing
    # on method names.
    if name in self._hparam_types:
      delattr(self, name)
      del self._hparam_types[name]

  def parse(self, values):
    """Override existing hyperparameter values, parsing new values from a string.

    See parse_values for more detail on the allowed format for values.

    Args:
      values: String.  Comma separated list of `name=value` pairs where 'value'
        must follow the syntax described above.

    Returns:
      The `HParams` instance.

    Raises:
      ValueError: If `values` cannot be parsed or a hyperparameter in `values`
      doesn't exist.
    """
    # parse_values only needs the element/scalar type of each hyperparameter.
    type_map = {
        name: param_type
        for name, (param_type, _) in self._hparam_types.items()
    }
    values_map = parse_values(values, type_map)
    return self.override_from_dict(values_map)

  def override_from_dict(self, values_dict):
    """Override existing hyperparameter values, parsing new values from a dictionary.

    Args:
      values_dict: Dictionary of name:value pairs.

    Returns:
      The `HParams` instance.

    Raises:
      KeyError: If a hyperparameter in `values_dict` doesn't exist.
      ValueError: If `values_dict` cannot be parsed.
    """
    for name, value in values_dict.items():
      self.set_hparam(name, value)
    return self

  @deprecation.deprecated(None, 'Use `override_from_dict`.')
  def set_from_map(self, values_map):
    """DEPRECATED. Use override_from_dict."""
    return self.override_from_dict(values_dict=values_map)

  def set_model_structure(self, model_structure):
    """Stores `model_structure` on this object."""
    self._model_structure = model_structure

  def get_model_structure(self):
    """Returns the stored model structure (None if never set)."""
    return self._model_structure

  def to_json(self, indent=None, separators=None, sort_keys=False):
    """Serializes the hyperparameters into JSON.

    Args:
      indent: If a non-negative integer, JSON array elements and object members
        will be pretty-printed with that indent level. An indent level of 0, or
        negative, will only insert newlines. `None` (the default) selects the
        most compact representation.
      separators: Optional `(item_separator, key_separator)` tuple. Default is
        `(', ', ': ')`.
      sort_keys: If `True`, the output dictionaries will be sorted by key.

    Returns:
      A JSON string.
    """
    return json.dumps(
        self.values(),
        indent=indent,
        separators=separators,
        sort_keys=sort_keys)

  def parse_json(self, values_json):
    """Override existing hyperparameter values, parsing new values from a json object.

    Args:
      values_json: String containing a json object of name:value pairs.

    Returns:
      The `HParams` instance.

    Raises:
      KeyError: If a hyperparameter in `values_json` doesn't exist.
      ValueError: If `values_json` cannot be parsed.
    """
    values_map = json.loads(values_json)
    return self.override_from_dict(values_map)

  def values(self):
    """Return the hyperparameter values as a Python dictionary.

    Returns:
      A dictionary with hyperparameter names as keys.  The values are the
      hyperparameter values.
    """
    return {n: getattr(self, n) for n in self._hparam_types}

  def get(self, key, default=None):
    """Returns the value of `key` if it exists, else `default`.

    Args:
      key: Name of the hyperparameter.
      default: Value returned when `key` is not a hyperparameter.  When not
        None and `key` exists, it must still be compatible with the
        hyperparameter's type (a list for multi-valued hyperparameters, a
        scalar otherwise).

    Returns:
      The hyperparameter value, or `default`.

    Raises:
      ValueError: If `key` exists and a non-None `default` is incompatible
        with its type.
    """
    if key in self._hparam_types:
      # Ensure that default is compatible with the parameter type.
      if default is not None:
        param_type, is_param_list = self._hparam_types[key]
        type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
        fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
                    'default=%s' % (key, type_str, default))

        is_default_list = isinstance(default, list)
        if is_param_list != is_default_list:
          raise ValueError(fail_msg)

        try:
          if is_default_list:
            for value in default:
              _cast_to_type_if_compatible(key, param_type, value)
          else:
            _cast_to_type_if_compatible(key, param_type, default)
        except ValueError as e:
          raise ValueError('%s. %s' % (fail_msg, e))

      return getattr(self, key)

    return default

  def __contains__(self, key):
    """Returns True if `key` is a registered hyperparameter name."""
    return key in self._hparam_types

  def __str__(self):
    """Returns the (name, value) pairs sorted by name, as a string."""
    return str(sorted(self.values().items()))

  def __repr__(self):
    return '%s(%s)' % (type(self).__name__, self.__str__())

  @staticmethod
  def _get_kind_name(param_type, is_list):
    """Returns the field name given parameter type and is_list.

    Args:
      param_type: Data type of the hparam.
      is_list: Whether this is a list.

    Returns:
      A string representation of the field name.

    Raises:
      ValueError: If parameter type is not recognized.
    """
    if issubclass(param_type, bool):
      # This check must happen before issubclass(param_type, six.integer_types),
      # since Python considers bool to be a subclass of int.
      typename = 'bool'
    elif issubclass(param_type, six.integer_types):
      # Setting 'int' and 'long' types to be 'int64' to ensure the type is
      # compatible with both Python2 and Python3.
      typename = 'int64'
    elif issubclass(param_type, (six.string_types, six.binary_type)):
      # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
      # compatible with both Python2 and Python3.
      typename = 'bytes'
    elif issubclass(param_type, float):
      typename = 'float'
    else:
      raise ValueError('Unsupported parameter type: %s' % str(param_type))

    suffix = 'list' if is_list else 'value'
    return '_'.join([typename, suffix])

  def to_proto(self, export_scope=None):  # pylint: disable=unused-argument
    """Converts a `HParams` object to a `HParamDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `HParamDef` protocol buffer.

    Raises:
      ValueError: If a hyperparameter has a type `_get_kind_name` does not
        support.
    """
    hparam_proto = hparam_pb2.HParamDef()
    for name, (param_type, is_list) in self._hparam_types.items():
      kind = HParams._get_kind_name(param_type, is_list)
      value = getattr(self, name)
      # String-like values are stored in the proto's bytes fields.
      if is_list:
        if kind.startswith('bytes'):
          v_list = [compat.as_bytes(v) for v in value]
        else:
          v_list = list(value)
        getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
      else:
        if kind.startswith('bytes'):
          value = compat.as_bytes(value)
        setattr(hparam_proto.hparam[name], kind, value)

    return hparam_proto

  @staticmethod
  def from_proto(hparam_def, import_scope=None):  # pylint: disable=unused-argument
    """Creates a new `HParams` from an `HParamDef` protocol buffer."""
    return HParams(hparam_def=hparam_def)
732
733
# Register HParamDef as the proto type associated with the name 'hparams',
# with HParams.to_proto / HParams.from_proto as the converters in each
# direction.  NOTE(review): presumably this lets HParams objects stored in a
# graph collection named 'hparams' round-trip through graph/meta-graph
# serialization — confirm against ops.register_proto_function.
ops.register_proto_function(
    'hparams',
    proto_type=hparam_pb2.HParamDef,
    to_proto=HParams.to_proto,
    from_proto=HParams.from_proto)
739