#include "parse.h"
#include "storage.h"

#include <limits.h>

// Parses the client request buffered in request->in.
//
// If request->requestFromClient.method is still empty, the request line and
// the headers are parsed first: method, url, resource, server and (when
// present in the URL) port are stored in request->requestFromClient, and the
// remaining header lines are handed to parseAllHeaders().  For a POST request
// the entity body is then accumulated into request->bodyFromServer, possibly
// across several calls.
//
// Returns:
//   PARTIAL_REQUEST        - the blank line ending the headers has not been
//                            received yet, or a POST body is still incomplete.
//   FULL_REQUEST           - headers (and the POST body, if any) are complete.
//   BAD_REQUEST            - malformed request line, URL, headers, or a POST
//                            without a usable Content-Length.
//   METHOD_NOT_IMPLEMENTED - the method is neither GET nor POST.
// On BAD_REQUEST and METHOD_NOT_IMPLEMENTED everything left in request->in is
// discarded.
//
// NOTE(review): method/url/server/resource are filled with unbounded
// sscanf("%s")/strcpy; this is only safe if those fields are at least
// MAX_HEADER_SIZE bytes long -- confirm against the HTTPState declaration.
RequestHeaderState parseRequest(Request *request) {
  char *headerEnd, *serverStart, *portStart, *resourceStart;
  RequestHeaderState state = PARTIAL_REQUEST;
  int major, minor, serverLength;
  char CRLF[3];

  if (request->requestFromClient.method[0] == '\0') {
    // We haven't parsed the request headers yet.
    // Assume the same line terminator is used everywhere in the request.
    if (strstr(request->in->start, "\r\n\r\n") != NULL) {
      strcpy(CRLF, "\r\n");
    }
    else if (strstr(request->in->start, "\n\n") != NULL) {
      strcpy(CRLF, "\n");
    }
    else if (strstr(request->in->start, "\r\r") != NULL) {
      strcpy(CRLF, "\r");
    }
    else {
      // Don't have a full request yet.
      goto end;
    }

    // Have a full request.
    GTIMER_START(request->processClientRequestTimer);

    // Parse the first line.
    if ((headerEnd = strstr(request->in->start, CRLF)) == NULL) {
      state = BAD_REQUEST;
      goto end;
    }
    *headerEnd = '\0';
    if (strlen(request->in->start) >= MAX_HEADER_SIZE) {
      // Request line is larger than what's allowed; it will look like it's malformed.
      state = BAD_REQUEST;
      goto end;
    }
    if (sscanf(request->in->start, "%s %s HTTP/%d.%d", request->requestFromClient.method, request->requestFromClient.url, &major, &minor) != 4) {
      // Malformed request.
      state = BAD_REQUEST;
      goto end;
    }
    if (major == 1 && minor >= 1) {
      request->requestFromClient.HTTP11 = true;
      // For HTTP/1.1, the default behavior is persistent connections and request->requestFromClient.connectionClose = false by default.
    }
    else {
      // For HTTP/1.0, the default behavior is non-persistent connections.
      request->requestFromClient.connectionClose = true;
    }
    if (strcmp(request->requestFromClient.method, "GET") != 0 && strcmp(request->requestFromClient.method, "POST") != 0) {
      // GET and POST are the only methods supported for now.
      XPROXY_ERROR ("ERROR: method %s not implemented\n", request->requestFromClient.method);
      state = METHOD_NOT_IMPLEMENTED;
      goto end;
    }

    // Extract server and port (if any) from the URL.
    // Look for lower-case "http://" in the beginning of url and skip it.
    serverStart = request->requestFromClient.url;
    if (strncmp(request->requestFromClient.url, "http://", 7) == 0) {
      serverStart += 7; // strlen("http://") = 7;
    }
    if ((resourceStart = strchr(serverStart, '/')) == NULL) {
      state = BAD_REQUEST;
      goto end;
    }
    strcpy(request->requestFromClient.resource, resourceStart);
    // Temporarily replace '/' in url with NUL so that the host[:port] part is a string of its own.
    *resourceStart = '\0';
    if ((portStart = strchr(serverStart, ':')) == NULL) {
      // Port is not specified; the port field is left untouched here.
      strcpy(request->requestFromClient.server, serverStart);
    }
    else {
      // Port is specified; portStart points to ':'.
      serverLength = portStart - serverStart;
      strncpy(request->requestFromClient.server, serverStart, serverLength);
      request->requestFromClient.server[serverLength] = '\0';
      portStart++;
      request->requestFromClient.port = atoi(portStart);
    }
    // Put back '/' in url.
    *resourceStart = '/';

    // Consume the request line and its terminator.
    request->in->length -= headerEnd - request->in->start + strlen(CRLF);
    request->in->start = headerEnd + strlen(CRLF);

    state = (parseAllHeaders(request->in, &(request->requestFromClient)) ? FULL_REQUEST : BAD_REQUEST);
    if (state == BAD_REQUEST) {
      goto end;
    }
  }

  if (strcmp(request->requestFromClient.method, "POST") == 0) {
    if (request->requestFromClient.contentLength < 0) {
      // Client didn't include a content length for a POST request, or sent a
      // negative one.  A negative length must never reach the memcpy below,
      // where it would be converted to a huge unsigned size.
      state = BAD_REQUEST;
      goto end;
    }
    // Check if there is enough space in the bodyFromServer buffer and allocate more if needed (the condition below should be true only once and thus space should be reallocated only once); add 2 for possible CRLF at the end of the request entity body.
    if ((request->requestFromClient.contentLength + 2) > request->bodyFromServerAllocatedLength) {
      reallocateBuffer(&(request->bodyFromServer), &(request->bodyFromServerAllocatedLength), request->requestFromClient.contentLength + 2);
      request->bodyFromServerStart = request->bodyFromServer;
    }
    if (request->in->length > 0) {
      // Bytes are already buffered: copy as many as are still missing from the body, but no more than what is available.
      int n = (((request->requestFromClient.contentLength - request->bodyFromServerLength) < request->in->length) ? (request->requestFromClient.contentLength - request->bodyFromServerLength) : request->in->length);
      memcpy(request->bodyFromServer + request->bodyFromServerLength, request->in->start, n);
      request->bodyFromServerLength += n;
      request->in->start += n;
      request->in->length -= n;
    }
    if (request->bodyFromServerLength == request->requestFromClient.contentLength) {
      // Received the request entity body.
      // Some clients append a line terminator after the entity body; consume it if it is already buffered.  If it has not been received yet it won't be consumed, and the next request on a persistent connection will look malformed.
      CRLF[0] = '\0';
      if (strncmp(request->in->start, "\r\n", 2) == 0) {
        strcpy(CRLF, "\r\n");
      }
      else if (request->in->start[0] == '\n') {
        strcpy(CRLF, "\n");
      }
      else if (request->in->start[0] == '\r') {
        strcpy(CRLF, "\r");
      }
      if (CRLF[0] != '\0') {
        //DEBUG_PARSE ("Entity body in POST request ended with CLRF\n");
        request->in->start += strlen(CRLF);
        request->in->length -= strlen(CRLF);
      }
      // Null terminate entity body.
      request->bodyFromServer[request->bodyFromServerLength] = '\0';

      DEBUG_POST ("\n%s\n", request->bodyFromServer);

      state = FULL_REQUEST;
    }
    else {
      state = PARTIAL_REQUEST;
    }
  }

 end:
  if (state == BAD_REQUEST || state == METHOD_NOT_IMPLEMENTED) {
    // Skip everything from the input socket buffer; safer if multiple requests have been received and some have been POST requests that didn't contain a valid length, then we don't know where that request ends, so we can't skip all the way past the blank line.
    request->in->start += request->in->length;
    request->in->length = 0;
  }

  return state;
}

