/*
 **************************************************************************
 *
 * Boot-ROM-Code to load an operating system across a TCP/IP network.
 *
 * Module:  resolve.c
 * Purpose: Implement a simple DNS name resolver
 * Entries: resolve, res_config
 *
 **************************************************************************
 *
 * Copyright (C) 1995-2003 Gero Kuhlmann <gero@gkminix.han.de>
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: resolve.c,v 1.5 2003/01/25 23:29:41 gkminix Exp $
 *
 **************************************************************************
 *
 * Error messages may be printed from within this module because it is
 * never called by a PXE function.
 */


#include <general.h>
#include <kernel/net.h>
#include <kernel/arpa.h>
#include <kernel/romlib.h>
#include <pxe/common.h>
#include "resolve.h"



/*
 **************************************************************************
 * 
 * Local error codes. They are stored in recv_err to record why the most
 * recent query attempt did not produce an address.
 */
enum {
	ERR_NOERR = 0,			/* no error		*/
	ERR_TIMEOUT = 1,		/* timeout error	*/
	ERR_INVALID = 2,		/* invalid packet	*/
	ERR_SERVER = 3,			/* server error		*/
	ERR_SEND = 4			/* send error		*/
};



/*
 **************************************************************************
 * 
 * Module state.
 *
 * Note that all IP addresses kept here are in network order!!!
 */
#ifndef NODNS
/* Name server configuration, filled in by res_config() */
static t_ipaddr servers[MAX_NS];		/* known name servers	*/
static int ns_num;			/* entries used in servers[]	*/
static unsigned char mydomain[MAX_NAME_LEN];	/* my own domain name	*/

/* The request currently being resolved */
static unsigned char reqbuf[MAX_NAME_LEN];	/* labelized request name */
static unsigned int req_len;		/* length of name in reqbuf	*/
static unsigned short dns_xid;		/* transaction ID of the query	*/

/* Packet handling (dnsbuf is used for both the query and the reply) */
static unsigned char dnsbuf[DNS_UDP_LEN];	/* DNS packet buffer	*/
static unsigned int dns_len;		/* length of data in dnsbuf	*/
static unsigned char namebuf[MAX_NAME_LEN];	/* last decoded name	*/
static int recv_err;			/* ERR_* code of last attempt	*/

/* Result of the last successful lookup */
static t_ipaddr ip_last = n_IP_ANY;	/* last IP address from server	*/
#endif



/*
 **************************************************************************
 * 
 * Convert an internet address in dotted-decimal ascii form ("a.b.c.d")
 * into binary network form.
 *
 * Each of the four parts must consist of at least one decimal digit and
 * have a value of at most 255. The last part may be followed either by the
 * end of the string or by a ':'. Returns n_IP_ANY if the string is not in
 * this form.
 */
static t_ipaddr conv_ip(name)
char *name;
{
  register char *cp;
  unsigned char ip[4];
  t_ipaddr addr;
  int dots;
  int num, i;

  dots = 0;
  cp = name;
  while (dots < 4) {
	i = 0;
	num = 0;
	while (*cp >= '0' && *cp <= '9') {
		/*
		 * Stop accumulating once the value is out of range: it stays
		 * > 255 and gets rejected below, and num can never overflow
		 * (max 2559) even with a 16-bit int.
		 */
		if (num <= 255)
			num = num * 10 + (*cp - '0');
		cp++;
		i++;
	}
	if (i == 0 || num > 255 ||
	    (dots < 3 && *cp != '.') ||
	    (dots == 3 && *cp != '\0' && *cp != ':'))
		return(n_IP_ANY);
	ip[dots++] = (unsigned char)num;
	cp++;
  }

  /* Copy instead of casting the byte array: avoids misaligned access */
  memcpy((unsigned char *)&addr, ip, sizeof(ip));
  return(addr);
}



#ifndef NODNS
/*
 **************************************************************************
 * 
 * Copy a domain name into the request buffer in label form.
 *
 * *buf points to the position in reqbuf where the first length byte is to
 * be written. On success it is advanced past the last label written; on
 * failure it is left unchanged. The name ends at '\0' or ':'. Letters are
 * folded to lower case.
 *
 * Each label must be 1..MAX_LABEL_LEN characters long and may contain
 * letters, digits and hyphens, but must not start with a hyphen. A leading
 * digit is allowed, as RFC 1123 relaxed the RFC 952 "letter first" rule.
 * The terminating root label is NOT written.
 *
 * Returns the number of dots in the name, or -1 if the name is NULL,
 * invalid, or would not fit into reqbuf (one byte is always kept free for
 * the root label).
 */
