From 2c0283fdb7347bb61bdb0ac22d306d1bd05b7f03 Mon Sep 17 00:00:00 2001 From: Justin Riley Date: Wed, 17 Aug 2016 10:13:43 -0400 Subject: [PATCH] add zone module with in-memory zone implementation --- onedns/exception.py | 10 ++++ onedns/tests/conftest.py | 31 +++++++++++- onedns/tests/test_dynamicresolver.py | 20 ++------ onedns/tests/test_zone.py | 61 +++++++++++++++++++++++ onedns/utils.py | 14 +++--- onedns/zone.py | 73 ++++++++++++++++++++++++++++ 6 files changed, 185 insertions(+), 24 deletions(-) create mode 100644 onedns/tests/test_zone.py create mode 100644 onedns/zone.py diff --git a/onedns/exception.py b/onedns/exception.py index 63cb4dd..eac61f0 100644 --- a/onedns/exception.py +++ b/onedns/exception.py @@ -26,3 +26,13 @@ class NoNetworksError(OneDnsException): def __init__(self, vm): self.msg = "No networks found for VM {id}: {vm}".format(vm=vm.name, id=vm.id) + + +class RecordDoesNotExist(OneDnsException): + """ + Raised when a zone record does not exist + """ + def __init__(self, key, val=None): + self.msg = "Record Does Not Exist: {}".format(key) + if val is not None: + self.msg += " -> {}".format(val) diff --git a/onedns/tests/conftest.py b/onedns/tests/conftest.py index 6ee93b2..5807730 100644 --- a/onedns/tests/conftest.py +++ b/onedns/tests/conftest.py @@ -1,5 +1,9 @@ import pytest +import dnslib +from IPy import IP +from onedns import zone +from onedns import server from onedns import resolver from onedns.tests import vcr from onedns.clients import one @@ -8,7 +12,23 @@ from onedns.clients import one DOMAIN = 'onedns.test' INTERFACE = '127.0.0.1' PORT = 9053 - +HOST_SHORT = 'testhost' +HOST = '.'.join([HOST_SHORT, DOMAIN]) +HOST_IP = '10.242.118.112' +TEST_LOOKUP_DATA = [ + (HOST, dnslib.QTYPE.A, HOST_IP), + (IP(HOST_IP).reverseName(), dnslib.QTYPE.PTR, HOST + '.') +] +TEST_GET_FQDN_DATA = [ + ('hostwithnodot', '192.168.1.23'), + ('hostwithdot.', '192.168.1.19'), +] +TEST_GET_FORWARD_DATA = [ + HOST_SHORT, + HOST_SHORT + '.', + HOST, + HOST + '.', +] @pytest.fixture(scope="function") def dns(request): @@ -30,3 +50,12 @@ def oneclient(request): @vcr.use_cassette() def vms(oneclient): return oneclient.vms() + + +@pytest.fixture(scope="function") +def onezone(): + z = zone.Zone(DOMAIN) + z.add_host(HOST_SHORT, HOST_IP) + for name, ip in TEST_GET_FQDN_DATA: + z.add_host(name, ip) + return z diff --git a/onedns/tests/test_dynamicresolver.py b/onedns/tests/test_dynamicresolver.py index 33159a5..32d0f43 100644 --- a/onedns/tests/test_dynamicresolver.py +++ b/onedns/tests/test_dynamicresolver.py @@ -1,28 +1,14 @@ import pytest import dnslib -from IPy import IP - from onedns.tests import utils from onedns.tests import conftest -HOST = '.'.join(['testhost', conftest.DOMAIN]) -HOST_IP = '10.242.118.112' -TEST_LOOKUP_DATA = [ - (HOST, dnslib.QTYPE.A, HOST_IP), - (IP(HOST_IP).reverseName(), dnslib.QTYPE.PTR, HOST + '.') -] -TEST_GET_FQDN_DATA = [ - ('hostwithnodot', '192.168.1.23'), - ('hostwithdot.', '192.168.1.19'), -] - - -@pytest.mark.parametrize("qname,qtype,output", TEST_LOOKUP_DATA) +@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA) def test_lookup(dns, qname, qtype, output): dns.clear() - dns.add_host(HOST, HOST_IP) + dns.add_host(conftest.HOST, conftest.HOST_IP) try: a = utils.dnsquery(qname, qtype) assert a.short() == output @@ -39,7 +25,7 @@ def test_nxdomain(dns): dns.close() -@pytest.mark.parametrize("name,ip", TEST_GET_FQDN_DATA) +@pytest.mark.parametrize("name,ip", conftest.TEST_GET_FQDN_DATA) def test_get_fqdn(dns, name, ip): dns.clear() dns.add_host(name, ip) diff --git a/onedns/tests/test_zone.py b/onedns/tests/test_zone.py new file mode 100644 index 0000000..ef1c8d0 --- /dev/null +++ b/onedns/tests/test_zone.py @@ -0,0 +1,61 @@ +import pytest + +import dnslib + +from IPy import IP + +from onedns import exception +from onedns.tests import conftest + + +def test_zone_clear(onezone): + assert onezone._forward + assert onezone._reverse + onezone.clear() + assert not onezone._forward + assert not onezone._reverse + + +def test_add_host(onezone): + onezone.clear() + with pytest.raises(exception.RecordDoesNotExist): + onezone._get_forward(conftest.HOST_SHORT, conftest.HOST_IP) + with pytest.raises(exception.RecordDoesNotExist): + onezone._get_reverse(conftest.HOST_IP, conftest.HOST_SHORT) + onezone.add_host(conftest.HOST_SHORT, conftest.HOST_IP) + assert onezone._get_forward(conftest.HOST_SHORT, conftest.HOST_IP) + assert onezone._get_reverse(conftest.HOST_IP, conftest.HOST_SHORT) + + +def test_remove_host(onezone): + onezone.clear() + onezone.add_host(conftest.HOST_SHORT, conftest.HOST_IP) + assert onezone._get_forward(conftest.HOST_SHORT, conftest.HOST_IP) + assert onezone._get_reverse(conftest.HOST_IP, conftest.HOST_SHORT) + onezone.remove_host(conftest.HOST_SHORT, conftest.HOST_IP) + with pytest.raises(exception.RecordDoesNotExist): + onezone._get_forward(conftest.HOST_SHORT, conftest.HOST_IP) + with pytest.raises(exception.RecordDoesNotExist): + onezone._get_reverse(conftest.HOST_IP, conftest.HOST_SHORT) + + +@pytest.mark.parametrize("host", conftest.TEST_GET_FORWARD_DATA) +def test_get_forward(onezone, host): + forward = onezone.get_forward(host) + fqdn = conftest.HOST + '.' + assert isinstance(forward, dnslib.RR) + assert forward.rname == fqdn + assert forward.rtype == dnslib.QTYPE.A + assert forward.rclass == dnslib.CLASS.IN + assert str(forward.rdata) == conftest.HOST_IP + + +def test_get_reverse(onezone): + reverse = onezone.get_reverse(conftest.HOST_IP) + fqdn = conftest.HOST + '.' + revip = IP(conftest.HOST_IP).reverseName() + assert isinstance(reverse, dnslib.RR) + assert reverse.rname == revip + assert reverse.rtype == dnslib.QTYPE.PTR + assert reverse.rclass == dnslib.CLASS.IN + assert str(reverse.rdata) == fqdn diff --git a/onedns/utils.py b/onedns/utils.py index e9fca0e..afee875 100644 --- a/onedns/utils.py +++ b/onedns/utils.py @@ -1,15 +1,17 @@ import os +import dnslib + from onedns.logger import log def get_fqdn(name, domain): - if not name.endswith(domain): - if name.endswith('.'): - return name + domain - else: - return '.'.join([name, domain]) - return name + domain = dnslib.DNSLabel(domain) + name = dnslib.DNSLabel(name) + if name.label[-1 * len(domain.label):] != domain.label: + return dnslib.DNSLabel(name.label + domain.label).idna() + else: + return name.idna() def get_kwargs_from_dict(d, prefix, lower=False): diff --git a/onedns/zone.py b/onedns/zone.py new file mode 100644 index 0000000..a93c55e --- /dev/null +++ b/onedns/zone.py @@ -0,0 +1,73 @@ +import dnslib +from dnslib import dns + +from IPy import IP + +from onedns import utils +from onedns import exception + + +class Zone(object): + def __init__(self, domain): + self.domain = domain + self._forward = {} + self._reverse = {} + + def clear(self): + self._forward = {} + self._reverse = {} + + def _get_fqdn(self, name): + return utils.get_fqdn(name, self.domain) + + def _get_rr(self, rname, rtype, rdata): + return dnslib.RR(rname=dnslib.DNSLabel(rname), + rtype=dnslib.QTYPE.reverse[rtype], + rclass=dnslib.CLASS.reverse['IN'], + rdata=getattr(dns, rtype)(rdata)) + + def _add_forward(self, name, ip): + self._forward[self._get_fqdn(name)] = IP(ip) + + def _get_forward(self, name, ip=None): + fqdn = self._get_fqdn(name) + fip = self._forward.get(fqdn) + if not fip or (ip and fip != IP(ip)): + raise exception.RecordDoesNotExist(name, ip) + return fip + + def _remove_forward(self, name, ip=None): + self._get_forward(name, ip) + del self._forward[self._get_fqdn(name)] + + def _add_reverse(self, ip, name): + self._reverse[IP(ip)] = self._get_fqdn(name) + + def _get_reverse(self, ip, name=None): + reverse = self._reverse.get(IP(ip)) + fqdn = self._get_fqdn(name) if name else None + if not reverse or (name and fqdn != reverse): + raise exception.RecordDoesNotExist(ip, fqdn) + return reverse + + def _remove_reverse(self, ip, name=None): + self._get_reverse(ip, name) + del self._reverse[IP(ip)] + + def add_host(self, name, ip): + self._add_forward(name, ip) + self._add_reverse(ip, name) + + def remove_host(self, name, ip): + self._remove_forward(name, ip) + self._remove_reverse(ip, name) + + def get_forward(self, name): + fqdn = self._get_fqdn(name) + forward = self._get_forward(fqdn) + return self._get_rr(fqdn, 'A', str(forward)) + + def get_reverse(self, ip): + ip = IP(ip) + reverse = self._get_reverse(ip) + return self._get_rr(ip.reverseName(), 'PTR', reverse)