resolver: add thread mutex and use new zone module

master
Justin Riley 2016-08-17 16:18:03 -04:00
parent 2c0283fdb7
commit b735c84a04
6 changed files with 92 additions and 57 deletions

View File

@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import print_function
import time import time
import threading
import dnslib import dnslib
from dnslib import server from dnslib import server
from dnslib import dns
from IPy import IP from wrapt import synchronized
from onedns import zone
from onedns import utils from onedns import utils
from onedns import exception
from onedns.logger import log
class DynamicResolver(server.BaseResolver): class DynamicResolver(server.BaseResolver):
@ -16,16 +18,18 @@ class DynamicResolver(server.BaseResolver):
Dynamic In-Memory DNS Resolver Dynamic In-Memory DNS Resolver
""" """
def __init__(self, domain, one_kwargs={}): _lock = threading.RLock()
def __init__(self, domain):
""" """
Initialise resolver from zone list Initialise resolver from zone list
Stores RRs as a list of (label, type, rr) tuples Stores RRs as a list of (label, type, rr) tuples
""" """
self.domain = domain self.zone = zone.Zone(domain)
self.zone = []
self._tcp_server = None self._tcp_server = None
self._udp_server = None self._udp_server = None
@synchronized(_lock)
def resolve(self, request, handler): def resolve(self, request, handler):
""" """
Respond to DNS request - parameters are request packet & handler. Respond to DNS request - parameters are request packet & handler.
@ -33,56 +37,45 @@ class DynamicResolver(server.BaseResolver):
""" """
reply = request.reply() reply = request.reply()
qname = request.q.qname qname = request.q.qname
qtype = dnslib.QTYPE[request.q.qtype] qtype = request.q.qtype
A_RECORDS = ['A', 'AAAA'] try:
for name, rtype, rr in self.zone: if qtype in (dnslib.QTYPE.A, dnslib.QTYPE.AAAA):
# Check if label & type match forward = self.zone.get_forward(qname)
if qname == name and (qtype in [rtype, 'ANY'] or rtype == 'CNAME'): reply.add_answer(forward)
reply.add_answer(rr) elif qtype == dnslib.QTYPE.PTR:
# Check for A/AAAA records associated with reply and reverse = self.zone.get_reverse(
# add in additional section utils.reverse_to_ip(qname.idna()))
if rtype in ['CNAME', 'NS', 'MX', 'PTR']: reply.add_answer(reverse)
for a_name, a_rtype, a_rr in self.zone: forward = self.zone.get_forward(str(reverse.rdata))
if a_name == rr.rdata.label and a_rtype in A_RECORDS: if forward:
reply.add_ar(a_rr) reply.add_ar(forward)
if not reply.rr: except exception.RecordDoesNotExist:
reply.header.rcode = dnslib.RCODE.NXDOMAIN reply.header.rcode = dnslib.RCODE.NXDOMAIN
return reply return reply
@synchronized(_lock)
def clear(self): def clear(self):
self.zone = [] self.zone.clear()
def add_host(self, name, ip, zone=None): @synchronized(_lock)
zone = zone or self.zone def load(self, zone):
self._add_forward(name, ip, zone=zone) self.zone = zone
self._add_reverse(ip, name, zone=zone)
def _get_fqdn(self, name): @synchronized(_lock)
return utils.get_fqdn(name, self.domain) def add_host(self, name, ip):
self.zone.add_host(name, ip)
def _add_forward(self, name, ip, zone=None): @synchronized(_lock)
zone = zone or self.zone def remove_host(self, name, ip):
f = dnslib.RR(rname=dnslib.DNSLabel(self._get_fqdn(name)), self.zone.remove_host(name, ip)
rtype=dnslib.QTYPE.reverse['A'],
rclass=dnslib.CLASS.reverse['IN'],
rdata=dns.A(ip))
zone.append((f.rname, 'A', f))
def _add_reverse(self, ip, name, zone=None):
zone = zone or self.zone
ip = IP(ip)
r = dnslib.RR(rname=dnslib.DNSLabel(ip.reverseName()),
rtype=dnslib.QTYPE.reverse['PTR'],
rclass=dnslib.CLASS.reverse['IN'],
rdata=dns.PTR(self._get_fqdn(name)))
zone.append((r.rname, 'PTR', r))
def start(self, dns_address='0.0.0.0', dns_port=53, def start(self, dns_address='0.0.0.0', dns_port=53,
api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0, api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0,
log="request,reply,truncated,error", log_prefix=False): log_components="request,reply,truncated,error",
logger = server.DNSLogger(log, log_prefix) log_prefix=False):
logger = server.DNSLogger(log_components, log_prefix)
print("Starting OneDNS (%s:%d) [%s]" % log.info("Starting OneDNS (%s:%d) [%s]" %
(dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP")) (dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP"))
server.DNSHandler.udplen = udplen server.DNSHandler.udplen = udplen

View File

@ -3,7 +3,6 @@ import dnslib
from IPy import IP from IPy import IP
from onedns import zone from onedns import zone
from onedns import server
from onedns import resolver from onedns import resolver
from onedns.tests import vcr from onedns.tests import vcr
from onedns.clients import one from onedns.clients import one
@ -30,6 +29,7 @@ TEST_GET_FORWARD_DATA = [
HOST + '.', HOST + '.',
] ]
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def dns(request): def dns(request):
dns = resolver.DynamicResolver(domain=DOMAIN) dns = resolver.DynamicResolver(domain=DOMAIN)

