# cpushare client
# Copyright (C) 2005-2006  Andrea Arcangeli <andrea@cpushare.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation;
# only version 2.1 of the License.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from twisted.internet import reactor, protocol, defer

import os, struct, string, errno, md5, signal

# local
from cpushare.config import *
from cpushare.proto_const import *

class seccomp_protocol_class(protocol.ProcessProtocol):
	def __init__(self, seccomp, d_start, d_end):
		self.sigkill_is_seccomp = True
		self.seccomp = seccomp
		self.d_start, self.d_end = d_start, d_end
		self.outReceived = self.enable_seccomp_mode
		print 'Starting seccomp task'
		loader = seccomp.cpushare_protocol.factory.seccomp_loader
		reactor.spawnProcess(self,
				     loader,
				     (loader,),
				     {}, # clear the enviroment for security reasons!
				     childFDs = { 0 : 'w',
						  1 : 'r',
						  2 : 'r', })
		self.seccomp_file = '/proc/' + str(self.transport.pid) + '/seccomp'

		try:
			os.stat(self.seccomp_file)
		except OSError, err:
			if err[0] != errno.ENOENT:
				raise
			self.transport.writeToChild(0, MAGIC_SECCOMP_PRCTL)
		else:
			self.transport.writeToChild(0, MAGIC_SECCOMP_PROC)

	def pending_sigstop(self):
		proto = self.seccomp.cpushare_protocol
		if proto.factory.factory_type == PROTO_FACTORY_TYPE_SECCOMP:
			if proto.factory.sigstop:
				assert self.transport.pid
				self.sigstop()

	def connectionMade(self):
		self.pending_sigstop()

		self.seccomp.cpushare_protocol.seccomp = self
		self.seccomp.cpushare_protocol.transport.registerProducer(self, 1)
		if hasattr(self.transport, 'registerProducer'):
			self.transport.registerProducer(self.seccomp.cpushare_protocol, 1)
		elif self.transport.pipes.has_key(0): # hack for older twisted
			self.transport.pipes[0].registerProducer(self.seccomp.cpushare_protocol, 1)
		self.transport.closeChildFD(2) # close stderr right away
		self.transport.writeToChild(0, self.seccomp.header + self.seccomp.text_data)
	def enable_seccomp_mode(self, data):
		assert data == MAGIC_ASK_SECCOMP, "didn't ask seccomp"

		try:
			if file(self.seccomp_file, 'r').read(1) != '0':
				raise Exception('seccomp already enabled?')
		except IOError, err:
			if err[0] != errno.ENOENT:
				raise
			# With the prctl API the seccomp-loader
			# enables seccomp automatically
		else:

			file(self.seccomp_file, 'w').write('1')

			if file(self.seccomp_file, 'r').read(1) != '1':
				assert self.transport.pid is not None
				print 'Killing the seccomp-loader before it starts the untrusted bytecode'
				self.sigkill()
				raise Exception('seccomp enable failure')

		#os.system('cat /proc/%d/maps' % self.transport.pid)

		# start the seccomp engine now
		self.transport.writeToChild(0, MAGIC_GOT_SECCOMP)
		self.outReceived = self.send_to_server
		self.d_start.callback(None) # now the buyer is connected
	def send_to_server(self, data):
		#print repr(data)
		self.seccomp.cpushare_protocol.sendString(PROTO_SECCOMP +
							  PROTO_SECCOMP_FORWARD + data)
	def recv_from_server(self, data):
		self.transport.writeToChild(0, data)
	def errReceived(self, data):
		raise "shouldn't happen"
	def processEnded(self, status):
		self.seccomp.cpushare_protocol.seccomp = None
		self.seccomp.cpushare_protocol.transport.unregisterProducer()
		if hasattr(self.transport, 'unregisterProducer'):
			self.transport.unregisterProducer()
		elif self.transport.pipes.has_key(0): # hack for older twisted
			self.transport.pipes[0].unregisterProducer()
		if status.value.exitCode or status.value.signal:
			if status.value.exitCode == 4:
				print 'Failure in setting the stack size to %d bytes.' % self.seccomp.stack
			if status.value.signal == signal.SIGKILL:
				if self.sigkill_is_seccomp:
					print 'Seccomp task gracefully killed by seccomp.'
				else:
					print 'Seccomp task killed by sigkill.'
			elif status.value.signal == signal.SIGSEGV:
				print 'Seccomp task gracefully killed by sigsegv.'
			elif status.value.signal == signal.SIGQUIT:
				print 'Seccomp task killed by sigquit - should never happen.'
			self.d_end.errback(status)
		else:
			print 'Seccomp task completed successfully.'
			self.d_end.callback(None)

	def signal(self, s):
		if self.transport.pid is not None:
			os.kill(self.transport.pid, s)
	def sigquit(self):
		self.signal(signal.SIGQUIT)
	def sigkill(self):
		self.signal(signal.SIGKILL)
	def sigstop(self):
		self.signal(signal.SIGSTOP)
	def sigcont(self):
		self.signal(signal.SIGCONT)

	def resumeProducing(self):
		self.transport.resumeProducing()
	def pauseProducing(self):
		self.transport.pauseProducing()
	def stopProducing(self):
		self.transport.loseConnection()

