/*-
 * Copyright (c) 2000 GOTOU YUUZOU <gotoyuzo@notwork.org>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id: icmpmodule.c,v 1.5 2001/09/21 17:19:30 gotoyuzo Exp $
 */
#include <ruby.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in_systm.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <stddef.h>
#include <stdio.h>
#include <string.h>

/* Ruby objects created by Init_icmpmodule() and shared by the methods below. */
VALUE mICMPModule;  /* the ICMPModule namespace (reused if already defined) */
VALUE cICMP;        /* ICMPModule::ICMP_c, a String subclass viewed as struct icmp */
VALUE cIP;          /* ICMPModule::IP_c, a String subclass viewed as struct ip */
VALUE eICMPError;   /* ICMPModule::ICMPError_c (< StandardError), raised on bad input */

/* Identity "conversion" for accessor fields that need no byte swapping. */
#define nop(x) (x)

/*
 * DEFINE_GETTER(type, field, size, func)
 *
 * Defines `static VALUE get_<field>(self)`: views the receiver String's
 * bytes as a `type`, raises ICMPError when the string is shorter than
 * `size` bytes, and returns `func(pack->field)` as an Integer.  `func` is a
 * byte-order conversion (ntohs/ntohl) or `nop` for fields returned as stored.
 * `len` is a long because rb_str2cstr() stores the length through a long*.
 */
#define DEFINE_GETTER(type, field, size, func) \
static VALUE                             \
get_##field(self)                        \
    VALUE self;                          \
{                                        \
    type *pack;                          \
    long len = 0;                        \
    pack = (type*)rb_str2cstr(self, &len);  \
    if(len < (long)(size)){              \
        rb_raise(eICMPError, "invalid packet size."); \
    }                                    \
    return UINT2NUM(func(pack->field));  \
}

/*
 * DEFINE_SETTER(type, field, size, func)
 *
 * Defines `static VALUE set_<field>(self, val)`: makes the receiver String
 * writable, raises ICMPError when it is shorter than `size` bytes, and
 * stores `func(NUM2UINT(val))` into `field` of the `type` overlay.  Returns
 * `val`.  Both parameters are declared VALUE; leaving `val` undeclared made
 * it an implicit int in K&R style, truncating the Ruby object on LP64.
 */
#define DEFINE_SETTER(type, field, size, func) \
static VALUE                             \
set_##field(self, val)                   \
    VALUE self, val;                     \
{                                        \
    type *pack;                          \
    long len = 0;                        \
    rb_str_modify(self);                 \
    pack = (type*)rb_str2cstr(self, &len);  \
    if(len < (long)(size)){              \
        rb_raise(eICMPError, "invalid packet size."); \
    }                                    \
    pack->field = func(NUM2UINT(val));   \
    return val;                          \
}

/*
 * Fallback for platforms whose headers do not provide offsetof(): the byte
 * offset of `member` within `type`, as a size_t.  (The member designator must
 * not be parenthesized, and the result must be an integer, not a pointer.)
 */
#ifndef offsetof
#  define offsetof(type, member) ((size_t)&(((type *)0)->member))
#endif

/*
 * Convert a Ruby String holding a dotted-quad address or a host name into an
 * IPv4 address in network byte order.
 *
 * The dotted-quad form is tried first with inet_addr(); on INADDR_NONE the
 * string is resolved with gethostbyname() and the first IPv4 address is used.
 * Raises ICMPError if the name cannot be resolved to an IPv4 address.
 */
static u_long
get_ipv4_addr(val)
    VALUE val;
{
    struct in_addr addr;
    char *host;
    long len;
    struct hostent *hp;

    host = rb_str2cstr(val, &len);
    /* Compare against INADDR_NONE in in_addr_t width; comparing a widened
       u_long with -1 never matches on LP64 platforms. */
    addr.s_addr = inet_addr(host);
    if(addr.s_addr == INADDR_NONE){
        hp = gethostbyname(host);
        if(hp == NULL || hp->h_addrtype != AF_INET ||
           hp->h_length != (int)sizeof(addr.s_addr) ||
           hp->h_addr_list[0] == NULL){
            rb_raise(eICMPError, "unknown host: %s", host);
        }
        /* Copy exactly 4 bytes; reading a u_long would over-read on LP64. */
        memcpy(&addr.s_addr, hp->h_addr_list[0], sizeof(addr.s_addr));
    }
    return addr.s_addr;
}

