#include "mgenSocketList.h"

// Construct an empty socket list.
// The defaults below are the values GetItem() copies into newly created
// items when they are non-zero: TTL starts at 3, TOS and buffer sizes at 0.
// All "lock" flags start cleared and there is no default multicast interface.
MgenSocketList::MgenSocketList()
  : head(NULL),
    tail(NULL),
    default_tx_buffer(0),
    default_rx_buffer(0),
    default_tos(0),
    default_ttl(3),
    default_tos_lock(false),
    default_ttl_lock(false),
    default_tx_buffer_lock(false),
    default_rx_buffer_lock(false),
    default_interface_lock(false)
{
    // Start with an empty default multicast interface name
    default_interface[0] = 0;
}

// Destructor: frees every item in the list (closing any open sockets)
// by delegating to Destroy().
MgenSocketList::~MgenSocketList()
{
    this->Destroy();
}

void MgenSocketList::Destroy()
{
    Item* next = head;
    while (next)
    {
        Item* current = next;
        next = next->next;
        if (current->socket.IsOpen()) current->socket.Close();
        delete current;
        
    }
    head = tail = NULL;  
}  // end MgenSocketList::Destroy()

// Link "item" in at the head of the doubly-linked list.
void MgenSocketList::Prepend(Item* item)
{
    item->prev = NULL;
    item->next = head;
    if (NULL == head)
        tail = item;          // list was empty: item is also the tail
    else
        head->prev = item;
    head = item;
}  // end MgenSocketList::Prepend()

// Link "item" in at the tail of the doubly-linked list.
void MgenSocketList::Append(Item* item)
{
    item->next = NULL;
    item->prev = tail;
    if (NULL == tail)
        head = item;          // list was empty: item is also the head
    else
        tail->next = item;
    tail = item;
}  // end MgenSocketList::Append()


// Unlink "item" from the list, fixing up head/tail as needed.
// The item is not deleted and its own prev/next fields are left as they were.
void MgenSocketList::Remove(Item* item)
{
    Item* before = item->prev;
    Item* after  = item->next;

    if (NULL != before)
        before->next = after;
    else
        head = after;

    if (NULL != after)
        after->prev = before;
    else
        tail = before;
}  // end MgenSocketList::Remove()


// This finds us a socket item with matching "port" number
MgenSocketList::Item* MgenSocketList::FindItemByPort(ProtoSocket::Protocol theProtocol,
                                                     unsigned short        thePort)
{
    // Linear scan from the head; the first item whose socket protocol and
    // port both match wins. NULL means no such item exists.
    for (Item* candidate = head; candidate; candidate = candidate->next)
    {
        bool sameProtocol = (candidate->socket.GetProtocol() == theProtocol);
        if (sameProtocol && (thePort == candidate->port))
            return candidate;
    }
    return NULL;
}  // end MgenSocketList::FindItemByPort()


// This search finds an item suitable for joining a multicast group
// We want a matching multicast interface name, a matching port number
// if "port" is non-zero
MgenSocketList::Item* MgenSocketList::FindItemByInterface(ProtoSocket::Protocol theProtocol,
                                                          const char*           interfaceName, 
                                                          unsigned short        thePort)
{
    // Interface names are stored in a field handled as 16 bytes everywhere in
    // this file. Copies keep at most 15 characters and are explicitly
    // NUL-terminated, because strncpy() leaves the destination unterminated
    // when the source is 16 or more characters long. The stored name is later
    // used as a C string (e.g. by Item::Open()).
    const size_t ifNameLen = 16;
    Item* next = head;
    while (next)
    {
        if (next->socket.GetProtocol() == theProtocol)
        {
            // If "thePort" is non-zero we must find a socket with that same
            // port; the first such socket decides the outcome of the search.
            if (0 != thePort) 
            {
                if (thePort == next->port)
                {
                    // OK, protocol and port match, so check the interface ...
                    if (interfaceName)
                    {
                        if (('\0' == next->interface_name[0]) && // unspecified interface &&
                            (0 == next->group_count))           // no joins yet, so we'll 
                        {                                       // claim it as a "match"
                            strncpy(next->interface_name, interfaceName, ifNameLen - 1);
                            next->interface_name[ifNameLen - 1] = '\0';
                            return next;
                        }
                        else
                        {
                            if (!strncmp(next->interface_name, interfaceName, ifNameLen))
                                return next;  // it's a "match"
                            else
                            {
                                DMSG(0, "MgenSocketList::FindItemByInterface() matching port but interface mismatch!\n");
                                return NULL;
                            }
                        }
                    }
                    else
                    {
                        // No interface requested: an item with no joins yet is
                        // reset to "unspecified"; otherwise it must already be
                        // unspecified to match.
                        if (0 == next->group_count)
                        {
                            next->interface_name[0] = '\0';
                            return next;
                        }
                        else
                        {
                            if ('\0' == next->interface_name[0])
                                return next;
                            else
                            {
                                DMSG(0, "MgenSocketList::FindItemByInterface() matching port but interface mismatch!\n");
                                return NULL;
                            }
                        }
                    }
                }
            }
            else
            {
                // Port-agnostic search: match on interface name only
                // (a NULL "interfaceName" matches an unspecified interface).
                bool match = interfaceName ? (0 == strncmp(interfaceName, next->interface_name, ifNameLen)) :
                                             ('\0' == next->interface_name[0]);
                if (match) return next;
            }
        }
        next = next->next;   
    }
    return (Item*)NULL;
}  // end MgenSocketList::FindItemByInterface()

