import config
from util import *
from domain_model import resource_factory




# Registry of every Node built so far, keyed by node name (Node.__init__
# stores each new instance here).
nodes = {}

# Special-case SELinux types for file resources, keyed by name prefix.
map2type = {
    '/dev/tty': 'devtty_t',
    '/usr/bin/perl': 'bin_t',
    "selinux": 'security_t',
    "/etc/ld.so.cache": "ld_so_cache_t",
    "/etc/mtab": 'etc_runtime_t',
    'CVS/Root': 'cvs_t',
    'PREFIX': 'polgen_temp_t',
    'VAR_PREFIX': 'polgen_temp_t',
    'USER_SUPPLIED': 'polgen_temp_t',
}


#-------------------------------------------
# Nodes - the ultimate product produced by the parser

class Node:
    """A component or resource discovered while parsing a specification.

    Constructing a Node registers it in the module-level ``nodes`` dictionary
    under its name, replacing any earlier node with the same name.
    """

    # Default for nodes whose spec gives no ``type`` attribute: node_factory
    # and setup_process read ``.type`` unconditionally.  Kept on the class so
    # it does not appear in ``self.__dict__`` (dumped by psl2xml and browse)
    # unless a type is explicitly assigned.
    type = None

    def __init__(self, sobj, name):
        self.name = name
        self.source = sobj
        self.parent = None
        self.reads = []
        self.writes = []
        self.listens_at = []
        self.location = 'internal'
        self.execs = []
        self.connects_at = []
        self.requires = []
        nodes[name] = self

    def _extended_slot(self, slot):
        """Return ``self.<slot>`` followed by the same slot collected,
        recursively, from every node named in ``self.requires``.

        Always returns a new list.  NOTE(review): a cycle in ``requires``
        recurses without bound.
        """
        inherited = [nodes[x]._extended_slot(slot) for x in self.requires]
        return getattr(self, slot) + flatten(inherited)

    def get_extended_reads (self):
        return self._extended_slot('reads')

    def get_extended_writes (self):
        return self._extended_slot('writes')

    def get_extended_connects_at (self):
        return self._extended_slot('connects_at')

    def get_extended_execs (self):
        return self._extended_slot('execs')

    def get_extended_listens_at (self):
        return self._extended_slot('listens_at')

    def __repr__ (self):
        return 'Component ' + self.name

    def browse(self):
        """Print every instance attribute as ``name: value``."""
        for key, value in self.__dict__.items():
            print(key + ": " + str(value))

#-----------------------------------------------------
# SpecString - intermediate data structure holding the specification text
# and its per-component substrings

class SpecString:
    """Raw specification text plus the chunks that follow each 'component'.

    ``components`` holds everything after each occurrence of the word
    'component'; text before the first occurrence is discarded.
    """
    def __init__(self, specstring):
        self.text = specstring
        self.set_component_strings()

    def set_component_strings (self):
        pieces = self.text.split('component')
        self.components = pieces[1:]

#--------------------------------------------------
# Reader Visitors

class ReaderVisitor:
    """Base class for visitors that work on a Reader.

    Subclasses implement ``do_it(sobj)`` and reach the reader through
    ``self.target``.
    """
    def __init__(self, reader):
        # The Reader being visited.
        self.target = reader

