From b735c84a04c24407a9351ef745a9453a2db5b271 Mon Sep 17 00:00:00 2001 From: Justin Riley Date: Wed, 17 Aug 2016 16:18:03 -0400 Subject: [PATCH] resolver: add thread mutex and use new zone module --- onedns/resolver.py | 85 +++++++++++++--------------- onedns/tests/conftest.py | 2 +- onedns/tests/test_dynamicresolver.py | 48 ++++++++++++---- onedns/utils.py | 10 ++++ requirements.txt | 2 + setup.py | 2 + 6 files changed, 92 insertions(+), 57 deletions(-) diff --git a/onedns/resolver.py b/onedns/resolver.py index 08195fd..5d5ef48 100644 --- a/onedns/resolver.py +++ b/onedns/resolver.py @@ -1,14 +1,16 @@ -# -*- coding: utf-8 -*- from __future__ import print_function import time +import threading import dnslib from dnslib import server -from dnslib import dns -from IPy import IP +from wrapt import synchronized +from onedns import zone from onedns import utils +from onedns import exception +from onedns.logger import log class DynamicResolver(server.BaseResolver): @@ -16,16 +18,18 @@ class DynamicResolver(server.BaseResolver): Dynamic In-Memory DNS Resolver """ - def __init__(self, domain, one_kwargs={}): + _lock = threading.RLock() + + def __init__(self, domain): """ Initialise resolver from zone list Stores RRs as a list of (label, type, rr) tuples """ - self.domain = domain - self.zone = [] + self.zone = zone.Zone(domain) self._tcp_server = None self._udp_server = None + @synchronized(_lock) def resolve(self, request, handler): """ Respond to DNS request - parameters are request packet & handler. @@ -33,57 +37,46 @@ class DynamicResolver(server.BaseResolver): """ reply = request.reply() qname = request.q.qname - qtype = dnslib.QTYPE[request.q.qtype] - A_RECORDS = ['A', 'AAAA'] - for name, rtype, rr in self.zone: - # Check if label & type match - if qname == name and (qtype in [rtype, 'ANY'] or rtype == 'CNAME'): - reply.add_answer(rr) - # Check for A/AAAA records associated with reply and - # add in additional section - if rtype in ['CNAME', 'NS', 'MX', 'PTR']: - for a_name, a_rtype, a_rr in self.zone: - if a_name == rr.rdata.label and a_rtype in A_RECORDS: - reply.add_ar(a_rr) - if not reply.rr: + qtype = request.q.qtype + try: + if qtype in (dnslib.QTYPE.A, dnslib.QTYPE.AAAA): + forward = self.zone.get_forward(qname) + reply.add_answer(forward) + elif qtype == dnslib.QTYPE.PTR: + reverse = self.zone.get_reverse( + utils.reverse_to_ip(qname.idna())) + reply.add_answer(reverse) + forward = self.zone.get_forward(str(reverse.rdata)) + if forward: + reply.add_ar(forward) + except exception.RecordDoesNotExist: reply.header.rcode = dnslib.RCODE.NXDOMAIN return reply + @synchronized(_lock) def clear(self): - self.zone = [] + self.zone.clear() - def add_host(self, name, ip, zone=None): - zone = zone or self.zone - self._add_forward(name, ip, zone=zone) - self._add_reverse(ip, name, zone=zone) + @synchronized(_lock) + def load(self, zone): + self.zone = zone - def _get_fqdn(self, name): - return utils.get_fqdn(name, self.domain) + @synchronized(_lock) + def add_host(self, name, ip): + self.zone.add_host(name, ip) - def _add_forward(self, name, ip, zone=None): - zone = zone or self.zone - f = dnslib.RR(rname=dnslib.DNSLabel(self._get_fqdn(name)), - rtype=dnslib.QTYPE.reverse['A'], - rclass=dnslib.CLASS.reverse['IN'], - rdata=dns.A(ip)) - zone.append((f.rname, 'A', f)) - - def _add_reverse(self, ip, name, zone=None): - zone = zone or self.zone - ip = IP(ip) - r = dnslib.RR(rname=dnslib.DNSLabel(ip.reverseName()), - rtype=dnslib.QTYPE.reverse['PTR'], - rclass=dnslib.CLASS.reverse['IN'], - rdata=dns.PTR(self._get_fqdn(name))) - zone.append((r.rname, 'PTR', r)) + @synchronized(_lock) + def remove_host(self, name, ip): + self.zone.remove_host(name, ip) def start(self, dns_address='0.0.0.0', dns_port=53, api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0, - log="request,reply,truncated,error", log_prefix=False): - logger = server.DNSLogger(log, log_prefix) + log_components="request,reply,truncated,error", + log_prefix=False): + logger = server.DNSLogger(log_components, log_prefix) - print("Starting OneDNS (%s:%d) [%s]" % - (dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP")) + log.info("Starting OneDNS (%s:%d) [%s]" % + (dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP")) server.DNSHandler.udplen = udplen diff --git a/onedns/tests/conftest.py b/onedns/tests/conftest.py index 5807730..2ff080f 100644 --- a/onedns/tests/conftest.py +++ b/onedns/tests/conftest.py @@ -3,7 +3,6 @@ import dnslib from IPy import IP from onedns import zone -from onedns import server from onedns import resolver from onedns.tests import vcr from onedns.clients import one @@ -30,6 +29,7 @@ TEST_GET_FORWARD_DATA = [ HOST + '.', ] + @pytest.fixture(scope="function") def dns(request): dns = resolver.DynamicResolver(domain=DOMAIN) diff --git a/onedns/tests/test_dynamicresolver.py b/onedns/tests/test_dynamicresolver.py index 32d0f43..fc1a0c4 100644 --- a/onedns/tests/test_dynamicresolver.py +++ b/onedns/tests/test_dynamicresolver.py @@ -1,14 +1,22 @@ import pytest + import dnslib +from IPy import IP + +from onedns import zone from onedns.tests import utils from onedns.tests import conftest +def _clear_and_add_test_host(dns, name, ip): + dns.clear() + dns.add_host(name, ip) + + @pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA) def test_lookup(dns, qname, qtype, output): - dns.clear() - dns.add_host(conftest.HOST, conftest.HOST_IP) + _clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP) try: a = utils.dnsquery(qname, qtype) assert a.short() == output @@ -25,15 +33,35 @@ def test_nxdomain(dns): dns.close() -@pytest.mark.parametrize("name,ip", conftest.TEST_GET_FQDN_DATA) -def test_get_fqdn(dns, name, ip): - dns.clear() - dns.add_host(name, ip) - assert dns.zone[0][0].label[0] == name.split('.')[0] - assert '.'.join(dns.zone[0][0].label[1:]) == conftest.DOMAIN - - def test_daemon(dns): dns.close() dns.daemon(dns_address=conftest.INTERFACE, dns_port=conftest.PORT, tcp=True, testing=True) + + +@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA) +def test_remove_host(dns, qname, qtype, output): + _clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP) + try: + a = utils.dnsquery(qname, qtype) + assert a.short() == output + dns.remove_host(conftest.HOST, conftest.HOST_IP) + a = utils.dnsquery(qname, qtype) + assert dnslib.RCODE.get(a.header.rcode) == 'NXDOMAIN' + finally: + dns.close() + + +def test_load_zone(dns): + new_ip = IP(IP(conftest.HOST_IP).int() + 1) + new_zone = zone.Zone(conftest.DOMAIN) + new_zone.add_host(conftest.HOST, new_ip) + _clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP) + try: + a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A) + assert a.short() == conftest.HOST_IP + dns.load(new_zone) + a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A) + assert a.short() == str(new_ip) + finally: + dns.close() diff --git a/onedns/utils.py b/onedns/utils.py index afee875..2fe1332 100644 --- a/onedns/utils.py +++ b/onedns/utils.py @@ -1,7 +1,12 @@ import os +import dns.name +import dns.reversename + import dnslib +from IPy import IP + from onedns.logger import log @@ -14,6 +19,11 @@ def get_fqdn(name, domain): return name.idna() +def reverse_to_ip(reverse): + rname = dns.name.Name(reverse.split('.')) + return IP(dns.reversename.to_address(rname)) + + def get_kwargs_from_dict(d, prefix, lower=False): tups_list = [] for i in d: diff --git a/requirements.txt b/requirements.txt index 137d95b..cf7351d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ oca==4.10.0 IPy==0.83 dnslib==0.9.6 +dnspython==1.14.0 +wrapt==1.10.8 diff --git a/setup.py b/setup.py index feb79f6..593ea91 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,8 @@ setup( "oca>=4.10.0", "IPy>=0.83", "dnslib>=0.9.6", + "dnspython>=1.14.0", + "wrapt>=1.10.8", ], setup_requires=[ 'pytest-runner>=2.9'