#!/usr/bin/python3
# stack_size - a program to parse ARM assembly files and produce
# a listing of stack sizes for the functions encountered.
#
# To use: compile your kernel as normal, then run this program on vmlinux:
# $ tools/stack_size/stack_size vmlinux
#
# If you have cross-compiled, then your CROSS_COMPILE environment variable
# should be set to the cross prefix for your toolchain
#
# Author: Tim Bird <tim dot bird ~(at)~ am dot sony dot com>
# Copyright 2011 Sony Network Entertainment, Inc
#
# GPL version 2.0 applies.
#

import os
import sys
import functools

from stat import *

# Version string printed by the -V / --version option.
VERSION = "1.1"

# Debug flag: when non-zero, dprint() prints its messages (set by --debug).
DEBUG = 0

def dprint(msg):
    """Print msg, but only while the module-level DEBUG flag is set."""
    if not DEBUG:
        return
    print(msg)

def usage():
    """Print the command-line help text and exit with status 0.

    The program name shown in the text is the basename of sys.argv[0].
    """
    program = os.path.basename(sys.argv[0])
    help_text = """Usage: %s [options] <filename>

Show stack size for functions in a program or object file.
Generate and parse the ARM assembly for the indicated filename,
and print out the size of the stack for each function.

Ex: $ %s -t 20 vmlinux

Options:
 -h       show this usage help
 -t <n>   show the top n functions with the largest stacks
          (default is to show top 10 functions with largest stacks)
 -f       show full function list
 -g       show histogram of stack depths
 -V or --version    show version information and exit
"""
    # The program name appears twice: in the synopsis and in the example.
    print(help_text % (program, program))
    sys.exit(0)

class function_class:
    """Record describing one function found in the disassembly.

    Attributes are the function's name, its address, and its stack depth.
    """

    def __init__(self, name, addr):
        # depth starts at 0 and is filled in later by whoever measures it.
        self.name, self.addr, self.depth = name, addr, 0

    def __repr__(self):
        fields = (self.name, self.depth)
        return "function: %s with depth %d" % fields

def open_assembly(filename):
    """Return an open, readable file object with the disassembly of filename.

    The disassembly is cached next to the input as "<base>.S", where <base>
    is filename with a trailing ".o" removed.  It is (re)generated by
    running "${CROSS_COMPILE}objdump -d filename" whenever the cache is
    missing or older than filename.  CROSS_COMPILE may be unset, in which
    case the plain "objdump" found on PATH is used.

    Prints a message and exits with status 1 if filename does not exist,
    if objdump cannot be run or reports failure, or if the disassembly
    cannot be opened.  The caller must close the returned file.
    """
    # Imported here because only this function needs to spawn a process.
    import subprocess

    base = filename[:-2] if filename.endswith(".o") else filename
    assembly_filename = "%s.S" % base

    try:
        st = os.stat(filename)
    except OSError:
        print("Could not find input file '%s'\n" % filename)
        sys.exit(1)

    try:
        assembly_st = os.stat(assembly_filename)
    except OSError:
        print("Could not find '%s' - trying to generate from '%s'\n" %
              (assembly_filename, filename))
        generate_assembly = True
    else:
        # regenerate when the input is newer than the cached disassembly
        generate_assembly = st.st_mtime > assembly_st.st_mtime

    if generate_assembly:
        print("Generating assembly file.  This might take a while...")
        # Native builds have no CROSS_COMPILE; fall back to an empty prefix.
        objdump = "%sobjdump" % os.environ.get("CROSS_COMPILE", "")
        cmd = "%s -d %s >%s" % (objdump, filename, assembly_filename)
        print("Using command: '%s'" % cmd)

        # Run objdump directly (no shell) so that file names containing
        # spaces or shell metacharacters are passed through untouched.
        failure = None
        try:
            with open(assembly_filename, "w") as out:
                result = subprocess.run([objdump, "-d", filename], stdout=out)
            if result.returncode != 0:
                failure = "exit status %d" % result.returncode
        except OSError as err:
            failure = str(err)

        if failure is not None:
            print("Error: '%s' failed (%s)\n" % (objdump, failure))
            # Do not leave a truncated disassembly behind: it would look
            # up to date on the next run and silently give empty results.
            try:
                os.remove(assembly_filename)
            except OSError:
                pass
            sys.exit(1)

    try:
        fd = open(assembly_filename)
    except OSError:
        print("Could not find '%s' -- conversion didn't work\n" %
              (assembly_filename))
        sys.exit(1)

    return fd

