/* $Id: sfssesskey.C,v 1.21 2001/03/16 03:53:58 dm Exp $ */

/*
 *
 * Copyright (C) 1998 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "sfsmisc.h"
#include "crypt.h"
#include "hashcash.h"

#ifdef MAINTAINER
/* Developer-only switch, initialised from the environment: true when
 * SFS_NOCRYPT is set to any value.  The session-key code in this file
 * consults it (via `if (!sfs_nocrypt)`) to skip enabling transport
 * encryption.  Compiled out entirely unless MAINTAINER is defined. */
bool sfs_nocrypt = getenv ("SFS_NOCRYPT");
#endif /* MAINTAINER */

/*
 * Derive the two directional session keys from both parties' key shares.
 *
 *   ksc, kcs - out: SHA-1 of an sfs_sesskeydat tagged SFS_KSC / SFS_KCS
 *   si, ci   - server and connect info, hashed into both keys
 *   smsg     - server's key-share message (ksc_share / kcs_share)
 *   cmsg     - client's key-share message (ksc_share / kcs_share)
 *   kc       - not referenced by this derivation
 *
 * The copies of the shares held in the local scratch struct are cleared
 * before returning.
 */
void
sfs_get_sesskey (sfs_hash *ksc, sfs_hash *kcs,
		 const sfs_servinfo &si, const sfs_kmsg *smsg, 
		 const sfs_connectinfo &ci, const bigint &kc,
		 const sfs_kmsg *cmsg)
{
  sfs_sesskeydat dat;

  // Context shared by both derivations.
  dat.si = si;
  dat.ci = ci;

  // First key: tag SFS_KSC, built from the ksc halves of each message.
  // (The trailing `true` presumably asks sha1_hashxdr to scrub its
  // marshalling buffer -- confirm against its definition.)
  dat.type = SFS_KSC;
  dat.sshare = smsg->ksc_share;
  dat.cshare = cmsg->ksc_share;
  sha1_hashxdr (ksc->base (), dat, true);

  // Second key: tag SFS_KCS, built from the kcs halves.
  dat.type = SFS_KCS;
  dat.sshare = smsg->kcs_share;
  dat.cshare = cmsg->kcs_share;
  sha1_hashxdr (kcs->base (), dat, true);

  // Don't leave secret key material lying in the scratch struct.
  bzero (&dat.sshare, sizeof (dat.sshare));
  bzero (&dat.cshare, sizeof (dat.cshare));
}

/*
 * Compute the session ID as the SHA-1 hash of an SFS_SESSINFO record
 * holding both session keys.  The local copies of the keys are wiped
 * once the hash has been taken.
 */
void
sfs_get_sessid (sfs_hash *sessid, const sfs_hash *ksc, const sfs_hash *kcs)
{
  sfs_sessinfo info;

  info.type = SFS_SESSINFO;
  info.ksc.set (*ksc);
  info.kcs.set (*kcs);
  sha1_hashxdr (sessid->base (), info, true);

  // The key copies are secret; scrub them before the struct goes away.
  bzero (info.kcs.base (), info.kcs.size ());
  bzero (info.ksc.base (), info.ksc.size ());
}

/*
 * Compute the authentication ID: the SHA-1 hash of an SFS_AUTHINFO
 * record binding the requested service and host name to the server's
 * host ID and the current session ID.
 */
void
sfs_get_authid (sfs_hash *authid, sfs_service service, sfs_hostname name,
		const sfs_hash *hostid, const sfs_hash *sessid)
{
  sfs_authinfo info;

  info.type = SFS_AUTHINFO;
  info.hostid = *hostid;
  info.sessid = *sessid;
  info.service = service;
  info.name = name;

  sha1_hashxdr (authid->base (), info);
}

/*
 * Key transport CX with freshly generated random keys that nobody else
 * knows.  Callers use this on every failure path instead of leaving the
 * transport as it was.  If SESSID is non-NULL, it receives the session
 * ID derived from those random keys.
 */
static void
set_random_key (axprt_crypt *cx, sfs_hash *sessid)
{
  sfs_hash rk_sc, rk_cs;

  rnd.getbytes (rk_sc.base (), rk_sc.size ());
  rnd.getbytes (rk_cs.base (), rk_cs.size ());

  cx->encrypt (rk_sc.base (), rk_sc.size (), rk_cs.base (), rk_cs.size ());

  if (sessid != NULL)
    sfs_get_sessid (sessid, &rk_sc, &rk_cs);

  // Wipe the throwaway keys from the stack.
  bzero (rk_sc.base (), rk_sc.size ());
  bzero (rk_cs.base (), rk_cs.size ());
}

/*
 * Downcast a generic transport to axprt_crypt.
 *
 * XXX - dynamic_cast is busted in egcs, so this uses static_cast.  The
 * assert stands in for a runtime check and is stricter than dynamic_cast
 * would be: the object's dynamic type must be exactly
 * refcounted<axprt_crypt>.  Nothing is verified when assertions are
 * compiled out.
 */
