#include "test.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <resolv.h>
#include <errno.h>
#include <pthread.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

/* Evaluates to 1 when `c` holds; otherwise reports the failure through t_error()
 * and evaluates to 0, so it can also guard early returns: if (!TEST(...)) return; */
#define TEST(c, ...) ((c) || (t_error(#c " failed: " __VA_ARGS__), 0))

/* Size of the fixed DNS message header (RFC 1035 section 4.1.1). */
#define DNS_HEADER_OFFSET 12
/* TYPE, CLASS, TTL and RDLENGTH fields that follow a resource record's NAME
 * (RFC 1035 section 4.1.3). */
#define DNS_RR_FIXED_SIZE 10

/* One query issued through res_query() together with the answer section the
 * fake server replies with. */
typedef struct {
    const char *domain_name;            /* name passed to res_query() */
    uint16_t type;                      /* expected QTYPE / answer TYPE */
    uint16_t class;                     /* expected QCLASS / answer CLASS */
    const char *expected_response_data; /* raw answer-section bytes sent back */
    size_t num_answers;                 /* ANCOUNT placed in the reply header */
    size_t response_size;               /* bytes of expected_response_data (no NUL) */
} test_dns_packet;

/* Plain braced initializer (not a compound literal) so it is a valid element
 * of a static array initializer in ISO C. */
#define DNS_PACKET(domain, qtype, qclass, response, answers, size)              \
    {                                                                           \
        .domain_name = (domain), .type = (qtype), .class = (qclass),            \
        .expected_response_data = (response), .num_answers = (answers),         \
        .response_size = (size)                                                 \
    }

/* Single answer RR whose name is a compression pointer to the question name. */
#define EXAMPLE_ANSWER_RR                                                       \
    "\xc0\x0c"         /* NAME: pointer to offset 12 */                         \
    "\x00\x01"         /* TYPE */                                               \
    "\x00\x01"         /* CLASS */                                              \
    "\x00\x00\x02\x58" /* TTL */                                                \
    "\x00\x04"         /* RDLENGTH */                                           \
    "\xc0\xa8\x11\x01" /* RDATA: IPv4 192.168.17.1 */

/* Two answer RRs.
 * NOTE(review): the second NAME points at offset 0x1c (28), which for
 * "fizz.buzz.com" lands inside the question's QTYPE field rather than on a
 * name. The checks below only compare the pointer bytes, so this passes;
 * confirm the fixture is intentional. */
#define EXTENDED_ANSWER_RR                                                      \
    "\xc0\x0c"                                                                  \
    "\x00\x01"                                                                  \
    "\x00\x01"                                                                  \
    "\x00\x00\x02\x58"                                                          \
    "\x00\x04"                                                                  \
    "\xc0\xa8\x11\x02" /* IPv4 192.168.17.2 */                                  \
    "\xc0\x1c"                                                                  \
    "\x00\x01"                                                                  \
    "\x00\x01"                                                                  \
    "\x00\x00\x02\x58"                                                          \
    "\x00\x04"                                                                  \
    "\xc0\xa8\x11\x03" /* IPv4 192.168.17.3 */

/* sizeof(literal) - 1: the fixture length excludes the string literal's NUL. */
#define TEST_IPV4                                                               \
    DNS_PACKET("example.com", 0x01, 0x01, EXAMPLE_ANSWER_RR, 0x01,              \
               sizeof(EXAMPLE_ANSWER_RR) - 1)
#define TEST_LONG_DOMAIN                                                        \
    DNS_PACKET("foo.bar.example.com", 0x01, 0x01, EXAMPLE_ANSWER_RR, 0x01,      \
               sizeof(EXAMPLE_ANSWER_RR) - 1)
#define TEST_EXTENDED_RESPONSE                                                  \
    DNS_PACKET("fizz.buzz.com", 0x01, 0x01, EXTENDED_ANSWER_RR, 0x02,           \
               sizeof(EXTENDED_ANSWER_RR) - 1)

static const test_dns_packet dns_tests[] = {TEST_IPV4, TEST_LONG_DOMAIN,
                                            TEST_EXTENDED_RESPONSE};
static const size_t dns_test_count = sizeof(dns_tests) / sizeof(*dns_tests);

/* Two-party barrier: main() waits until the server thread has set up its
 * socket (or failed to) before sending any query. */
