/*
 * Copyright (C), 2000-2007 by the monit project group.
 * All Rights Reserved.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "config.h"

#ifdef HAVE_SYS_TYPES_H
#include <sys/types.h>
#endif

#ifdef HAVE_STRING_H
#include <string.h>
#endif

#ifdef HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif 

#ifdef HAVE_STRINGS_H
#include <strings.h>
#endif

#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif

#include "net.h"
#include "ssl.h"
#include "monitor.h"
#include "socket.h"



/**
 * Implementation of the socket interface.
 * 
 * @author Jan-Henrik Haukeland, <hauk@tildeslash.com>
 * @version \$Id: socket.c,v 1.63 2007/10/02 13:58:49 hauk Exp $
 * @file
 */


/* ------------------------------------------------------------- Definitions */

/* How a Socket came to be: created by one of the socket_create*() / 
   socket_new() factories (TYPE_LOCAL) or wrapped around a descriptor
   handed to socket_create_a() (TYPE_ACCEPT) */
#define TYPE_LOCAL   0
#define TYPE_ACCEPT  1

/* Capacity of the per-socket read-ahead buffer filled by fill() */
#define RBUFFER_SIZE 1024

struct Socket_T {
  /* --- Endpoint information */
  char *host;              /* Host name or IP string, xstrdup'ed, freed by socket_free() */
  int port;                /* Port number supplied at creation (remote port for outgoing
                              sockets, local port for accepted ones) */
  Port_T Port;             /* Port object the socket was created from (socket_create only) */
  /* --- Transport */
  int socket;              /* Underlying socket descriptor */
  int type;                /* Compared against SOCK_STREAM / SOCK_DGRAM by the I/O code */
  int timeout;             /* I/O timeout in seconds */
  int connection_type;     /* TYPE_LOCAL or TYPE_ACCEPT */
  /* --- SSL state; I/O goes through SSL whenever ssl is non-NULL */
  ssl_connection *ssl;
  ssl_server_connection *sslserver;
  /* --- Read-ahead buffer: bytes buffer[offset .. length-1] are unread */
  int length;
  int offset;
  unsigned char buffer[RBUFFER_SIZE+1];
};


/* -------------------------------------------------------------- Prototypes */


static int fill(Socket_T S, int timeout);


/* ------------------------------------------------------------------ Public */


/**
 * Create a new Socket connected to host:port. Either a TCP or a UDP
 * socket is created, as selected by the type parameter. When use_ssl
 * is TRUE the connection is switched to SSL with automatic protocol
 * version negotiation and no certificate md5 check; only TCP sockets
 * may use SSL. This is a convenience wrapper around socket_create_t().
 * @param host The remote host to connect to, either a DNS host name
 * or an IP address string
 * @param port The port number to connect to
 * @param type The socket type to use (SOCKET_TCP or SOCKET_UDP)
 * @param use_ssl If TRUE the socket is created supporting SSL
 * @param timeout The timeout value in seconds
 * @return The connected Socket or NULL if an error occurred
 */
Socket_T socket_new(const char *host, int port, int type, int use_ssl,
                    int timeout) {
  Ssl_T options;

  /* Default SSL options: auto-negotiated version, no fingerprint check */
  options.certmd5= NULL;
  options.version= SSL_VERSION_AUTO;
  options.use_ssl= use_ssl;

  return socket_create_t(host, port, type, options, timeout);
}


/**
 * Factory method that creates a connected Socket from a monit Port
 * object. Host, port, type and timeout are taken from the Port, and
 * the connection is switched to SSL when the Port's SSL options ask
 * for it. Unix domain sockets report LOCALHOST as their host.
 * @param port The Port_T object to create a socket from (must not be NULL)
 * @return The connected Socket or NULL if an error occurred
 */
Socket_T socket_create(void *port) {

  int fd;
  Port_T p= port;
  Socket_T S= NULL;

  ASSERT(port);

  fd= create_generic_socket(p);
  if(fd == -1)
    return NULL;

  NEW(S);
  S->socket= fd;
  S->length= 0;
  S->offset= 0;
  S->type= p->type;
  S->port= p->port;
  S->timeout= p->timeout;
  S->connection_type= TYPE_LOCAL;
  S->host= xstrdup((p->family == AF_UNIX) ? LOCALHOST : p->hostname);

  if(p->SSL.use_ssl && !socket_switch2ssl(S, p->SSL)) {
    /* socket_free() also tears down any partially set up SSL state */
    socket_free(&S);
    return NULL;
  }

  S->Port= port;
  return S;
}