// Parses header lines from socket->start until the blank line that ends the
// header section, passing each "name: value" pair to parseHeader().
//
// Every parsed line, and finally the blank line itself, is consumed from the
// socket buffer (socket->start is advanced, socket->length reduced).  Each
// line is NUL-terminated in place, overwriting the first byte of its
// terminator.
//
// A line ends at the earliest '\r' or '\n'; "\r\n" counts as one terminator.
//
// Returns true once the blank line has been consumed.  Returns false when no
// line terminator is left in the buffer, a line is MAX_HEADER_SIZE characters
// or longer, a line has no ':', or parseHeader() rejects the header.
bool parseAllHeaders(Socket *socket, HTTPState *state) {
  char *headerEnd, *headerValue;
  int headerNameLength, lineLength, terminatorLength;
  char headerName[MAX_HEADER_SIZE];

  // Parse all headers.
  while (1) {
    // Find the earliest line terminator.  Searching for "\r\n" first would
    // skip over earlier bare "\n" or "\r" terminators and merge several
    // header lines into one.
    headerEnd = strpbrk(socket->start, "\r\n");
    if (headerEnd == NULL) {
      // Malformed request.
      return false;
    }
    terminatorLength = ((headerEnd[0] == '\r' && headerEnd[1] == '\n') ? 2 : 1);

    if (headerEnd == socket->start) {
      // This is the last line (only the line terminator).
      socket->length -= terminatorLength;
      socket->start += terminatorLength;
      return true;
    }

    *headerEnd = '\0'; // null terminate the header; blow away the first terminator byte
    lineLength = (int) (headerEnd - socket->start);
    if (lineLength >= MAX_HEADER_SIZE) {
      // Header length is larger than allowed length.
      XPROXY_ERROR ("ERROR: %d header length is larger than allowed length %d\n", lineLength, (int) MAX_HEADER_SIZE);
      return false;
    }
    if ((headerValue = strchr(socket->start, ':')) == NULL) {
      // Couldn't find ":", the end of the header name.
      return false;
    }

    // The name is shorter than the line, which is shorter than
    // MAX_HEADER_SIZE, so it fits in headerName together with its NUL.
    headerNameLength = (int) (headerValue - socket->start);
    memcpy(headerName, socket->start, headerNameLength);
    headerName[headerNameLength] = '\0';

    headerValue++; // skip over ":"
    while (*headerValue == ' ') {
      // Eat white space.
      headerValue++;
    }
    // Now headerValue points to the header value.
    if (!parseHeader(state, headerName, headerValue)) {
      // Malformed request.
      return false;
    }

    // Consume the line and its terminator.
    socket->length -= lineLength + terminatorLength;
    socket->start = headerEnd + terminatorLength;
  }

  return true; // to avoid compilation warnings about not having a return value
}

