• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Slicing the input Model file
17
Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is
not intended to be invoked directly by users. See that script for
details on how the slicing tool is used.
21
22This script does the following work:
23
Perform a topological sort similar to the test generator, except that:
* It stops at the N-th operation it encounters,
* It renames the output of the N-th operation to a model output,
* It names that as the output of the model, and
* Only the inputs and weights used by the submodel are emitted.
29
30"""
31
32from __future__ import absolute_import
33from __future__ import division
34from __future__ import print_function
35import argparse
36from functools import reduce
37import math
38import os
39import struct
40import sys
41import contextlib
42import test_generator
43import pprint
44# Stuff from test generator
45from test_generator import Configuration
46from test_generator import Example
47from test_generator import Float32Scalar
48from test_generator import Input
49from test_generator import Int32Scalar
50from test_generator import Internal
51from test_generator import Model
52from test_generator import Output
53from test_generator import Parameter
54from test_generator import smart_open
55
56
def import_source():
  """Parse the command line and execute the spec file it names.

  Executing the spec is expected to register the model's operands,
  operations and examples with test_generator, which the rest of this
  script then slices (assumption based on how the results are consumed
  below; confirm against test_generator).

  Returns:
    A tuple (model, example, number): the output model file name and the
    output example file name (both default to "-"), and the number of
    operations to keep in the sliced model (an int, default 1).

  Exits the process with status 1 if the spec file does not exist; argparse
  exits with status 2 if the arguments are malformed (e.g. a non-integer -n).
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("spec", help="the spec file")
  parser.add_argument(
      "-n", "--number",
      type=int,  # reject non-integers here rather than crash later
      help="number of operations in the sliced model. Default = 1",
      default=1)
  parser.add_argument(
      "-m", "--model", help="the output model file", default="-")
  parser.add_argument(
      "-e", "--example", help="the output example file", default="-")
  args = parser.parse_args()

  if not os.path.exists(args.spec):
    # Diagnostics belong on stderr rather than stdout.
    print("cannot find file %s" % args.spec, file=sys.stderr)
    sys.exit(1)

  test_generator.FileNames.SpecFile = os.path.basename(args.spec)
  # Read through `with` so the file handle is closed before the spec runs.
  with open(args.spec) as spec_fh:
    spec_source = spec_fh.read()
  # NOTE: exec runs arbitrary Python from the spec file; only pass trusted
  # spec files to this tool.
  exec(spec_source)

  return (args.model, args.example, args.number)
79
80
class slicing:
  """Keeps the first `threshold` operations found by a topological sort.

  An instance's format_as_py_op is used as the TopologicalSort callback.
  It records the Python definition of each operation up to the threshold.
  It remembers the outputs of the last accepted operation, which become the
  sliced model's outputs. It also tracks every operand the accepted
  operations touch, so only those operands are emitted.
  """

  def __init__(self, threshold):
    # threshold: maximum number of operations kept in the sliced model.
    self.__nr_op_seen = 0  # ops with a non-None PyDefinition seen so far
    self.__threshold = threshold
    self.__last_outs = []  # outputs of the last accepted operation
    self.__all_formatted_ops = []  # "model = model.<op>" lines, in order
    self.__referenced_operands = set()  # operands used by accepted ops

  def format_as_py_op(self, op):
    """TopologicalSort callback: record `op` if it is within the threshold.

    Args:
      op: a node visited by the sort. Operations expose PyDefinition();
        nodes without it (e.g. weights) are skipped.

    Returns:
      True if the node is not an operation or the operation was recorded;
      False once more than `threshold` operations have been seen (presumably
      telling TopologicalSort to stop -- confirm in test_generator);
      None if PyDefinition() returned None.
    """
    try:
      py_definition = op.PyDefinition
    except AttributeError:  # not an op, but things like weights
      return True
    # Only the lookup is guarded above: an AttributeError raised *inside*
    # PyDefinition() is a real bug and must not be mistaken for "not an op".
    fmt = py_definition()
    if fmt is None:
      return None
    self.__nr_op_seen += 1
    if self.__nr_op_seen > self.__threshold:
      return False
    self.__last_outs = op.outs
    self.__referenced_operands.update(op.ins)
    self.__referenced_operands.update(op.outs)
    self.__all_formatted_ops.append("model = model.%s" % fmt)
    return True

  def dump(self, model_file):
    """Print each recorded operation line to `model_file`."""
    for line in self.__all_formatted_ops:
      print(line, file=model_file)

  def dump_example(self, example_file):
    """Write the example file for the sliced model.

    For each output of the last accepted operation, this:
      * prints an alias line `aliased_output0 = <name>`;
      * records the output's element count in an override map keyed by
        operand name.
    The override map and the referenced operands are then passed to
    Example.py_dump.
    """
    alias_def = ("# Alias for the output variable {operand_name}\n"
                 "aliased_output{number} = {operand_name}\n")
    override = {}
    for lo in self.__last_outs:
      override[lo.get_name()] = lo.type.get_nr_elements()
      op = {
          'operand_name': lo.get_name(),
          'number': 0  # only support one output as of now
      }
      print(alias_def.format(**op), file=example_file)
    Example.py_dump(example_file, override, self.__referenced_operands)

  def format_operands(self):
    """Return Python definitions for the operands used by the sliced model.

    Walks test_generator's operand registry in its own order and emits one
    definition line per referenced operand, joined by newlines:
      * Parameters carry their initializer.
      * Outputs of the last accepted operation are declared as IgnoredOutput.
      * Other operands keep their own class name.
    """
    last_outs = set(self.__last_outs)  # built once, not per operand
    op_definitions = []
    for o in test_generator.Operand.operands.objects():
      if o not in self.__referenced_operands:
        continue
      ty = o.type
      op_def = """{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" """
      if isinstance(o, test_generator.Parameter):
        op_def += """, {initializer})"""
        init = o.initializer
        py_operand_name = "Parameter"
      else:
        op_def += ")"
        init = []
        py_operand_name = ("IgnoredOutput" if o in last_outs
                           else o.__class__.__name__)
      op_definitions.append(op_def.format(
          element_type=ty.get_element_type(),
          shape=ty.get_raw_shape(),
          op_name=o.get_name(),
          operand=py_operand_name,
          initializer=init))
    return "\n".join(op_definitions)
151          "op_name": o.get_name(),
152          "operand": py_operand_name,
153          "initializer": init
154      }
155      op_definitions.append(op_def.format(**op))
156    return "\n".join(op_definitions)
157
158
159if __name__ == "__main__":
160  (model, example, number) = import_source()
161  s = slicing(int(number))
162
163  with smart_open(model) as model_file:
164    spec_file = " (from: %s)" % (test_generator.FileNames.SpecFile)
165    print("# Generated file%s. Do not edit" % (spec_file), file=model_file)
166    print("model = Model()", file=model_file)
167    test_generator.TopologicalSort(lambda x: s.format_as_py_op(x))
168    print(s.format_operands(), file=model_file)
169    s.dump(model_file)
170  with smart_open(example) as example_file:
171    s.dump_example(example_file)
172