/**
 * Create a new Socket connected to host:port using an explicit set of
 * SSL options. Apart from taking a full Ssl_T instead of a use_ssl
 * flag, this behaves like socket_new().
 * @param host The remote host to connect to, either a DNS host name
 * or an IP address string
 * @param port The port number to connect to
 * @param type The socket type to use (SOCKET_TCP or SOCKET_UDP)
 * @param ssl Options for SSL; when ssl.use_ssl is set the socket must
 * be a TCP socket
 * @param timeout The timeout value in seconds, must be > 0
 * @return The connected Socket or NULL if an error occurred
 */
Socket_T socket_create_t(const char *host, int port, int type, Ssl_T ssl,
                         int timeout) {

  int fd;
  int proto;
  Socket_T S= NULL;

  ASSERT(host);
  ASSERT((type==SOCKET_UDP)||(type==SOCKET_TCP));
  if(ssl.use_ssl) {
    ASSERT(type==SOCKET_TCP);
  }
  ASSERT(timeout>0);

  /* Map the module's socket type onto the BSD socket type */
  if(type == SOCKET_UDP)
    proto= SOCK_DGRAM;
  else
    proto= SOCK_STREAM;

  fd= create_socket(host, port, proto, timeout);
  if(fd == -1)
    return NULL;

  NEW(S);
  S->socket= fd;
  S->length= 0;
  S->offset= 0;
  S->port= port;
  S->type= proto;
  S->timeout= timeout;
  S->host= xstrdup(host);
  S->connection_type= TYPE_LOCAL;

  if(ssl.use_ssl && !socket_switch2ssl(S, ssl)) {
    socket_free(&S);
    return NULL;
  }

  return S;
}


/**
 * Factory method that wraps an accepted connection in a Socket
 * object. The given descriptor must come from accept(2); only TCP
 * sockets are supported. If sslserver is non-NULL, the connection is
 * registered with that SSL server context and made SSL capable. On
 * failure the Socket, and with it the descriptor, is released.
 * @param socket The accepted socket descriptor
 * @param remote_host The remote host the connection originated from
 * @param port The local port number the connection arrived on
 * @param sslserver An SSL server connection context, may be NULL
 * @return A Socket or NULL if an error occurred
 */
Socket_T socket_create_a(int socket, const char *remote_host,
                         int port, void *sslserver) {

  Socket_T S;

  ASSERT(socket>=0);
  ASSERT(remote_host);

  NEW(S);
  S->socket= socket;
  S->port= port;
  S->host= xstrdup(remote_host);
  S->type= SOCK_STREAM;
  S->timeout= NET_TIMEOUT;
  S->connection_type= TYPE_ACCEPT;
  S->offset= 0;
  S->length= 0;

  if(sslserver) {
    S->sslserver= sslserver;
    S->ssl= insert_accepted_ssl_socket(S->sslserver);
    /* Short-circuit: only embed when the SSL object was obtained */
    if(!S->ssl || !embed_accepted_ssl_socket(S->ssl, S->socket)) {
      socket_free(&S);
      return NULL;
    }
  }

  return S;
}


/**
 * Destroy a Socket object. Any SSL connection with a handler is shut
 * down first: sockets created locally close and delete their SSL
 * connection, while accepted sockets hand it back to their SSL server
 * context. Then the descriptor is closed, all memory is released and
 * *S is reset.
 * @param S Address of a Socket object reference (both must be non-NULL)
 */
void socket_free(Socket_T *S) {

  Socket_T sock;

  ASSERT(S && *S);
  sock= *S;

#ifdef HAVE_OPENSSL
  if(sock->ssl && sock->ssl->handler) {
    switch(sock->connection_type) {
      case TYPE_LOCAL:
        close_ssl_socket(sock->ssl);
        delete_ssl_socket(sock->ssl);
        break;
      case TYPE_ACCEPT:
        if(sock->sslserver)
          close_accepted_ssl_socket(sock->sslserver, sock->ssl);
        break;
      default:
        break;
    }
  }
#endif

  close_socket(sock->socket);
  FREE(sock->host);
  FREE(*S);
}


/* ------------------------------------------------------------ Properties */


/**
 * Check whether the socket is ready for I/O, using the stream or
 * datagram check that matches the socket type.
 * @param S A Socket object
 * @return TRUE if the socket is ready, otherwise FALSE (also FALSE for
 * any type other than SOCK_STREAM or SOCK_DGRAM)
 */
int socket_is_ready(Socket_T S) {

  ASSERT(S);

  if(S->type == SOCK_STREAM)
    return check_socket(S->socket);

  if(S->type == SOCK_DGRAM)
    return check_udp_socket(S->socket);

  return FALSE;
}


