diff --git a/ndhc/packet.c b/ndhc/packet.c index 95017f8..fc87acd 100644 --- a/ndhc/packet.c +++ b/ndhc/packet.c @@ -202,14 +202,25 @@ static int get_packet(struct dhcpMessage *packet, int fd) return bytes; } +// When summing ones-complement 16-bit values using a 32-bit unsigned +// representation, fold the carry bits that have spilled into the upper +// 16-bits of the 32-bit unsigned value back into the 16-bit ones-complement +// binary value. +static inline uint16_t foldcarry(uint32_t v) +{ + v = (v >> 16) + (v & 0xffff); + v += v >> 16; + return v; +} + // This function is not suitable for summing buffers that are greater than // 128k-1 bytes in length: failure case will be incorrect checksums via // unsigned overflow, which is a defined operation and is safe. This limit // should not be an issue for IPv4 or IPv6 packet, which are limited to // at most 64k bytes. -static uint16_t net_checksum(void *buf, size_t size, uint16_t init) +static uint16_t net_checksum(void *buf, size_t size) { - uint32_t sum = init; + uint32_t sum = 0; int odd = size & 0x01; size_t i; size &= ~((size_t)0x01); @@ -225,9 +236,36 @@ static uint16_t net_checksum(void *buf, size_t size, uint16_t init) uint16_t lo = 0; sum += ntohs((lo + (hi << 8))); } - sum = (sum >> 16) + (sum & 0xffff); - sum += sum >> 16; - return ~sum; + return ~foldcarry(sum); +} + +// For two sequences of bytes A and B that return checksums CS(A) and CS(B), +// this function will calculate the checksum CS(AB) of the concatenated value +// AB given the checksums of the individual parts CS(A) and CS(B). +static inline uint16_t net_checksum_add(uint16_t a, uint16_t b) +{ + return ~foldcarry((~a & 0xffff) + (~b & 0xffff)); +} + +// Returns 1 if IP checksum is correct, otherwise 0. +static int ip_checksum(struct ip_udp_dhcp_packet *packet) +{ + return net_checksum(&packet->ip, sizeof packet->ip) == 0; +} + +// Returns 1 if UDP checksum is correct, otherwise 0. +static int udp_checksum(struct ip_udp_dhcp_packet *packet) +{ + struct iphdr ph = { + .saddr = packet->ip.saddr, + .daddr = packet->ip.daddr, + .protocol = packet->ip.protocol, + .tot_len = packet->udp.len, + }; + uint16_t udpcs = net_checksum(&packet->udp, ntohs(packet->udp.len)); + uint16_t hdrcs = net_checksum(&ph, sizeof ph); + uint16_t cs = net_checksum_add(udpcs, hdrcs); + return cs == 0; } // Read a packet from a raw socket. Returns -1 on fatal error, -2 on @@ -235,9 +273,8 @@ static uint16_t net_checksum(void *buf, size_t size, uint16_t init) static int get_raw_packet(struct dhcpMessage *payload, int fd) { struct ip_udp_dhcp_packet packet; - uint16_t check; - memset(&packet, 0, IP_UPD_DHCP_SIZE); + int len = safe_read(fd, (char *)&packet, IP_UPD_DHCP_SIZE); if (len == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK) @@ -247,7 +284,11 @@ static int get_raw_packet(struct dhcpMessage *payload, int fd) } /* ignore any extra garbage bytes */ - len = ntohs(packet.ip.tot_len); + if (ntohs(packet.ip.tot_len) != len) { + log_line("Received %u bytes for a %u byte UDP packet. Discarding extra.", + len, packet.ip.tot_len); + len = ntohs(packet.ip.tot_len); + } // Validate the IP and UDP headers. if (packet.ip.protocol != IPPROTO_UDP) { @@ -262,9 +303,7 @@ static int get_raw_packet(struct dhcpMessage *payload, int fd) log_line("IP header length incorrect"); return -2; } - check = packet.ip.check; - packet.ip.check = 0; - if (check != net_checksum(&packet.ip, sizeof packet.ip, 0)) { + if (!ip_checksum(&packet)) { log_line("IP header checksum incorrect"); return -2; } @@ -281,13 +320,7 @@ static int get_raw_packet(struct dhcpMessage *payload, int fd) return -2; } - /* verify the UDP checksum by replacing the header with a psuedo header */ - memset(&packet.ip, 0, offsetof(struct iphdr, protocol)); - /* preserved fields: protocol, check, saddr, daddr */ - packet.ip.tot_len = packet.udp.len; /* cheat on the psuedo-header */ - check = packet.udp.check; - packet.udp.check = 0; - if (check && check != net_checksum(&packet, len, 0)) { + if (packet.udp.check && !udp_checksum(&packet)) { log_error("Packet with bad UDP checksum received, ignoring"); return -2; } @@ -354,13 +387,13 @@ static int send_dhcp_raw(struct dhcpMessage *payload) .data = *payload, }; // UDP checksumming needs a temporary pseudoheader with a fake length. - packet.udp.check = net_checksum(&packet, IP_UPD_DHCP_SIZE - padding, 0); + packet.udp.check = net_checksum(&packet, IP_UPD_DHCP_SIZE - padding); // Set the true IP packet length for the final packet. packet.ip.tot_len = htons(IP_UPD_DHCP_SIZE - padding); packet.ip.ihl = sizeof packet.ip >> 2; packet.ip.version = IPVERSION; packet.ip.ttl = IPDEFTTL; - packet.ip.check = net_checksum(&packet.ip, sizeof packet.ip, 0); + packet.ip.check = net_checksum(&packet.ip, sizeof packet.ip); r = safe_sendto(fd, (const char *)&packet, IP_UPD_DHCP_SIZE - padding, 0, (struct sockaddr *)&dest, sizeof dest);