/*MgenSocketList::Item* MgenSocketList::FindItemBySocket(const ProtoSocket* socket)
{
    Item* next = head;
    while (next)
    {
        if (&next->socket == socket)
            return next;
        next = next->next;   
    }
    return (Item*)NULL;
}  // end MgenSocketList::FindItemBySocket()*/

// Return the item for ("theProtocol", "thePort"), creating one if none exists.
// A newly created item gets:
//   - each non-zero list default (TTL, TOS, tx/rx buffer size),
//   - the default multicast interface, if one is configured,
//   - the list's socket notifier.
// It is then placed at the head of the list.
// If allocation reports failure (NULL), an error is logged and NULL is returned.
MgenSocketList::Item* MgenSocketList::GetItem(ProtoSocket::Protocol theProtocol,
                                              unsigned short        thePort)
{
    Item* existing = FindItemByPort(theProtocol, thePort);
    if (NULL != existing)
        return existing;

    Item* fresh = new Item(theProtocol, thePort);
    if (NULL == fresh)
    {
        DMSG(0, "MgenSocketList::GetItem() memory allocation error\n");
        return NULL;
    }

    // Apply only the defaults that have been set (zero means "not set")
    if (default_ttl) fresh->SetTTL(default_ttl);
    if (default_tos) fresh->SetTOS(default_tos);
    if (default_tx_buffer) fresh->SetTxBufferSize(default_tx_buffer);
    if (default_rx_buffer) fresh->SetRxBufferSize(default_rx_buffer);
    if (GetDefaultMulticastInterface())
        fresh->SetMulticastInterface(default_interface);
    fresh->SetSocketNotifier(socket_notifier);

    Prepend(fresh);
    return fresh;
}  // end MgenSocketList::GetItem()


// "thePort" is for WIN32 where you must join the group on a specific socket
MgenSocketList::Item* MgenSocketList::JoinGroup(const ProtoAddress&   groupAddress, 
                                                const char*           interfaceName,
                                                unsigned short        thePort)
{
    // Join "groupAddress" (optionally on "interfaceName") using a UDP item
    // selected by FindItemByInterface(). A new item is created for the join
    // when no suitable item exists, or when the candidate already holds
    // IP_MAX_MEMBERSHIPS groups (for a positive limit).
    // Returns the item used for the join, or NULL on failure. A newly
    // created item is discarded again if the join fails.

#ifndef IP_MAX_MEMBERSHIPS
#ifdef WIN32
// WIN32 allows one IP multicast membership per socket
#define IP_MAX_MEMBERSHIPS 1
#else
// Solaris (and perhaps other Unix) have no 
// pre-defined limit on group memberships
#define IP_MAX_MEMBERSHIPS -1    
#endif // if/else WIN32  
#endif // !IP_MAX_MEMBERSHIPS

#ifdef SIMULATE
#ifdef IP_MAX_MEMBERSHIPS
#undef IP_MAX_MEMBERSHIPS
#endif // IP_MAX_MEMBERSHIPS
#define IP_MAX_MEMBERSHIPS -1
#endif // SIMULATE

    // Sockets with space to join are at top of list
    // (TBD) find socket of appropriate ProtoAddress::Type ???
    
    bool newItem = false;
    Item* item = FindItemByInterface(ProtoSocket::UDP, interfaceName, thePort);
    if (!item || 
        ((IP_MAX_MEMBERSHIPS > 0) &&
         (item->GroupCount() >= (unsigned int)IP_MAX_MEMBERSHIPS)))
    {
        // Create new "dummy" socket for group joins only
        if (!(item = new Item(ProtoSocket::UDP, thePort)))
        {
            DMSG(0, "MgenSocketList::JoinGroup() memory allocation error: %s\n",
                    GetErrorString());
            // The function returns Item*; "false" is not a valid null pointer
            // constant in C++11 and later, so return NULL explicitly.
            return NULL;
        }
        item->SetSocketNotifier(socket_notifier);
        Prepend(item);
        newItem = true;
    }
    if (item->JoinGroup(groupAddress, interfaceName))
    {
        if ((IP_MAX_MEMBERSHIPS > 0) &&
            (item->GroupCount() >= (unsigned int)IP_MAX_MEMBERSHIPS))
        {
            Remove(item);  // move "full" group socket to end of list
            Append(item);   
        }
        return item;   
    }
    else
    {
        if (newItem)
        {
            Remove(item);
            delete(item);
        }
        return (Item*)NULL;   
    }
}  // end MgenSocketList::JoinGroup()