def do_show_histogram(functions):
    """Print a text histogram of function stack depths.

    functions maps function names to objects with an integer "depth"
    attribute.  Depths are grouped into buckets that are 10 wide (0-9,
    10-19, ...).  One row is printed per bucket, from bucket 0 up to the
    deepest bucket in use, labelled with the bucket's lower bound and
    showing one "*" per function.  A row holding 72 or more functions is
    capped at 64 stars followed by the real count in parentheses.
    """
    depth_count = {}
    for func in functions.values():
        # Floor division keeps bucket indexes integral.  True division
        # would produce float keys that never match the range() below and
        # would make range() itself raise TypeError.
        bucket = func.depth // 10
        depth_count[bucket] = depth_count.get(bucket, 0) + 1

    max_depth_bucket = max(depth_count, default=0)

    print("=========== HISTOGRAM ==============")
    for i in range(max_depth_bucket + 1):
        count = depth_count.get(i, 0)

        # use numbers for bars that are too long
        if count < 72:
            bar = count * "*"
        else:
            bar = (72 - 8) * "*" + "(%d)*" % count
        print("%3d: %s" % (i * 10, bar))

def cmp_depth(f1, f2):
    """Three-way compare two functions by their depth attribute.

    Returns -1, 0 or 1 when f1.depth is respectively smaller than, equal
    to, or greater than f2.depth (the contract of the Python 2 cmp()).
    """
    if f1.depth < f2.depth:
        return -1
    if f1.depth > f2.depth:
        return 1
    return 0

# Number of lines after a function's symbol header (the header itself counts
# as the first one) that are examined for prologue stack adjustments.
PROGRAM_ENTRY_MAXLINES = 6

def _count_pushed_registers(line):
    """Return how many registers the push instruction on line names.

    Only the "{...}" register list is counted.  objdump may append a
    comment containing commas, e.g.
    "push {lr}  ; (str lr, [sp, #-4]!)", which must not inflate the count.
    If no brace-delimited list is found, every comma-separated field of
    the line is counted instead.
    """
    operands = line.split("push", 1)[1]
    open_brace = operands.find("{")
    close_brace = operands.find("}", open_brace + 1)
    if open_brace == -1 or close_brace == -1:
        return len(line.split(","))
    reglist = operands[open_brace + 1:close_brace]
    return len([reg for reg in reglist.split(",") if reg.strip()])