// Interprets a single header and records it in *state.
//
// headerName is matched case-insensitively.  Headers with a dedicated field
// in HTTPState are decoded into that field; Cache-Control only updates
// state->maxAge (when a max-age directive is present) and is not stored;
// every other header is appended to state->additionalHeaders as
// "name: value\r\n".
//
// Returns true when the header was accepted.  Returns false when:
//   - Date, If-Modified-Since or Last-Modified cannot be converted by
//     strToTime();
//   - IM is not "xdelta", A-IM does not contain "xdelta", Transfer-Encoding is
//     not "chunked", Content-Encoding is not "gzip", or Accept-Encoding does
//     not contain "gzip";
//   - an unrecognized header does not fit in state->additionalHeaders.
//
// NOTE(review): contentType, eTag and server are filled with unbounded
// strcpy; confirm those fields are at least MAX_HEADER_SIZE bytes long.
bool parseHeader(HTTPState *state, char *headerName, char *headerValue) {
  if (strcasecmp(headerName, "date") == 0) {
    if (strToTime(headerValue, &(state->date))) {
      return true;
    }
  }
  else if (strcasecmp(headerName, "expires") == 0) {
    if (!strToTime(headerValue, &(state->expires))) {
      state->expires = 0; // "0" or other invalid expiration dates mean "already expired"
    }
    return true;
  }
  else if (strcasecmp(headerName, "if-modified-since") == 0) {
    if (strToTime(headerValue, &(state->ifModifiedSince))) {
      return true;
    }
  }
  else if (strcasecmp(headerName, "last-modified") == 0) {
    if (strToTime(headerValue, &(state->lastModified))) {
      return true;
    }
  }
  else if (strcasecmp(headerName, "content-length") == 0) {
    state->contentLength = atoi(headerValue);
    return true;
  }
  else if (strcasecmp(headerName, "content-md5") == 0) {
    decodeMIME64(headerValue, state->contentMD5);
    state->haveContentMD5 = true;
    return true;
  }
  else if (strcasecmp(headerName, "delta-base") == 0) {
    decodeMIME64(headerValue, state->deltaBase);
    state->haveDeltaBase = true;
    return true;
  }
  else if (strcasecmp(headerName, "content-type") == 0) {
    strcpy(state->contentType, headerValue);
    return true;
  }
  else if (strcasecmp(headerName, "im") == 0) {
    if (strcasecmp(headerValue, "xdelta") == 0) {
      // Only "xdelta" can be specified.
      state->IMXDelta = true;
      return true;
    }
  }
  else if (strcasecmp(headerName, "a-im") == 0) {
    if (strstr(headerValue, "xdelta") != NULL) {
      // "xdelta" must be on the list of delta algorithms; assume that lower-case "xdelta" is specified.
      state->aIMXDelta = true;
      return true;
    }
  }
  else if (strcasecmp(headerName, "connection") == 0 ||
           strcasecmp(headerName, "proxy-connection") == 0) {
    state->connectionClose = (strcasecmp(headerValue, "close") == 0);
    return true;
  }
  else if (strcasecmp(headerName, "transfer-encoding") == 0) {
    if (strcasecmp(headerValue, "chunked") == 0) {
      // Only "chunked" can be specified.
      state->transferEncodingChunked = true;
      return true;
    }
  }
  else if (strcasecmp(headerName, "content-encoding") == 0) {
    if (strcasecmp(headerValue, "gzip") == 0) {
      // Only "gzip" can be specified.
      state->contentEncodingGZIP = true;
      return true;
    }
  }
  else if (strcasecmp(headerName, "accept-encoding") == 0) {
    if (strstr(headerValue, "gzip") != NULL) {
      // "gzip" must be on the list of acceptable encodings; assume lower-case "gzip".
      state->acceptEncodingGZIP = true;
      return true;
    }
  }
  else if (strcasecmp(headerName, "if-none-match") == 0 || strcasecmp(headerName, "etag") == 0) {
    strcpy(state->eTag, headerValue);
    return true;
  }
  else if (strcasecmp(headerName, "server") == 0) {
    strcpy(state->server, headerValue);
    return true;
  }
  else if (strcasecmp(headerName, "cache-control") == 0) {
    // Cache-Control is never copied to additionalHeaders, so it is accepted
    // regardless of how much room is left there.
    char *directive = strstr(headerValue, "max-age");
    if (directive != NULL) {
      // Use the '=' that belongs to max-age itself, not the first '=' found
      // anywhere in the value (another directive may precede max-age).
      directive += 7; // strlen("max-age") = 7
      while (*directive == ' ') {
        directive++;
      }
      if (*directive == '=') {
        state->maxAge = atoi(directive + 1);
      }
    }
    return true;
  }
  else {
    // Copy header to additionalHeaders.
    // Make sure we don't overflow the additionalHeaders array: count exactly
    // what sprintf below will write, i.e. name + ": " + value + "\r\n" + NUL.
    size_t used = strlen(state->additionalHeaders);
    size_t needed = strlen(headerName) + strlen(headerValue) + 5;
    size_t capacity = (size_t) ((MAX_NUM_HEADERS/2)*MAX_HEADER_SIZE);

    if ((used + needed) <= capacity) {
      if ((strcasecmp(headerName, "cookie") == 0) || (strcasecmp(headerName, "set-cookie") == 0)) {
        state->hasCookie = true;
      }

      sprintf(state->additionalHeaders + used, "%s: %s\r\n", headerName, headerValue);
      return true;
    }
  }

  return false; // the header value was not acceptable
}