/**
 * Return the descriptor wrapped by this Socket.
 * @param S A Socket object
 * @return The underlying socket descriptor
 */
int socket_get_socket(Socket_T S) {
  ASSERT(S);
  return S->socket;
}


/**
 * Return the socket type recorded when this Socket was created.
 * @param S A Socket object
 * @return The socket type
 */
int socket_get_type(Socket_T S) {
  ASSERT(S);
  return S->type;
}


/**
 * Return the Port object this Socket was created from. Only
 * socket_create() records one; for other sockets the field is never
 * assigned, so NULL is expected.
 * @param S A Socket object
 * @return The Port object or NULL
 */
void *socket_get_Port(Socket_T S) {
  ASSERT(S);
  return S->Port;
}


/**
 * Return the port number stored for this Socket at creation time.
 * @param S A Socket object
 * @return The port number the socket is connected to
 */
int socket_get_remote_port(Socket_T S) {
  ASSERT(S);
  return S->port;
}


/**
 * Return the remote host of this Socket, either a DNS host name or an
 * IP address string. The string is owned by the Socket.
 * @param S A Socket object
 * @return The remote host
 */
const char *socket_get_remote_host(Socket_T S) {
  ASSERT(S);
  return S->host;
}


/* ---------------------------------------------------------------- Public */


/**
 * Switch an already connected socket to SSL. A new SSL connection
 * object for the requested protocol version is created and attached
 * to the socket descriptor. If ssl.certmd5 is set, the peer
 * certificate's md5 sum must match it.
 *
 * On failure S->ssl may already reference a partially set up SSL
 * connection. The caller must release it with socket_free(), as
 * socket_create() and socket_create_t() do.
 * @param S The already connected socket, must not be NULL
 * @param ssl Options for ssl
 * @return TRUE if ssl is ready otherwise FALSE
 */
int socket_switch2ssl(Socket_T S, Ssl_T ssl)  {

  ASSERT(S);

  if(! (S->ssl= new_ssl_connection(NULL, ssl.version)))
    return FALSE;

  if(! embed_ssl_socket(S->ssl, S->socket))
    return FALSE;

  /* Optional certificate pinning by md5 fingerprint */
  if(ssl.certmd5 && !check_ssl_md5sum(S->ssl, ssl.certmd5)) {
    LogError("md5sum of certificate does not match!");
    return FALSE;
  }

  return TRUE;
}


/**
 * Format a printf-style message and write it to the socket. Use this
 * function to send text based messages to a client.
 * @param S A Socket_T object
 * @param m A format string, followed by its arguments
 * @return The bytes sent or -1 if an error occured
 */
int socket_print(Socket_T S, const char *m, ...) {

  int sent;
  long len;
  char *text;
  va_list args;

  ASSERT(S);
  ASSERT(m);

  va_start(args, m);
  text= Util_formatString(m, args, &len);
  va_end(args);

  sent= socket_write(S, text, len);
  FREE(text);

  return sent;
}


/**
 * Write size bytes from the buffer b. The write is retried until
 * everything has been sent or the underlying write returns <= 0.
 * Before writing, any unread data pending in the input buffer is
 * discarded via socket_reset().
 * @param S A Socket_T object
 * @param b The data to be written
 * @param size The size of the data in b
 * @return The number of bytes sent (which may be less than size if a
 * write returned 0), or -1 if a write reported an error
 */
int socket_write(Socket_T S, void *b, int size) {

  int n= 0;
  /* Byte pointer: arithmetic on void * is not valid ISO C */
  unsigned char *start= b;
  unsigned char *p= start;

  ASSERT(S);

  /* Clear any extra data read from the server */
  socket_reset(S);

  while(size > 0) {

    if(S->ssl) {
      n= send_ssl_socket(S->ssl, p, size, S->timeout);
    } else {
      if(S->type==SOCK_DGRAM)
        n= udp_write(S->socket,  p, size, S->timeout);
      else
        n= sock_write(S->socket,  p, size, S->timeout);
    }

    if(n <= 0) break;
    p+= n;
    size-= n;

  }

  if(n < 0) {
    /* A failed write is an error, even after a partial write */
    return -1;
  }

  return (int)(p - start);

}


/**
 * Read a single byte, refilling the internal buffer from the socket
 * when it is exhausted. The byte is returned as an int in the range
 * 0 to 255.
 * @param S A Socket_T object
 * @return The byte read, or -1 if no more data could be read (end of
 * stream, timeout or error)
 */
