#include "headers.h"
#include "myUtil.cc"
#pragma comment(lib, "Ws2_32.lib")

/*
 * Receives a single byte from sock and returns it.
 * If recv() does not report exactly one byte, the failure is raised
 * through mexErrMsgTxt("recv() failed").
 */
char readOneChar(SOCKET sock){
  char ch;
  if (recv(sock, &ch, 1, 0) != 1){
    mexErrMsgTxt("recv() failed");
  }
  return ch;
}

/*
 * Reads one character with readOneChar() and returns its numeric value
 * (0..9). A character outside '0'..'9' is raised through
 * mexErrMsgTxt("expected digit").
 */
static int readDigit(SOCKET sock){
  const char ch = readOneChar(sock);
  if (ch < '0' || ch > '9'){
    mexErrMsgTxt("expected digit");
  }
  /* the decimal digit characters are consecutive, so this is the digit value */
  return ch - '0';
}

/*
 * Reads n decimal digits from sock, most significant digit first, and
 * returns the number they form. Returns 0 without reading when n <= 0.
 */
static int readNDigits(SOCKET sock, int n){
  int value = 0;
  int remaining;
  for (remaining = n; remaining > 0; --remaining){
    value = 10 * value + readDigit(sock);
  }
  return value;
}

/*
 * Returns the next byte available on sock, requested with MSG_PEEK so
 * that a later recv() can still read it.
 * If recv() does not report exactly one byte, the failure is raised
 * through mexErrMsgTxt("recv() failed").
 */
static char peekOneChar(SOCKET sock){
  char ch;
  if (recv(sock, &ch, 1, MSG_PEEK) != 1){
    mexErrMsgTxt("recv() failed");
  }
  return ch;
}

/*
 * Reads one '#'-prefixed binary block from sock.
 *
 * Bytes consumed, in order:
 *   '#'          marker byte
 *   one digit D  number of size digits that follow
 *   D digits     payload size in bytes, decimal
 *   payload      at least that many raw bytes
 *
 * On return *resultBuf points to a buffer obtained from mxCalloc() that
 * starts with the payload, and *nBytesRead is the payload size.
 * Every failure is raised through mexErrMsgTxt().
 *
 * NOTE(review): each recv() asks for all free buffer space, so bytes that
 * follow the payload are consumed and dropped only if they happen to
 * arrive with the last chunk - confirm this suits the peer's framing.
 */
void readBinary(SOCKET sock, char** resultBuf, int* nBytesRead){
  int capacity = 100000; /* starting size of the receive buffer */
  char* data = (char*) mxCalloc(capacity, sizeof(char));
  int filled = 0;        /* bytes stored in data so far */
  int room = capacity;   /* bytes still unused in data */

  /* ---- header: marker byte ---- */
  char marker;
  if (recv(sock, &marker, 1, 0) != 1){
    mexErrMsgTxt("recv() failed");
  }
  if (marker != '#'){
    mexErrMsgTxt("expected # as first character");
  }

  /* ---- header: digit count, then the payload size itself ---- */
  int sizeDigits = readDigit(sock);
  int payloadSize = readNDigits(sock, sizeDigits);

  /* ---- body ---- */
  while (filled < payloadSize){
    int got = recv(sock, data + filled, room, 0);
    if (got < 1){
      mexErrMsgTxt("recv() failed");
    }

    filled += got;
    room -= got;

    if (room != 0){
      continue;
    }

    /* buffer exhausted: move the contents into one twice as large */
    capacity *= 2;
    char* bigger = (char*) mxCalloc(capacity, sizeof(char));
    memcpy(/* dest */bigger, /* src */data, filled * sizeof(char));
    mxFree(data);
    data = bigger;
    room = capacity - filled;
  }

  /* ---- results: report only the payload, not anything read past it ---- */
  *nBytesRead = (filled > payloadSize) ? payloadSize : filled;
  *resultBuf = data;
}

/*
 * Reads an ASCII reply from sock.
 *
 * Reading stops once the last byte delivered by a recv() call is '\n';
 * a newline in the middle of a received chunk does not end the read.
 * The terminating newline is not counted in the result.
 *
 * On return *resultBuf points to a buffer obtained from mxCalloc() and
 * *nBytesRead is the number of bytes before the final newline.
 * A recv() result below 1 is raised through mexErrMsgTxt().
 */