static pthread_barrier_t sync_barrier;

/* Decodes a big-endian (network order) 16-bit field without an unaligned,
 * type-punned load. */
static uint16_t read_u16(const unsigned char *p)
{
    return (uint16_t)((p[0] << 8) | p[1]);
}

/* Decodes a big-endian (network order) 32-bit field. */
static uint32_t read_u32(const unsigned char *p)
{
    return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
           ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}

/*
 * Builds the reply for dns_tests[response_index] into `response`.
 *
 * id:       query id in host byte order; it is echoed back in network order.
 * question: the received query; its question section is copied verbatim.
 * Returns the number of bytes written: header + question + answer section.
 */
static size_t construct_response(uint16_t id, const unsigned char *question,
                                 unsigned char *response, int response_index)
{
    const test_dns_packet *test = &dns_tests[response_index];
    HEADER dns_header;
    const size_t dns_header_offset = sizeof(dns_header);
    /* Encoded QNAME is strlen + 2 (first length byte and root label), plus
     * 2 bytes QTYPE and 2 bytes QCLASS. */
    const size_t question_size = strlen(test->domain_name) + 6;

    memset(&dns_header, 0, sizeof(dns_header));
    dns_header.id = htons(id);
    dns_header.qr = 1;
    dns_header.rd = 1;
    dns_header.ra = 1;
    dns_header.qdcount = htons(1); /* one question, independent of host endianness */
    dns_header.ancount = htons((uint16_t)test->num_answers);

    memcpy(response, &dns_header, dns_header_offset);
    memcpy(response + dns_header_offset, question + dns_header_offset,
           question_size);
    memcpy(response + dns_header_offset + question_size,
           test->expected_response_data, test->response_size);
    return dns_header_offset + question_size + test->response_size;
}

/* Binds UDP socket `s` to 127.0.0.1:53. Returns bind()'s result. */
static int bind_to_socket(int s)
{
    struct sockaddr_in dns_server;

    memset(&dns_server, 0, sizeof(dns_server));
    dns_server.sin_addr.s_addr = inet_addr("127.0.0.1");
    dns_server.sin_family = AF_INET;
    dns_server.sin_port = htons(53);
    return bind(s, (struct sockaddr *)&dns_server, sizeof(dns_server));
}

/* Replaces the file at `path` with `contents`. Reports through t_error() and
 * returns -1 if the file cannot be opened, written, or closed; 0 otherwise. */
static int write_namespaced_file(const char *path, const char *contents)
{
    FILE *ft = fopen(path, "w");
    if (ft == NULL) {
        t_error("unable to open namespaced %s\n", path);
        return -1;
    }
    const int write_failed = fputs(contents, ft) == EOF;
    const int close_failed = fclose(ft) == EOF; /* fclose flushes buffered data */
    if (write_failed || close_failed) {
        t_error("unable to write namespaced %s\n", path);
        return -1;
    }
    return 0;
}

/* Points the resolver at the local fake server and installs a minimal hosts
 * file. Returns 0 on success, -1 on failure. */
static int set_environment(void)
{
    if (write_namespaced_file("/etc/resolv.conf", "nameserver 127.0.0.1\n") < 0)
        return -1;
    return write_namespaced_file("/etc/hosts", "127.0.0.1 localhost\n");
}

/*
 * Server thread: answers exactly dns_test_count queries, in order, with the
 * matching dns_tests[] fixture. The argument is unused. Always returns NULL.
 */
static void *dns_server(void *arguments)
{
    (void)arguments;

    const int s = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
    const int setup_failed = s < 0 || bind_to_socket(s) < 0;
    if (setup_failed) {
        /* Reported before the barrier: main() does nothing else until the
         * barrier releases it, so this does not race with its own t_error(). */
        t_error("Failed to set up DNS server socket: %s\n", strerror(errno));
    }
    /* Reach the barrier even on failure so main() can never block forever. */
    pthread_barrier_wait(&sync_barrier);
    if (setup_failed) {
        if (s >= 0)
            close(s);
        return NULL;
    }

    unsigned char packet_buffer[NS_PACKETSZ];
    for (size_t packets_received = 0; packets_received < dns_test_count;
         ++packets_received) {
        struct sockaddr_in from = {0};
        socklen_t from_length = sizeof(from); /* value-result: reset every call */
        const ssize_t packet_size =
            recvfrom(s, packet_buffer, sizeof(packet_buffer), 0,
                     (struct sockaddr *)&from, &from_length);
        /* On I/O errors stop serving without calling t_error() here (it would
         * run concurrently with main); the unanswered query then makes
         * res_query() fail and be reported by main(). */
        if (packet_size < 0)
            break;

        /* Query id is the first header field, in network byte order. */
        const uint16_t response_id = read_u16(packet_buffer);
        unsigned char response_buffer[NS_PACKETSZ] = {0};
        const size_t response_size =
            construct_response(response_id, packet_buffer, response_buffer,
                               (int)packets_received);
        if (sendto(s, response_buffer, response_size, 0,
                   (struct sockaddr *)&from, from_length) < 0)
            break;
    }
    close(s);
    return NULL;
}