static inline axprt_crypt *
xprt2crypt (axprt *x)
{
  axprt_crypt *crypt = static_cast<axprt_crypt *> (x);
  assert (typeid (*crypt) == typeid (refcounted<axprt_crypt>));
  return crypt;
}

/*
 * Server half of SFSPROC_ENCRYPT session-key negotiation.
 *
 *   sbp     - the pending SFSPROC_ENCRYPT call (asserted); answered or
 *             rejected here
 *   sk      - server private key, used to decrypt the client's key shares
 *   ci, si  - connect info / server info mixed into the session keys;
 *             ci.name must be set (i.e. connect was called first)
 *   sessid  - optional out: session ID of whatever keys end up in use
 *   charge  - hashcash charge the client's payment must satisfy
 *   cx      - transport to key; if NULL, taken from sbp's server transport
 *
 * On every failure path the transport is keyed with random keys via
 * set_random_key (presumably so the channel is never left usable in the
 * clear), and sessid then reflects those random keys.
 */
void
sfs_server_crypt (svccb *sbp, rabin_priv *sk,
		  const sfs_connectinfo &ci, const sfs_servinfo &si,
		  sfs_hash *sessid, const sfs_hashcharge &charge,
		  axprt_crypt *cx)
{
  assert (sbp->prog () == SFS_PROGRAM && sbp->proc () == SFSPROC_ENCRYPT);
  if (!cx)
    cx = xprt2crypt (sbp->getsrv ()->xprt ());
  if (!ci.name) {
    warn << "sfs_server_crypt: client called encrypt before connect?\n";
    sbp->reject (PROC_UNAVAIL);
    set_random_key (cx, sessid);
    return;
  }

  sfs_encryptarg *arg = sbp->template getarg<sfs_encryptarg> ();
  rabin_pub kc (arg->pubkey);
  sfs_kmsg smsg;

  if (!hashcash_check(arg->payment.base (), ci.hostid.base (), 
		      charge.target.base (), 
		      charge.bitcost)) {
    warn << "payment doesn't match charge\n";
    sbp->reject (GARBAGE_ARGS);
    set_random_key (cx, sessid);
    return;
  }

  if (kc.n.nbits () < sfs_minpubkeysize) {
    warn << "client public key too small\n";
    sbp->reject (GARBAGE_ARGS);
    set_random_key (cx, sessid);
    return;
  }
  if (kc.n.nbits () > sfs_maxpubkeysize) {
    warn << "client public key too large\n";
    sbp->reject (GARBAGE_ARGS);
    set_random_key (cx, sessid);
    return;
  }

  // Our random key shares, sent to the client under its public key.
  rnd.getbytes (&smsg, sizeof (smsg));
  sfs_encryptres res = kc.encrypt (wstr (&smsg, sizeof (smsg)));
  if (!res) {
    warn << "could not encrypt with client's public key\n";
    bzero (&smsg, sizeof (smsg));
    sbp->reject (GARBAGE_ARGS);
    set_random_key (cx, sessid);
    return;
  }
  // Copy the client's ciphertext out before replying; the argument
  // presumably goes away with the reply.
  bigint kmsg = arg->kmsg;
  sbp->reply (&res);

  sfs_hash ksc, kcs;
  str cmsgptxt = sk->decrypt (kmsg, sizeof (sfs_kmsg));
  if (!cmsgptxt) {
    bzero (&smsg, sizeof (smsg));
    set_random_key (cx, sessid);
    return;
  }

  sfs_get_sesskey (&ksc, &kcs, si, &smsg, ci, kc.n, sfs_get_kmsg (cmsgptxt));
  if (sessid)
    sfs_get_sessid (sessid, &ksc, &kcs);

  // Only the sfs_nocrypt escape hatch is maintainer-only.  The encrypt
  // call itself must be compiled in every build.
#ifdef MAINTAINER
  if (!sfs_nocrypt)
#endif /* MAINTAINER */
    cx->encrypt (ksc.base (), ksc.size (), kcs.base (), kcs.size ());

  bzero (ksc.base (), ksc.size ());
  bzero (kcs.base (), kcs.size ());
  bzero (&smsg, sizeof (smsg));
}

/*
 * State carried across the asynchronous SFSPROC_ENCRYPT call issued by
 * sfs_client_crypt and consumed (then deleted) by sfs_client_crypt_cb.
 */
struct sfs_client_crypt_state {
  typedef callback<void, const sfs_hash *>::ref cb_t;

  sfs_kmsg cmsg;		// our random key shares; scrubbed in the callback
  sfs_encryptres res;		// RPC result: server's encrypted key shares
  ptr<axprt> x;			// transport to switch to encryption
  ref<rabin_priv> sk;		// our private key; decrypts res
  sfs_servinfo si;		// server info, mixed into the session keys
  sfs_connectinfo ci;		// connect info, mixed into the session keys
  cb_t cb;			// gets the session ID, or NULL on failure

  sfs_client_crypt_state (ref<rabin_priv> key, const sfs_servinfo &sinfo,
			  const sfs_connectinfo &cinfo, cb_t donecb)
    : sk (key), si (sinfo), ci (cinfo), cb (donecb) {}
};