static void readASC(SOCKET sock, char** resultBuf, int* nBytesRead){
  int capacity = 10;     /* starting size of the receive buffer */
  char* text = (char*) mxCalloc(capacity, sizeof(char));
  int filled = 0;        /* bytes stored in text so far */
  int room = capacity;   /* bytes still unused in text */
  int done = 0;

  *resultBuf = NULL;
  *nBytesRead = 0;

  do {
    int got = recv(sock, text + filled, room, 0);
    if (got < 1){
      mexErrMsgTxt("?? zero-byte recv() ??");
    }

    filled += got;
    room -= got;

    if (text[filled - 1] == '\n'){
      /* chunk ended on the terminator: drop it and stop */
      --filled;
      done = 1;
    } else if (room == 0){
      /* buffer exhausted: move the contents into one twice as large */
      capacity *= 2;
      char* bigger = (char*) mxCalloc(capacity, sizeof(char));
      memcpy(/* dest */bigger, /* src */text, filled * sizeof(char));
      mxFree(text);
      text = bigger;
      room = capacity - filled;
    }
  } while (!done);

  /* hand the buffer and its length back to the caller */
  *nBytesRead = filled;
  *resultBuf = text;
}

static mxArray* ascii2charMatrix(char* resultBuf, int nBytesRead){
  mxArray* out;
#define NDIM (2)
  int dim[NDIM];
  dim[0] = 1;
  dim[1] = nBytesRead;
  out = mxCreateCharArray(NDIM, dim);
  
  mxChar* writePtr = mxGetChars(out);
  char* readPtr = resultBuf;
  int ix;
  for (ix = 0; ix < nBytesRead; ++ix){
    *(writePtr++) = (mxChar)(*(readPtr++));
  }
  
  return out;
}

/*
 * Converts a buffer of 4-byte float values into a 1-by-(nBytesRead/4)
 * real double matrix (despite the function's name, the result is not a
 * char array).
 *
 * The four bytes of every value are reversed before being interpreted
 * as a float, i.e. the buffer's byte order is treated as the opposite of
 * the host's. nBytesRead must be a multiple of four; otherwise the error
 * is raised through mexErrMsgTxt().
 */
static mxArray* binary2charMatrix(char* resultBuf, int nBytesRead){
  if ((nBytesRead % 4) != 0){
    mexErrMsgTxt("binary-to-float: received number of bytes is not multiple of four!");
  }

  int nValues = nBytesRead / 4;
  mxArray* out = mxCreateDoubleMatrix(1, nValues, mxREAL);
  double* dst = mxGetPr(out);

  int v;
  for (v = 0; v < nValues; ++v){
    const char* src = resultBuf + 4 * v;
    float value;
    char* raw = (char*)&value;
    int b;
    /* copy the four bytes in reverse order */
    for (b = 0; b < 4; ++b){
      raw[b] = src[3 - b];
    }
    dst[v] = value; /* widened from float to double */
  }
  return out;
}

/*
 * MEX entry point: reads one reply from a socket and returns it.
 *
 * prhs[0]  char array whose raw data holds the SOCKET handle (the handle
 *          is copied out of the array's storage byte for byte)
 * plhs[0]  reply: a double row vector when the reply starts with '#'
 *          (binary block), otherwise a char row vector (ASCII text)
 *
 * Exactly one input and one output are required; violations and all
 * receive errors are raised through mexErrMsgTxt().
 */
void mexFunction (int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]){

  if (nrhs != 1){
    mexErrMsgTxt("need 1 input args");
  } else if (nlhs != 1){
    mexErrMsgTxt("need 1 output args");
  }

  if (mxIsChar(prhs[0]) != 1){
    mexErrMsgTxt("socket arg must be char");
  }

  /* sizeof(SOCKET) bytes are copied out of the char data below; refuse
     arrays that are too short so the copy cannot read past their end */
  if (mxGetNumberOfElements(prhs[0]) * sizeof(mxChar) < sizeof(SOCKET)){
    mexErrMsgTxt("socket arg is too short to hold a socket handle");
  }

  /***********************************************/
  /* copy socket into data structure */
  /***********************************************/
  SOCKET s;
  memcpy(/* dest */&s, /* src */mxGetChars(prhs[0]), sizeof(SOCKET));

  /***********************************************/
  /* read data */
  /***********************************************/
  char* resultBuf = NULL;
  int nBytesRead = 0;

  char c = peekOneChar(s);
  if (c == '#'){

    /***********************************************
     * first character '#' indicates binary transmission
     * that may encode zeros
     ***********************************************/
    readBinary(s, &resultBuf, &nBytesRead);
    plhs[0] = binary2charMatrix(resultBuf, nBytesRead);

  } else {

    /***********************************************
     * ASCII, terminated by a newline ('\n')
     ***********************************************/
    readASC(s, &resultBuf, &nBytesRead);
    plhs[0] = ascii2charMatrix(resultBuf, nBytesRead);
  }

  /* both converters copy the data into a new mxArray, so the raw
     receive buffer is no longer needed */
  mxFree(resultBuf);
}