/*
 * ICMP_c.new([str]) -> ICMP_c
 *
 * With an argument, returns a copy of `str` re-classed as `klass`.  Without
 * one, returns ICMP_ADVLENMIN zero bytes re-classed as `klass`.
 */
static VALUE
icmp_new(argc, argv, klass)
    int argc;
    VALUE *argv, klass;
{
    VALUE src, obj;
    char zeros[ICMP_ADVLENMIN];

    if(rb_scan_args(argc, argv, "01", &src) == 0){
        memset(zeros, 0, sizeof(zeros));
        obj = rb_str_new(zeros, sizeof(zeros));
    }
    else{
        obj = rb_str_dup(src);
    }
    OBJSETUP(obj, klass, T_STRING);
    return obj;
}

static VALUE
get_icmp_gwaddr(self)
    VALUE self;
{
    struct icmp *icmp;
    size_t len;
    u_char *p;
    char buf[16];

    icmp = (struct icmp*)rb_str2cstr(self, &len);
    if(len < ICMP_MINLEN){
        rb_raise(eICMPError, "invalid packet size.");
    }
    p = (char*)&icmp->icmp_gwaddr.s_addr;
    sprintf(buf, "%d.%d.%d.%d", p[0], p[1], p[2], p[3]);
    return rb_str_new2(buf);
}

/*
 * ICMP_c#icmp_gwaddr = addr
 *
 * Stores `addr` (dotted quad or host name, resolved by get_ipv4_addr) in the
 * gateway address field; nil stores 0.0.0.0.  Raises ICMPError if the
 * packet is shorter than ICMP_MINLEN.  Returns `val`.
 */
static VALUE
set_icmp_gwaddr(self, val)
    VALUE self, val;
{
    struct icmp *icmp;
    long len;

    rb_str_modify(self);
    icmp = (struct icmp*)rb_str2cstr(self, &len);
    if(len < ICMP_MINLEN){
        rb_raise(eICMPError, "invalid packet size.");
    }
    icmp->icmp_gwaddr.s_addr = NIL_P(val) ? 0 : get_ipv4_addr(val);
    return val;
}

/*
 * ICMP_c#icmp_ip -> IP_c
 *
 * Returns a copy of everything from the embedded IP header to the end of
 * the packet, classed as IP_c.  Raises ICMPError if the packet is shorter
 * than ICMP_ADVLENMIN.
 */
static VALUE
get_icmp_ip(self)
    VALUE self;
{
    struct icmp *icmp;
    long len;
    VALUE str;

    icmp = (struct icmp*)rb_str2cstr(self, &len);
    if(len < ICMP_ADVLENMIN){
        rb_raise(eICMPError, "invalid packet size.");
    }
    str = rb_str_new((char*)&icmp->icmp_ip,
                     len - (long)offsetof(struct icmp, icmp_ip));
    OBJSETUP(str, cIP, T_STRING);

    return str;
}

/*
 * ICMP_c#icmp_ip = ip
 *
 * Replaces everything from the embedded IP header onward with `val`.
 * Raises ICMPError if `val` is shorter than struct ip or the receiver is
 * shorter than ICMP_ADVLENMIN.  Returns `val`.
 */
static VALUE
set_icmp_ip(self, val)
    VALUE self, val;
{
    /* rb_str2cstr() stores through a long*; an int here would be
       overwritten past its end on LP64 platforms. */
    long len;

    rb_str2cstr(val, &len);
    if(len < (long)sizeof(struct ip)){
        rb_raise(eICMPError, "invalid packet size.");
    }
    rb_str_modify(self);
    rb_str2cstr(self, &len);
    if(len < ICMP_ADVLENMIN){
        rb_raise(eICMPError, "invalid packet size.");
    }
    rb_str_resize(self, offsetof(struct icmp, icmp_ip));
    rb_str_concat(self, val);

    return val;
}

/*
 * ICMP_c#icmp_data -> String or nil
 *
 * Returns a copy of the bytes following the ICMP_MINLEN-byte header, or nil
 * when there are none.
 */
static VALUE
get_icmp_data(self)
    VALUE self;
{
    char *buf;
    long len;

    buf = rb_str2cstr(self, &len);
    if(len <= ICMP_MINLEN){
        return Qnil;
    }
    return rb_str_new(buf + ICMP_MINLEN, len - ICMP_MINLEN);
}

/*
 * ICMP_c#icmp_data = str
 *
 * Replaces the bytes after the ICMP_MINLEN-byte header with `val`, resizing
 * the receiver to ICMP_MINLEN + val.length when needed.  Returns `val`.
 */
