1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Hyperparameter values.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import json 21import numbers 22import re 23 24import six 25 26from tensorflow.contrib.training.python.training import hparam_pb2 27from tensorflow.python.framework import ops 28from tensorflow.python.util import compat 29from tensorflow.python.util import deprecation 30 31# Define the regular expression for parsing a single clause of the input 32# (delimited by commas). A legal clause looks like: 33# <variable name>[<index>]? = <rhs> 34# where <rhs> is either a single token or [] enclosed list of tokens. 35# For example: "var[1] = a" or "x = [1,2,3]" 36PARAM_RE = re.compile(r""" 37 (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x" 38 (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None 39 \s*=\s* 40 ((?P<val>[^,\[]*) # single value: "a" or None 41 | 42 \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3" 43 ($|,\s*)""", re.VERBOSE) 44 45 46def _parse_fail(name, var_type, value, values): 47 """Helper function for raising a value error for bad assignment.""" 48 raise ValueError( 49 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' % 50 (name, var_type.__name__, value, values)) 51 52 53def _reuse_fail(name, values): 54 """Helper function for raising a value error for reuse of name.""" 55 raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name, 56 values)) 57 58 59def _process_scalar_value(name, parse_fn, var_type, m_dict, values, 60 results_dictionary): 61 """Update results_dictionary with a scalar value. 62 63 Used to update the results_dictionary to be returned by parse_values when 64 encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) 65 66 Mutates results_dictionary. 67 68 Args: 69 name: Name of variable in assignment ("s" or "arr"). 70 parse_fn: Function for parsing the actual value. 71 var_type: Type of named variable. 72 m_dict: Dictionary constructed from regex parsing. 73 m_dict['val']: RHS value (scalar) 74 m_dict['index']: List index value (or None) 75 values: Full expression being parsed 76 results_dictionary: The dictionary being updated for return by the parsing 77 function. 78 79 Raises: 80 ValueError: If the name has already been used. 81 """ 82 try: 83 parsed_value = parse_fn(m_dict['val']) 84 except ValueError: 85 _parse_fail(name, var_type, m_dict['val'], values) 86 87 # If no index is provided 88 if not m_dict['index']: 89 if name in results_dictionary: 90 _reuse_fail(name, values) 91 results_dictionary[name] = parsed_value 92 else: 93 if name in results_dictionary: 94 # The name has already been used as a scalar, then it 95 # will be in this dictionary and map to a non-dictionary. 96 if not isinstance(results_dictionary.get(name), dict): 97 _reuse_fail(name, values) 98 else: 99 results_dictionary[name] = {} 100 101 index = int(m_dict['index']) 102 # Make sure the index position hasn't already been assigned a value. 103 if index in results_dictionary[name]: 104 _reuse_fail('{}[{}]'.format(name, index), values) 105 results_dictionary[name][index] = parsed_value 106 107 108def _process_list_value(name, parse_fn, var_type, m_dict, values, 109 results_dictionary): 110 """Update results_dictionary from a list of values. 111 112 Used to update results_dictionary to be returned by parse_values when 113 encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) 114 115 Mutates results_dictionary. 116 117 Args: 118 name: Name of variable in assignment ("arr"). 119 parse_fn: Function for parsing individual values. 120 var_type: Type of named variable. 121 m_dict: Dictionary constructed from regex parsing. 122 m_dict['val']: RHS value (scalar) 123 values: Full expression being parsed 124 results_dictionary: The dictionary being updated for return by the parsing 125 function. 126 127 Raises: 128 ValueError: If the name has an index or the values cannot be parsed. 129 """ 130 if m_dict['index'] is not None: 131 raise ValueError('Assignment of a list to a list index.') 132 elements = filter(None, re.split('[ ,]', m_dict['vals'])) 133 # Make sure the name hasn't already been assigned a value 134 if name in results_dictionary: 135 raise _reuse_fail(name, values) 136 try: 137 results_dictionary[name] = [parse_fn(e) for e in elements] 138 except ValueError: 139 _parse_fail(name, var_type, m_dict['vals'], values) 140 141 142def _cast_to_type_if_compatible(name, param_type, value): 143 """Cast hparam to the provided type, if compatible. 144 145 Args: 146 name: Name of the hparam to be cast. 147 param_type: The type of the hparam. 148 value: The value to be cast, if compatible. 149 150 Returns: 151 The result of casting `value` to `param_type`. 152 153 Raises: 154 ValueError: If the type of `value` is not compatible with param_type. 155 * If `param_type` is a string type, but `value` is not. 156 * If `param_type` is a boolean, but `value` is not, or vice versa. 157 * If `param_type` is an integer type, but `value` is not. 158 * If `param_type` is a float type, but `value` is not a numeric type. 159 """ 160 fail_msg = ( 161 "Could not cast hparam '%s' of type '%s' from value %r" % 162 (name, param_type, value)) 163 164 # Some callers use None, for which we can't do any casting/checking. :( 165 if issubclass(param_type, type(None)): 166 return value 167 168 # Avoid converting a non-string type to a string. 169 if (issubclass(param_type, (six.string_types, six.binary_type)) and 170 not isinstance(value, (six.string_types, six.binary_type))): 171 raise ValueError(fail_msg) 172 173 # Avoid converting a number or string type to a boolean or vice versa. 174 if issubclass(param_type, bool) != isinstance(value, bool): 175 raise ValueError(fail_msg) 176 177 # Avoid converting float to an integer (the reverse is fine). 178 if (issubclass(param_type, numbers.Integral) and 179 not isinstance(value, numbers.Integral)): 180 raise ValueError(fail_msg) 181 182 # Avoid converting a non-numeric type to a numeric type. 183 if (issubclass(param_type, numbers.Number) and 184 not isinstance(value, numbers.Number)): 185 raise ValueError(fail_msg) 186 187 return param_type(value) 188 189 190def parse_values(values, type_map, ignore_unknown=False): 191 """Parses hyperparameter values from a string into a python map. 192 193 `values` is a string containing comma-separated `name=value` pairs. 194 For each pair, the value of the hyperparameter named `name` is set to 195 `value`. 196 197 If a hyperparameter name appears multiple times in `values`, a ValueError 198 is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). 199 200 If a hyperparameter name in both an index assignment and scalar assignment, 201 a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). 202 203 The hyperparameter name may contain '.' symbols, which will result in an 204 attribute name that is only accessible through the getattr and setattr 205 functions. (And must be first explicit added through add_hparam.) 206 207 WARNING: Use of '.' in your variable names is allowed, but is not well 208 supported and not recommended. 209 210 The `value` in `name=value` must follows the syntax according to the 211 type of the parameter: 212 213 * Scalar integer: A Python-parsable integer point value. E.g.: 1, 214 100, -12. 215 * Scalar float: A Python-parsable floating point value. E.g.: 1.0, 216 -.54e89. 217 * Boolean: Either true or false. 218 * Scalar string: A non-empty sequence of characters, excluding comma, 219 spaces, and square brackets. E.g.: foo, bar_1. 220 * List: A comma separated list of scalar values of the parameter type 221 enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. 222 223 When index assignment is used, the corresponding type_map key should be the 224 list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not 225 "arr[1]"). 226 227 Args: 228 values: String. Comma separated list of `name=value` pairs where 229 'value' must follow the syntax described above. 230 type_map: A dictionary mapping hyperparameter names to types. Note every 231 parameter name in values must be a key in type_map. The values must 232 conform to the types indicated, where a value V is said to conform to a 233 type T if either V has type T, or V is a list of elements of type T. 234 Hence, for a multidimensional parameter 'x' taking float values, 235 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. 236 ignore_unknown: Bool. Whether values that are missing a type in type_map 237 should be ignored. If set to True, a ValueError will not be raised for 238 unknown hyperparameter type. 239 240 Returns: 241 A python map mapping each name to either: 242 * A scalar value. 243 * A list of scalar values. 244 * A dictionary mapping index numbers to scalar values. 245 (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") 246 247 Raises: 248 ValueError: If there is a problem with input. 249 * If `values` cannot be parsed. 250 * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). 251 * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', 252 'a[1]=1,a[1]=2', or 'a=1,a=[1]') 253 """ 254 results_dictionary = {} 255 pos = 0 256 while pos < len(values): 257 m = PARAM_RE.match(values, pos) 258 if not m: 259 raise ValueError('Malformed hyperparameter value: %s' % values[pos:]) 260 # Check that there is a comma between parameters and move past it. 261 pos = m.end() 262 # Parse the values. 263 m_dict = m.groupdict() 264 name = m_dict['name'] 265 if name not in type_map: 266 if ignore_unknown: 267 continue 268 raise ValueError('Unknown hyperparameter type for %s' % name) 269 type_ = type_map[name] 270 271 # Set up correct parsing function (depending on whether type_ is a bool) 272 if type_ == bool: 273 274 def parse_bool(value): 275 if value in ['true', 'True']: 276 return True 277 elif value in ['false', 'False']: 278 return False 279 else: 280 try: 281 return bool(int(value)) 282 except ValueError: 283 _parse_fail(name, type_, value, values) 284 285 parse = parse_bool 286 else: 287 parse = type_ 288 289 # If a singe value is provided 290 if m_dict['val'] is not None: 291 _process_scalar_value(name, parse, type_, m_dict, values, 292 results_dictionary) 293 294 # If the assigned value is a list: 295 elif m_dict['vals'] is not None: 296 _process_list_value(name, parse, type_, m_dict, values, 297 results_dictionary) 298 299 else: # Not assigned a list or value 300 _parse_fail(name, type_, '', values) 301 302 return results_dictionary 303 304 305class HParams(object): 306 """Class to hold a set of hyperparameters as name-value pairs. 307 308 A `HParams` object holds hyperparameters used to build and train a model, 309 such as the number of hidden units in a neural net layer or the learning rate 310 to use when training. 311 312 You first create a `HParams` object by specifying the names and values of the 313 hyperparameters. 314 315 To make them easily accessible the parameter names are added as direct 316 attributes of the class. A typical usage is as follows: 317 318 ```python 319 # Create a HParams object specifying names and values of the model 320 # hyperparameters: 321 hparams = HParams(learning_rate=0.1, num_hidden_units=100) 322 323 # The hyperparameter are available as attributes of the HParams object: 324 hparams.learning_rate ==> 0.1 325 hparams.num_hidden_units ==> 100 326 ``` 327 328 Hyperparameters have type, which is inferred from the type of their value 329 passed at construction type. The currently supported types are: integer, 330 float, boolean, string, and list of integer, float, boolean, or string. 331 332 You can override hyperparameter values by calling the 333 [`parse()`](#HParams.parse) method, passing a string of comma separated 334 `name=value` pairs. This is intended to make it possible to override 335 any hyperparameter values from a single command-line flag to which 336 the user passes 'hyper-param=value' pairs. It avoids having to define 337 one flag for each hyperparameter. 338 339 The syntax expected for each value depends on the type of the parameter. 340 See `parse()` for a description of the syntax. 341 342 Example: 343 344 ```python 345 # Define a command line flag to pass name=value pairs. 346 # For example using argparse: 347 import argparse 348 parser = argparse.ArgumentParser(description='Train my model.') 349 parser.add_argument('--hparams', type=str, 350 help='Comma separated list of "name=value" pairs.') 351 args = parser.parse_args() 352 ... 353 def my_program(): 354 # Create a HParams object specifying the names and values of the 355 # model hyperparameters: 356 hparams = tf.contrib.training.HParams( 357 learning_rate=0.1, 358 num_hidden_units=100, 359 activations=['relu', 'tanh']) 360 361 # Override hyperparameters values by parsing the command line 362 hparams.parse(args.hparams) 363 364 # If the user passed `--hparams=learning_rate=0.3` on the command line 365 # then 'hparams' has the following attributes: 366 hparams.learning_rate ==> 0.3 367 hparams.num_hidden_units ==> 100 368 hparams.activations ==> ['relu', 'tanh'] 369 370 # If the hyperparameters are in json format use parse_json: 371 hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') 372 ``` 373 """ 374 375 _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. 376 377 def __init__(self, hparam_def=None, model_structure=None, **kwargs): 378 """Create an instance of `HParams` from keyword arguments. 379 380 The keyword arguments specify name-values pairs for the hyperparameters. 381 The parameter types are inferred from the type of the values passed. 382 383 The parameter names are added as attributes of `HParams` object, so they 384 can be accessed directly with the dot notation `hparams._name_`. 385 386 Example: 387 388 ```python 389 # Define 3 hyperparameters: 'learning_rate' is a float parameter, 390 # 'num_hidden_units' an integer parameter, and 'activation' a string 391 # parameter. 392 hparams = tf.contrib.training.HParams( 393 learning_rate=0.1, num_hidden_units=100, activation='relu') 394 395 hparams.activation ==> 'relu' 396 ``` 397 398 Note that a few names are reserved and cannot be used as hyperparameter 399 names. If you use one of the reserved name the constructor raises a 400 `ValueError`. 401 402 Args: 403 hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef 404 protocol buffer. If provided, this object is initialized by 405 deserializing hparam_def. Otherwise **kwargs is used. 406 model_structure: An instance of ModelStructure, defining the feature 407 crosses to be used in the Trial. 408 **kwargs: Key-value pairs where the key is the hyperparameter name and 409 the value is the value for the parameter. 410 411 Raises: 412 ValueError: If both `hparam_def` and initialization values are provided, 413 or if one of the arguments is invalid. 414 415 """ 416 # Register the hyperparameters and their type in _hparam_types. 417 # This simplifies the implementation of parse(). 418 # _hparam_types maps the parameter name to a tuple (type, bool). 419 # The type value is the type of the parameter for scalar hyperparameters, 420 # or the type of the list elements for multidimensional hyperparameters. 421 # The bool value is True if the value is a list, False otherwise. 422 self._hparam_types = {} 423 self._model_structure = model_structure 424 if hparam_def: 425 self._init_from_proto(hparam_def) 426 if kwargs: 427 raise ValueError('hparam_def and initialization values are ' 428 'mutually exclusive') 429 else: 430 for name, value in six.iteritems(kwargs): 431 self.add_hparam(name, value) 432 433 def _init_from_proto(self, hparam_def): 434 """Creates a new HParams from `HParamDef` protocol buffer. 435 436 Args: 437 hparam_def: `HParamDef` protocol buffer. 438 """ 439 assert isinstance(hparam_def, hparam_pb2.HParamDef) 440 for name, value in hparam_def.hparam.items(): 441 kind = value.WhichOneof('kind') 442 if kind.endswith('_value'): 443 # Single value. 444 if kind.startswith('int64'): 445 # Setting attribute value to be 'int' to ensure the type is compatible 446 # with both Python2 and Python3. 447 self.add_hparam(name, int(getattr(value, kind))) 448 elif kind.startswith('bytes'): 449 # Setting attribute value to be 'str' to ensure the type is compatible 450 # with both Python2 and Python3. UTF-8 encoding is assumed. 451 self.add_hparam(name, compat.as_str(getattr(value, kind))) 452 else: 453 self.add_hparam(name, getattr(value, kind)) 454 else: 455 # List of values. 456 if kind.startswith('int64'): 457 # Setting attribute value to be 'int' to ensure the type is compatible 458 # with both Python2 and Python3. 459 self.add_hparam(name, [int(v) for v in getattr(value, kind).value]) 460 elif kind.startswith('bytes'): 461 # Setting attribute value to be 'str' to ensure the type is compatible 462 # with both Python2 and Python3. UTF-8 encoding is assumed. 463 self.add_hparam( 464 name, [compat.as_str(v) for v in getattr(value, kind).value]) 465 else: 466 self.add_hparam(name, [v for v in getattr(value, kind).value]) 467 468 def add_hparam(self, name, value): 469 """Adds {name, value} pair to hyperparameters. 470 471 Args: 472 name: Name of the hyperparameter. 473 value: Value of the hyperparameter. Can be one of the following types: 474 int, float, string, int list, float list, or string list. 475 476 Raises: 477 ValueError: if one of the arguments is invalid. 478 """ 479 # Keys in kwargs are unique, but 'name' could the name of a pre-existing 480 # attribute of this object. In that case we refuse to use it as a 481 # hyperparameter name. 482 if getattr(self, name, None) is not None: 483 raise ValueError('Hyperparameter name is reserved: %s' % name) 484 if isinstance(value, (list, tuple)): 485 if not value: 486 raise ValueError( 487 'Multi-valued hyperparameters cannot be empty: %s' % name) 488 self._hparam_types[name] = (type(value[0]), True) 489 else: 490 self._hparam_types[name] = (type(value), False) 491 setattr(self, name, value) 492 493 def set_hparam(self, name, value): 494 """Set the value of an existing hyperparameter. 495 496 This function verifies that the type of the value matches the type of the 497 existing hyperparameter. 498 499 Args: 500 name: Name of the hyperparameter. 501 value: New value of the hyperparameter. 502 503 Raises: 504 KeyError: If the hyperparameter doesn't exist. 505 ValueError: If there is a type mismatch. 506 """ 507 param_type, is_list = self._hparam_types[name] 508 if isinstance(value, list): 509 if not is_list: 510 raise ValueError( 511 'Must not pass a list for single-valued parameter: %s' % name) 512 setattr(self, name, [ 513 _cast_to_type_if_compatible(name, param_type, v) for v in value]) 514 else: 515 if is_list: 516 raise ValueError( 517 'Must pass a list for multi-valued parameter: %s.' % name) 518 setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) 519 520 def del_hparam(self, name): 521 """Removes the hyperparameter with key 'name'. 522 523 Does nothing if it isn't present. 524 525 Args: 526 name: Name of the hyperparameter. 527 """ 528 if hasattr(self, name): 529 delattr(self, name) 530 del self._hparam_types[name] 531 532 def parse(self, values): 533 """Override existing hyperparameter values, parsing new values from a string. 534 535 See parse_values for more detail on the allowed format for values. 536 537 Args: 538 values: String. Comma separated list of `name=value` pairs where 'value' 539 must follow the syntax described above. 540 541 Returns: 542 The `HParams` instance. 543 544 Raises: 545 ValueError: If `values` cannot be parsed or a hyperparameter in `values` 546 doesn't exist. 547 """ 548 type_map = dict() 549 for name, t in self._hparam_types.items(): 550 param_type, _ = t 551 type_map[name] = param_type 552 553 values_map = parse_values(values, type_map) 554 return self.override_from_dict(values_map) 555 556 def override_from_dict(self, values_dict): 557 """Override existing hyperparameter values, parsing new values from a dictionary. 558 559 Args: 560 values_dict: Dictionary of name:value pairs. 561 562 Returns: 563 The `HParams` instance. 564 565 Raises: 566 KeyError: If a hyperparameter in `values_dict` doesn't exist. 567 ValueError: If `values_dict` cannot be parsed. 568 """ 569 for name, value in values_dict.items(): 570 self.set_hparam(name, value) 571 return self 572 573 @deprecation.deprecated(None, 'Use `override_from_dict`.') 574 def set_from_map(self, values_map): 575 """DEPRECATED. Use override_from_dict.""" 576 return self.override_from_dict(values_dict=values_map) 577 578 def set_model_structure(self, model_structure): 579 self._model_structure = model_structure 580 581 def get_model_structure(self): 582 return self._model_structure 583 584 def to_json(self, indent=None, separators=None, sort_keys=False): 585 """Serializes the hyperparameters into JSON. 586 587 Args: 588 indent: If a non-negative integer, JSON array elements and object members 589 will be pretty-printed with that indent level. An indent level of 0, or 590 negative, will only insert newlines. `None` (the default) selects the 591 most compact representation. 592 separators: Optional `(item_separator, key_separator)` tuple. Default is 593 `(', ', ': ')`. 594 sort_keys: If `True`, the output dictionaries will be sorted by key. 595 596 Returns: 597 A JSON string. 598 """ 599 return json.dumps( 600 self.values(), 601 indent=indent, 602 separators=separators, 603 sort_keys=sort_keys) 604 605 def parse_json(self, values_json): 606 """Override existing hyperparameter values, parsing new values from a json object. 607 608 Args: 609 values_json: String containing a json object of name:value pairs. 610 611 Returns: 612 The `HParams` instance. 613 614 Raises: 615 KeyError: If a hyperparameter in `values_json` doesn't exist. 616 ValueError: If `values_json` cannot be parsed. 617 """ 618 values_map = json.loads(values_json) 619 return self.override_from_dict(values_map) 620 621 def values(self): 622 """Return the hyperparameter values as a Python dictionary. 623 624 Returns: 625 A dictionary with hyperparameter names as keys. The values are the 626 hyperparameter values. 627 """ 628 return {n: getattr(self, n) for n in self._hparam_types.keys()} 629 630 def get(self, key, default=None): 631 """Returns the value of `key` if it exists, else `default`.""" 632 if key in self._hparam_types: 633 # Ensure that default is compatible with the parameter type. 634 if default is not None: 635 param_type, is_param_list = self._hparam_types[key] 636 type_str = 'list<%s>' % param_type if is_param_list else str(param_type) 637 fail_msg = ("Hparam '%s' of type '%s' is incompatible with " 638 'default=%s' % (key, type_str, default)) 639 640 is_default_list = isinstance(default, list) 641 if is_param_list != is_default_list: 642 raise ValueError(fail_msg) 643 644 try: 645 if is_default_list: 646 for value in default: 647 _cast_to_type_if_compatible(key, param_type, value) 648 else: 649 _cast_to_type_if_compatible(key, param_type, default) 650 except ValueError as e: 651 raise ValueError('%s. %s' % (fail_msg, e)) 652 653 return getattr(self, key) 654 655 return default 656 657 def __contains__(self, key): 658 return key in self._hparam_types 659 660 def __str__(self): 661 return str(sorted(self.values().items())) 662 663 def __repr__(self): 664 return '%s(%s)' % (type(self).__name__, self.__str__()) 665 666 @staticmethod 667 def _get_kind_name(param_type, is_list): 668 """Returns the field name given parameter type and is_list. 669 670 Args: 671 param_type: Data type of the hparam. 672 is_list: Whether this is a list. 673 674 Returns: 675 A string representation of the field name. 676 677 Raises: 678 ValueError: If parameter type is not recognized. 679 """ 680 if issubclass(param_type, bool): 681 # This check must happen before issubclass(param_type, six.integer_types), 682 # since Python considers bool to be a subclass of int. 683 typename = 'bool' 684 elif issubclass(param_type, six.integer_types): 685 # Setting 'int' and 'long' types to be 'int64' to ensure the type is 686 # compatible with both Python2 and Python3. 687 typename = 'int64' 688 elif issubclass(param_type, (six.string_types, six.binary_type)): 689 # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is 690 # compatible with both Python2 and Python3. 691 typename = 'bytes' 692 elif issubclass(param_type, float): 693 typename = 'float' 694 else: 695 raise ValueError('Unsupported parameter type: %s' % str(param_type)) 696 697 suffix = 'list' if is_list else 'value' 698 return '_'.join([typename, suffix]) 699 700 def to_proto(self, export_scope=None): # pylint: disable=unused-argument 701 """Converts a `HParams` object to a `HParamDef` protocol buffer. 702 703 Args: 704 export_scope: Optional `string`. Name scope to remove. 705 706 Returns: 707 A `HParamDef` protocol buffer. 708 """ 709 hparam_proto = hparam_pb2.HParamDef() 710 for name in self._hparam_types: 711 # Parse the values. 712 param_type, is_list = self._hparam_types.get(name, (None, None)) 713 kind = HParams._get_kind_name(param_type, is_list) 714 715 if is_list: 716 if kind.startswith('bytes'): 717 v_list = [compat.as_bytes(v) for v in getattr(self, name)] 718 else: 719 v_list = [v for v in getattr(self, name)] 720 getattr(hparam_proto.hparam[name], kind).value.extend(v_list) 721 else: 722 v = getattr(self, name) 723 if kind.startswith('bytes'): 724 v = compat.as_bytes(getattr(self, name)) 725 setattr(hparam_proto.hparam[name], kind, v) 726 727 return hparam_proto 728 729 @staticmethod 730 def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument 731 return HParams(hparam_def=hparam_def) 732 733 734ops.register_proto_function( 735 'hparams', 736 proto_type=hparam_pb2.HParamDef, 737 to_proto=HParams.to_proto, 738 from_proto=HParams.from_proto) 739