From a80a85191c0b80edf11adee4038340ee25c0ed15 Mon Sep 17 00:00:00 2001 From: Justin Riley Date: Tue, 2 Aug 2016 11:19:19 -0400 Subject: [PATCH] add DynamicResolver class with 100% test coverage --- onedns/server.py | 102 +++++++++++++++++++++++++++ onedns/tests/__init__.py | 0 onedns/tests/conftest.py | 16 +++++ onedns/tests/test_dynamicresolver.py | 51 ++++++++++++++ 4 files changed, 169 insertions(+) create mode 100644 onedns/server.py create mode 100644 onedns/tests/__init__.py create mode 100644 onedns/tests/conftest.py create mode 100644 onedns/tests/test_dynamicresolver.py diff --git a/onedns/server.py b/onedns/server.py new file mode 100644 index 0000000..efe892b --- /dev/null +++ b/onedns/server.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function + +import dnslib +from dnslib import server +from dnslib import dns + +from IPy import IP + + +class DynamicResolver(server.BaseResolver): + """ + Dynamic In-Memory DNS Resolver + """ + + def __init__(self, domain, one_kwargs={}): + """ + Initialise resolver from zone list + Stores RRs as a list of (label, type, rr) tuples + """ + self.domain = domain + self.zone = [] + self._tcp_server = None + self._udp_server = None + + def resolve(self, request, handler): + """ + Respond to DNS request - parameters are request packet & handler. + Method is expected to return DNS response + """ + reply = request.reply() + qname = request.q.qname + qtype = dnslib.QTYPE[request.q.qtype] + for name, rtype, rr in self.zone: + # Check if label & type match + if qname == name and (qtype in [rtype, 'ANY'] or rtype == 'CNAME'): + reply.add_answer(rr) + # Check for A/AAAA records associated with reply and + # add in additional section + if rtype in ['CNAME', 'NS', 'MX', 'PTR']: + for a_name, a_rtype, a_rr in self.zone: + if a_name == rr.rdata.label and a_rtype in ['A', 'AAAA']: + reply.add_ar(a_rr) + if not reply.rr: + reply.header.rcode = dnslib.RCODE.NXDOMAIN + return reply + + def clear(self): + self.zone = [] + + def add_host(self, name, ip): + self._add_forward(name, ip) + self._add_reverse(ip, name) + + def _get_fqdn(self, name): + if not name.endswith(self.domain): + if name.endswith('.'): + return name + self.domain + else: + return '.'.join([name, self.domain]) + return name + + def _add_forward(self, name, ip): + f = dnslib.RR(rname=dnslib.DNSLabel(self._get_fqdn(name)), + rtype=dnslib.QTYPE.reverse['A'], + rclass=dnslib.CLASS.reverse['IN'], + rdata=dns.A(ip)) + self.zone.append((f.rname, 'A', f)) + + def _add_reverse(self, ip, name): + ip = IP(ip) + r = dnslib.RR(rname=dnslib.DNSLabel(ip.reverseName()), + rtype=dnslib.QTYPE.reverse['PTR'], + rclass=dnslib.CLASS.reverse['IN'], + rdata=dns.PTR(self._get_fqdn(name))) + self.zone.append((r.rname, 'PTR', r)) + + def start(self, dns_address='0.0.0.0', dns_port=53, + api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0, + log="request,reply,truncated,error", log_prefix=False): + logger = server.DNSLogger(log, log_prefix) + + print("Starting OneDNS (%s:%d) [%s]" % + (dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP")) + + server.DNSHandler.udplen = udplen + + self._udp_server = server.DNSServer(self, port=dns_port, + address=dns_address, logger=logger) + self._udp_server.start_thread() + + if tcp: + self._tcp_server = server.DNSServer(self, port=dns_port, + address=dns_address, tcp=True, + logger=logger) + self._tcp_server.start_thread() + + def close(self): + for srv in [self._tcp_server, self._udp_server]: + if srv: + srv.stop() + srv.server.socket.close() diff --git a/onedns/tests/__init__.py b/onedns/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/onedns/tests/conftest.py b/onedns/tests/conftest.py new file mode 100644 index 0000000..f0a3276 --- /dev/null +++ b/onedns/tests/conftest.py @@ -0,0 +1,16 @@ +import pytest + +from onedns import server + + +DOMAIN = 'onedns.test' +INTERFACE = '127.0.0.1' +PORT = 9053 + + +@pytest.fixture(scope="function") +def dns(request): + dns = server.DynamicResolver(domain=DOMAIN) + dns.start(dns_address=INTERFACE, dns_port=PORT, tcp=True) + request.addfinalizer(dns.close) + return dns diff --git a/onedns/tests/test_dynamicresolver.py b/onedns/tests/test_dynamicresolver.py new file mode 100644 index 0000000..d73480b --- /dev/null +++ b/onedns/tests/test_dynamicresolver.py @@ -0,0 +1,51 @@ +import pytest +import dnslib + +from IPy import IP + +from onedns.tests import conftest + + +HOST = '.'.join(['testhost', conftest.DOMAIN]) +HOST_IP = '10.242.118.112' +TEST_LOOKUP_DATA = [ + (HOST, dnslib.QTYPE.A, HOST_IP), + (IP(HOST_IP).reverseName(), dnslib.QTYPE.PTR, HOST + '.') +] +TEST_GET_FQDN_DATA = [ + ('hostwithnodot', '192.168.1.23'), + ('hostwithdot.', '192.168.1.19'), +] + + +@pytest.mark.parametrize("qname,qtype,output", TEST_LOOKUP_DATA) +def test_lookup(dns, qname, qtype, output): + dns.clear() + dns.add_host(HOST, HOST_IP) + try: + q = dnslib.DNSRecord(q=dnslib.DNSQuestion(qname, qtype)) + a_pkt = q.send(conftest.INTERFACE, conftest.PORT, tcp=False) + a = dnslib.DNSRecord.parse(a_pkt) + assert a.short() == output + finally: + dns.close() + + +def test_nxdomain(dns): + dns.clear() + try: + q = dnslib.DNSRecord(q=dnslib.DNSQuestion( + 'unknownhost', dnslib.QTYPE.A)) + a_pkt = q.send(conftest.INTERFACE, conftest.PORT, tcp=False) + a = dnslib.DNSRecord.parse(a_pkt) + assert dnslib.RCODE.get(a.header.rcode) == 'NXDOMAIN' + finally: + dns.close() + + +@pytest.mark.parametrize("name,ip", TEST_GET_FQDN_DATA) +def test_get_fqdn(dns, name, ip): + dns.clear() + dns.add_host(name, ip) + assert dns.zone[0][0].label[0] == name.split('.')[0] + assert '.'.join(dns.zone[0][0].label[1:]) == conftest.DOMAIN