def _parse_functions(lines):
    """Scan disassembly lines and return a {name: function_class} dict.

    A function starts at a symbol header such as "c0008000 <stext>:".
    Its depth is the running total of 4 per pushed register plus the
    immediates of "sub sp, sp, #N" instructions, looking only at the
    first PROGRAM_ENTRY_MAXLINES - 1 lines of the function (header
    included).  If a name occurs more than once, the last one wins.
    """
    functions = {}
    func = None
    depth = 0
    max_cur_depth = 0
    func_line_count = 0

    for line in lines:
        # find function starts: an 8-character address followed by " <"
        if line[8:10] == " <":
            addr_text, rest = line.split(" ", 1)
            try:
                func_addr = int(addr_text, 16)
            except ValueError:
                # Not a hex address, so not a symbol header; fall through
                # and treat it like any other line.
                pass
            else:
                # record depth of the function that just ended
                if func is not None:
                    func.depth = max_cur_depth
                    dprint("stack depth: %d" % max_cur_depth)

                # start new function; strip the leading "<" and the
                # trailing ">:" plus newline from the symbol name
                funcname = rest[1:-3]
                dprint("function: %s" % funcname)
                func = function_class(funcname, func_addr)
                functions[funcname] = func
                depth = 0
                max_cur_depth = 0
                func_line_count = 0

        func_line_count += 1
        in_prologue = func_line_count < PROGRAM_ENTRY_MAXLINES

        # initial register push: 4 bytes per register in the list
        if in_prologue and "push\t" in line:
            dprint("push args=%s" % line.split("push", 1)[1])
            depth += 4 * _count_pushed_registers(line)
            max_cur_depth = max(max_cur_depth, depth)
            dprint("depth=%d" % depth)

        # reservation of local stack space
        if in_prologue and "sub\tsp, sp, #" in line:
            dprint("sub args=%s" % line.split("sub", 1)[1])
            # the third comma-separated field looks like ' #84\t ; 0x54\n';
            # take its first word and strip the leading '#'
            imm = line.split(",", 2)[2]
            depth += int(imm.split()[0][1:])
            max_cur_depth = max(max_cur_depth, depth)
            dprint("depth=%d" % depth)

        # pop instructions are intentionally ignored: only stack growth
        # in the prologue is measured.

    # record depth of last function
    if func is not None:
        func.depth = max_cur_depth
        dprint("max stack depth: %d" % max_cur_depth)

    return functions

def main():
    """Parse the command line, analyse the file and print the report.

    Options (see usage()): -h, -f, -g, -t <n>, --debug, -V/--version.
    The first remaining argument is the file to analyse.  Prints a
    summary, then the top-n (or full) list of functions sorted by
    ascending stack depth, then optionally a histogram.  Exits with
    status 1 if -t is not followed by an integer.
    """
    global DEBUG

    # work on a copy so that option removal does not alter sys.argv
    argv = list(sys.argv)

    if "-h" in argv:
        usage()

    full_list = "-f" in argv
    if full_list:
        argv.remove("-f")

    show_histogram = "-g" in argv
    if show_histogram:
        argv.remove("-g")

    top_count = 10
    if "-t" in argv:
        pos = argv.index("-t")
        try:
            top_count = int(argv[pos + 1])
        except (IndexError, ValueError):
            print("Error: -t requires an integer argument")
            sys.exit(1)
        # remove the option and its value by position, not by value
        del argv[pos:pos + 2]

    if "--debug" in argv:
        DEBUG = 1
        argv.remove("--debug")

    if "--version" in argv or "-V" in argv:
        print("show_stack version %s" % VERSION)
        sys.exit(0)

    if len(argv) < 2:
        print("Error: missing filename to scan")
        usage()
    filename = argv[1]

    # "with" guarantees the file is closed even if parsing raises
    with open_assembly(filename) as fd:
        functions = _parse_functions(fd)

    if not functions:
        print("No functions found. Done.")
        return

    # calculate results; when every depth is 0 the last function listed
    # is reported, otherwise the first one reaching the maximum
    max_depth = 0
    maxfunc = list(functions.values())[-1]
    for candidate in functions.values():
        if candidate.depth > max_depth:
            max_depth = candidate.depth
            maxfunc = candidate

    print("============ RESULTS ===============")
    print("number of functions     = %d" % len(functions))
    print("max function stack depth= %d" % maxfunc.depth)
    print("function with max depth = %s\n" % maxfunc.name)

    if full_list or top_count:
        funclist = sorted(functions.values(),
                          key=functools.cmp_to_key(cmp_depth))

        print("%-32s %s" % ("Function Name        ", "Stack Depth"))
        print("%-32s %s" % ("=====================", "==========="))

        if full_list:
            top_count = len(funclist)

        # the list is in ascending order, so the deepest functions are
        # the last top_count entries
        first = max(0, len(funclist) - top_count)
        for entry in funclist[first:]:
            print("%-32s %4d" % (entry.name, entry.depth))

    if show_histogram:
        do_show_histogram(functions)

# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
