• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Generate a mock model for LLVM tests for Register Allocation.
2The generated model is not a neural net - it is just a tf.function with the
3correct input and output parameters. By construction, the mock model will always
4output the first liverange that can be evicted.
5"""
import json
import os
import sys

import tensorflow as tf
# Key under which the mock policy returns its decision.
POLICY_DECISION_LABEL = 'index_to_evict'
# JSON description of the policy's output tensor; build_mock_model writes it to
# output_spec.json next to the SavedModel. It is generated from
# POLICY_DECISION_LABEL so the logged name can never drift from the key the
# policy actually returns. json.dumps with indent=4, wrapped in a leading and a
# trailing newline, yields a 4-space-indented JSON document.
# NOTE(review): 'StatefulPartitionedCall' port 0 is presumably the graph tensor
# the consumer reads the decision from -- confirm on the LLVM side.
POLICY_OUTPUT_SPEC = '\n' + json.dumps(
    [{
        'logging_name': POLICY_DECISION_LABEL,
        'tensor_spec': {
            'name': 'StatefulPartitionedCall',
            'port': 0,
            'type': 'int64_t',
            'shape': [1],
        },
    }],
    indent=4) + '\n'
# Features fed to the model as int64 vectors of length NUM_REGISTERS (see
# get_input_signature).
PER_REGISTER_FEATURE_LIST = ['mask']
# Length of every per-register feature vector.
# NOTE(review): presumably must match the register count the LLVM eviction
# advisor feeds the model -- confirm before changing.
NUM_REGISTERS = 33
27
28
def get_input_signature():
  """Returns the input signature of the mock register-allocation policy.

  Returns:
    A dict mapping each name in PER_REGISTER_FEATURE_LIST to an int64
    `tf.TensorSpec` of rank-1 shape (NUM_REGISTERS,), named after the feature.
  """
  # `(NUM_REGISTERS,)` is an explicit 1-tuple; `(NUM_REGISTERS)` would be a
  # plain int that only works because TensorShape treats a scalar as a single
  # dimension.
  return {
      key: tf.TensorSpec(dtype=tf.int64, shape=(NUM_REGISTERS,), name=key)
      for key in PER_REGISTER_FEATURE_LIST
  }
35
36
def get_output_spec_path(path):
  """Returns where the output spec JSON lives inside a model directory.

  Args:
    path: Directory the SavedModel is (or will be) written to.

  Returns:
    `path` joined with the fixed file name 'output_spec.json'.
  """
  spec_filename = 'output_spec.json'
  return os.path.join(path, spec_filename)
39
40
def build_mock_model(path):
  """Build and save the mock model with the given signature.

  The saved "model" is a `tf.function`, not a trained network: it returns the
  argmax over the last axis of the 'mask' feature (cast to int32) under the
  POLICY_DECISION_LABEL key. The SavedModel is written to `path` with a single
  signature named 'action', and POLICY_OUTPUT_SPEC is then written to
  get_output_spec_path(path).

  Args:
    path: Output directory for the SavedModel and its output spec.
  """
  module = tf.Module()
  # We have to set this useless variable in order for the TF C API to correctly
  # intake it. It stays 0, so adding it to the result below keeps it used by
  # the traced function without changing the decision.
  module.var = tf.Variable(0, dtype=tf.int64)

  def action(*inputs):
    # Traced with a single positional argument: the feature dict produced by
    # get_input_signature().
    features = inputs[0]
    # tf.math.argmax returns int64 by default, matching module.var and the
    # 'int64_t' type declared in POLICY_OUTPUT_SPEC.
    # NOTE(review): TF does not document which index argmax picks on ties.
    result = tf.math.argmax(
        tf.cast(features['mask'], tf.int32), axis=-1) + module.var
    return {POLICY_DECISION_LABEL: result}

  module.action = tf.function(action)
  # Use a distinct name for the signature map rather than rebinding `action`.
  signatures = {
      'action': module.action.get_concrete_function(get_input_signature())
  }
  tf.saved_model.save(module, path, signatures=signatures)

  output_spec_path = get_output_spec_path(path)
  with open(output_spec_path, 'w', encoding='utf-8') as f:
    print(f'Writing output spec to {output_spec_path}.')
    f.write(POLICY_OUTPUT_SPEC)
61
62
def main(argv):
  """Builds the mock model at the path given on the command line.

  Args:
    argv: Full argument vector (e.g. `sys.argv`); argv[1] must be the output
      directory for the SavedModel.

  Raises:
    SystemExit: with a usage message unless exactly one argument is given.
  """
  # An explicit check instead of `assert`, which is stripped under `python -O`.
  if len(argv) != 2:
    prog = argv[0] if argv else '<script>'
    raise SystemExit(f'usage: {prog} MODEL_PATH')
  model_path = argv[1]
  build_mock_model(model_path)
67
68
if __name__ == '__main__':
  # Script entry point; main() receives the full argv and validates it.
  main(argv=sys.argv)
71