Move get_raw_packet() to packet.c and make get_packet() and get_raw_packet()

static functions.
This commit is contained in:
Nicholas J. Kain 2011-06-11 11:05:53 -04:00
parent 52306aa217
commit 6191a07eb3
3 changed files with 75 additions and 7 deletions

View File

@ -1,5 +1,5 @@
/* dhcpmsg.c - dhcp packet generation and sending functions
* Time-stamp: <2011-03-31 03:27:02 nk>
* Time-stamp: <2011-06-11 11:03:22 njk>
*
* (c) 2004-2011 Nicholas J. Kain <njkain at gmail dot com>
* (c) 2001 Russ Dill <Russ.Dill@asu.edu>
@ -46,6 +46,5 @@ int send_selecting(uint32_t xid, uint32_t server, uint32_t requested);
int send_renew(uint32_t xid, uint32_t server, uint32_t ciaddr);
int send_decline(uint32_t xid, uint32_t server, uint32_t requested);
int send_release(uint32_t server, uint32_t ciaddr);
int get_raw_packet(struct dhcpMessage *payload, int fd);
#endif

View File

@ -1,5 +1,5 @@
/* packet.c - send and react to DHCP message packets
* Time-stamp: <2011-06-11 10:58:37 njk>
* Time-stamp: <2011-06-11 11:03:05 njk>
*
* (c) 2004-2011 Nicholas J. Kain <njkain at gmail dot com>
* (c) 2001 Russ Dill <Russ.Dill@asu.edu>
@ -40,8 +40,9 @@
#include "io.h"
#include "options.h"
/* Read a packet from socket fd, return -1 on read error, -2 on packet error */
int get_packet(struct dhcpMessage *packet, int fd)
// Read a packet from a cooked socket. Returns -1 on fatal error, -2 on
// transient error.
static int get_packet(struct dhcpMessage *packet, int fd)
{
int bytes;
@ -59,6 +60,75 @@ int get_packet(struct dhcpMessage *packet, int fd)
return bytes;
}
// Read a packet from a raw socket. Returns -1 on fatal error, -2 on
// transient error.
static int get_raw_packet(struct dhcpMessage *payload, int fd)
{
struct ip_udp_dhcp_packet packet;
uint16_t check;
memset(&packet, 0, IP_UPD_DHCP_SIZE);
int len = safe_read(fd, (char *)&packet, IP_UPD_DHCP_SIZE);
if (len == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
return -2;
log_line("get_raw_packet: read error %s", strerror(errno));
return -1;
}
/* ignore any extra garbage bytes */
len = ntohs(packet.ip.tot_len);
// Validate the IP and UDP headers.
if (packet.ip.protocol != IPPROTO_UDP) {
log_line("IP header is not UDP: %d", packet.ip.protocol);
return -2;
}
if (packet.ip.version != IPVERSION) {
log_line("IP version is not IPv4");
return -2;
}
if (packet.ip.ihl != sizeof packet.ip >> 2) {
log_line("IP header length incorrect");
return -2;
}
check = packet.ip.check;
packet.ip.check = 0;
if (check != checksum(&packet.ip, sizeof packet.ip)) {
log_line("IP header checksum incorrect");
return -2;
}
if (packet.udp.dest != htons(DHCP_CLIENT_PORT)) {
log_line("UDP destination port incorrect: %d", ntohs(packet.udp.dest));
return -2;
}
if (len > IP_UPD_DHCP_SIZE) {
log_line("Data longer than that of a IP+UDP+DHCP message: %d", len);
return -2;
}
if (ntohs(packet.udp.len) != (short)(len - sizeof packet.ip)) {
log_line("UDP header length incorrect");
return -2;
}
/* verify the UDP checksum by replacing the header with a psuedo header */
memset(&packet.ip, 0, offsetof(struct iphdr, protocol));
/* preserved fields: protocol, check, saddr, daddr */
packet.ip.tot_len = packet.udp.len; /* cheat on the psuedo-header */
check = packet.udp.check;
packet.udp.check = 0;
if (check && check != checksum(&packet, len)) {
log_error("Packet with bad UDP checksum received, ignoring");
return -2;
}
memcpy(payload, &packet.data,
len - sizeof packet.ip - sizeof packet.udp);
log_line("Received a packet via raw socket.");
return len - sizeof packet.ip - sizeof packet.udp;
}
/* Compute Internet Checksum for @count bytes beginning at location @addr. */
uint16_t checksum(void *addr, int count)
{

View File

@ -1,5 +1,5 @@
/* packet.h - send and react to DHCP message packets
* Time-stamp: <2011-03-30 23:57:02 nk>
* Time-stamp: <2011-06-11 11:03:14 njk>
*
* (c) 2004-2011 Nicholas J. Kain <njkain at gmail dot com>
* (c) 2001 Russ Dill <Russ.Dill@asu.edu>
@ -66,7 +66,6 @@ enum {
DHCP_SIZE = sizeof(struct dhcpMessage),
};
int get_packet(struct dhcpMessage *packet, int fd);
uint16_t checksum(void *addr, int count);
int raw_packet(struct dhcpMessage *payload, uint32_t source_ip,
int source_port, uint32_t dest_ip, int dest_port,