static int copy_name(buf, name)
unsigned char **buf;
unsigned char *name;
{
  register unsigned char *cp;
  unsigned char c;
  unsigned char *lenp;
  unsigned char *bp;
  int dots;
  int i;

  if (name == NULL)
	return(-1);
  lenp = *buf;			/* length byte of the current label */
  bp = (*buf) + 1;		/* next character position */

  dots = 0;
  for (cp = name, i = 0; *cp && *cp != ':'; cp++) {
	/* Keep the last byte of reqbuf free for the root label */
	if (bp >= reqbuf + sizeof(reqbuf) - 1)
		return(-1);
	if (*cp == '.') {
		/* Close the current label and reserve a new length byte */
		if (i == 0 || i > MAX_LABEL_LEN) return(-1);
		dots++;
		*lenp = i;
		lenp = bp++;
		i = 0;
	} else {
		c = *cp;
		if (c >= 'A' && c <= 'Z')
			c += 'a' - 'A';
		if (!((c >= 'a' && c <= 'z') ||
		      (c >= '0' && c <= '9') ||
		      (c == '-' && i > 0)))
			return(-1);
		*bp++ = c;
		i++;
	}
  }

  if (i == 0 || i > MAX_LABEL_LEN)
	return(-1);
  *lenp = i;
  *buf = bp;
  return(dots);
}



/*
 **************************************************************************
 * 
 * Build the labelized request name in reqbuf and set req_len.
 *
 * A name without any dot gets mydomain appended. The result is terminated
 * by the zero-length root label. The syntax checks are done by copy_name().
 * Returns TRUE on success, FALSE if either part is rejected.
 */
static int create_name(name)
char *name;
{
  unsigned char *wp = reqbuf;
  int ndots;

  ndots = copy_name(&wp, (unsigned char *)name);
  if (ndots < 0)
	return(FALSE);
  if (ndots == 0 && copy_name(&wp, mydomain) < 0)
	return(FALSE);

  *wp++ = '\0';			/* root domain label */
  req_len = wp - reqbuf;
  return(TRUE);
}



#ifdef NSDEBUG
/*
 **************************************************************************
 * 
 * Debug helper: print a labelized name as dot-separated labels.
 *
 * NOTE(review): "%ls" with a length argument looks like a romlib printf
 * extension (print <len> characters starting at cp) - confirm there.
 */
static void print_name(cp)
unsigned char *cp;
{
  int len;

  for (; *cp != 0; cp += len) {
	len = *cp++;
	printf("%ls.", cp, len);
  }
}
#endif



/*
 **************************************************************************
 * 
 * Build a type A / class IN query for the name in reqbuf and send it over
 * the currently open UDP connection. Returns non-zero if udp_write()
 * reported success.
 *
 * This is a rather simple resolver. It sets the recursion-desired flag so
 * that the server does all the nameserver lookup work itself if it doesn't
 * have an authoritative answer. A drawback is that some name servers don't
 * support recursion or have it intentionally disabled.
 *
 * dns_xid is stored without byte swapping; recv_answer() compares it the
 * same way.
 */
static int send_query()
{
#define dp ((struct dnshdr *)dnsbuf)

  register unsigned char *bp;

  /* Setup the request header */
  memset(dnsbuf, 0, sizeof(dnsbuf));
  dp->xid = dns_xid;
  dp->flags = htons(HEADER_RD | OPCODE_QUERY);
  dp->qdcount = htons(1);

  /*
   * Question section: the labelized name followed by QTYPE and QCLASS.
   * The two 16-bit values follow a variable-length name and may start at
   * an odd offset. Store them byte by byte in network order instead of
   * through an (unsigned short *), which could be a misaligned access.
   */
  bp = dnsbuf + sizeof(struct dnshdr);
  memcpy(bp, reqbuf, req_len);
  bp += req_len;
  *bp++ = (unsigned char)((DNS_TYPE_A >> 8) & 0xff);
  *bp++ = (unsigned char)(DNS_TYPE_A & 0xff);
  *bp++ = (unsigned char)((DNS_CLASS_IN >> 8) & 0xff);
  *bp++ = (unsigned char)(DNS_CLASS_IN & 0xff);
#ifdef NSDEBUG
  printf("Sending DNS request for ");
  print_name(reqbuf);
  printf("\n");
#endif

  /* Finally send the request */
  return(udp_write((unsigned char *)dnsbuf, getds(),
		(int)(bp - dnsbuf)) == PXENV_STATUS_SUCCESS);
#undef dp
}



