#!/usr/bin/python
# stack_size - a program to parse ARM assembly files and produce
# a listing of stack sizes for the functions encountered.
#
# To use: compile your kernel as normal, then run this program on vmlinux:
# $ tools/stack_size/stack_size vmlinux
#
# If you have cross-compiled, then your CROSS_COMPILE environment variable
# should be set to the cross prefix for your toolchain
#
# Author: Tim Bird <tim dot bird ~(at)~ am dot sony dot com>
# Copyright 2011 Sony Network Entertainment, Inc
#
# GPL version 2.0 applies.
#

import os, sys
from stat import *

# Version string reported by the --version / -V option.
version = "1.1"

# Debug flag: set to 1 by the --debug option; dprint() prints only when
# this is non-zero.
debug = 0

def dprint(msg):
	"""Print msg, but only when the module-level 'debug' flag is set."""
	# debug is only read here, so no 'global' declaration is needed
	if debug:
		print(msg)

def usage():
	"""Print the command line help text and exit with status 0.

	The program name shown in the text is the basename of sys.argv[0].
	"""
	prog = os.path.basename(sys.argv[0])
	# single-argument print(...) parses the same on Python 2 and 3
	print("""Usage: %s [options] <filename>

Show stack size for functions in a program or object file.
Generate and parse the ARM assembly for the indicated filename,
and print out the size of the stack for each function.

Ex: $ %s -t 20 vmlinux

Options:
 -h       show this usage help
 -t <n>   show the top n functions with the largest stacks
          (default is to show top 10 functions with largest stacks)
 -f       show full function list
 -g       show histogram of stack depths
 -V or --version    show version information and exit
""" % (prog, prog))
	sys.exit(0)

class function_class():
	"""Record describing one function: its name, address and stack depth.

	The depth starts out as zero; callers assign the computed value to
	the 'depth' attribute afterwards.
	"""

	def __init__(self, name, addr):
		self.depth = 0
		self.addr = addr
		self.name = name

	def __repr__(self):
		# human readable summary; the address is not part of it
		template = "function: %s with depth %d"
		fields = (self.name, self.depth)
		return template % fields

def open_assembly(filename):
	"""Return an open file object for the disassembly of 'filename'.

	The disassembly is cached in '<base>.S', where <base> is the input
	name with a trailing '.o' removed.  It is (re)generated by running
	'<CROSS_COMPILE>objdump -d' when the cache file is missing or older
	than the input file.

	Prints a message and exits with status 1 when the input file does
	not exist, when objdump cannot be run or reports failure, or when
	the assembly file cannot be opened.
	"""
	# local import: keeps this function independent of the import list
	# at the top of the file
	import subprocess

	base = filename
	if base.endswith(".o"):
		base = base[:-2]

	assembly_filename = "%s.S" % base

	try:
		st = os.stat(filename)
	except OSError:
		print("Could not find input file '%s'\n" % filename)
		sys.exit(1)

	try:
		assembly_st = os.stat(assembly_filename)
		# test whether source file is newer than assembly file
		generate_assembly = st.st_mtime > assembly_st.st_mtime
	except OSError:
		print("Could not find '%s' - trying to generate from '%s'\n" %
			(assembly_filename, filename))
		generate_assembly = True

	if generate_assembly:
		print("Generating assembly file.  This might take a while...")
		# CROSS_COMPILE is only set for cross builds; without it the
		# native 'objdump' is used
		objdump = os.environ.get("CROSS_COMPILE", "") + "objdump"
		print("Using command: '%s -d %s >%s'" %
			(objdump, filename, assembly_filename))

		# Run objdump directly (no shell) so that file names containing
		# spaces or shell metacharacters are passed through unchanged.
		try:
			out = open(assembly_filename, "w")
			try:
				status = subprocess.call([objdump, "-d", filename],
					stdout=out)
			finally:
				out.close()
		except (OSError, IOError) as err:
			print("Could not run '%s': %s" % (objdump, err))
			status = -1

		if status != 0:
			# Do not leave a partial file behind: it would be newer
			# than the input and be mistaken for a valid cache on
			# the next run.
			try:
				os.remove(assembly_filename)
			except OSError:
				pass
			print("Could not generate '%s' -- conversion didn't work\n" %
				(assembly_filename))
			sys.exit(1)

	try:
		fd = open(assembly_filename)
	except (OSError, IOError):
		print("Could not find '%s' -- conversion didn't work\n" %
			(assembly_filename))
		sys.exit(1)

	return fd