// Receives and parses the server response for request.
//
// While request->responseFromServer.HTTPCode is still -1 the status line and
// the headers are parsed from request->out (calling recvBytes once if
// *recvBytesCalled is false).  Afterwards the body is accumulated into
// request->bodyFromServer using one of three strategies:
//   - chunked transfer encoding: delegated to parseChunkedResponse();
//   - Content-Length >= 0: exactly that many bytes are collected;
//   - HTTP code 200 without a length: bytes are collected until recvBytes
//     returns 0.
// Any other response is treated as having no body.
//
// recvBytes is called at most once per invocation; *recvBytesCalled is set to
// true when that happens so the caller knows the socket has been read.
//
// Returns PARTIAL_RESPONSE when more data is needed, FULL_RESPONSE when the
// response is complete, and BAD_RESPONSE on a receive error or a malformed
// status line / header section.
//
// NOTE(review): reasonPhrase is filled with unbounded sscanf("%s")/strcpy;
// confirm the field is at least MAX_HEADER_SIZE bytes long.
ResponseHeaderState parseResponse(Request *request, bool *recvBytesCalled) {
  char *headerEnd;
  char temp[MAX_HEADER_SIZE];
  int n = 0;
  int reasonPhraseOffset = 0;
  char CRLF[3];

  // Receive and parse response headers if we haven't already done so.
  if (request->responseFromServer.HTTPCode == -1) {
    if (!*recvBytesCalled) {
      if (request->out->recvBytes(NULL, MAX_SOCKET_BUFFER) <= 0) {
        return BAD_RESPONSE;
      }
      *recvBytesCalled = true;

      DEBUG_PARSE ("RESPONSE RECEIVED:\n%s\nRESPONSE RECEIVED LENGTH: %d\n", request->out->start, request->out->length);
    }

    // Check to see if we have all response headers.
    // Assume same line terminator everywhere in the response.
    if (strstr(request->out->start, "\r\n\r\n") != NULL) {
      strcpy(CRLF, "\r\n");
    }
    else if (strstr(request->out->start, "\n\n") != NULL) {
      strcpy(CRLF, "\n");
    }
    else if (strstr(request->out->start, "\r\r") != NULL) {
      strcpy(CRLF, "\r");
    }
    else {
      return PARTIAL_RESPONSE;
    }

    // Have full response header.
    GTIMER_START(request->parseResponseHeaderTimer);

    // Parse the first line.
    if ((headerEnd = strstr(request->out->start, CRLF)) == NULL) {
      return BAD_RESPONSE;
    }
    *headerEnd = '\0';
    if (strlen(request->out->start) >= MAX_HEADER_SIZE) {
      // Response line larger than what's allowed; it will look like it's malformed.
      return BAD_RESPONSE;
    }
    // %n records where the reason phrase starts; it does not count towards
    // sscanf's return value, so three conversions are still expected.
    if (sscanf(request->out->start, "%s %d %n%s", temp, &(request->responseFromServer.HTTPCode), &reasonPhraseOffset, request->responseFromServer.reasonPhrase) != 3) {
      return BAD_RESPONSE;
    }

    // By this time, reasonPhrase contains just the first word of the phrase;
    // copy the whole phrase from its recorded position.  Searching the line
    // for that first word instead could match earlier text, e.g. the "HTTP"
    // of the version in "HTTP/1.1 505 HTTP Version Not Supported".
    strcpy(request->responseFromServer.reasonPhrase, request->out->start + reasonPhraseOffset);

    // Consume the status line and its terminator.
    request->out->length -= headerEnd - request->out->start + strlen(CRLF);
    request->out->start = headerEnd + strlen(CRLF);

    // Parse all headers.
    if (!parseAllHeaders(request->out, &(request->responseFromServer))) {
      return BAD_RESPONSE;
    }

    if (request->responseFromServer.transferEncodingChunked) {
      // Chunked encoding takes precedence over Content-Length.  The chunk
      // parser uses contentLength as its own scratch field and expects it to
      // start at -1, so drop any length the server also sent.
      request->responseFromServer.contentLength = -1;
    }

    GTIMER_STOP(request->parseResponseHeaderTimer);
    request->parseResponseHeaderTime = (int) (1000000 * GTIMER_ELAPSED(request->parseResponseHeaderTimer));
  }

  // Receive response body.
  if (request->responseFromServer.transferEncodingChunked) {
    return parseChunkedResponse(request, recvBytesCalled);
  }
  else if (request->responseFromServer.contentLength >= 0) {

    DEBUG_PARSE ("Response content length: %d\n", request->responseFromServer.contentLength);
    // Receive exactly contentLength bytes from the socket.
    // Check if there is enough space in the bodyFromServer buffer and allocate more if needed (the condition below should be true only once and thus space should be reallocated only once).
    if (request->responseFromServer.contentLength > request->bodyFromServerAllocatedLength) {
      reallocateBuffer(&(request->bodyFromServer), &(request->bodyFromServerAllocatedLength), request->responseFromServer.contentLength);
      request->bodyFromServerStart = request->bodyFromServer;
    }
    // The response might have contained all headers and part of the response body.  If yes, then we must have already called recvBytes above and shouldn't do it because it will block.  Instead, just copy the bytes already available in the socket buffer.
    n = 0;
    if (request->out->length > 0) {
      // recvBytes was called already: take what is still missing, but no more than what is buffered.
      n = (((request->responseFromServer.contentLength - request->bodyFromServerLength) < request->out->length) ? (request->responseFromServer.contentLength - request->bodyFromServerLength) : request->out->length);
      memcpy(request->bodyFromServer + request->bodyFromServerLength, request->out->start, n);
      request->out->start += n;
      request->out->length -= n;
    }
    else if (!*recvBytesCalled) {
      // recvBytes was not called yet: receive straight into the body buffer.
      if ((n = request->out->recvBytes(request->bodyFromServer + request->bodyFromServerLength, request->responseFromServer.contentLength - request->bodyFromServerLength)) <= 0) {
        return BAD_RESPONSE;
      }
      *recvBytesCalled = true;
    }
    request->bodyFromServerLength += n;
    if (request->bodyFromServerLength == request->responseFromServer.contentLength) {
      // Received the entire response body.
      return FULL_RESPONSE;
    }
  }
  else if (request->responseFromServer.HTTPCode == 200) {
    // Receive until we read zero bytes from the socket (by closing the socket, the other side signals the end of the response)
    // Check if there is enough space in the bodyFromServer buffer and allocate more if needed; the check below is fine as long as MAX_SOCKET_BUFFER is smaller than the initial request->bodyFromServerAllocatedLength.
    if (MAX_SOCKET_BUFFER > (request->bodyFromServerAllocatedLength - request->bodyFromServerLength)) {
      reallocateBuffer(&(request->bodyFromServer), &(request->bodyFromServerAllocatedLength), request->bodyFromServerAllocatedLength * 2);
      request->bodyFromServerStart = request->bodyFromServer;
    }
    // The response might have contained all headers and part of the response body.  If yes, then we must have already called recvBytes above and shouldn't do it because it will block.  Instead, just copy the bytes already available in the socket buffer.
    n = 0;
    if (request->out->length > 0) {
      // recvBytes was called already.
      memcpy(request->bodyFromServer + request->bodyFromServerLength, request->out->start, request->out->length);
      n = request->out->length;
      request->out->start += request->out->length;
      request->out->length = 0;
    }
    else if (!*recvBytesCalled) {
      n = request->out->recvBytes(request->bodyFromServer + request->bodyFromServerLength, MAX_SOCKET_BUFFER);
      *recvBytesCalled = true;
      if (n < 0) {
        // Error in receiving response body.
        return BAD_RESPONSE;
      }
      if (n == 0) {
        // Received full response body.
        return FULL_RESPONSE;
      }
    }
    request->bodyFromServerLength += n;
  }
  else {
    // 100, 304 or other responses with no bodies.
    return FULL_RESPONSE;
  }

  return PARTIAL_RESPONSE;
}