/*
 **************************************************************************
 * 
 * Decode the (possibly compressed) domain name starting at "name" inside
 * dnsbuf into namebuf, in labelized form and terminated by the root label.
 * Upper-case letters are folded to lower case, so that comparisons against
 * the lower-case reqbuf are case-insensitive as DNS requires.
 *
 * Returns a pointer to the first byte after the name in the packet (after
 * the first compression pointer, if any), or NULL if the name is malformed.
 * Malformed covers:
 *   - running past dns_len
 *   - reserved label types
 *   - exceeding namebuf
 *   - too many compression jumps (protects against pointer loops)
 *
 * The packet is untrusted, so every access is checked against dns_len
 * before it is made.
 */
static unsigned char *ndecode(name)
unsigned char *name;
{
  unsigned char *endp;
  unsigned int pos;		/* read offset into dnsbuf */
  unsigned int len;		/* bytes written to namebuf */
  unsigned int lab;		/* current length/pointer byte */
  unsigned int jumps;		/* compression pointers followed */
  unsigned char c;

  endp = NULL;
  pos = (unsigned int)(name - dnsbuf);
  len = 0;
  jumps = 0;
  for (;;) {
	if (pos >= dns_len)
		return(NULL);
	lab = dnsbuf[pos];
	if (lab == 0)
		break;				/* root label: end of name */
	if ((lab & COMPR_MASK) == COMPR_MASK) {
		/* Compression pointer: need both bytes, and bound the jumps */
		if (pos + 1 >= dns_len || ++jumps > MAX_NAME_LEN)
			return(NULL);
		if (endp == NULL)
			endp = dnsbuf + pos + 2;
		pos = ((lab << 8) | dnsbuf[pos + 1]) & OFFSET_MASK;
	} else if ((lab & COMPR_MASK) != 0) {
		return(NULL);			/* reserved label type */
	} else {
		/* Copy length byte and label, keeping room for the terminator */
		lab++;
		if (pos + lab > dns_len || len + lab >= MAX_NAME_LEN)
			return(NULL);
		while (lab-- > 0) {
			/* Length bytes are <= 63 and thus never in 'A'..'Z' */
			c = dnsbuf[pos++];
			if (c >= 'A' && c <= 'Z')
				c += 'a' - 'A';
			namebuf[len++] = c;
		}
	}
  }

  namebuf[len] = '\0';
  return(endp != NULL ? endp : dnsbuf + pos + 1);
}



/*
 **************************************************************************
 * 
 * Decode a received DNS packet
 */
static t_ipaddr recv_answer()
{
#define rpp ((struct rr *)cp)
#define dp ((struct dnshdr *)dnsbuf)

  register unsigned char *cp;
  int flags, recnum, status;

  /* Wait for a UDP packet and return with error if timeout */
  dns_len = sizeof(dnsbuf);
  status = udp_read((unsigned char *)dnsbuf, getds(), &dns_len,
							DNS_TIMEOUT, 0);
  if (status != PXENV_STATUS_SUCCESS) {
	recv_err = ERR_TIMEOUT;
	return(n_IP_ANY);
  }

  /* Check for correct header and simply skip invalid packets */
  flags = ntohs(dp->flags);
  if (dns_len > DNS_UDP_LEN ||
      (flags & HEADER_QR) != HEADER_QR ||
      (flags & HEADER_OPCODE) != OPCODE_QUERY ||
      dp->xid != dns_xid) {
	recv_err = ERR_INVALID;
	return(n_IP_ANY);
  }

  /* Decode the packet */
  if ((flags & HEADER_RCODE) == RCODE_NOERR) {
	/* Skip all query records */
	cp = dnsbuf + sizeof(struct dnshdr);
	if (ntohs(dp->qdcount) != 1 || (cp = ndecode(cp)) == NULL ||
	    memcmp(reqbuf, namebuf, req_len)) {
		recv_err = ERR_INVALID;
		return(n_IP_ANY);
	}
	cp += 4;

	/* Decode the first address record */
	for (recnum = 0; recnum < ntohs(dp->ancount); recnum++) {
		if ((cp = ndecode(cp)) == NULL) {
			recv_err = ERR_INVALID;
			return(n_IP_ANY);
		}
#ifdef NSDEBUG
		printf("Received DNS answer for ");
		print_name(namebuf);
		printf("\n");
#endif
		if (rpp->type == htons(DNS_TYPE_A) &&
		    rpp->class == htons(DNS_CLASS_IN) &&
		    rpp->rdlength >= htons(IP_ALEN) &&
		    !memcmp(reqbuf, namebuf, req_len)) {
			recv_err = ERR_NOERR;
			return(*((t_ipaddr *)(cp + sizeof(struct rr))));
		}
		cp += ntohs(rpp->rdlength) + sizeof(struct rr);
	}
  }

  /* Decode error condition */
  flags &= HEADER_RCODE;
  recv_err = ERR_SERVER;
  if (flags != RCODE_NAME) {
	if (flags == RCODE_NOERR)
		recv_err = ERR_INVALID;
	else
		printf("\nDNS: server error %x\n", flags);
  }
  return(n_IP_ANY);

#undef dp
#undef rrp
}
#endif /* NODNS */



