#!/usr/bin/python3
# Copyright 2019, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spec Visualizer
Visualizes a Python spec file for the test generator as an HTML page.
Invoked by ml/nn/runtime/test/specs/visualize_spec.sh;
See that script for details on how this script is used.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import sys
from string import Template
# Stuff from test generator
import test_generator as tg
from test_generator import ActivationConverter
from test_generator import BoolScalar
from test_generator import Configuration
from test_generator import DataTypeConverter
from test_generator import DataLayoutConverter
from test_generator import Example
from test_generator import Float16Scalar
from test_generator import Float32Scalar
from test_generator import Float32Vector
from test_generator import GetJointStr
from test_generator import IgnoredOutput
from test_generator import Input
from test_generator import Int32Scalar
from test_generator import Int32Vector
from test_generator import Internal
from test_generator import Model
from test_generator import Operand
from test_generator import Output
from test_generator import Parameter
from test_generator import RelaxedModeConverter
from test_generator import SymmPerChannelQuantParams
# Path of the HTML template that DumpHtml fills in. It is looked up in the
# directory containing this script (symlinks resolved via realpath).
TEMPLATE_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "spec_viz_template.html")
# Maps str(example.testName) -> {"subgraph": ..., "details": ...}.
# Populated by ProcessExample and serialized to JSON by DumpHtml.
global_graphs = dict()
def FormatArray(data, is_scalar=False):
    """Render a sequence of values as display text.

    Args:
        data: the values to render.
        is_scalar: when True, ``data`` must contain exactly one element, which
            is rendered on its own without brackets.

    Returns:
        "[v0, v1, ...]" for non-scalars, or str() of the single element.
    """
    if not is_scalar:
        return "[" + ", ".join(map(str, data)) + "]"
    assert len(data) == 1
    return str(data[0])
def FormatDict(data):
    """Render a mapping as one "Key: value" line per entry.

    Keys are passed through str.capitalize(). Entries keep the mapping's
    iteration order and are separated by a newline.

    Args:
        data: a mapping of display labels to values.

    Returns:
        The formatted text; "" for an empty mapping.
    """
    # The separator is spelled as an escape sequence. A raw line break inside a
    # plain string literal is a SyntaxError.
    return "\n".join("%s: %s" % (k.capitalize(), v) for k, v in data.items())
def GetOperandInfo(op):
    """Collect an operand's type metadata for display.

    The result always has "lifetime" and "type". Additional keys depend on the type:
    - "dimensions" for non-scalar types;
    - "scale" and "zero point" when the scale is non-zero;
    - for TENSOR_QUANT8_SYMM_PER_CHANNEL, a per-channel "scale" list (this
      replaces any scalar scale) plus "channel dim".

    Returns:
        dict mapping display labels to values, in insertion order.
    """
    operand_type = op.type
    info = {"lifetime": op.lifetime, "type": operand_type.type}
    if not operand_type.IsScalar():
        info["dimensions"] = FormatArray(operand_type.dimensions)
    if operand_type.scale != 0:
        info.update({"scale": operand_type.scale, "zero point": operand_type.zeroPoint})
    if operand_type.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        channel_params = operand_type.extraParams
        info.update({
            "scale": FormatArray(channel_params.scales),
            "channel dim": channel_params.channelDim,
        })
    return info
def FormatOperand(op):
    """Render one operand reference as an HTML link with a hover tooltip.

    The link text is str(op) and the link target is
    "#details-operands-<op.model_index>". The tooltip is an HTML title
    attribute. It lists the fields from GetOperandInfo, plus the operand data
    for Parameters holding at most 10 values.

    NOTE(review): assumes the page template inserts this string as HTML and
    defines the "details-operands-N" anchors. Confirm against
    spec_viz_template.html.
    """
    import html  # stdlib; escapes text placed inside markup

    # All keys and values in op_info will appear in the tooltip. We only display the operand data
    # if it has at most 10 values. This should be convenient enough for most parameters.
    op_info = GetOperandInfo(op)
    if isinstance(op, Parameter) and len(op.value) <= 10:
        op_info["data"] = FormatArray(op.value, op.type.IsScalar())
    # Every keyword passed to format() must be referenced by the template,
    # otherwise the in-page link is silently lost. The tooltip goes into a
    # double-quoted attribute, so both texts are HTML-escaped.
    template = '<a href="{inpage_link}" title="{tooltip_content}">{op_name}</a>'
    return template.format(
        op_name=html.escape(str(op)),
        tooltip_content=html.escape(FormatDict(op_info)),
        inpage_link="#details-operands-%d" % (op.model_index),
    )