// Leave "groupAddress" on "item"'s socket.
// On success the item is moved to the head of the list, where JoinGroup()
// expects sockets that have room for another membership. Returns whether
// the leave succeeded.
bool MgenSocketList::LeaveGroup(MgenSocketList::Item* item,
                                const ProtoAddress&   groupAddress, 
                                const char*           interfaceName)
{
    const bool left = item->LeaveGroup(groupAddress, interfaceName);
    if (left)
    {
        Remove(item);
        Prepend(item);
    }
    return left;
}  // end MgenSocketList::LeaveGroup()

///////////////////////////////////////////////////////////
// MgenSocketList::Item() implementation


// Construct an item for "theProtocol"/"thePort".
// The cached socket options (TOS, TTL, tx/rx buffer) start at 0, which
// Open() treats as "not set". There are no references or group memberships,
// no multicast interface name, and the item is not linked into any list.
MgenSocketList::Item::Item(ProtoSocket::Protocol theProtocol,
                           unsigned short        thePort)
    : socket(theProtocol),
      tos(0),
      ttl(0),
      tx_buffer(0),
      rx_buffer(0),
      port(thePort),
      reference_count(0),
      group_count(0),
      prev(NULL),
      next(NULL)
{
    interface_name[0] = 0;
}


// Cache the desired transmit buffer size, applying it right away if the
// socket is open. If the open socket rejects the size, the cached value is
// left unchanged and false is returned. Open() applies a non-zero cached value.
bool MgenSocketList::Item::SetTxBufferSize(unsigned int bufferSize)
{
    if (socket.IsOpen() && !socket.SetTxBufferSize(bufferSize))
        return false;
    tx_buffer = bufferSize;
    return true;
}  // end MgenSocketList::Item::SetTxBufferSize()

// Cache the desired receive buffer size, applying it right away if the
// socket is open. If the open socket rejects the size, the cached value is
// left unchanged and false is returned. Open() applies a non-zero cached value.
bool MgenSocketList::Item::SetRxBufferSize(unsigned int bufferSize)
{
    if (socket.IsOpen() && !socket.SetRxBufferSize(bufferSize))
        return false;
    rx_buffer = bufferSize;
    return true;
}  // end MgenSocketList::Item::SetRxBufferSize()


// Cache the desired IP TOS value, applying it right away if the socket is
// open. If the open socket rejects it, the cached value is left unchanged
// and false is returned. Open() applies a non-zero cached value.
bool MgenSocketList::Item::SetTOS(unsigned char tosValue)
{
    if (socket.IsOpen() && !socket.SetTOS(tosValue))
        return false;
    tos = tosValue;
    return true;
}  // end MgenSocketList::Item::SetTOS()


// Cache the desired TTL value, applying it right away if the socket is
// open. If the open socket rejects it, the cached value is left unchanged
// and false is returned. Open() applies a non-zero cached value.
bool MgenSocketList::Item::SetTTL(unsigned char ttlValue)
{
    if (socket.IsOpen() && !socket.SetTTL(ttlValue))
        return false;
    ttl = ttlValue;
    return true;
}  // end MgenSocketList::Item::SetTTL()

// Set this item's multicast interface name, or clear it when
// "interfaceName" is NULL.
// If the socket is open, a non-NULL name is applied immediately. If the
// socket rejects it, false is returned and the stored name is unchanged.
// The stored copy keeps at most 15 characters and is always NUL-terminated:
// strncpy() alone does not terminate a source of 16 or more characters, and
// Open() later passes the stored name to the socket as a C string.
bool MgenSocketList::Item::SetMulticastInterface(const char* interfaceName)
{
    // Interface name field length assumed throughout this file
    const size_t ifNameLen = 16;
    if (interfaceName)
    {
        if (socket.IsOpen())
        {
            if (!socket.SetMulticastInterface(interfaceName))
                return false;
        }
        strncpy(interface_name, interfaceName, ifNameLen - 1);
        interface_name[ifNameLen - 1] = '\0';
    }
    else
    {
        // NOTE(review): clearing only resets the stored name; an interface
        // already applied to an open socket is left in place.
        interface_name[0] = '\0';
    }
    return true;
}  // end MgenSocketList::Item::SetMulticastInterface()


