/* $Id: network.c,v 1.16 2003/06/14 12:43:26 sjoerd Exp $ */
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <assert.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <time.h>
#include <netdb.h>
#include <netinet/in.h>
#include <errno.h>

#include "global.h"
#include "list.h"
#include "network.h"

/* Module debug logging: forwards to DEBUG() (presumably declared in
 * global.h -- confirm) with the DNETWORK category and a "Network" tag. */
#define NETDEBUG(...) DEBUG(DNETWORK,"Network",__VA_ARGS__)

/* Context passed through list_foreach() to the fd-set callbacks below:
 * the read and write sets being filled or inspected (wfds is NULL when
 * only listeners are involved) and the highest fd added so far, which
 * becomes select()'s nfds argument. */
struct network_info {
  fd_set *rfds, *wfds;
  int maxfd;
};

/* Packet reference counting helpers. They are defined further down in this
 * file, but network_connection_sent() already needs network_packet_ref(). */
static void network_packet_ref(Network_packet *packet);
static void network_packet_unref(Network_packet *packet);

/*
 * Open a TCP socket listening on `port` on the wildcard address of
 * `ai_family` (as understood by getaddrinfo(), e.g. AF_INET, AF_INET6 or
 * AF_UNSPEC). SO_REUSEADDR is enabled before bind().
 *
 * Returns a new listener whose new_connection queue is filled as clients
 * are accepted by network_update(), or NULL on failure (the reason is
 * logged via NETDEBUG). Every resource acquired so far is released on the
 * failure paths.
 */
Network_listener *
new_network_listener(int port, int ai_family) {
#define BACKLOG 5
  Network_listener *result;
  int fd, ret, yes = 1;
  struct addrinfo req, *ans;

  memset(&req, 0, sizeof(req));

  req.ai_flags = AI_PASSIVE;
  req.ai_family = ai_family;
  req.ai_socktype = SOCK_STREAM;
  req.ai_protocol = IPPROTO_TCP;

  if ((ret = getaddrinfo(NULL, "0", &req, &ans)) != 0) {
    NETDEBUG("getaddrinfo failed: %s", gai_strerror(ret));
    return NULL;
  }

  /* Patch the requested port (network byte order) into the first result.
     The port field lives in a family-specific struct. */
  if (ans->ai_family == AF_INET6)
    ((struct sockaddr_in6 *) ans->ai_addr)->sin6_port = htons(port);
  else
    ((struct sockaddr_in *) ans->ai_addr)->sin_port = htons(port);

  fd = socket(ans->ai_family, ans->ai_socktype, ans->ai_protocol);
  if (fd < 0) {
    NETDEBUG("socket failed: %s", strerror(errno));
    freeaddrinfo(ans);
    return NULL;
  }

  if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes)) == -1) {
    NETDEBUG("setsockopt failed: %s", strerror(errno));
    goto fail;
  }

  if (bind(fd, ans->ai_addr, ans->ai_addrlen) < 0) {
    NETDEBUG( "bind failed: %s", strerror(errno));
    goto fail;
  }

  if (listen(fd, BACKLOG) == -1) {
    NETDEBUG( "listen: %s", strerror(errno));
    goto fail;
  }

  /* The address list was only needed for bind(). */
  freeaddrinfo(ans);

  result = malloc(sizeof(*result));
  if (result == NULL) {
    NETDEBUG("out of memory");
    close(fd);
    return NULL;
  }
  result->fd = fd;
  result->new_connection = new_list();
  NETDEBUG("Listening on port %d",port);
  return result;

fail:
  close(fd);
  freeaddrinfo(ans);
  return NULL;
}

/*
 * Tear down a listener. Connections that were accepted but never popped
 * by the application are deleted first. Then the listening socket is
 * closed (if still open), and the queue and the listener are freed.
 */
void
del_network_listener(Network_listener * listener) {
  Network_connection *pending;

  for (pending = network_listener_pop(listener);
       pending != NULL;
       pending = network_listener_pop(listener))
    del_network_connection(pending);

  if (listener->fd > -1)
    close(listener->fd);
  del_list(listener->new_connection);
  free(listener);
}

/* Detach one accepted connection from the listener's queue (in the order
 * list_pop() defines), or NULL when none is waiting. The caller then owns
 * the connection and must eventually pass it to del_network_connection(). */
Network_connection *
network_listener_pop(Network_listener * listener) {
  Network_connection *next = list_pop(listener->new_connection);
  return next;
}