class SpecParser (ReaderVisitor):
    """Visitor that builds Node objects from the reader's component strings.

    The parsing below assumes each component chunk has the shape
    ``<name> {<attr> <values> ...}`` -- inferred from the slicing in do_it,
    TODO confirm against real .psl files.
    """
    def do_it (self, sobj):
        """Create a Node (with ``source`` = ``sobj``) for every string in
        ``self.target.specobj.components`` and fill in its attributes.

        ``parent`` and ``type`` consume the single token that follows them.
        The other recognised attributes collect comma-separated values until
        a token ending in ``}`` and are then stored with ``sput``.
        Unrecognised attribute names only print a warning.
        NOTE(review): values collected without a closing ``}`` are never
        stored.
        """
        strm = self.target.out
        components = self.target.specobj.components
        for c in components:
            start = c.strip()
            name = start.split()[0]
            co = Node(sobj, name)
            # Skip the name plus the two characters after it and drop the
            # final character (presumably " {" ... "}" -- TODO confirm).
            attributes = start[len(name) + 2:-1].strip()
            # begin: True while the next token should be an attribute name,
            # False while collecting values for attribute ``aname``.
            begin = True
            vals = []
            # next: single-token attribute ('parent'/'type') awaiting its value.
            next = None
            for a in attributes.split():
                if begin:
                    aname = a
                    if next == 'parent':
                        #print name + ' ' + a
                        # The parent must already be in ``nodes`` (KeyError otherwise).
                        co.parent = nodes[a]
                        next = None
                    elif next == 'type':
                        co.type = aname
                        next = None
                    elif aname == '{':
                        continue
                    elif aname == 'parent':
                        next = 'parent'
                    elif aname == 'type':
                        next = 'type'
                    # An unknown name leaves ``begin`` True, so the value tokens
                    # after it are also checked (and warned about) as names.
                    # ('requires' appears twice in this list; harmless.)
                    elif aname not in ['reads', 'requires', 'pipe', 'pipe_reads', 'pipe_writes', 'writes', 'socket_read', \
                                     'socket_write', 'requires', 'execs', 'pipeline', 'listens_at', 'connects_at']:
                        print 'Warning: Undefined Attribute: ' + aname
                    else:
                        #print 'attr = ' + aname
                        # Known list attribute: following tokens are its values.
                        begin = False               
                else:
                    if a == ' ': # should not happen: str.split() never yields ' '
                        continue
                    # Every value token loses its last character, assumed to be a
                    # ',' separator or the closing '}'.
                    # NOTE(review): a final item written without a trailing comma
                    # (e.g. "b" in "{ a, b }") loses its own last letter, and a
                    # lone "}" token yields one empty item '' (reported by sput).
                    items = a[:-1].split(',')
                    if a[0] == '{':
                        # NOTE(review): this discards the whole first piece, so for
                        # "{a,b}" the item "a" is lost; only harmless when '{' is
                        # its own token. Confirm the intended spec format.
                        items = items[1:]
                        if len(a) == 1:
                            continue
                        else:
                           for x in items:
                                vals.append(fix_value(aname, x))
                           # "{...}" in one token: the list is complete, store it.
                           if a[-1] == '}':
                                begin = True
                                sput(co, aname, vals)
                                vals = []
                    elif a[-1] == '}':
                                # Closing token: add its items, store the attribute
                                # and go back to expecting an attribute name.
                                begin = True
                                for x in items:
                                    vals.append(fix_value(aname, x))
                                sput(co, aname, vals)
                                vals = []
                             
                    else:
                        # Middle token: just accumulate its items.
                        for x in items:
                                vals.append(fix_value(aname, x))

class Spec2Html (ReaderVisitor):
    """Visitor that renders the reader's component strings as an HTML table."""
    def do_it (self, sobj):
        """Write an HTML page listing every component to ``self.target.out``.

        Each component becomes a table row with a link to its entry in
        PolicyHelp.html, followed by its attribute tokens.  Tokens alternate
        between starting a new ``<br>----`` line and continuing the current
        one.  All text is HTML-escaped.  The output stream is closed
        afterwards, even if writing fails.  ``sobj`` is unused.
        """
        from xml.sax.saxutils import escape, quoteattr
        strm = self.target.out
        try:
            # Open <html> before any content and close it at the end so the
            # page is well-formed.
            strm.write('<html>\n')
            strm.write('<h1>' + escape(get_current_program().name) + ' Specification </h1> \n')
            strm.write('<table border = 1 width = 80%>\n')
            for c in self.target.specobj.components:
                text = c.strip()
                if not text:
                    # An empty chunk has no component name to render.
                    continue
                name = text.split()[0]
                strm.write('<tr><td>')
                strm.write('Component <a href = ' + quoteattr('PolicyHelp.html#' + name)
                           + '>' + escape(name) + '</a> \n')
                attributes = text[len(name) + 2:-1].strip()
                for i, token in enumerate(attributes.split()):
                    # Even-indexed tokens start a new line; odd ones follow
                    # on the same line after a space.
                    prefix = ' ' if i % 2 else '<br>----'
                    strm.write(prefix + escape(token) + ' \n')
                strm.write('</td></tr>')
            strm.write('</table> \n\n')
            strm.write('</html>\n')
        finally:
            strm.close()


#------------------------------------------------
# Readers

class Reader:
    """Base reader: an input spec file plus an HTML output file."""
    def __init__(self, filename, folder = ""):
        """Open ``<data_directory>/<folder>/<filename>`` for reading and
        ``<results_directory>/<filename>.html`` for writing.

        With the default empty ``folder`` the input path contains '//'.
        """
        source_path = "/".join([config.data_directory, folder, filename])
        html_path = "/".join([config.results_directory, filename + '.html'])
        print("Reading from " + source_path)
        self.file = open(source_path, 'r')
        self.out = open(html_path, 'w')
        self.first_line = 1
        self.specobj = None
        self.name = filename
        