class seccomp_class(object):
	def __init__(self, header, state_machine):
		#self.state_machine = state_machine
		self.cpushare_protocol = state_machine.protocol
		self.cache_dir = self.cpushare_protocol.factory.cache_dir

		name_len = ord(header[0])
		self.name = header[1:name_len+1]
		header = header[name_len+1:]

		size = struct.calcsize(HEADER_FMT)
		assert size + 16 == len(header), "corrupted header"

		self.header = header[:size]

		self.text_size, \
		self.data_size, \
		self.bss_size, \
		self.call_address, \
		self.signal_address, \
		self.heap, \
		self.stack, \
		self.cksum = struct.unpack(HEADER_FMT, self.header)
		#print self.cksum

		self.digest = header[size:]
		assert len(self.digest) == 16
		self.hexdigest = self.hexdigest(self.digest)

	def size(self):
		return self.text_size + self.data_size

	def hexdigest(self, digest):
		hexdigest = ''
		for i in digest:
			i = ord(i)
			hexdigest += string.hexdigits[(i >> 4) & 0xf]
			hexdigest += string.hexdigits[i & 0xf]
		return hexdigest

	def find_text_data(self):
		try:
			dentries = os.listdir(self.cache_dir)
		except OSError, err:
			if err[0] != errno.ENOENT:
				raise Exception('unknown error')
			return
		if self.hexdigest in dentries:
			self.text_data = ''
			filename = os.path.join(self.cache_dir, self.hexdigest)
			os.utime(filename, None) # set st_mtime for lru rotation w/ noatime
			f = file(filename)
			while 1:
				x = f.read(8192)
				if x:
					self.text_data += x
				else:
					# this way we shouldn't race with other cpus
					return len(self.text_data) == self.size() and \
					       self.hexdigest == md5.new(self.text_data).hexdigest()

	def store_text_data(self, data):
		try:
			os.mkdir(self.cache_dir)
		except OSError, err:
			if err[0] != errno.EEXIST:
				raise Exception('unknown error')
		self.text_data = data
		assert len(data) == self.size(), "wrong text_data length"
		if self.size() <= CACHE_MAX_SIZE:
			file(os.path.join(self.cache_dir, self.hexdigest), 'w').write(data)

		self.check_cache_size()

	def check_cache_size(self):
		dentries = [ (os.stat(os.path.join(self.cache_dir, dentry)), dentry)
			     for dentry in os.listdir(self.cache_dir) ]
		dentries = [ (stat.st_mtime, stat.st_size, dentry)
			     for stat, dentry in dentries ]

		total_size = sum([ dentry[1] for dentry in dentries ])
		if total_size <= CACHE_MAX_SIZE:
			return

		dentries.sort()
		for dentry in dentries:
			print 'Deleting %d bytes from %s' % (dentry[1], self.cache_dir)
			os.unlink(os.path.join(self.cache_dir, dentry[2]))
			total_size -= dentry[1]
			if total_size <= CACHE_MAX_SIZE:
				return

	def build_header(self, cksum):
		return struct.pack(HEADER_FMT,
				   self.text_size,
				   self.data_size,
				   self.bss_size,
				   self.call_address,
				   self.signal_address,
				   self.heap, self.stack, cksum)

	def robustness_check(self):
		obj = self.build_header(0)
		obj += self.text_data

		cksum = 0L
		for c in obj:
			cksum += ord(c)
			cksum &= 2**32-1

		assert cksum == self.cksum, "corrupted cksum %d %d" % (cksum, self.cksum)
		assert len(self.text_data) == self.size(), "corrupted size"
		assert md5.new(self.text_data).digest() == self.digest, "corrupted digest"
		assert md5.new(self.text_data).hexdigest() == self.hexdigest, "corrupted hexdigest"

	def run(self):
		self.robustness_check()

		d_start = defer.Deferred()
		d_end = defer.Deferred()
		self.protocol = seccomp_protocol_class(self, d_start, d_end)
		return d_start, d_end

	def __repr__(self):
		x = '<'
		for i in self.__dict__:
			x += i + ': ' + str(getattr(self, i)) + '\n'
		x = x[:-1] + '>'
		return x