static Network_connection *
init_network_connection(void) {
  Network_connection *result;

  result = malloc(sizeof(Network_connection));
  assert(result != NULL);
  result->fd = -1;

  result->in_queue = new_list();
  result->out_queue = new_list();

  result->inp = NULL;
  result->bytes_received = 0;
  result->size = 0;

  result->outp = NULL;
  result->hostname = NULL;
  result->shostname = NULL;
  result->localhost = NULL;

  return result;
}

/*
 * Connect to `hostname`:`port` over TCP, trying every address
 * getaddrinfo() returns (IPv4 or IPv6) in order until one connects.
 *
 * Returns a new connection wrapping the connected socket, or NULL if the
 * name cannot be resolved or no address accepts the connection. Sockets
 * from failed attempts are closed, and the address list is always freed.
 */
Network_connection *
new_network_connection(char *hostname, int port) {
  Network_connection *result;
  int sockfd = -1, ret;
  struct addrinfo req, *ans, *tmpaddr;
  char name[NI_MAXHOST], portname[NI_MAXSERV];

  memset(&req, 0, sizeof(req));
  req.ai_flags = 0;
  req.ai_family = AF_UNSPEC;
  req.ai_socktype = SOCK_STREAM;
  req.ai_protocol = IPPROTO_TCP;

  /* getaddrinfo() reports errors through its return value, not errno. */
  if ((ret = getaddrinfo(hostname, NULL, &req, &ans)) != 0) {
    NETDEBUG("getaddrinfo failed: %s", gai_strerror(ret));
    return NULL;
  }

  for (tmpaddr = ans; tmpaddr != NULL; tmpaddr = tmpaddr->ai_next) {
    /* The port field lives in a family-specific struct. */
    if (tmpaddr->ai_family == AF_INET6)
      ((struct sockaddr_in6 *) tmpaddr->ai_addr)->sin6_port = htons(port);
    else
      ((struct sockaddr_in *) tmpaddr->ai_addr)->sin_port = htons(port);

    if (getnameinfo(tmpaddr->ai_addr, tmpaddr->ai_addrlen,
                    name, sizeof(name), portname, sizeof(portname),
                    NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
      /* Only used for the log line below; never print garbage. */
      name[0] = '\0';
      portname[0] = '\0';
    }

    NETDEBUG( "Trying %s port %s...", name, portname);

    sockfd =
      socket(tmpaddr->ai_family, tmpaddr->ai_socktype, tmpaddr->ai_protocol);
    if (sockfd < 0) {
      NETDEBUG("socket failed: %s", strerror(errno));
      continue;
    }
    if (connect(sockfd, tmpaddr->ai_addr, tmpaddr->ai_addrlen) == 0)
      break;

    NETDEBUG( "connect failed: %s", strerror(errno));
    /* Don't leak the socket of a failed attempt. */
    close(sockfd);
    sockfd = -1;
  }
  freeaddrinfo(ans);

  if (sockfd < 0)
    return NULL;
  NETDEBUG("succeeded");

  result = init_network_connection();
  result->fd = sockfd;

  return result;
}

/*
 * Close and destroy a connection together with everything it still holds:
 * - completed packets that were never popped from the input queue;
 * - the partially received packet, if any;
 * - the references taken on queued outgoing packets, including the one
 *   currently being written;
 * - the cached address strings.
 * Packets are released through del_network_packet() so their data buffer
 * is freed and packets still referenced by the caller stay valid.
 */
void
del_network_connection(Network_connection * con) {
  Network_packet *p;

  if (con->fd >= 0)
    close(con->fd);

  while ((p = network_connection_pop(con)) != NULL) {
    del_network_packet(p);
  }
  del_list(con->in_queue);
  /* A packet under construction; plain free() would leak its data. */
  if (con->inp != NULL)
    del_network_packet(con->inp);

  while ((p = list_pop(con->out_queue)) != NULL) {
    del_network_packet(p);
  }
  del_list(con->out_queue);
  /* The packet being written carries the reference taken in
     network_connection_sent(); drop only that reference. */
  if (con->outp != NULL)
    del_network_packet(con->outp);

  free(con->hostname);
  free(con->shostname);
  free(con->localhost);
  free(con);
}

/* Detach one fully received packet from the connection's input queue, or
 * NULL if none is ready. The caller then owns the packet and releases it
 * with del_network_packet(). */
Network_packet *
network_connection_pop(Network_connection * con) {
  Network_packet *packet = list_pop(con->in_queue);
  return packet;
}

/*
 * Queue `packet` for transmission on `con`. The connection takes its own
 * reference, so the caller still holds (and must release) theirs.
 * Returns FALSE and queues nothing when the connection is already closed.
 */
int
network_connection_sent(Network_connection * con, Network_packet *packet) {
  if (con->fd >= 0) {
    list_append(con->out_queue, packet);
    network_packet_ref(packet);
    return TRUE;
  }
  return FALSE;
}

/* NW_OK while the socket is open, NW_ERROR once it has been closed
 * (fd reset to -1 after an error or EOF). */
Network_state
network_connection_state(Network_connection * con) {
  if (con->fd >= 0)
    return NW_OK;
  return NW_ERROR;
}

char *
network_connection_get_remote(Network_connection *con) {
  struct sockaddr_storage sock;
  socklen_t len = sizeof(struct sockaddr_storage);

  if (con->hostname == NULL) {
    con->hostname = malloc(sizeof(char) * NI_MAXHOST);
    assert(getpeername(con->fd,(struct sockaddr *)&sock,&len) == 0);
    if (getnameinfo((struct sockaddr*)&sock,len,
                con->hostname,NI_MAXHOST,
                NULL,0,NI_NUMERICHOST | NI_NUMERICSERV) != 0) {

      WARN("getnameinfo failed: %s",strerror(errno));
    }
  } 
  return con->hostname;
}

char *
network_connection_get_sremote(Network_connection *con) {
  struct sockaddr_storage sock;
  socklen_t len = sizeof(struct sockaddr_storage);

  if (con->shostname == NULL) {
    con->shostname = malloc(sizeof(char) * NI_MAXHOST);
    assert(getpeername(con->fd,(struct sockaddr *)&sock,&len) == 0);
    if (getnameinfo((struct sockaddr*)&sock,len,
                con->shostname,NI_MAXHOST,
                NULL,0,0) != 0) {

      WARN("getnameinfo failed: %s",strerror(errno));
    }
  } 
  return con->shostname;
}

char *
network_connection_get_local(Network_connection *con) {
  struct sockaddr_storage sock;
  socklen_t len = sizeof(struct sockaddr_storage);

  if (con->localhost == NULL) {
    con->localhost = malloc(sizeof(char) * NI_MAXHOST);
    assert(getsockname(con->fd,(struct sockaddr *)&sock,&len) == 0);

    if (getnameinfo((struct sockaddr*)&sock,len,
                con->localhost,NI_MAXHOST,
                NULL,0,NI_NUMERICHOST | NI_NUMERICSERV) != 0) {

      WARN("getnameinfo failed: %s",strerror(errno));
    }
  }
  printf("->>> %s <<<<<\n",con->localhost);
  return con->localhost;
}

/* Take one additional reference on `packet`. */
static void
network_packet_ref(Network_packet *packet) {
  packet->refcount += 1;
}

/* Drop one reference on `packet`. Once no references remain (count at or
 * below zero), the payload and the packet itself are freed. */
static void
network_packet_unref(Network_packet *packet) {
  if (--packet->refcount > 0)
    return;
  free(packet->data);
  free(packet);
}

/*
 * Allocate a packet with an uninitialized payload buffer of `size` bytes.
 * The packet starts with one reference, owned by the caller and released
 * with del_network_packet(). Allocations are checked with assert(), as in
 * init_network_connection(). A NULL data pointer is tolerated only for
 * size 0, because malloc(0) may return NULL.
 */
Network_packet *
new_network_packet(int size) {
  Network_packet *result;

  result = malloc(sizeof(*result));
  assert(result != NULL);
  result->size = size;
  result->data = malloc(size);
  assert(result->data != NULL || size == 0);
  result->refcount = 0;
  network_packet_ref(result);
  return result;
}

/* Public release function: drops the caller's reference. The packet is
 * only freed once every holder (e.g. a send queue) has let go. */
void
del_network_packet(Network_packet * packet)
{
  network_packet_unref(packet);
}

/* Below: the internal I/O machinery driven by network_update(). */
/*
 * Load the next queued outgoing packet into c->outp (NULL if the queue is
 * empty) and rewind the write cursor. bytes_sent starts at
 * -sizeof(uint32_t) so the 4-byte length prefix is written before the
 * payload.
 */
static void
network_connection_wnext(Network_connection *c) {
  assert(c->outp == NULL);
  c->bytes_sent = -(int)sizeof(uint32_t);
  c->outp = list_pop(c->out_queue);
}

/* list_foreach() callback: register one listener's socket in the read set
 * and raise maxfd if needed. Always returns TRUE to continue iterating. */
static int
network_do_add_listeners(void *data,void *user_data) {
  Network_listener *listener = data;
  struct network_info *info = user_data;
  int fd = listener->fd;

  NETDEBUG("Read listen fd added -> %d",fd);
  FD_SET(fd,info->rfds);
  if (fd > info->maxfd)
    info->maxfd = fd;
  return TRUE;
}

/* Add every listener socket in `listeners` (may be NULL) to `fds`,
 * updating *maxfd to the highest descriptor added. */
static void
network_add_listeners(List *listeners,fd_set *fds,int *maxfd) {
  struct network_info info;

  if (listeners != NULL) {
    info.rfds = fds;
    info.wfds = NULL;
    info.maxfd = *maxfd;
    list_foreach(listeners,network_do_add_listeners,&info);
    *maxfd = info.maxfd;
  }
}

/*
 * list_foreach() callback for one connection. Closed connections are
 * skipped with a warning. Open ones always go in the read set; they go in
 * the write set only when a packet is (or can now be) loaded for sending.
 * Always returns TRUE to continue iterating.
 */
static int
network_do_add_connections(void *data,void *user_data) {
  Network_connection *conn = data;
  struct network_info *info = user_data;

  if (conn->fd < 0) {
    WARN("Not reading closed connection");
    return TRUE;
  }

  NETDEBUG("add read fd %d",conn->fd);
  FD_SET(conn->fd,info->rfds);
  if (conn->fd > info->maxfd)
    info->maxfd = conn->fd;

  /* Pull the next packet from the queue if nothing is in flight. */
  if (conn->outp == NULL)
    network_connection_wnext(conn);
  if (conn->outp == NULL)
    return TRUE;

  /* maxfd already covers this fd from the read registration above. */
  NETDEBUG("add write fd %d",conn->fd);
  FD_SET(conn->fd,info->wfds);
  return TRUE;
}

/* Register every connection in `con` (may be NULL) in the read and, where
 * needed, write sets, updating *maxfd to the highest descriptor added. */
static void
network_add_connections(List *con,fd_set *rfds,fd_set *wfds,int *maxfd) {
  struct network_info info;

  if (con != NULL) {
    info.rfds = rfds;
    info.wfds = wfds;
    info.maxfd = *maxfd;
    list_foreach(con,network_do_add_connections,&info);
    *maxfd = info.maxfd;
  }
}

/* Accept one pending client on `listener` and append the resulting
 * connection to its new_connection queue. Failures are only logged. */
static void
network_listener_accept_connection(Network_listener *listener) {
  struct sockaddr_storage peer;
  socklen_t peer_len = sizeof(peer);
  Network_connection *conn;
  int client_fd = accept(listener->fd,(struct sockaddr *)&peer,&peer_len);

  if (client_fd >= 0) {
    conn = init_network_connection();
    conn->fd = client_fd;
    list_append(listener->new_connection,conn);
    NETDEBUG("New connection accepted on fd %d",client_fd);
    return;
  }
  NETDEBUG("Accept failed: %s",strerror(errno));
}

/* list_foreach() callback: accept a client if select() flagged this
 * listener as readable. Always returns TRUE to continue iterating. */
static int
network_check_listener(void *data, void *user_data) {
  Network_listener *listener = data;
  struct network_info *info = user_data;

  if (!FD_ISSET(listener->fd,info->rfds))
    return TRUE;
  NETDEBUG("Read available on listener");
  network_listener_accept_connection(listener);
  return TRUE;
}

/* After select(): accept on every listener in `listeners` (may be NULL)
 * whose socket is set in `fds`. */
static void
network_check_listeners(List *listeners,fd_set *fds) {
  struct network_info info;

  if (listeners != NULL) {
    info.rfds = fds;
    info.wfds = NULL;
    list_foreach(listeners,network_check_listener,&info);
  }
}

/*
 * Non-blocking recv() of up to `len` bytes from con into `data`.
 *
 * Returns the number of bytes read, or 0 when no data is available right
 * now (EAGAIN/EWOULDBLOCK) or the call was interrupted (EINTR). Returns -1
 * when the peer closed the connection or a real error occurred; in that
 * case con->fd is closed and set to -1.
 */
static int
network_connection_do_recv(Network_connection *con, void *data,size_t len) {
  ssize_t ret;

  ret = recv(con->fd,data,len,MSG_DONTWAIT);
  if (ret < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)) {
    return 0;
  }
  if (ret <= 0) {
    if (ret < 0)
      NETDEBUG("Network error on fd %d -> %s",con->fd,strerror(errno));
    else
      NETDEBUG("Connection to fd %d closed",con->fd);
    close(con->fd);
    con->fd = -1;
    return -1;
  }
  /* len is bounded by the packet size limit, so this fits in an int. */
  return (int)ret;
}