/*
 * Completion handler for the SFSPROC_ENCRYPT call made by
 * sfs_client_crypt.  It derives the session keys, keys the transport,
 * and reports the session ID through st->cb, or NULL if the RPC failed.
 * It always deletes ST.
 */
static void
sfs_client_crypt_cb (sfs_client_crypt_state *st, clnt_stat err)
{
  if (err) {
    warnx << st->si.host.hostname << ": negotiating session key: "
	  << err << "\n";
    // Don't leave our key shares behind in freed memory.
    bzero (&st->cmsg, sizeof (st->cmsg));
    (*st->cb) (NULL);
    delete st;
    return;
  }

  sfs_hash ksc, kcs;
  str smsgptxt = st->sk->decrypt (st->res, sizeof (sfs_kmsg));
  if (!smsgptxt) {
    /* Bad ciphertext -- just start encrypting with a random key */
    rnd.getbytes (ksc.base (), ksc.size ());
    rnd.getbytes (kcs.base (), kcs.size ());
  }
  else
    sfs_get_sesskey (&ksc, &kcs, st->si, sfs_get_kmsg (smsgptxt),
		     st->ci, st->sk->n, &st->cmsg);

  sfs_hash sessid;
  sfs_get_sessid (&sessid, &ksc, &kcs);

  // Mirror image of the server: we send with kcs and receive with ksc.
  // Only the sfs_nocrypt escape hatch is maintainer-only; the encrypt
  // call itself must be compiled in every build.
#ifdef MAINTAINER
  if (!sfs_nocrypt)
#endif /* MAINTAINER */
    xprt2crypt (st->x)->encrypt (kcs.base (), kcs.size (),
				 ksc.base (), ksc.size ());

  bzero (ksc.base (), ksc.size ());
  bzero (kcs.base (), kcs.size ());
  bzero (&st->cmsg, sizeof (st->cmsg));

  (*st->cb) (&sessid);
  delete st;
}

/*
 * Client half of SFSPROC_ENCRYPT session-key negotiation.
 *
 *   c       - RPC client for SFS_PROGRAM / SFS_VERSION (asserted)
 *   clntkey - our private key; its public half is sent to the server
 *   ci      - connect info (hostid is used for the hashcash payment)
 *   cres    - server's connect reply: server info, public key, charge
 *   cb      - called with the session ID, or NULL on failure
 *   cx      - transport to key; if NULL, c's own transport is used
 *
 * The server's key and charge are checked locally before anything is
 * sent.  On local failure the chosen transport is keyed with random keys
 * and cb(NULL) is called synchronously.  Otherwise the RPC is issued and
 * sfs_client_crypt_cb completes the exchange.
 */
void
sfs_client_crypt (ptr<aclnt> c, ptr<rabin_priv> clntkey,
		  const sfs_connectinfo &ci, const sfs_connectok &cres,
		  callback<void, const sfs_hash *>::ref cb,
		  ptr<axprt_crypt> cx)
{
  assert (c->rp.progno == SFS_PROGRAM && c->rp.versno == SFS_VERSION);
  sfs_client_crypt_state *st
    = New sfs_client_crypt_state (clntkey, cres.servinfo, ci, cb);
  sfs_encryptarg arg;

  if (cres.servinfo.host.pubkey.nbits () < sfs_minpubkeysize) {
    warn << cres.servinfo.host.hostname << ": public key too small\n";
    goto fail;
  }
  if (cres.servinfo.host.pubkey.nbits () > sfs_maxpubkeysize) {
    warn << cres.servinfo.host.hostname << ": public key too large\n";
    goto fail;
  }
  if (cres.charge.bitcost > sfs_maxhashcost) {
    warn << cres.servinfo.host.hostname << ": hashcash charge too great\n";
    goto fail;
  }

  if (cx)
    st->x = cx;
  else
    st->x = c->xprt ();
  rnd.getbytes (&st->cmsg, sizeof (st->cmsg));
 
  hashcash_pay(arg.payment.base (), ci.hostid.base (), 
	       cres.charge.target.base (), cres.charge.bitcost);
  arg.kmsg = rabin_pub (st->si.host.pubkey).encrypt (wstr (&st->cmsg,
							   sizeof (st->cmsg)));
  // Same check the server applies to its own encryption result.
  if (!arg.kmsg) {
    warn << cres.servinfo.host.hostname
	 << ": could not encrypt with server's public key\n";
    goto fail;
  }
  arg.pubkey = clntkey->n;
  /* The timed call is for forward secrecy.  clntkey gets used for
   * other session keys, so we don't want to sit on it forever. */
  c->timedcall (300, SFSPROC_ENCRYPT, &arg, &st->res,
		wrap (sfs_client_crypt_cb, st));

  return;

 fail:
  // Scrub any key shares already generated before freeing the state.
  bzero (&st->cmsg, sizeof (st->cmsg));
  delete st;
  // Key the same transport the success path would have used.
  set_random_key (cx ? &*cx : xprt2crypt (c->xprt ()), NULL);
  (*cb) (NULL);
  return;
}