def GetSubgraph(example):
    """Produce the node and edge lists used by the d3 graph view.

    Operands and operations both become nodes. Each node is placed on a grid.
    Its row ("layer") comes from the data dependencies, and its column is its
    arrival order within that row.

    Args:
        example: the example whose ``model`` (operations and operands) is drawn.

    Returns:
        {"nodes": [...], "edges": [...]}. Operation nodes (group 2, id
        "operation<i>") come first, followed by operand nodes (group 1,
        id str(operand)). Each edge links an operand to or from an operation
        by id.
    """
    model = example.model

    # Build one topological order covering operands and operations.
    # model.operations is already topologically sorted, so it suffices to
    # visit each operation's inputs, then the operation, then its outputs.
    order = []
    seen = set()

    def visit(node):
        if node not in seen:
            seen.add(node)
            order.append(node)

    for operation in model.operations:
        for operand in operation.ins:
            visit(operand)
        visit(operation)
        for operand in operation.outs:
            visit(operand)

    # Forward pass: one layer below the deepest producer (sources get layer 0).
    layer = {}
    for node in order:
        layer[node] = 1 + max((layer[src] for src in node.ins), default=-1)
    # Backward pass: move each node down to the layer just above its
    # shallowest consumer. Nodes without consumers keep their layer.
    for node in reversed(order):
        layer[node] = min((layer[dst] for dst in node.outs), default=layer[node] + 1) - 1

    # Grid layout: 200px per column, 100px per row, centered in each cell.
    filled = [0] * (max(layer.values()) + 1)
    coords = {}
    for node in order:
        row = layer[node]
        coords[node] = ((filled[row] + 0.5) * 200, (row + 0.5) * 100)
        filled[row] += 1

    def operation_id(position):
        return "operation%d" % position

    edges = []
    nodes = []
    for position, operation in enumerate(model.operations):
        op_id = operation_id(position)
        for operand in operation.ins:
            edges.append({"source": str(operand), "target": op_id})
        for operand in operation.outs:
            edges.append({"target": str(operand), "source": op_id})
        x, y = coords[operation]
        nodes.append({"index": position, "id": op_id, "name": operation.optype,
                      "group": 2, "x": x, "y": y})
    for position, operand in enumerate(model.operands):
        x, y = coords[operand]
        nodes.append({"index": position, "id": str(operand), "name": str(operand),
                      "group": 1, "x": x, "y": y})
    return {"nodes": nodes, "edges": edges}
# Each of the following Get*Info functions returns a list of dictionaries
# whose contents appear in the tables and sidebar views of the generated page.
def GetConfigurationsInfo(example):
    """Return the example's run configuration as a single-row table.

    Every value in the row is converted to str.
    """
    row = dict()
    row["relaxed"] = str(example.model.isRelaxed)
    row["use shared memory"] = str(tg.Configuration.useSHM())
    row["expect failure"] = str(example.expectFailure)
    return [row]
def GetOperandsInfo(example):
    """Return one table row (dict) per operand of the example's model.

    Each row starts with the operand's position ("index"), its name, and the
    "operand" group tag. It is then extended with the fields from
    GetOperandInfo. Parameter, Input and Output operands also get their
    values under "data".
    """
    rows = []
    for position, operand in enumerate(example.model.operands):
        row = {"index": position, "name": str(operand), "group": "operand"}
        row.update(GetOperandInfo(operand))
        if isinstance(operand, (Parameter, Input, Output)):
            row["data"] = FormatArray(operand.value, operand.type.IsScalar())
        rows.append(row)
    return rows