/*
 * Make progress on receiving one length-prefixed packet from con.
 *
 * Wire format (matching network_connection_do_write): a 4-byte big-endian
 * length, then that many payload bytes. Either part may arrive over
 * several calls. Progress is kept in con->size (the raw header),
 * con->bytes_received and con->inp. A completed packet is appended to
 * con->in_queue. A header announcing more than MAX_PACKET_SIZE bytes
 * closes the connection.
 */
static void
network_connection_do_read(Network_connection *con) {
  enum { MAX_PACKET_SIZE = 100000 };
  uint32_t amount;
  int ret;

  NETDEBUG("Reading on fd %d",con->fd);
  if (con->inp == NULL) {
    /* Length prefix. The offset must advance in bytes, not in units of
       sizeof(con->size). */
    ret = network_connection_do_recv(con,
                                     (char *)&(con->size) + con->bytes_received,
                                     sizeof(uint32_t) - con->bytes_received);
    if (ret < 0) return;
    con->bytes_received += ret;

    if (con->bytes_received == sizeof(uint32_t)) {
      amount = ntohl(con->size);
      con->bytes_received = 0;
      if (amount > MAX_PACKET_SIZE) {
        fprintf(stderr,"fd %d pushing a too big packet (%lu)\n",
                con->fd,(unsigned long)amount);
        close(con->fd);
        con->fd = -1;
        /* Connection is gone: don't allocate or read any further. */
        return;
      }
      con->inp = new_network_packet(amount);
      NETDEBUG("Going to get packet of size %lu",(unsigned long)amount);
    }
  }
  if (con->inp != NULL) {
    /* Skip recv() for an empty payload: a 0-byte recv() returns 0, which
       network_connection_do_recv() treats as the peer closing. */
    if (con->bytes_received < con->inp->size) {
      ret = network_connection_do_recv(con,
                                       (char *)con->inp->data + con->bytes_received,
                                       con->inp->size - con->bytes_received);
      if (ret < 0) return;
      con->bytes_received += ret;
      NETDEBUG("Got %d bytes on fd %d",ret,con->fd);
    }
    if (con->bytes_received == con->inp->size) {
      list_append(con->in_queue,con->inp);
      con->inp = NULL;
      con->bytes_received = 0;
      NETDEBUG("Got packet on fd %d",con->fd);
    }
  }
}