int socket_read_byte(Socket_T S) {

  ASSERT(S);

  /* Serve from the buffer, or refill it; fill() resets offset to 0 */
  if(S->offset < S->length || fill(S, S->timeout) > 0)
    return S->buffer[S->offset++];

  return -1;
}


/**
 * Read up to size bytes into the byte buffer pointed to by b. Reading
 * stops early when socket_read_byte() reports that no more data is
 * available (end of stream, timeout or error).
 * @param S A Socket_T object
 * @param b A byte buffer with room for at least size bytes
 * @param size The maximum number of bytes to read
 * @return The number of bytes stored in b. This is 0 if nothing could
 * be read; the function never returns -1.
 */
int socket_read(Socket_T S, void *b, int size) {

  int c;
  unsigned char *start= b;
  unsigned char *p= start;

  ASSERT(S);

  while((size-- > 0) && ((c= socket_read_byte(S)) >= 0)) {
    *p++= (unsigned char)c;
  }

  /* Pointer difference instead of casting pointers to long, which
     truncates on LLP64 platforms */
  return (int)(p - start);

}


/**
 * Reads in at most one less than size characters and stores them in
 * the buffer pointed to by s. Reading stops after end of stream, a
 * newline or a NUL byte. A newline is stored in the buffer; a NUL byte
 * is consumed but not stored. A '\0' is stored after the last
 * character in the buffer.
 * @param S A Socket_T object
 * @param s A character buffer to store the string in
 * @param size The size of the string buffer s; if size <= 0 nothing is
 * read or written and NULL is returned
 * @return s on success, and NULL on error or when end of file occurs
 * while no characters have been read.
 */
char *socket_readln(Socket_T S, char *s, int size) {

  int c;
  unsigned char *p= (unsigned char *)s;

  ASSERT(S);
  ASSERT(s);

  /* Without this guard --size never hits 0 and the loop overruns s */
  if(size <= 0)
    return NULL;

  /* c > 0 also ends the line on a NUL byte (which is not stored) */
  while(--size && ((c= socket_read_byte(S)) > 0)) {
    *p++= (unsigned char)c;
    if(c == '\n')
      break;
  }

  *p= 0;

  if(*s)
    return s;

  return NULL;

}


/**
 * Discard any pending input. Data already buffered, and anything that
 * can be read without waiting, is thrown away.
 * @param S A Socket_T object
 */
void socket_reset(Socket_T S) {

  int drained;

  ASSERT(S);

  /* Keep draining (non-blocking) until nothing more arrives */
  do {
    drained= fill(S, 0);
  } while(drained > 0);

  S->length= 0;
  S->offset= 0;
}


/* --------------------------------------------------------------- Private */


/*
 * Refill the internal buffer from the socket (through SSL when
 * S->ssl is set). The buffer is emptied first. Only the first read
 * may wait, up to timeout seconds; later reads use a zero timeout and
 * collect whatever is immediately available, up to RBUFFER_SIZE bytes.
 * If a hard read error occurs after some data was already buffered,
 * the buffered data is kept and returned. The error will surface on
 * the next call instead of silently discarding received bytes.
 * @param S A Socket object
 * @param timeout The number of seconds to wait for data to be read
 * (forced to 0 for datagram sockets)
 * @return The number of bytes now in the buffer (0 if nothing was
 * available or the peer closed), or -1 if a read error occurred
 * before any data was buffered
 */
static int fill(Socket_T S, int timeout) {

  int n;

  S->offset= 0;
  S->length= 0;
  /* Optimizing, assuming a request/response pattern and that a udp_write
     was issued before we are called, we don't have to wait for data */
  if(S->type==SOCK_DGRAM) timeout= 0;

  /* Read as much as we can, but only block on the first read */
  while(RBUFFER_SIZE > S->length) {

    if(S->ssl) {
      n= recv_ssl_socket(S->ssl, S->buffer + S->length,
                         RBUFFER_SIZE-S->length, timeout);
    } else {
      n= sock_read(S->socket, S->buffer + S->length,
                   RBUFFER_SIZE-S->length, timeout);
    }

    timeout= 0;

    if(n > 0) {
      S->length+= n;
      continue;
    }

    if(n == 0)
      break;

    /* n < 0: "no more data right now" is not an error */
    if(errno == EAGAIN || errno == EWOULDBLOCK || S->type==SOCK_DGRAM)
      break;

    /* Hard error: hand back what we already have, report it next time */
    if(S->length > 0)
      break;

    return -1;

  }

  return S->length;

}


/* syntax highlighted by Code2HTML, v. 0.9.1 */