class PSLReader (Reader):
    """A class for reading psl files."""
    def readlines (self, sobj):
        """Read the whole .psl file into ``self.specobj`` (a SpecString).

        Lines whose first non-blank character is ``#`` are skipped.  Each
        remaining line has its trailing newline removed, and the lines are
        concatenated with no separator.  ``sobj`` is unused.
        """
        kept = []
        for line in self.file.readlines():
            if line.strip().startswith("#"):
                continue
            # Remove only a real newline.  Slicing off the last character
            # unconditionally would eat the final character of a file that
            # does not end in a newline.
            if line.endswith("\n"):
                line = line[:-1]
            kept.append(line)
        self.specobj = SpecString("".join(kept))



def fix_value(aname, val):
    """Normalise one attribute value.

    For attribute names ending in ``types`` the value is turned into an
    SELinux type name by appending ``_t`` unless it already ends that way.
    Every other value is returned unchanged.
    """
    if not aname.endswith('types'):
        return val
    return val if val.endswith('_t') else val + '_t'

def sput (obj, slot, val):
    """Set ``obj.<slot>`` to ``val`` and register the nodes it refers to.

    Items of ``reads``/``writes`` become ``data_file`` nodes, items of
    ``execs`` become ``process`` nodes and items of ``requires`` become
    ``code_module`` nodes, all via ``node_factory``.  Other slots are only
    assigned.
    """
    #print "Setting the " + slot + ' of ' + str(obj) + ' to ' + str(val)
    setattr(obj, slot, val)
    node_kind = {'reads': 'data_file',
                 'writes': 'data_file',
                 'execs': 'process',
                 'requires': 'code_module'}.get(slot)
    if node_kind is None:
        return
    for item in val:
        if node_kind == 'data_file' and item == "":
            # Debug trace for empty file entries (presumably a list the spec
            # parser mis-split -- TODO confirm).
            print("Setting the " + slot + ' of ' + str(obj) + ' to ' + str(val))
        node_factory(obj.source, item, node_kind)

def node_factory(sobj, name, type):
    """Return the node called ``name``, creating it if it does not exist.

    A new Node (which registers itself in ``nodes``) gets ``type`` assigned.
    An existing node is returned only if its type equals ``type``.
    Otherwise a name clash is printed and None is returned.  A node that
    never had a type set has type None.
    """
    existing = nodes.get(name)
    if existing is None:
        created = Node(sobj, name)
        created.type = type
        #may need indication of internal versus external file
        return created
    # getattr: spec-parsed components only get ``.type`` when the spec gives
    # one, so a plain attribute read could raise AttributeError.
    if getattr(existing, 'type', None) == type:
        return existing
    print("Name clash with " + name)
    return None



#preliminary - will need work
def psl2xml (file):
    """Dump the nodes parsed from the spec file ``file`` as XML.

    Preliminary - will need work.

    ``read_it``/``parse_it`` populate the module-level ``nodes`` dictionary.
    The output goes to ``<data_directory>/<file minus its last 4 chars>.xml``
    (presumably replacing a ``.psl`` extension).  Each node becomes a
    ``<component>`` with one child element per instance attribute.  List
    attributes are written one item per line, comma-separated, and empty
    lists are omitted.  Text content is XML-escaped.
    """
    from xml.sax.saxutils import escape
    out_pathname = config.data_directory + '/' + file[:-4] + '.xml'
    text = read_it(file)
    parse_it(text) # creates nodes dictionary
    # Open the output only once parsing succeeded, and always close it.
    strm = open(out_pathname, 'w')
    try:
        strm.write('<components> \n')
        for item in nodes.values():
            strm.write('<component> \n')
            for k, there in item.__dict__.items():
                is_list = isinstance(there, list)
                if is_list and not there:
                    continue
                strm.write("   <" + k + "> \n")
                if is_list:
                    last = len(there) - 1
                    for i, v in enumerate(there):
                        sep = ',' if i < last else ''
                        strm.write('    ' + escape(str(v)) + sep + '\n')
                else:
                    strm.write('     ' + escape(str(there)) + '\n')
                strm.write("   </" + k + "> \n")
            strm.write('</component> \n')
        strm.write('</components>')
    finally:
        strm.close()