/*
 * Non-blocking send() of up to `len` bytes from `msg` on con.
 *
 * Returns the number of bytes accepted by the kernel (possibly fewer than
 * len), or 0 when the socket buffer is full (EAGAIN/EWOULDBLOCK) or the
 * call was interrupted (EINTR). On any other error, returns -1 after
 * closing con->fd and setting it to -1.
 * MSG_NOSIGNAL, where available, turns a write to a reset peer into an
 * EPIPE error instead of a process-killing SIGPIPE.
 */
static int
network_connection_do_send(Network_connection *con, void *msg, size_t len) {
  ssize_t ret;
  int flags = MSG_DONTWAIT;

#ifdef MSG_NOSIGNAL
  flags |= MSG_NOSIGNAL;
#endif
  ret = send(con->fd,msg,len,flags);
  /* Test errno, not the return value: a successful 11-byte send must not be
     mistaken for EAGAIN. */
  if (ret < 0 && (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)) {
    return 0;
  }
  if (ret < 0) {
    NETDEBUG("Network error-> %s",strerror(errno));
    close(con->fd);
    con->fd = -1;
    return -1;
  }
  return (int)ret;
}

/*
 * Make progress on sending con->outp. The 4-byte big-endian length prefix
 * goes first, then the payload.
 *
 * con->bytes_sent runs from -sizeof(uint32_t) up to 0 while the prefix is
 * sent, then from 0 up to outp->size for the payload. Partial sends resume
 * at the right byte on the next call. When the packet is complete, the
 * connection's reference is dropped and outp is cleared.
 */