static VALUE
set_icmp_data(self, val)
    VALUE self, val;
{
    VALUE src = val;
    char *data, *buf;
    long vlen, len;

    rb_str_modify(self);
    data = rb_str2cstr(src, &vlen);
    if(src == self){
        /* The resize below may move self's buffer, which `data` points
           into; copy the payload first to avoid reading freed memory. */
        src = rb_str_new(data, vlen);
        data = rb_str2cstr(src, &vlen);
    }

    rb_str2cstr(self, &len);
    if(len != ICMP_MINLEN + vlen){
        rb_str_resize(self, ICMP_MINLEN + vlen);
    }
    /* Address the payload through the byte buffer rather than the
       one-element icmp_data array so the copy stays in bounds. */
    buf = rb_str2cstr(self, &len);
    memcpy(buf + ICMP_MINLEN, data, vlen);
    return val;
}

/*
 * Compute the Internet checksum (RFC 1071) of `len` bytes at `s`.
 *
 * Returns the one's-complement of the one's-complement sum of the data
 * taken as 16-bit words in memory order; storing the result as a native
 * unsigned short in the checksum field yields the correct wire bytes.  A
 * trailing odd byte is treated as the first byte of a zero-padded word,
 * which is correct on both little- and big-endian hosts.
 */
static int
icmp_calc_cksum(s, len)
    char *s;
    size_t len;
{
    const unsigned char *p = (const unsigned char *)s;
    unsigned long sum = 0;
    unsigned short word;

    while(len > 1){
        /* memcpy avoids unaligned / type-punned u_short reads. */
        memcpy(&word, p, sizeof(word));
        sum += word;
        p += 2;
        len -= 2;
    }
    if(len == 1){
        word = 0;
        *(unsigned char *)&word = *p;
        sum += word;
    }
    /* Fold carries back into the low 16 bits until none remain. */
    while(sum >> 16){
        sum = (sum & 0xffff) + (sum >> 16);
    }
    return (int)(~sum & 0xffff);
}

/*
 * ICMP_c#truncate -> self
 *
 * Resizes the receiver to the length implied by its icmp_type.  Echo and
 * echo-reply packets keep their current length.  Raises ICMPError for a
 * packet shorter than ICMP_MINLEN or an unknown type.  When the packet has
 * to grow, a warning is issued (in verbose mode) and the added bytes are
 * zero-filled.
 */
static VALUE
icmp_truncate(self)
    VALUE self;
{
    struct icmp *icmp;
    char *buf;
    long olen, nlen, blen;

    icmp = (struct icmp*)rb_str2cstr(self, &olen);
    if(olen < ICMP_MINLEN){
        rb_raise(eICMPError, "invalid packet size.");
    }
    nlen = olen;
    switch(icmp->icmp_type){
    case ICMP_IREQ:              /* FALLTHROUGH */
    case ICMP_IREQREPLY:         /* FALLTHROUGH */
    case ICMP_ROUTERSOLICIT:     /* header only (RFC 1256) */
        nlen = ICMP_MINLEN;
        break;
    case ICMP_UNREACH:           /* FALLTHROUGH */
    case ICMP_TIMXCEED:          /* FALLTHROUGH */
    case ICMP_PARAMPROB:         /* FALLTHROUGH */
    case ICMP_SOURCEQUENCH:      /* FALLTHROUGH */
    case ICMP_REDIRECT:
        /* Header + embedded IP header (including any options) + 8 bytes.
           ip_hl is only consulted when the embedded header is present. */
        nlen = ICMP_ADVLENMIN;
        if(olen >= ICMP_ADVLENMIN && (long)ICMP_ADVLEN(icmp) > nlen){
            nlen = ICMP_ADVLEN(icmp);
        }
        break;
    case ICMP_TSTAMP:            /* FALLTHROUGH */
    case ICMP_TSTAMPREPLY:
        nlen = ICMP_TSLEN;
        break;
    case ICMP_ROUTERADVERT:
        /* RFC 1256: num_addrs entries of wpa 32-bit words each. */
        nlen = ICMP_MINLEN +
               (long)icmp->icmp_num_addrs * (long)icmp->icmp_wpa * 4;
        break;
    case ICMP_MASKREQ:           /* FALLTHROUGH */
    case ICMP_MASKREPLY:
        nlen = ICMP_MASKLEN;
        break;
    case ICMP_ECHO:              /* FALLTHROUGH */
    case ICMP_ECHOREPLY:
        nlen = olen;
        break;
    default:
        rb_raise(eICMPError, "unknown icmp_type.");
        break;
    }
    if(olen < nlen){
        rb_warning("packet length is not enough.");
    }
    rb_str_resize(self, nlen);
    if(olen < nlen){
        /* Zero the grown region rather than relying on rb_str_resize()
           to clear it, so no stale memory ends up in the packet. */
        buf = rb_str2cstr(self, &blen);
        memset(buf + olen, 0, (size_t)(nlen - olen));
    }
    return self;
}