// Decodes a chunked response body from request->out into
// request->bodyFromServer.
//
// request->responseFromServer.contentLength is used as scratch state between
// calls:
//   -1 and !doneReadingChunks : a chunk-size line has to be read next;
//   > 0                       : that many bytes of the current chunk remain;
//    0                        : the chunk data is complete, its trailing line
//                               terminator has to be skipped;
//   -1 and doneReadingChunks  : the zero-size chunk was seen; one more line
//                               is discarded and the response is complete.
//
// recvBytes is called once if *recvBytesCalled is false, and
// *recvBytesCalled is then set to true.
//
// Returns PARTIAL_RESPONSE when more data is needed, FULL_RESPONSE when the
// last chunk and the line after it have been consumed, and BAD_RESPONSE on a
// receive error or malformed chunk framing (a size that is not a hexadecimal
// number or does not fit in an int, or chunk data not followed by a line
// terminator).
//
// NOTE(review): returning PARTIAL_RESPONSE with unconsumed bytes left in
// request->out assumes recvBytes keeps them for the next call, as the
// incomplete-size-line path always has -- confirm against Socket::recvBytes.
ResponseHeaderState parseChunkedResponse(Request *request, bool *recvBytesCalled) {
  char *lineEnd, *cursor;
  int digit, digitCount, chunkSize, terminatorLength, n, newSize;

  if (!*recvBytesCalled) {
    // Call recvBytes if we haven't already done so.
    if (request->out->recvBytes(NULL, MAX_SOCKET_BUFFER) <= 0) {
      return BAD_RESPONSE;
    }
    *recvBytesCalled = true;
  }

  // Use request->responseFromServer.contentLength as temporary storage for the chunk sizes, then reset it to -1 when done.
  while (1) {
    if (request->responseFromServer.contentLength == -1 && !request->responseFromServer.doneReadingChunks) {
      // Need to read the chunk size.
      if (request->out->length == 0) {
        return PARTIAL_RESPONSE;
      }

      // Check to see if we received the chunk-size line completely.  The line
      // ends at the earliest '\r' or '\n'.
      lineEnd = strpbrk(request->out->start, "\r\n");
      if (lineEnd == NULL) {
        return PARTIAL_RESPONSE;
      }
      if (*lineEnd == '\r' && (lineEnd + 1) == (request->out->start + request->out->length)) {
        // The '\r' is the last buffered byte: the '\n' of a "\r\n" pair may
        // still be in flight.  More data always follows a chunk-size line, so
        // wait for it rather than consuming half of the terminator.
        return PARTIAL_RESPONSE;
      }
      terminatorLength = ((lineEnd[0] == '\r' && lineEnd[1] == '\n') ? 2 : 1);
      *lineEnd = '\0';

      // Convert the hex size to decimal bytes, rejecting anything that is not
      // a hex digit and sizes that would overflow an int.
      cursor = request->out->start;
      while (*cursor == ' ' || *cursor == '\t') {
        cursor++;
      }
      chunkSize = 0;
      digitCount = 0;
      while (isxdigit((unsigned char) *cursor)) {
        if (isdigit((unsigned char) *cursor)) {
          digit = *cursor - '0';
        }
        else if (*cursor < 'a') {
          digit = 10 + *cursor - 'A';
        }
        else {
          digit = 10 + *cursor - 'a';
        }
        if (chunkSize > (INT_MAX - digit) / 16) {
          return BAD_RESPONSE;
        }
        chunkSize = chunkSize * 16 + digit;
        digitCount++;
        cursor++;
      }
      // The size may be followed by an extension (';') or, as some sites do,
      // by white space; anything else is malformed.
      if (digitCount == 0 || (*cursor != '\0' && *cursor != ';' && *cursor != ' ' && *cursor != '\t')) {
        return BAD_RESPONSE;
      }
      request->responseFromServer.contentLength = chunkSize;
      // Now request->responseFromServer.contentLength contains the chunk size.

      // Skip to the beginning of the chunk data.
      request->out->length -= lineEnd - request->out->start + terminatorLength;
      request->out->start = lineEnd + terminatorLength;

      // If the chunk size that we just parsed is zero, then we're done reading chunks.
      if (request->responseFromServer.contentLength == 0) {
        request->responseFromServer.contentLength = -1;
        request->responseFromServer.doneReadingChunks = true;
      }

      DEBUG_PARSE ("contentLength = %d\n", request->responseFromServer.contentLength);
    }
    else if (request->responseFromServer.contentLength == 0) {
      // We're done reading the chunk, need to skip the line terminator that
      // follows the chunk data.  It has to be at the very start of the
      // buffered data; a terminator found further ahead belongs to a later
      // line.
      if (request->out->length == 0) {
        return PARTIAL_RESPONSE;
      }
      if (request->out->start[0] == '\r' && request->out->length == 1) {
        // Possibly the first half of "\r\n"; the next chunk-size line is
        // still to come, so wait for more data.
        return PARTIAL_RESPONSE;
      }
      if (request->out->start[0] == '\r' && request->out->start[1] == '\n') {
        terminatorLength = 2;
      }
      else if (request->out->start[0] == '\n' || request->out->start[0] == '\r') {
        terminatorLength = 1;
      }
      else {
        // Chunk data is not followed by a line terminator.
        return BAD_RESPONSE;
      }
      request->out->start += terminatorLength;
      request->out->length -= terminatorLength;
      // Restore contentLength.
      request->responseFromServer.contentLength = -1;
    }
    else if (request->responseFromServer.contentLength == -1 && request->responseFromServer.doneReadingChunks) {
      // Read (discard) one line after the last chunk and return FULL_RESPONSE.
      lineEnd = strpbrk(request->out->start, "\r\n");
      if (lineEnd == NULL) {
        return PARTIAL_RESPONSE;
      }
      terminatorLength = ((lineEnd[0] == '\r' && lineEnd[1] == '\n') ? 2 : 1);
      request->out->length -= lineEnd - request->out->start + terminatorLength;
      request->out->start = lineEnd + terminatorLength;

      DEBUG_PARSE ("Done reading chunks\n");
      return FULL_RESPONSE;
    }
    else if (request->out->length > 0 && request->responseFromServer.contentLength > 0) {
      // Read contentLength (or as much as possible) from the socket buffer.
      n = (request->responseFromServer.contentLength < request->out->length ? request->responseFromServer.contentLength : request->out->length);
      // Check if there is enough space in the bodyFromServer buffer and allocate more if needed.
      if (request->responseFromServer.contentLength > (request->bodyFromServerAllocatedLength - request->bodyFromServerLength)) {
        // Double the buffer, but never end up with less room than the bytes
        // about to be copied need.
        newSize = request->bodyFromServerAllocatedLength * 2;
        if (newSize < request->bodyFromServerLength + n) {
          newSize = request->bodyFromServerLength + n;
        }
        reallocateBuffer(&(request->bodyFromServer), &(request->bodyFromServerAllocatedLength), newSize);
        request->bodyFromServerStart = request->bodyFromServer;
      }
      memcpy(request->bodyFromServer + request->bodyFromServerLength, request->out->start, n);
      request->bodyFromServerLength += n;
      request->responseFromServer.contentLength -= n;
      request->out->start += n;
      request->out->length -= n;

      DEBUG_PARSE ("Bytes read from chunk: %d\n", n);
    }
    else {
      // Chunk data is expected but nothing is buffered.
      return PARTIAL_RESPONSE;
    }
  }
}