def do_show_histogram(functions):
	"""Print a text histogram of the stack depths of 'functions'.

	'functions' is a dictionary whose values have an integer 'depth'
	attribute.  Depths are grouped into buckets 10 wide, and one line
	is printed per bucket from 0 up to the largest bucket in use
	(empty buckets included), labelled with the bucket's lower bound.
	"""
	depth_count = {}

	max_depth_bucket = 0
	for func in functions.values():
		# floor division keeps the bucket an integer on Python 3 too
		depth_bucket = func.depth // 10
		if depth_bucket > max_depth_bucket:
			max_depth_bucket = depth_bucket

		depth_count[depth_bucket] = depth_count.get(depth_bucket, 0) + 1

	print("=========== HISTOGRAM ==============")
	for i in range(max_depth_bucket + 1):
		count = depth_count.get(i, 0)

		# use numbers for bars that are too long
		if count < 72:
			bar = count * "*"
		else:
			bar = (72 - 8) * "*" + "(%d)*" % count
		print("%3d: %s" % (i * 10, bar))

def cmp_depth(f1, f2):
	"""Compare two objects by their 'depth' attribute.

	Returns -1, 0 or 1 when f1.depth is less than, equal to or greater
	than f2.depth.  Written without the cmp() builtin, which Python 3
	does not provide.
	"""
	return (f1.depth > f2.depth) - (f1.depth < f2.depth)

def _parse_options(argv):
	"""Parse the command line arguments (program name not included).

	Returns the tuple
	(filename, full_list, show_histogram, top_count, want_debug).

	Calls usage() (which exits) for -h, prints the version and exits
	with status 0 for --version / -V, and exits with status 1 when -t
	is not followed by an integer or when no filename remains.
	"""
	args = list(argv)
	full_list = 0
	show_histogram = 0
	top_count = 10
	want_debug = 0

	if "-h" in args:
		usage()

	if "-f" in args:
		full_list = 1
		args.remove("-f")

	if "-g" in args:
		show_histogram = 1
		args.remove("-g")

	if "-t" in args:
		pos = args.index("-t")
		try:
			top_count = int(args[pos + 1])
		except (IndexError, ValueError):
			print("Error: -t requires an integer argument")
			sys.exit(1)
		# delete by position, so that another argument with the same
		# text as the count (e.g. the filename) is left alone
		del args[pos:pos + 2]

	if "--debug" in args:
		want_debug = 1
		args.remove("--debug")

	if "--version" in args or "-V" in args:
		print("show_stack version %s" % version)
		sys.exit(0)

	if not args:
		print("Error: missing filename to scan")
		# usage() exits with status 0, but a missing argument is an
		# error: show the help and report failure instead
		try:
			usage()
		except SystemExit:
			pass
		sys.exit(1)

	return (args[0], full_list, show_histogram, top_count, want_debug)

def _count_pushed_registers(args):
	"""Return the number of registers named in the operands of a push.

	'args' is the text following the 'push' mnemonic.  Only the names
	inside the '{...}' list are counted: objdump may append a comment
	after ';' (for example '; (str lr, [sp, #-4]!)') whose commas must
	not be mistaken for register separators.
	"""
	operands = args.split(";", 1)[0]
	open_brace = operands.find("{")
	close_brace = operands.find("}", open_brace + 1)
	if open_brace != -1 and close_brace != -1:
		operands = operands[open_brace + 1:close_brace]
	return len([reg for reg in operands.split(",") if reg.strip()])