/*
 * Walks the uncompressed name starting at `buffer`, checking each label
 * against the corresponding part of `domain_name`.
 *
 * Stores in *length the number of name bytes consumed before the root label
 * and returns a pointer to the terminating zero byte (or to the offending
 * label byte if a label does not fit the expected name).
 */
static unsigned char *check_domain_name(const char *domain_name,
                                        unsigned char *buffer,
                                        unsigned *length)
{
    const size_t domain_length = strlen(domain_name);
    unsigned name_location = 0;

    while (*buffer != 0x00) {
        const unsigned label_length = *buffer;
        const size_t remaining =
            name_location < domain_length ? domain_length - name_location : 0;
        /* Guard the memcmp below against reading past the expected name. */
        if (!TEST(label_length <= remaining,
                  "Label of %u bytes does not fit expected domain name %s\n",
                  label_length, domain_name))
            break;
        /* Labels are not NUL-terminated, so print them with an explicit width. */
        TEST(!memcmp(buffer + 1, domain_name + name_location, label_length),
             "Expected domain name %s, got %.*s\n", domain_name + name_location,
             (int)label_length, (const char *)(buffer + 1));
        name_location += label_length + 1;
        buffer += label_length + 1;
    }
    *length = name_location;
    return buffer;
}

/* Checks the single question of the reply in `buffer` (QNAME, QTYPE, QCLASS)
 * and returns a pointer to the first byte after it. */
static unsigned char *check_dns_questions(unsigned char *buffer,
                                          const test_dns_packet *packet)
{
    unsigned name_length = 0;
    unsigned char *x = check_domain_name(packet->domain_name,
                                         buffer + DNS_HEADER_OFFSET,
                                         &name_length);
    x++; /* skip the root label */

    /* RFC 1035 4.1.2: QTYPE precedes QCLASS. */
    const uint16_t type = read_u16(x);
    TEST(type == packet->type, "Expected type 0x%04x, got 0x%04x\n",
         packet->type, type);
    x += sizeof(uint16_t);

    const uint16_t class = read_u16(x);
    TEST(class == packet->class, "Expected class 0x%04x, got 0x%04x\n",
         packet->class, class);
    x += sizeof(uint16_t);
    return x;
}

/*
 * Checks one answer RR at `buffer` against the fixture bytes at offset
 * *previous_answer_length of packet->expected_response_data.
 *
 * *previous_answer_length holds the total size of all preceding answers and
 * is advanced past this one. Returns a pointer just past this RR.
 */