def post_process ():
    """Resolve socket references to Node objects.

    For every node, append ``nodes[x]`` for each name ``x`` in
    ``socket_write``/``socket_read`` to ``socket_write_obj``/
    ``socket_read_obj``.  Node.__init__ sets none of these attributes, so a
    missing name list counts as empty and a missing ``*_obj`` list is
    created.  Raises KeyError for a name not in ``nodes``.  Calling this
    twice appends the targets twice.
    """
    for c in nodes.values():
        for slot in ('socket_write', 'socket_read'):
            resolved = getattr(c, slot + '_obj', None)
            if resolved is None:
                resolved = []
                setattr(c, slot + '_obj', resolved)
            for x in getattr(c, slot, []):
                resolved.append(nodes[x])

def get_resource_name (givenstr):
    """Return the resource name for node name ``givenstr``.

    Currently a placeholder identity mapping.
    """
    resource_name = givenstr
    return resource_name

def get_sc_from_type(givenstr, ty):
    """Return the security context of ``givenstr`` for resource kind ``ty``.

    ``'file'`` delegates to file_sc_from_type and ``'process'`` to
    get_process_sc_from_type.  Any other kind yields None.
    """
    if ty == 'process':
        return get_process_sc_from_type(givenstr)
    if ty == 'file':
        return file_sc_from_type(givenstr)
    return None
    
def _path_default_type(givenstr):
    """Guess a type name for a file name that matched no earlier rule."""
    if len(givenstr) < 5:
        return givenstr + '_t'
    if contains_str(givenstr, '/lib'):
        return 'shlib_t'
    if givenstr.startswith('/etc/'):
        dirs = givenstr.split('/')
        if len(dirs) < 4:
            return 'etc_t'
        # Concatenate the subdirectories below /etc and drop the final
        # component: '/etc/a/b/file' -> 'a_b_t'.
        return '_'.join(dirs[2:-1]) + '_t'
    if givenstr.startswith('/dev/'):
        return givenstr[5:] + '_t'
    # well, I don't know
    return givenstr + '_t'


def file_sc_from_type (givenstr, owned = False):
    """Return the SELinux context ``'system_u:object_r:<type>'`` of file
    resource ``givenstr``.

    The type comes from the first matching rule:
      1. the node's location is 'external': ``givenstr + '_t'``;
      2. ``givenstr`` already ends in ``_t``: used as-is;
      3. ``givenstr`` contains the current program's ``name_for_type`` or
         starts with ``/tmp``: ``polgen_temp_t``;
      4. ``givenstr`` starts with a ``map2type`` key: the mapped type;
      5. otherwise a type derived from the path (see _path_default_type).

    ``owned`` is accepted for compatibility and is unused.
    Raises KeyError if ``givenstr`` is not a key of ``nodes``.
    """
    start = 'system_u:object_r:'
    if nodes[givenstr].location == 'external':
        dtrace(givenstr)
        completion = givenstr + '_t'
    elif givenstr.endswith('_t'):
        completion = givenstr
    else:
        prog_name = get_current_program().name_for_type
        # Test containment explicitly.  str.find() returns -1 (truthy) when
        # the name is absent and 0 (falsy) for a match at the start, so using
        # its result as a boolean inverts the check.  An empty name never
        # matches.
        if (prog_name and prog_name in givenstr) or givenstr.startswith('/tmp'):
            completion = 'polgen_temp_t'
        else:
            # Special cases first: the first map2type key that prefixes the name.
            for prefix in map2type.keys():
                if givenstr.startswith(prefix):
                    completion = map2type[prefix]
                    break
            else:
                completion = _path_default_type(givenstr)
    return start + completion

def get_process_sc_from_type(givenstr):
    """Return the security context for process component ``givenstr``.

    The top-level program gets ``polgen_temp_exec_t``.  A known node whose
    parent is the top-level program or one of its sub-modules (see akop) gets
    ``polgen_temp_t``.  Anything else gets ``bin_t``.

    note: sc not cached on node objects, so this will be called several times
    """
    top = get_current_program().name
    if givenstr == top:
        chosen = 'polgen_temp_exec_t'
    else:
        node = nodes.get(givenstr)
        parent = node.parent if node is not None else None
        if parent is not None and akop(parent.name, top):
            chosen = 'polgen_temp_t'
        else:
            chosen = 'bin_t'
    return 'system_u:object_r:' + chosen

def akop (n1, n2):
    """Return True if node ``n1`` is ``n2`` itself or a sub-module of it.

    Walks the ``parent`` chain from ``nodes[n1]``, looking each parent up
    again in ``nodes`` by name.  Returns True when a node named ``n2`` is
    reached and False at a node with no parent.  Raises KeyError if a name
    on the way is not in ``nodes``.  NOTE(review): a cyclic parent chain
    recurses until the recursion limit.
    """
    node = nodes[n1]
    if node.name == n2:
        return True
    parent = node.parent
    return parent is not None and akop(parent.name, n2)
        