def GetOperationsInfo(example):
    """Return one table row (dict) per operation of the example's model.

    "inputs" and "outputs" hold the operands rendered by FormatOperand,
    joined with ", ".
    """
    def describe(position, operation):
        return {
            "index": position,
            "name": operation.optype,
            "group": "operation",
            "opcode": operation.optype,
            "inputs": ", ".join(FormatOperand(src) for src in operation.ins),
            "outputs": ", ".join(FormatOperand(dst) for dst in operation.outs),
        }

    return [describe(position, operation)
            for position, operation in enumerate(example.model.operations)]
# TODO: Remove the unused fd from the parameter.
def ProcessExample(example, fd):
    """Record the visualization data of one example in global_graphs.

    Passed as the DumpExample callback to Example.DumpAllExamples. The entry
    is keyed by str(example.testName). It holds the d3 "subgraph" and the
    "details" tables (configurations, operands, operations). ``fd`` is
    accepted but not used.
    """
    print("    Processing variation %s" % example.testName)
    # Only an item of the module-level dict is assigned (the name is never
    # rebound), so no `global` declaration is required.
    subgraph = GetSubgraph(example)
    details = {
        "configurations": GetConfigurationsInfo(example),
        "operands": GetOperandsInfo(example),
        "operations": GetOperationsInfo(example),
    }
    global_graphs[str(example.testName)] = {"subgraph": subgraph, "details": details}
def DumpHtml(spec_file, out_file):
    """Write the final HTML page by filling in the template file.

    Reads TEMPLATE_FILE and substitutes two placeholders:
    - $spec_name with the spec's base name;
    - $graph_dump with the JSON serialization of global_graphs.
    The result is written to out_file.

    Raises:
        KeyError / ValueError: from string.Template.substitute, if the
            template references an unknown placeholder or contains an
            invalid "$". In that case out_file is left untouched.
    """
    # Explicit encoding so reading and writing do not depend on the locale.
    with open(TEMPLATE_FILE, "r", encoding="utf-8") as template_fd:
        html_template = template_fd.read()
    # Render before opening the output. Opening with "w" truncates the file,
    # so a substitution error must not leave behind an empty page.
    page = Template(html_template).substitute(
        spec_name=os.path.basename(spec_file),
        graph_dump=json.dumps(global_graphs),
    )
    with open(out_file, "w", encoding="utf-8") as out_fd:
        out_fd.write(page)
def ParseCmdLine():
    """Parse the command line and set up tg.FileNames for the spec.

    tg.FileNames is initialized with the spec path and "-" via
    InitializeFileLists, then advanced once with NextFile().

    Returns:
        (spec_path, out_path), both converted to absolute paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    parser.add_argument("-o", "--out", default="out.html", help="the output html path")
    args = parser.parse_args()
    file_names = tg.FileNames
    file_names.InitializeFileLists(args.spec, "-")
    file_names.NextFile()
    spec_path = os.path.abspath(args.spec)
    out_path = os.path.abspath(args.out)
    return spec_path, out_path
if __name__ == '__main__':
    spec_file, out_file = ParseCmdLine()
    print("Visualizing from spec: %s" % spec_file)
    # The spec is executed in this module's namespace. It presumably relies
    # on the test_generator names imported at the top of this file.
    # NOTE: this runs arbitrary Python code; only use it on trusted spec files.
    # Read via `with` so the file handle is closed.
    with open(spec_file, "r") as spec_fd:
        spec_source = spec_fd.read()
    # Compiling with the real path makes tracebacks point into the spec file
    # instead of "<string>".
    exec(compile(spec_source, spec_file, "exec"))
    Example.DumpAllExamples(DumpExample=ProcessExample, example_fd=0)
    DumpHtml(spec_file, out_file)
    print("Output HTML file: %s" % out_file)