static VALUE
icmp_set_cksum(self)
    VALUE self;
{
    struct icmp *icmp;
    size_t len;

    rb_str_modify(self);
    icmp = (struct icmp*)rb_str2cstr(self, &len);
    icmp->icmp_cksum = 0;
    icmp->icmp_cksum = icmp_calc_cksum(icmp, len);
    return self;
}

/*
 * ICMP_c#setup -> self
 *
 * Convenience: truncate the packet to its type's length, then recompute the
 * checksum.  Both helpers return the receiver, so their results chain.
 */
static VALUE
icmp_setup(self)
    VALUE self;
{
    return icmp_set_cksum(icmp_truncate(self));
}

/*
 * Integer accessors for struct icmp fields.  Each conversion matches the
 * field's width: 16-bit network-order fields use ntohs/htons, 32-bit ones
 * ntohl/htonl, single bytes nop.  icmp_void is 32 bits wide (ntohs
 * truncated it), and icmp_nextmtu is a 16-bit network-order MTU.
 * NOTE(review): icmp_mask is still passed through unconverted (raw network
 * byte order); kept as-is to avoid changing existing callers' results.
 */
DEFINE_GETTER(struct icmp, icmp_type, ICMP_MINLEN, nop)
DEFINE_SETTER(struct icmp, icmp_type, ICMP_MINLEN, nop)
DEFINE_GETTER(struct icmp, icmp_code, ICMP_MINLEN, nop)
DEFINE_SETTER(struct icmp, icmp_code, ICMP_MINLEN, nop)
DEFINE_GETTER(struct icmp, icmp_cksum, ICMP_MINLEN, ntohs)
DEFINE_SETTER(struct icmp, icmp_cksum, ICMP_MINLEN, htons)
DEFINE_GETTER(struct icmp, icmp_pptr, ICMP_MINLEN, nop)
DEFINE_SETTER(struct icmp, icmp_pptr, ICMP_MINLEN, nop)
DEFINE_GETTER(struct icmp, icmp_id, ICMP_MINLEN, ntohs)
DEFINE_SETTER(struct icmp, icmp_id, ICMP_MINLEN, htons)
DEFINE_GETTER(struct icmp, icmp_seq, ICMP_MINLEN, ntohs)
DEFINE_SETTER(struct icmp, icmp_seq, ICMP_MINLEN, htons)
DEFINE_GETTER(struct icmp, icmp_void, ICMP_MINLEN, ntohl)
DEFINE_SETTER(struct icmp, icmp_void, ICMP_MINLEN, htonl)
DEFINE_GETTER(struct icmp, icmp_pmvoid, ICMP_MINLEN, ntohs)
DEFINE_SETTER(struct icmp, icmp_pmvoid, ICMP_MINLEN, htons)
DEFINE_GETTER(struct icmp, icmp_nextmtu, ICMP_MINLEN, ntohs)
DEFINE_SETTER(struct icmp, icmp_nextmtu, ICMP_MINLEN, htons)
DEFINE_GETTER(struct icmp, icmp_num_addrs, ICMP_MINLEN, nop)
DEFINE_SETTER(struct icmp, icmp_num_addrs, ICMP_MINLEN, nop)
DEFINE_GETTER(struct icmp, icmp_wpa, ICMP_MINLEN, nop)
DEFINE_SETTER(struct icmp, icmp_wpa, ICMP_MINLEN, nop)
DEFINE_GETTER(struct icmp, icmp_lifetime, ICMP_MINLEN, ntohs)
DEFINE_SETTER(struct icmp, icmp_lifetime, ICMP_MINLEN, htons)
DEFINE_GETTER(struct icmp, icmp_otime, ICMP_TSLEN, ntohl)
DEFINE_SETTER(struct icmp, icmp_otime, ICMP_TSLEN, htonl)
DEFINE_GETTER(struct icmp, icmp_rtime, ICMP_TSLEN, ntohl)
DEFINE_SETTER(struct icmp, icmp_rtime, ICMP_TSLEN, htonl)
DEFINE_GETTER(struct icmp, icmp_ttime, ICMP_TSLEN, ntohl)
DEFINE_SETTER(struct icmp, icmp_ttime, ICMP_TSLEN, htonl)
DEFINE_GETTER(struct icmp, icmp_mask, ICMP_MASKLEN, nop)
DEFINE_SETTER(struct icmp, icmp_mask, ICMP_MASKLEN, nop)