def setup_process (nodeobj, pid):
    """Create the Process resource for component ``nodeobj`` with id ``pid``.

    The process is obtained through ``resource_factory`` with the context
    from get_process_sc_from_type and stored in its own ``processes`` map
    under ``pid``.  For a component of type 'daemon', 'setsid' is recorded
    in ``special_calls``.  Returns the process.
    """
    resource_name = get_resource_name(nodeobj.name)
    proc = resource_factory(nodeobj.source, resource_name, 'Process',
                            get_process_sc_from_type(resource_name))
    proc.pid = pid
    if nodeobj.type == 'daemon':
        proc.special_calls.append('setsid')
    proc.processes[pid] = proc
    return proc

def do_execs (obj, startobj):
    """Record that ``startobj`` execs ``obj`` and return the exec'd Process.

    obj      -- a Process or File resource (distinguished with ``the_type``).
                Its name (via ``get_resource_name``) must be a key of
                ``nodes``; otherwise KeyError.
    startobj -- the process performing the exec (the top process).  If it
                is None, nothing is done and None is returned.

    Process: an '<name>_exec_file' File resource with an ``..._exec_t``
    context is obtained through ``resource_factory``.
    File: a Process resource of the same name is obtained through
    ``resource_factory`` and gets a fresh pid if its pid is 0.
    In both cases the process is stored in ``startobj.processes`` under its
    pid, and ``when_exec`` records the exec.  For any other object a
    message is printed and None is returned.
    """
    pobj = None
    if startobj != None:
        pname = get_resource_name(obj.name)
        sobj = nodes[pname].source
        #print obj
        if the_type(obj) == 'Process':
            pobj = obj
            startobj.processes[pobj.pid] = pobj
            filename = pname + '_exec_file'
            # Take the component's file context, drop its last two characters
            # (the '_t' suffix) and append '_exec_t'.
            resource_factory(sobj, filename, 'File', file_sc_from_type(pname, True)[:-2] + '_exec_t')
            when_exec(sobj, startobj, pobj, filename)
        elif the_type(obj) == 'File':
            filename = obj.name
            pobj =  resource_factory(sobj, filename, 'Process', \
                                     get_process_sc_from_type(pname))
            # Type problems are only reported; processing continues.
            if the_type(pobj) != 'Process':
                print 'Error: '+ filename + ' not a process'
            if the_type(startobj) != 'Process':
                print 'Error: Startobj '+ startobj.name + ' not a process'
            # pid 0 presumably means "not assigned yet".
            if pobj.pid == 0:
                pobj.pid = new_pid()
            startobj.processes[pobj.pid] = pobj
            #print "..." + str(pobj)
            when_exec(sobj, startobj, pobj, filename)
        else:
            print "Exec'ed object is not a process or file"
    return pobj

def new_pid ():
    """Return the current program's next pid and advance its counter."""
    program = get_current_program()
    pid = program.next_pid
    program.next_pid = pid + 1
    return pid

def when_exec (sobj, responsible_process, process, filestr):
    """Record that ``responsible_process`` executes file ``filestr``, which
    then runs as ``process``.

    Arguments that are not Processes (per ``the_type``) are reported but do
    not stop processing.  The function then:
      * adds the caller's 'read' and 'exec' transitions on the File;
      * cross-links the two processes via ``executing``/``executed_by``;
      * creates a 'Spawn' resource '<caller>-directly-execs-<callee>' with
        the caller's context;
      * adds an 'activate' transition for the caller and a 'spawn'
        transition for ``process``.
    """
    #print 'exec ' + responsible_process.name + ' to ' + process.name
    if the_type(responsible_process) != 'Process':
        print('**When exec - source ' + str(responsible_process) + ' not process')
    if the_type(process) != 'Process':
        print('**When exec - target ' + str(process) + ' not process')
    # Lookup kept for its side effect: presumably fails loudly if the exec'd
    # file was never registered as a File resource (TODO confirm).
    resources_dict[filestr + ':File']
    for access in ('read', 'exec'):
        responsible_process.add_transition(access, filestr, 'File')
    responsible_process.executing.append(process)
    process.executed_by.append(responsible_process)
    spawn_name = responsible_process.name + '-directly-execs-' + process.name
    resource_factory(sobj, spawn_name, 'Spawn', responsible_process.get_sc())
    responsible_process.add_transition('activate', spawn_name, 'Spawn')
    process.add_transition('spawn', spawn_name, 'Spawn')
    
                