static void
network_connection_do_write(Network_connection *con) {
  uint32_t header;
  int ret;

  if (con->bytes_sent < 0) {
    /* Number of header bytes already on the wire. */
    int header_off = (int)sizeof(uint32_t) + con->bytes_sent;

    NETDEBUG("Writing packet size on fd %d",con->fd);
    header = htonl(con->outp->size);
    /* Byte offset into the header; int-pointer arithmetic would step
       sizeof(int) at a time and run past the variable. */
    ret = network_connection_do_send(con, (char *)&header + header_off,
                                     (size_t)(-con->bytes_sent));
    if (ret < 0) return;
    con->bytes_sent += ret;
  }
  if (con->bytes_sent >= 0) {
    if (con->bytes_sent < con->outp->size) {
      NETDEBUG("Writing packet on fd %d",con->fd);
      ret = network_connection_do_send(con,
                                       (char *)con->outp->data + con->bytes_sent,
                                       con->outp->size - con->bytes_sent);
      if (ret < 0) return;
      con->bytes_sent += ret;
    }
    if (con->bytes_sent == con->outp->size) {
      del_network_packet(con->outp);
      con->outp = NULL;
    }
  }
}

/*
 * list_foreach() callback after select(): service reads and then writes on
 * one connection, according to the ready sets in user_data.
 *
 * Connections that were already closed were never added to the sets.
 * FD_ISSET() with a negative descriptor is undefined, so fd is checked
 * before each test. It is re-checked before the write because the read may
 * have closed the socket. Always returns TRUE to continue iterating.
 */
