Define stateless functions to validate the checksum of an ip header and of a

udp packet and use them in get_raw_packet().

Print a warning if the raw UDP receive gets a quantum of data that is greater
than that of a single UDP datagram.

Remove unnecessary argument from net_checksum().  Initializing a nonzero
checksum value is not very helpful in practice.

Define a function net_checksum_add() that, for two sequences of bytes A and B
that return checksums CS(A) and CS(B), will calculate the checksum CS(AB) of
the concatenated value AB given the checksums of the individual parts CS(A)
and CS(B).
This commit is contained in:
Nicholas J. Kain 2011-06-26 16:33:07 -04:00
parent 7d0e05504f
commit b70070e592

View File

@ -202,14 +202,25 @@ static int get_packet(struct dhcpMessage *packet, int fd)
return bytes;
}
// When summing ones-complement 16-bit values using a 32-bit unsigned
// representation, fold the carry bits that have spilled into the upper
// 16-bits of the 32-bit unsigned value back into the 16-bit ones-complement
// binary value.
static inline uint16_t foldcarry(uint32_t v)
{
v = (v >> 16) + (v & 0xffff);
v += v >> 16;
return v;
}
// This function is not suitable for summing buffers that are greater than
// 128k-1 bytes in length: failure case will be incorrect checksums via
// unsigned overflow, which is a defined operation and is safe. This limit
// should not be an issue for IPv4 or IPv6 packet, which are limited to
// at most 64k bytes.
static uint16_t net_checksum(void *buf, size_t size, uint16_t init)
static uint16_t net_checksum(void *buf, size_t size)
{
uint32_t sum = init;
uint32_t sum = 0;
int odd = size & 0x01;
size_t i;
size &= ~((size_t)0x01);
@ -225,9 +236,36 @@ static uint16_t net_checksum(void *buf, size_t size, uint16_t init)
uint16_t lo = 0;
sum += ntohs((lo + (hi << 8)));
}
sum = (sum >> 16) + (sum & 0xffff);
sum += sum >> 16;
return ~sum;
return ~foldcarry(sum);
}
// For two sequences of bytes A and B that return checksums CS(A) and CS(B),
// this function will calculate the checksum CS(AB) of the concatenated value
// AB given the checksums of the individual parts CS(A) and CS(B).
static inline uint16_t net_checksum_add(uint16_t a, uint16_t b)
{
return ~foldcarry((~a & 0xffff) + (~b & 0xffff));
}
// Returns 1 if IP checksum is correct, otherwise 0.
static int ip_checksum(struct ip_udp_dhcp_packet *packet)
{
return net_checksum(&packet->ip, sizeof packet->ip) == 0;
}
// Returns 1 if UDP checksum is correct, otherwise 0.
static int udp_checksum(struct ip_udp_dhcp_packet *packet)
{
struct iphdr ph = {
.saddr = packet->ip.saddr,
.daddr = packet->ip.daddr,
.protocol = packet->ip.protocol,
.tot_len = packet->udp.len,
};
uint16_t udpcs = net_checksum(&packet->udp, ntohs(packet->udp.len));
uint16_t hdrcs = net_checksum(&ph, sizeof ph);
uint16_t cs = net_checksum_add(udpcs, hdrcs);
return cs == 0;
}
// Read a packet from a raw socket. Returns -1 on fatal error, -2 on
@ -235,9 +273,8 @@ static uint16_t net_checksum(void *buf, size_t size, uint16_t init)
static int get_raw_packet(struct dhcpMessage *payload, int fd)
{
struct ip_udp_dhcp_packet packet;
uint16_t check;
memset(&packet, 0, IP_UPD_DHCP_SIZE);
int len = safe_read(fd, (char *)&packet, IP_UPD_DHCP_SIZE);
if (len == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
@ -247,7 +284,11 @@ static int get_raw_packet(struct dhcpMessage *payload, int fd)
}
/* ignore any extra garbage bytes */
if (ntohs(packet.ip.tot_len) != len) {
log_line("Received %u bytes for a %u byte UDP packet. Discarding extra.",
len, packet.ip.tot_len);
len = ntohs(packet.ip.tot_len);
}
// Validate the IP and UDP headers.
if (packet.ip.protocol != IPPROTO_UDP) {
@ -262,9 +303,7 @@ static int get_raw_packet(struct dhcpMessage *payload, int fd)
log_line("IP header length incorrect");
return -2;
}
check = packet.ip.check;
packet.ip.check = 0;
if (check != net_checksum(&packet.ip, sizeof packet.ip, 0)) {
if (!ip_checksum(&packet)) {
log_line("IP header checksum incorrect");
return -2;
}
@ -281,13 +320,7 @@ static int get_raw_packet(struct dhcpMessage *payload, int fd)
return -2;
}
/* verify the UDP checksum by replacing the header with a psuedo header */
memset(&packet.ip, 0, offsetof(struct iphdr, protocol));
/* preserved fields: protocol, check, saddr, daddr */
packet.ip.tot_len = packet.udp.len; /* cheat on the psuedo-header */
check = packet.udp.check;
packet.udp.check = 0;
if (check && check != net_checksum(&packet, len, 0)) {
if (packet.udp.check && !udp_checksum(&packet)) {
log_error("Packet with bad UDP checksum received, ignoring");
return -2;
}
@ -354,13 +387,13 @@ static int send_dhcp_raw(struct dhcpMessage *payload)
.data = *payload,
};
// UDP checksumming needs a temporary pseudoheader with a fake length.
packet.udp.check = net_checksum(&packet, IP_UPD_DHCP_SIZE - padding, 0);
packet.udp.check = net_checksum(&packet, IP_UPD_DHCP_SIZE - padding);
// Set the true IP packet length for the final packet.
packet.ip.tot_len = htons(IP_UPD_DHCP_SIZE - padding);
packet.ip.ihl = sizeof packet.ip >> 2;
packet.ip.version = IPVERSION;
packet.ip.ttl = IPDEFTTL;
packet.ip.check = net_checksum(&packet.ip, sizeof packet.ip, 0);
packet.ip.check = net_checksum(&packet.ip, sizeof packet.ip);
r = safe_sendto(fd, (const char *)&packet, IP_UPD_DHCP_SIZE - padding,
0, (struct sockaddr *)&dest, sizeof dest);