# mypy: allow-untyped-defs
from typing import Optional
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto


def attr_value_proto(dtype, shape, s):
    """Create a dict of objects matching a NodeDef's attr field.

    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
    specifically designed for a NodeDef. The values have been reverse engineered from
    standard TensorBoard logged data.

    Args:
        dtype: Accepted for interface compatibility; currently not used.
        shape: Optional iterable of dimension sizes. When given, it is stored
            under the ``"_output_shapes"`` key as a one-element list of
            ``TensorShapeProto``.
        s: Optional string. When given, it is UTF-8 encoded and stored under
            the ``"attr"`` key.

    Returns:
        dict mapping attr names to ``AttrValue`` messages; empty when both
        ``shape`` and ``s`` are ``None``.
    """
    attr = {}
    if s is not None:
        attr["attr"] = AttrValue(s=s.encode(encoding="utf_8"))
    if shape is not None:
        shapeproto = tensor_shape_proto(shape)
        attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))
    return attr


def tensor_shape_proto(outputsize):
    """Create an object matching a tensor_shape field.

    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto .

    Args:
        outputsize: Iterable of dimension sizes; each element becomes the
            ``size`` of one ``TensorShapeProto.Dim``, in order.

    Returns:
        A ``TensorShapeProto`` with one dim per element of ``outputsize``.
    """
    return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize])


def node_proto(
    name,
    op="UnSpecified",
    input=None,
    dtype=None,
    shape: Optional[tuple] = None,
    outputsize=None,
    attributes="",
):
    """Create an object matching a NodeDef.

    Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto .

    Args:
        name: Node name; UTF-8 encoded before being stored on the NodeDef.
        op: Operation type string.
        input: Name(s) of the node's inputs. ``None`` means no inputs; a list
            or tuple is used element-wise; any other value is treated as a
            single input.
        dtype: Forwarded to ``attr_value_proto``, which currently ignores it.
        shape: Currently unused; the output shape recorded on the node comes
            from ``outputsize``.
        outputsize: Optional iterable of dimension sizes recorded under the
            ``"_output_shapes"`` attr.
        attributes: String recorded under the ``"attr"`` attr. The default
            ``""`` still produces an (empty) entry; pass ``None`` to omit it.

    Returns:
        A ``NodeDef`` message.
    """
    if input is None:
        inputs = []
    elif isinstance(input, (list, tuple)):
        # Tuples are expanded like lists; wrapping a tuple as one element
        # would hand NodeDef a nested sequence instead of input names.
        inputs = list(input)
    else:
        inputs = [input]
    return NodeDef(
        name=name.encode(encoding="utf_8"),
        op=op,
        input=inputs,
        attr=attr_value_proto(dtype, outputsize, attributes),
    )