// Open (or re-use) this item's socket and take one reference on it.
//   addrType   - address type used when opening the socket. If the socket is
//                already open, a mismatch only produces a warning.
//   bindOnOpen - if true, the socket is bound to "port". An already-open but
//                unbound socket is bound here.
// On success, non-zero cached options (tx/rx buffer size, TOS, TTL) and a
// non-empty multicast interface name are (re)applied to the socket. Their
// return values are not checked. Then reference_count is incremented and
// true is returned.
// On an open or bind error, false is returned and reference_count is unchanged.
bool MgenSocketList::Item::Open(ProtoAddress::Type addrType, bool bindOnOpen)
{
    if (socket.IsOpen())
    {        
        if (socket.GetAddressType() != addrType)
            DMSG(0, "MgenSocketList::Item::Open() Warning: socket address type mismatch\n");
        if (bindOnOpen && !socket.IsBound())
        {
            if (!socket.Bind(port))
            {
                DMSG(0, "MgenSocketList::Item::Open() socket bind error\n");
                return false;
            }
        }
    }
    else if (!socket.Open(port, addrType, bindOnOpen))
    {
        // NOTE(review): this call runs only after socket.Open() has just
        // failed, so it acts on an unopened socket. The "by default" comment
        // suggests it was meant for a freshly opened socket, to disable
        // multicast loopback. Relocating it would change loopback behavior
        // for successfully opened sockets, so confirm the intent first.
        socket.SetLoopback(false);  //  by default
        DMSG(0, "MgenSocketList::Item::Open() Error: socket open error\n");
        return false;
    }
    // (Re)apply cached socket options; zero / empty means "not set"
    if (tx_buffer)
       socket.SetTxBufferSize(tx_buffer);

    if (rx_buffer)
       socket.SetRxBufferSize(rx_buffer);

    if (tos)
       socket.SetTOS(tos);

    if (ttl)
       socket.SetTTL(ttl);

    if ('\0' != interface_name[0])
       socket.SetMulticastInterface(interface_name);

    // One reference per successful Open(); released by Close()
    reference_count++;
    return true;
}  // end MgenSocketList::Item::Open()

// Release one reference taken by Open(). The socket is closed when the last
// reference is released. A call with no outstanding references trips the
// ASSERT (where ASSERT is active) and is otherwise ignored.
void MgenSocketList::Item::Close()
{
    ASSERT(reference_count);
    if (0 == reference_count)
        return;
    if (0 == --reference_count)
        socket.Close();
}  // end MgenSocketList::Item::Close()

// Join "groupAddress" on this item's socket, opening it first via Open(),
// which takes a reference.
// Sockets with a non-zero port are bound on open (needed for WIN32).
// Zero-port "dummy" sockets, created to allow large numbers of group joins
// on Unix, are left unbound.
// On success, multicast loopback is enabled and group_count is incremented.
// If the join fails, the reference taken by Open() is released.
bool MgenSocketList::Item::JoinGroup(const ProtoAddress& groupAddress, 
                                     const char*           interfaceName)
{
    const bool bindSocket = (0 != port);
    if (!Open(groupAddress.GetType(), bindSocket))
    {
        DMSG(0, "MgenSocketList::Item::JoinGroup() Error: socket open error\n");
        return false;
    }
    if (!socket.JoinGroup(groupAddress, interfaceName))
    {
        Close();  // undo the reference acquired by Open() above
        DMSG(0, "MgenSocketList::Item::JoinGroup() Error: socket join error\n");
        return false;
    }
    socket.SetLoopback(true);
    group_count++;
    return true;
}  // end MgenSocketList::Item::JoinGroup()

// Leave "groupAddress" on this item's socket.
// On success, group_count is decremented and one reference is released via
// Close(), which closes the socket once no references remain.
// Returns false, with no state change, if the socket-level leave fails.
bool MgenSocketList::Item::LeaveGroup(const ProtoAddress& groupAddress, 
                                      const char*         interfaceName)
{
    if (!socket.LeaveGroup(groupAddress, interfaceName))
        return false;
    ASSERT(group_count);
    group_count--;
    Close();
    return true;
}  // end MgenSocketList::Item::LeaveGroup()



// syntax highlighted by Code2HTML, v. 0.9.1