static int
network_check_connection(void *data, void *user_data) {
  Network_connection *c = (Network_connection *)data;
  struct network_info *n = (struct network_info *)user_data;

  if (c->fd >= 0 && FD_ISSET(c->fd,n->rfds)) {
    network_connection_do_read(c);
  }

  if (c->fd >= 0 && FD_ISSET(c->fd,n->wfds)) {
    network_connection_do_write(c);
  }
  return TRUE;
}

/* After select(): service every connection in `con` (may be NULL) whose
 * socket appears in the returned read/write sets. */
static void
network_check_connections(List *con,fd_set *rfds,fd_set *wfds) {
  struct network_info info;

  if (con != NULL) {
    info.rfds = rfds;
    info.wfds = wfds;
    list_foreach(con,network_check_connection,&info);
  }
}


/*
 * Run one round of the network event loop.
 *
 * All listeners and open connections are put into one select(). Connections
 * are also watched for writability when they have queued output.
 * `timeout` is in microseconds; a value <= 0 makes select() wait without a
 * timeout. With nothing to wait on and timeout == 0, the call returns TRUE
 * immediately instead of blocking forever.
 *
 * When select() reports activity, pending accepts, reads and writes are
 * processed and TRUE is returned. FALSE is returned when the timeout
 * expires or select() is interrupted by a signal (EINTR), so the caller can
 * simply call again. Any other select() failure terminates the process.
 */
int
network_update(List * listeners, List * connections, int timeout) {
  struct timeval val;
  fd_set rfds;
  fd_set wfds;
  int maxfd = -1;
  int ret;

  FD_ZERO(&rfds);
  FD_ZERO(&wfds);

  if (timeout > 0) {
    val.tv_sec = timeout / 1000000;
    val.tv_usec = timeout % 1000000;
  }
  network_add_listeners(listeners,&rfds,&maxfd);
  network_add_connections(connections,&rfds,&wfds,&maxfd);

  NETDEBUG("Doing update");
  if (maxfd == -1 && timeout == 0) {
    return TRUE;
  }
  ret = select(maxfd+1,&rfds,&wfds,NULL,(timeout > 0) ? &val : NULL);

  if (ret == -1) {
    /* A signal arriving during the wait is not a fatal condition. */
    if (errno == EINTR)
      return FALSE;
    WARN("select failed: %s",strerror(errno));
    exit(-1);
  }
  if (ret > 0) {
    network_check_listeners(listeners,&rfds);
    network_check_connections(connections,&rfds,&wfds);
    return TRUE;
  }

  return FALSE;
}