/*
 * IP_c.new([str]) -> IP_c
 *
 * With an argument, returns a copy of `str` re-classed as `klass`, raising
 * ICMPError if it is shorter than struct ip.  Without one, returns a zeroed
 * struct ip re-classed as `klass`.
 */
static VALUE
ip_new(argc, argv, klass)
    int argc;
    VALUE *argv, klass;
{
    struct ip buf;
    long len;
    VALUE tmp, new_ip;

    if(rb_scan_args(argc, argv, "01", &tmp) == 1){
        rb_str2cstr(tmp, &len);
        if(len < (long)sizeof(struct ip)){
            rb_raise(eICMPError, "invalid packet size.");
        }
        new_ip = rb_str_dup(tmp);
    }
    else{
        memset(&buf, 0, sizeof(buf));
        new_ip = rb_str_new((char*)&buf, sizeof(buf));
    }
    OBJSETUP(new_ip, klass, T_STRING);
    return new_ip;
}

static VALUE
get_ip_src(self)
    VALUE self;
{
    struct ip *ip;
    size_t len;
    u_char *p;
    char buf[16];

    ip = (struct ip*)rb_str2cstr(self, &len);
    if(len < sizeof(struct ip)){
        rb_raise(eICMPError, "invalid packet size.");
    }
    p = (char*)&ip->ip_src.s_addr;
    sprintf(buf, "%d.%d.%d.%d", p[0], p[1], p[2], p[3]);
    return rb_str_new2(buf);
}

/*
 * Store `val` (dotted quad or host name, resolved by get_ipv4_addr) as the
 * source address; nil stores 0.0.0.0.  Raises ICMPError if the packet is
 * shorter than struct ip.  Returns `val`.
 * NOTE(review): not registered as a Ruby method in Init_icmpmodule().
 */
static VALUE
set_ip_src(self, val)
    VALUE self, val;
{
    struct ip *ip;
    long len;

    rb_str_modify(self);
    ip = (struct ip*)rb_str2cstr(self, &len);
    if(len < (long)sizeof(struct ip)){
        rb_raise(eICMPError, "invalid packet size.");
    }
    ip->ip_src.s_addr = NIL_P(val) ? 0 : get_ipv4_addr(val);
    return val;
}

static VALUE
get_ip_dst(self)
    VALUE self;
{
    struct ip *ip;
    size_t len;
    u_char *p;
    char buf[16];

    ip = (struct ip*)rb_str2cstr(self, &len);
    if(len < sizeof(struct ip)){
        rb_raise(eICMPError, "invalid packet size.");
    }
    p = (char*)&ip->ip_dst.s_addr;
    sprintf(buf, "%d.%d.%d.%d", p[0], p[1], p[2], p[3]);
    return rb_str_new2(buf);
}

/*
 * Read-only Integer accessors for struct ip header fields; 16-bit fields are
 * converted from network byte order.  No trailing ';' — each macro already
 * expands to a complete function definition, and a stray ';' at file scope
 * is not valid ISO C.
 */
DEFINE_GETTER(struct ip, ip_v, sizeof(struct ip), nop)
DEFINE_GETTER(struct ip, ip_hl, sizeof(struct ip), nop)
DEFINE_GETTER(struct ip, ip_tos, sizeof(struct ip), nop)
DEFINE_GETTER(struct ip, ip_len, sizeof(struct ip), ntohs)
DEFINE_GETTER(struct ip, ip_id, sizeof(struct ip), ntohs)
DEFINE_GETTER(struct ip, ip_off, sizeof(struct ip), ntohs)
DEFINE_GETTER(struct ip, ip_ttl, sizeof(struct ip), nop)
DEFINE_GETTER(struct ip, ip_p, sizeof(struct ip), nop)
DEFINE_GETTER(struct ip, ip_sum, sizeof(struct ip), ntohs)