View File

@ -1,14 +1,22 @@
import pytest import pytest
import dnslib import dnslib
from IPy import IP
from onedns import zone
from onedns.tests import utils from onedns.tests import utils
from onedns.tests import conftest from onedns.tests import conftest
def _clear_and_add_test_host(dns, name, ip):
dns.clear()
dns.add_host(name, ip)
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA) @pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
def test_lookup(dns, qname, qtype, output): def test_lookup(dns, qname, qtype, output):
dns.clear() _clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
dns.add_host(conftest.HOST, conftest.HOST_IP)
try: try:
a = utils.dnsquery(qname, qtype) a = utils.dnsquery(qname, qtype)
assert a.short() == output assert a.short() == output
@ -25,15 +33,35 @@ def test_nxdomain(dns):
dns.close() dns.close()
@pytest.mark.parametrize("name,ip", conftest.TEST_GET_FQDN_DATA)
def test_get_fqdn(dns, name, ip):
dns.clear()
dns.add_host(name, ip)
assert dns.zone[0][0].label[0] == name.split('.')[0]
assert '.'.join(dns.zone[0][0].label[1:]) == conftest.DOMAIN
def test_daemon(dns): def test_daemon(dns):
dns.close() dns.close()
dns.daemon(dns_address=conftest.INTERFACE, dns_port=conftest.PORT, dns.daemon(dns_address=conftest.INTERFACE, dns_port=conftest.PORT,
tcp=True, testing=True) tcp=True, testing=True)
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
def test_remove_host(dns, qname, qtype, output):
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
try:
a = utils.dnsquery(qname, qtype)
assert a.short() == output
dns.remove_host(conftest.HOST, conftest.HOST_IP)
a = utils.dnsquery(qname, qtype)
assert dnslib.RCODE.get(a.header.rcode) == 'NXDOMAIN'
finally:
dns.close()
def test_load_zone(dns):
new_ip = IP(IP(conftest.HOST_IP).int() + 1)
new_zone = zone.Zone(conftest.DOMAIN)
new_zone.add_host(conftest.HOST, new_ip)
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
try:
a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A)
assert a.short() == conftest.HOST_IP
dns.load(new_zone)
a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A)
assert a.short() == str(new_ip)
finally:
dns.close()

View File

@ -1,7 +1,12 @@
import os import os
import dns.name
import dns.reversename
import dnslib import dnslib
from IPy import IP
from onedns.logger import log from onedns.logger import log
@ -14,6 +19,11 @@ def get_fqdn(name, domain):
return name.idna() return name.idna()
def reverse_to_ip(reverse):
rname = dns.name.Name(reverse.split('.'))
return IP(dns.reversename.to_address(rname))
def get_kwargs_from_dict(d, prefix, lower=False): def get_kwargs_from_dict(d, prefix, lower=False):
tups_list = [] tups_list = []
for i in d: for i in d:

View File

@ -1,3 +1,5 @@
oca==4.10.0 oca==4.10.0
IPy==0.83 IPy==0.83
dnslib==0.9.6 dnslib==0.9.6
dnspython==1.14.0
wrapt==1.10.8

View File

@ -22,6 +22,8 @@ setup(
"oca>=4.10.0", "oca>=4.10.0",
"IPy>=0.83", "IPy>=0.83",
"dnslib>=0.9.6", "dnslib>=0.9.6",
"dnspython>=1.14.0",
"wrapt>=1.10.8",
], ],
setup_requires=[ setup_requires=[
'pytest-runner>=2.9' 'pytest-runner>=2.9'