/*
 **************************************************************************
 * 
 * Resolve a host name using DNS
 */
t_ipaddr resolve(name)
char *name;
{
  t_ipaddr ip;
#ifndef NODNS
  int ns, retry;
#endif

  /* First check if the name is already in IP number form */
  if ((ip = conv_ip(name)) != n_IP_ANY)
	return(ip);

#ifndef NODNS
  /* For local addresses we don't have to go through the resolver */
  if (!memcmp(name, "localhost", 9))
	return(n_IP_LOCALHOST);
  else if (!memcmp(name, "broadcast", 9))
	return(n_IP_BROADCAST);

  /*
   * If there are no domain name servers initialized, we can't continue. Other-
   * wise create a labelized domain name to resolve.
   */
  if (ns_num == 0 || !create_name(name)) {
	printf("\nDNS: invalid host name\n");
	return(n_IP_ANY);
  }

  /*
   * If this request is for the same name as the previous request, we can
   * just return the IP number and don't need to ask the server.
   */
  if (req_len > 0 && ip_last != n_IP_ANY && !memcmp(reqbuf, namebuf, req_len))
	return(ip_last);

  /* Now query each name server in turn */
  for (ns = 0; ns < ns_num; ns++) {
	dns_xid = (int)(get_ticks() + random(0x7fff));
	if (udp_open(servers[ns], n_IP_ANY,
			htons(DNS_C_PORT),
			htons(DNS_S_PORT)) != PXENV_STATUS_SUCCESS) {
		printf("\nDNS: udp_open error\n");
		break;
	}
	retry = 0;
	while (retry++ < DNS_RETRY) {
		if (!send_query()) {
			recv_err = ERR_SEND;
			break;
		}
		if ((ip = recv_answer()) != n_IP_ANY) {
			udp_close();
			ip_last = ip;
			return(ip);
		}
		/* If a server error occurred, retries are useless */
		if (recv_err == ERR_SERVER)
			break;
	}
	if (recv_err == ERR_TIMEOUT)
		printf("\nDNS: timeout\n");
	else if (recv_err == ERR_INVALID)
		printf("\nDNS: invalid packets\n");
	else if (recv_err == ERR_SEND)
		printf("\nDNS: unable to send request\n");
	if (udp_close() != PXENV_STATUS_SUCCESS) {
		printf("\nDNS: udp_close error\n");
		break;
	}
  }
#endif

  printf("\nDNS: can't resolve host name\n");
  return(n_IP_ANY);
}



/*
 **************************************************************************
 * 
 * Initialize name server list and current domain from BOOTP record.
 *
 * The option data returned by get_vend() is read as a length byte followed
 * by that many bytes of data. Up to MAX_NS server addresses are copied
 * from VEND_DNS, still in network order. The VEND_DOMAIN string is copied
 * into mydomain with any trailing '.' or NUL bytes removed. copy_name()
 * rejects a trailing dot, so an absolute domain name like "example.com."
 * would otherwise make every single-label lookup fail. If an option is
 * missing, the corresponding setting is left unchanged.
 */
void res_config()
{
#ifndef NODNS
  register unsigned char *cp;
  int i;

  /*
   * Select the reply buffer for get_vend(). NOTE(review): presumably
   * cur_bootp_buf tells get_vend() which BOOTP buffer to read - confirm.
   */
  cur_bootp_buf = BOOTP_REPLY;

  /* Copy name server information */
  if ((cp = get_vend(VEND_DNS)) != NULL) {
	i = *cp++;
	ns_num = 0;
	while (ns_num < MAX_NS && i >= IP_ALEN) {
		/* Option data may be unaligned: copy instead of casting */
		memcpy((unsigned char *)&servers[ns_num], cp, IP_ALEN);
		ns_num++;
		cp += IP_ALEN;
		i -= IP_ALEN;
	}
  }

  /* Copy domain name information */
  if ((cp = get_vend(VEND_DOMAIN)) != NULL) {
	i = *cp++;
	while (i > 0 && (cp[i - 1] == '.' || cp[i - 1] == '\0'))
		i--;
	if (i < MAX_NAME_LEN) {
		memcpy(mydomain, cp, i);
		mydomain[i] = '\0';
	}
  }
#endif
}



/* syntax highlighted by Code2HTML, v. 0.9.1 */