#!/usr/bin/python

# This reads the output of a kernel compile with
# CONFIG_MEASURE_INLINES enabled and counts the number of code
# instantiations. This is slightly tricky as gcc reports
# instantiations of inlines inside of other inlines in header files,
# so we need to build a call tree and credit only instantiations in C
# files.
#
# Usage:
#
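#  make 2> inline-data
#  scripts/count-inlines vmlinux inline-data | less
#
# The first argument is the kernel image handed to objdump and
# addr2line (vmlinux is assumed here); the second is the captured
# compiler output.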
#

import sys, re, os

# Non-zero: also print the calls/callers lists for each counted function.
showtree = 1
seen = {}
defined = {}
# function named by the latest "In function" line -> inline names warned
# about inside it
calls = {}
# inline name -> functions in which a warning about it was reported
callers = {}
# declaring source file -> list of (inline name, declaration line)
file = {}
# inline name -> list of (file, line) where it was used
uses = {}
# inline name -> number of credited instantiations (filled by credit())
usecount = {}
# inline names used from files whose name ends in 'c', one entry per use
used = []
findcache = {}
bytesused = {}
# inline name -> total number of warnings seen for it
total = {}
# inline name -> number of warnings reported in files ending in 'c'
in_c = {}
# function name -> number of addresses that addr2line mapped to it
fsize = {}
# Matches GCC's "In function `name'" context line.
fr = re.compile(r"In function `(\S+)'")
# Matches "<file>:<line>: warning: `name' is deprecated (declared at
# <file>:<line>)" and captures use file, use line, name, declaring file
# and declaring line.
dr = re.compile(r"(\S+):(\d+): warning: `(\S+)' is deprecated \(declared at (\S+):(\d+)\)")
# Function named by the most recent "In function" line.
current = "unknown"
# "file:line" string -> function name already resolved by line2func()
linecache = {}
# function name -> {symbol name reported by addr2line: address count}
fusers = {}

# Fail with a usage message instead of an IndexError traceback when the
# two required arguments are missing.
if len(sys.argv) < 3:
    sys.stderr.write("usage: %s <kernel-image> <inline-data>\n" % sys.argv[0])
    sys.exit(1)

kernel = sys.argv[1]
inlines = open(sys.argv[2])

# Credit nested inlines
def credit(func, nest):
    """Count one instantiation of func and of every inline nested in it.

    func -- name of the inline function to credit.
    nest -- current recursion depth; callers start at 0.

    Increments usecount[func], then recurses into every name listed in
    calls[func].  The process exits with status -1 once the nesting
    exceeds 10 levels (presumably a guard against cyclic call data).
    """
    if nest > 10:
        # Say why we are bailing out instead of exiting silently.
        sys.stderr.write("count-inlines: inline nesting deeper than 10 "
                         "levels at %s, giving up\n" % func)
        sys.exit(-1)
    usecount[func] = usecount.get(func, 0) + 1
    # dict.get() replaces has_key(), which no longer exists in Python 3.
    for callee in calls.get(func, ()):
        credit(callee, nest + 1)

def line2func(pos):
    """Map a "file:line" position to the inline function declared there.

    pos -- position string of the form "<file>:<line>"; a trailing
           newline is tolerated.

    Returns the name from file[<file>] whose declaration line is the
    closest one at or before <line>, or "unknown(<pos>)" when the file
    has no such entry or the line field is not a number.  Results are
    memoised in linecache, keyed by pos.
    """
    try:
        return linecache[pos]
    except KeyError:
        pass

    # Strip only a newline: the caller already removes it, so slicing off
    # the last character unconditionally would eat a digit of the line.
    where = pos.rstrip("\n")
    fname, _sep, rest = where.partition(":")
    best = "unknown(%s)" % where

    # Take the first whitespace-separated token so that a non-numeric or
    # decorated line field (e.g. "?" or "12 (discriminator 1)") cannot
    # raise ValueError.
    fields = rest.split()
    try:
        lineno = int(fields[0])
    except (IndexError, ValueError):
        lineno = None

    if lineno is not None:
        best_line = None
        # file[...] is in warning order, not line order, so track the
        # greatest declaration line that does not exceed lineno.
        for func, line in file.get(fname, ()):
            if line <= lineno and (best_line is None or line >= best_line):
                best, best_line = func, line

    linecache[pos] = best
    return best

def locate_section(file, section):
    """Return (start, size) of a section as listed by "objdump -h".

    file    -- path of the object file to inspect.
    section -- section name, e.g. ".text".

    Both values are parsed as hexadecimal from the two columns that
    follow the section name (size first, then start address).  If several
    lines match, the last one wins.  Raises ValueError when no line of
    the objdump output describes the section.
    """
    # Escape the name so that the dots in ".text" match literally.
    tsec = re.compile(r" %s\s+(\S+)\s+(\S+)" % re.escape(section))
    found = None
    listing = os.popen("objdump -h %s" % file)
    try:
        for l in listing:
            m = tsec.search(l)
            if m is None:
                continue
            try:
                size, start = [int(x, 16) for x in m.groups()]
            except ValueError:
                # Columns are not hexadecimal: not a section header row.
                continue
            found = (start, size)
    finally:
        listing.close()

    if found is None:
        raise ValueError("section %s not found in objdump output for %s"
                         % (section, file))
    return found

def scan_range(file, start, size):
    """Run addr2line over every address in [start, start+size) of file.

    file  -- object file passed to "addr2line -f -e".
    start -- first address to resolve.
    size  -- number of consecutive addresses to resolve.

    A forked child feeds the addresses to addr2line, whose standard
    output is redirected into a pipe read by the parent.  Every line read
    is echoed to standard output.  For each "file:line" answer the parent
    increments fsize[<function from line2func>] and, under that function,
    the count in fusers for the symbol name addr2line reported for the
    same address.
    """
    r, w = os.pipe()
    pid = os.fork()
    if pid:
        os.close(w)
        # With -f, addr2line prints the function name line *before* the
        # file:line line of the same address, so remember the name and
        # credit it once the location arrives.
        user = "unknown"
        reader = os.fdopen(r)
        for l in reader:
            l = l.rstrip("\n")
            print(l)
            if ":" in l:
                f = line2func(l)
                fsize[f] = fsize.get(f, 0) + 1
                # setdefault keeps the counts already gathered for f; the
                # previous except-and-reassign discarded them whenever a
                # new user name showed up.
                users = fusers.setdefault(f, {})
                users[user] = users.get(user, 0) + 1
            else:
                user = l
        reader.close()
        # Reap the child so it does not linger as a zombie.
        os.waitpid(pid, 0)

    else: # child
        status = 1
        try:
            os.close(r)
            os.dup2(w, 1)
            os.close(w)
            a2l = os.popen("addr2line -f -e %s" % file, "w")
            # Count by hand: range() would build a list with one entry
            # per byte on Python 2, and "%x" avoids the "L" suffix hex()
            # adds to long integers there.
            b = start
            end = start + size
            while b < end:
                a2l.write("0x%x\n" % b)
                b += 1
            # close() flushes the addresses and waits for addr2line.
            a2l.close()
            status = 0
        finally:
            # _exit() skips interpreter cleanup, so output buffered by
            # the parent before the fork is not flushed a second time
            # into the pipe.
            os._exit(status)

sys.stderr.write("Scanning inlining info...\n")

# Walk the captured compiler output.  Lines matching dr describe one use
# of an inline function (presumably the inlines are tagged deprecated so
# that GCC reports each use); any other line may be an "In function"
# context line naming the function that the following warnings belong to.
for record in inlines:
    hit = dr.match(record)
    if not hit:
        context = fr.search(record)
        if context:
            current = context.group(1)
        continue

    use, useline, name, src, srcline = hit.groups()
    srcline, useline = int(srcline), int(useline)
    # Files whose name ends in 'c' are treated as C files.
    used_in_c = use[-1] == 'c'

    total[name] = total.get(name, 0) + 1
    if used_in_c:
        in_c[name] = in_c.get(name, 0) + 1

    file.setdefault(src, []).append((name, srcline))
    # A warning located in the declaring file on the declaration line or
    # the line after it is not recorded as a use.
    if src != use or (srcline != useline and srcline != useline-1):
        uses.setdefault(name, []).append((use, useline))
        if used_in_c:
            used.append(name)

    # Record the call edge once in each direction.
    if name not in calls.setdefault(current, []):
        calls[current].append(name)
    if current not in callers.setdefault(name, []):
        callers[name].append(current)

inlines.close()

# Credit every use made from a C file, including the inlines nested in it.
for func in used:
    credit(func, 0)

# Resolve every address of each code section of the kernel image,
# reporting progress on standard error.
for sec in (".text", ".init.text", ".exit.text"):
    sys.stderr.write("Scanning %s..." % sec + "\n")
    sec_start, sec_size = locate_section(kernel, sec)
    scan_range(kernel, sec_start, sec_size)

# Print, largest first, the number of addresses attributed to each
# function, followed by one line listing the symbols those addresses
# were reported under and how often.
by_size = sorted([(count, name) for name, count in fsize.items()],
                 reverse=True)

for count, name in by_size:
    print("%d %s" % (count, name))
    # .get(): a function may have a size but no recorded users, which
    # used to abort the report with a KeyError.
    users = sorted([(n, who) for who, n in fusers.get(name, {}).items()],
                   reverse=True)
    # Entries already start with a blank; joining with another blank
    # reproduces the spacing of the former comma-terminated prints.
    print(" ".join([" %s(%d)" % (who, n) for n, who in users]))


def _format_counts(names):
    """Render names as "name(count)" items, highest use count first.

    Names without a non-zero entry in usecount are folded into a single
    "<other>(n)" item.  Defined once here rather than on every loop
    iteration, and no longer named "list", which hid the builtin.
    """
    counted = [(usecount.get(n), n) for n in names]
    counted = [pair for pair in counted if pair[0]]
    uncounted = len(names) - len(counted)
    if uncounted:
        counted.append((uncounted, "<other>"))
    counted.sort()
    counted.reverse()
    return " ".join(["%s(%d)" % (n, c) for c, n in counted])

# Print every credited inline, most instantiated first, optionally with
# the inlines it uses and the functions that use it.
ranked = sorted([(count, name) for name, count in usecount.items()],
                reverse=True)

for count, name in ranked:
    print("%d  %s (%d in *.c)" % (count, name, in_c.get(name, 0)))
    if showtree:
        print("calls: " + _format_counts(calls.get(name, [])))
        print("callers: " + _format_counts(callers.get(name, [])))
        print("")