def _scan_functions(fd):
	"""Scan disassembly lines and compute a stack depth per function.

	Returns (functions, last_func): a dictionary mapping each function
	name to its function_class object, and the object of the last
	function seen (None when no function was found).

	Only pushes and 'sub sp, sp, #n' instructions near the start of a
	function are counted; pops are not tracked.
	"""
	# a function's label line counts as line 1, so only instructions
	# on lines 2..5 of a function are treated as part of its entry code
	PROGRAM_ENTRY_MAXLINES = 6

	functions = {}
	func = None

	max_cur_depth = 0
	depth = 0
	func_line_count = 0

	for line in fd:
		# find function starts: '<8 hex digits> <name>:'
		if line[8:10] == " <":
			(func_addr_s, rest) = line.split(" ", 1)
			try:
				func_addr = int(func_addr_s, 16)
			except ValueError:
				# not an address after all - treat as an ordinary line
				func_addr = None

			if func_addr is not None:
				# strip the leading '<' and the trailing '>:' + newline
				funcname = rest[1:-3]

				# record depth of last function
				if func:
					func.depth = max_cur_depth
					dprint("stack depth: %d" % max_cur_depth)

				# start new function
				dprint("function: %s" % funcname)
				func = function_class(funcname, func_addr)
				functions[funcname] = func
				depth = 0
				max_cur_depth = 0
				func_line_count = 0

		func_line_count += 1
		at_entry = func_line_count < PROGRAM_ENTRY_MAXLINES

		# calculate stack depth
		# pattern is: an initial register push ("push {...}"), then a
		# reservation of local stack space ("sub sp, sp, #n")
		if at_entry and line.find("push\t") != -1:
			args = line.split("push", 1)[1]
			dprint("push args=%s" % args)
			# each pushed register takes 4 bytes
			depth += 4 * _count_pushed_registers(args)
			if depth > max_cur_depth:
				max_cur_depth = depth

			dprint("depth=%d" % depth)

		if at_entry and line.find("sub\tsp, sp, #") != -1:
			args = line.split("sub", 1)[1]
			dprint("sub args=%s" % args)
			(sp1, sp2, imm) = line.split(",", 2)
			# imm looks like ' #84\t ; 0x54\n': take the first word
			# and strip the leading '#'
			depth += int(imm.split()[0][1:])
			if depth > max_cur_depth:
				max_cur_depth = depth

			dprint("depth=%d" % depth)

	# record depth of last function
	if func:
		func.depth = max_cur_depth
		dprint("max stack depth: %d" % max_cur_depth)

	return (functions, func)

def _print_results(functions, last_func, full_list, top_count):
	"""Print the summary and the list of functions with deepest stacks.

	'last_func' is reported as the deepest function when every function
	has depth 0.  The list is sorted by increasing depth and shows the
	'top_count' deepest functions, or all of them when 'full_list' is
	set.
	"""
	max_depth = 0
	maxfunc = last_func
	for func in functions.values():
		if func.depth > max_depth:
			max_depth = func.depth
			maxfunc = func

	print("============ RESULTS ===============")
	print("number of functions     = %d" % len(functions))
	print("max function stack depth= %d" % maxfunc.depth)
	print("function with max depth = %s\n" % maxfunc.name)

	if full_list or top_count:
		funclist = sorted(functions.values(), key=lambda f: f.depth)
		print("%-32s %s" % ("Function Name        ","Stack Depth"))
		print("%-32s %s" % ("=====================","==========="))

		if full_list:
			top_count = len(funclist)

		# the deepest functions are at the end of the sorted list
		start = max(len(funclist) - top_count, 0)
		for func in funclist[start:]:
			print("%-32s %4d" % (func.name, func.depth))

def main():
	"""Entry point: parse options, scan the disassembly, print results.

	Reads the options from sys.argv, obtains the disassembly of the
	named file through open_assembly(), and prints the summary, the
	function list and (with -g) the histogram.
	"""
	global debug

	(filename, full_list, show_histogram, top_count, want_debug) = \
		_parse_options(sys.argv[1:])
	if want_debug:
		debug = 1

	fd = open_assembly(filename)
	try:
		(functions, last_func) = _scan_functions(fd)
	finally:
		fd.close()

	if not functions:
		print("No functions found. Done.")
		return

	_print_results(functions, last_func, full_list, top_count)

	if show_histogram:
		do_show_histogram(functions)

# Run main() only when this file is executed as a script.
if __name__=="__main__":
	main()