static unsigned char *check_dns_answers(unsigned char *buffer,
                                        const test_dns_packet *packet,
                                        int *previous_answer_length)
{
    const unsigned char *expected =
        (const unsigned char *)packet->expected_response_data;
    const size_t answer_offset = (size_t)*previous_answer_length;
    size_t name_size;

    if (!TEST(answer_offset + 2 + DNS_RR_FIXED_SIZE <= packet->response_size,
              "Answer at offset %zu overruns the expected response data\n",
              answer_offset))
        return buffer;

    if ((buffer[0] & 0xc0) == 0xc0) {
        /* Compression pointer: compare it with the fixture's pointer bytes. */
        const uint16_t pointer = read_u16(buffer);
        const uint16_t expected_pointer = read_u16(expected + answer_offset);
        TEST(pointer == expected_pointer,
             "Expected name pointer 0x%04x, got 0x%04x\n", expected_pointer,
             pointer);
        buffer += sizeof(uint16_t);
        name_size = sizeof(uint16_t);
    } else {
        unsigned name_location = 0;
        buffer = check_domain_name(packet->domain_name, buffer, &name_location);
        buffer++; /* skip the root label */
        name_size = (size_t)name_location + 1;
    }

    /* RFC 1035 4.1.3: TYPE precedes CLASS. */
    const uint16_t type = read_u16(buffer);
    TEST(type == packet->type, "Expected type 0x%04x, got 0x%04x\n",
         packet->type, type);
    buffer += sizeof(uint16_t);

    const uint16_t class = read_u16(buffer);
    TEST(class == packet->class, "Expected class 0x%04x, got 0x%04x\n",
         packet->class, class);
    buffer += sizeof(uint16_t);

    const uint32_t ttl = read_u32(buffer);
    TEST(ttl > 0, "Expected TTL %lu to be greater than 0\n",
         (unsigned long)ttl);
    buffer += sizeof(uint32_t);

    const uint16_t resource_length = read_u16(buffer);
    buffer += sizeof(uint16_t);

    const size_t rdata_offset = answer_offset + name_size + DNS_RR_FIXED_SIZE;
    if (TEST(resource_length == sizeof(uint32_t),
             "Expected resource length %d to be 4 for an IPv4 address\n",
             resource_length) &&
        TEST(rdata_offset + sizeof(uint32_t) <= packet->response_size,
             "Resource data at offset %zu overruns the expected response data\n",
             rdata_offset)) {
        const unsigned char *expected_ip = expected + rdata_offset;
        TEST(!memcmp(buffer, expected_ip, sizeof(uint32_t)),
             "Expected IPv4 addresses to match: 0x%08x, 0x%08x\n",
             (unsigned)read_u32(buffer), (unsigned)read_u32(expected_ip));
    }

    /* Accumulate so a third or later answer still finds its fixture bytes. */
    *previous_answer_length = (int)(rdata_offset + resource_length);
    buffer += resource_length;
    return buffer;
}

/* Issues one res_query() for `test` and validates the header, question and
 * answer sections of the reply. */
static void dns_test(const test_dns_packet *test)
{
    unsigned char res_buffer[NS_PACKETSZ] = {0};
    const int length = res_query(test->domain_name, test->class, test->type,
                                 res_buffer, sizeof(res_buffer));
    if (!TEST(length >= DNS_HEADER_OFFSET, "res_query(\"%s\") returned %d\n",
              test->domain_name, length))
        return;

    /* ANCOUNT is the header field at offset 6. */
    const size_t num_answers = read_u16(res_buffer + 6);
    /* Stop on a mismatch: the per-answer checks index fixture data by count. */
    if (!TEST(num_answers == test->num_answers,
              "Expected %zu answers, got %zu\n", test->num_answers,
              num_answers))
        return;

    unsigned char *answer_buffer = check_dns_questions(res_buffer, test);
    int previous_answer_length = 0;
    for (size_t answer = 0; answer < num_answers; ++answer) {
        answer_buffer =
            check_dns_answers(answer_buffer, test, &previous_answer_length);
    }
    TEST(answer_buffer - res_buffer == length,
         "Expected to parse %d response bytes, parsed %td\n", length,
         answer_buffer - res_buffer);
}

int main(void)
{
    if (t_enter_dns_ns() < 0) {
        t_error("Failed to enter test namespace: %s\n", strerror(errno));
        return t_status;
    }

    if (set_environment() < 0) {
        t_error("Failed to set environment\n");
        return t_status;
    }

    int status = pthread_barrier_init(&sync_barrier, NULL, 2);
    if (status != 0) {
        t_error("Failed to initialize barrier: %s\n", strerror(status));
        return t_status;
    }

    pthread_t dns_thread;
    status = pthread_create(&dns_thread, NULL, dns_server, NULL);
    if (status != 0) {
        /* Without the server thread the barrier wait below would never return. */
        t_error("Failed to create DNS server thread: %s\n", strerror(status));
        pthread_barrier_destroy(&sync_barrier);
        return t_status;
    }
    pthread_barrier_wait(&sync_barrier);

    for (size_t test_index = 0; test_index < dns_test_count; ++test_index) {
        dns_test(&dns_tests[test_index]);
    }

    pthread_join(dns_thread, NULL);
    pthread_barrier_destroy(&sync_barrier);
    return t_status;
}