static VALUE
get_ip_body(self)
    VALUE self;
{
    int len, hlen;
    char *ip;
    VALUE body;

    ip = rb_str2cstr(self, &len);
    if(len < sizeof(struct ip)){
        rb_raise(eICMPError, "invalid IP packet size.");
    }
    hlen = FIX2INT(get_ip_hl(self)) * 4;
    if(hlen < sizeof(struct ip)) rb_raise(eICMPError, "invalid header size.");
    if(len > hlen){
       return rb_str_new(ip + hlen, len - hlen);
    }
    return Qnil;
}

/*
 * rb_rescue() body for Init_icmpmodule(): look up the top-level constant
 * named by `id` on Object.  Any exception it raises is handled by the
 * rescue callback supplied by the caller.
 */
static VALUE
icmp_const_get(id)
    VALUE id;
{
    VALUE toplevel = rb_cObject;

    return rb_const_get(toplevel, id);
}

/*
 * Extension entry point.
 *
 * Looks up the top-level ICMPModule constant; if rb_const_get raises, the
 * rescue callback creates the module with rb_define_module_id instead.
 * Under it, defines ICMPError_c (< StandardError) and the String subclasses
 * ICMP_c and IP_c, then registers their methods.
 */
void
Init_icmpmodule()
{
/* `name` -> get_<name>, no arguments.  `name` is only stringized/pasted,
   so field names that are macros in <netinet/ip_icmp.h> are not expanded. */
#define REG_READER(klass, name) \
    rb_define_method((klass), #name, get_##name, 0)
/* `name` -> get_<name> (0 args) and `name=` -> set_<name> (1 arg). */
#define REG_ACCESSOR(klass, name) do { \
    rb_define_method((klass), #name, get_##name, 0); \
    rb_define_method((klass), #name "=", set_##name, 1); \
} while(0)

    mICMPModule = rb_rescue(icmp_const_get, rb_intern("ICMPModule"),
                            rb_define_module_id, rb_intern("ICMPModule"));
    eICMPError = rb_define_class_under(mICMPModule,
                                       "ICMPError_c", rb_eStandardError);

    /* ICMP_c: raw ICMP message */
    cICMP = rb_define_class_under(mICMPModule, "ICMP_c", rb_cString);
    rb_define_singleton_method(cICMP, "new", icmp_new, -1);
    REG_ACCESSOR(cICMP, icmp_gwaddr);
    REG_ACCESSOR(cICMP, icmp_ip);
    REG_ACCESSOR(cICMP, icmp_data);
    rb_define_method(cICMP, "truncate", icmp_truncate, 0);
    rb_define_method(cICMP, "set_cksum", icmp_set_cksum, 0);
    rb_define_method(cICMP, "setup", icmp_setup, 0);

    REG_ACCESSOR(cICMP, icmp_type);
    REG_ACCESSOR(cICMP, icmp_code);
    REG_ACCESSOR(cICMP, icmp_cksum);
    REG_ACCESSOR(cICMP, icmp_pptr);
    REG_ACCESSOR(cICMP, icmp_id);
    REG_ACCESSOR(cICMP, icmp_seq);
    REG_ACCESSOR(cICMP, icmp_void);
    REG_ACCESSOR(cICMP, icmp_pmvoid);
    REG_ACCESSOR(cICMP, icmp_nextmtu);
    REG_ACCESSOR(cICMP, icmp_num_addrs);
    REG_ACCESSOR(cICMP, icmp_wpa);
    REG_ACCESSOR(cICMP, icmp_lifetime);
    REG_ACCESSOR(cICMP, icmp_otime);
    REG_ACCESSOR(cICMP, icmp_rtime);
    REG_ACCESSOR(cICMP, icmp_ttime);
    REG_ACCESSOR(cICMP, icmp_mask);

    /* IP_c: raw IP datagram (read-only accessors) */
    cIP = rb_define_class_under(mICMPModule, "IP_c", rb_cString);
    rb_define_singleton_method(cIP, "new", ip_new, -1);
    REG_READER(cIP, ip_src);
    REG_READER(cIP, ip_dst);
    REG_READER(cIP, ip_v);
    REG_READER(cIP, ip_hl);
    REG_READER(cIP, ip_tos);
    REG_READER(cIP, ip_len);
    REG_READER(cIP, ip_id);
    REG_READER(cIP, ip_off);
    REG_READER(cIP, ip_ttl);
    REG_READER(cIP, ip_p);
    REG_READER(cIP, ip_sum);
    rb_define_method(cIP, "body", get_ip_body, 0);

#undef REG_ACCESSOR
#undef REG_READER
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */