resolver: add thread mutex and use new zone module

master
Justin Riley 2016-08-17 16:18:03 -04:00
parent 2c0283fdb7
commit b735c84a04
6 changed files with 92 additions and 57 deletions

View File

@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
import time
import threading
import dnslib
from dnslib import server
from dnslib import dns
from IPy import IP
from wrapt import synchronized
from onedns import zone
from onedns import utils
from onedns import exception
from onedns.logger import log
class DynamicResolver(server.BaseResolver):
@ -16,16 +18,18 @@ class DynamicResolver(server.BaseResolver):
Dynamic In-Memory DNS Resolver
"""
def __init__(self, domain, one_kwargs={}):
_lock = threading.RLock()
def __init__(self, domain):
"""
Initialise resolver from zone list
Stores RRs as a list of (label, type, rr) tuples
"""
self.domain = domain
self.zone = []
self.zone = zone.Zone(domain)
self._tcp_server = None
self._udp_server = None
@synchronized(_lock)
def resolve(self, request, handler):
"""
Respond to DNS request - parameters are request packet & handler.
@ -33,56 +37,45 @@ class DynamicResolver(server.BaseResolver):
"""
reply = request.reply()
qname = request.q.qname
qtype = dnslib.QTYPE[request.q.qtype]
A_RECORDS = ['A', 'AAAA']
for name, rtype, rr in self.zone:
# Check if label & type match
if qname == name and (qtype in [rtype, 'ANY'] or rtype == 'CNAME'):
reply.add_answer(rr)
# Check for A/AAAA records associated with reply and
# add in additional section
if rtype in ['CNAME', 'NS', 'MX', 'PTR']:
for a_name, a_rtype, a_rr in self.zone:
if a_name == rr.rdata.label and a_rtype in A_RECORDS:
reply.add_ar(a_rr)
if not reply.rr:
qtype = request.q.qtype
try:
if qtype in (dnslib.QTYPE.A, dnslib.QTYPE.AAAA):
forward = self.zone.get_forward(qname)
reply.add_answer(forward)
elif qtype == dnslib.QTYPE.PTR:
reverse = self.zone.get_reverse(
utils.reverse_to_ip(qname.idna()))
reply.add_answer(reverse)
forward = self.zone.get_forward(str(reverse.rdata))
if forward:
reply.add_ar(forward)
except exception.RecordDoesNotExist:
reply.header.rcode = dnslib.RCODE.NXDOMAIN
return reply
@synchronized(_lock)
def clear(self):
self.zone = []
self.zone.clear()
def add_host(self, name, ip, zone=None):
zone = zone or self.zone
self._add_forward(name, ip, zone=zone)
self._add_reverse(ip, name, zone=zone)
@synchronized(_lock)
def load(self, zone):
self.zone = zone
def _get_fqdn(self, name):
return utils.get_fqdn(name, self.domain)
@synchronized(_lock)
def add_host(self, name, ip):
self.zone.add_host(name, ip)
def _add_forward(self, name, ip, zone=None):
zone = zone or self.zone
f = dnslib.RR(rname=dnslib.DNSLabel(self._get_fqdn(name)),
rtype=dnslib.QTYPE.reverse['A'],
rclass=dnslib.CLASS.reverse['IN'],
rdata=dns.A(ip))
zone.append((f.rname, 'A', f))
def _add_reverse(self, ip, name, zone=None):
zone = zone or self.zone
ip = IP(ip)
r = dnslib.RR(rname=dnslib.DNSLabel(ip.reverseName()),
rtype=dnslib.QTYPE.reverse['PTR'],
rclass=dnslib.CLASS.reverse['IN'],
rdata=dns.PTR(self._get_fqdn(name)))
zone.append((r.rname, 'PTR', r))
@synchronized(_lock)
def remove_host(self, name, ip):
self.zone.remove_host(name, ip)
def start(self, dns_address='0.0.0.0', dns_port=53,
api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0,
log="request,reply,truncated,error", log_prefix=False):
logger = server.DNSLogger(log, log_prefix)
log_components="request,reply,truncated,error",
log_prefix=False):
logger = server.DNSLogger(log_components, log_prefix)
print("Starting OneDNS (%s:%d) [%s]" %
log.info("Starting OneDNS (%s:%d) [%s]" %
(dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP"))
server.DNSHandler.udplen = udplen

View File

@ -3,7 +3,6 @@ import dnslib
from IPy import IP
from onedns import zone
from onedns import server
from onedns import resolver
from onedns.tests import vcr
from onedns.clients import one
@ -30,6 +29,7 @@ TEST_GET_FORWARD_DATA = [
HOST + '.',
]
@pytest.fixture(scope="function")
def dns(request):
dns = resolver.DynamicResolver(domain=DOMAIN)

View File

@ -1,14 +1,22 @@
import pytest
import dnslib
from IPy import IP
from onedns import zone
from onedns.tests import utils
from onedns.tests import conftest
def _clear_and_add_test_host(dns, name, ip):
dns.clear()
dns.add_host(name, ip)
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
def test_lookup(dns, qname, qtype, output):
dns.clear()
dns.add_host(conftest.HOST, conftest.HOST_IP)
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
try:
a = utils.dnsquery(qname, qtype)
assert a.short() == output
@ -25,15 +33,35 @@ def test_nxdomain(dns):
dns.close()
@pytest.mark.parametrize("name,ip", conftest.TEST_GET_FQDN_DATA)
def test_get_fqdn(dns, name, ip):
dns.clear()
dns.add_host(name, ip)
assert dns.zone[0][0].label[0] == name.split('.')[0]
assert '.'.join(dns.zone[0][0].label[1:]) == conftest.DOMAIN
def test_daemon(dns):
dns.close()
dns.daemon(dns_address=conftest.INTERFACE, dns_port=conftest.PORT,
tcp=True, testing=True)
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
def test_remove_host(dns, qname, qtype, output):
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
try:
a = utils.dnsquery(qname, qtype)
assert a.short() == output
dns.remove_host(conftest.HOST, conftest.HOST_IP)
a = utils.dnsquery(qname, qtype)
assert dnslib.RCODE.get(a.header.rcode) == 'NXDOMAIN'
finally:
dns.close()
def test_load_zone(dns):
new_ip = IP(IP(conftest.HOST_IP).int() + 1)
new_zone = zone.Zone(conftest.DOMAIN)
new_zone.add_host(conftest.HOST, new_ip)
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
try:
a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A)
assert a.short() == conftest.HOST_IP
dns.load(new_zone)
a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A)
assert a.short() == str(new_ip)
finally:
dns.close()

View File

@ -1,7 +1,12 @@
import os
import dns.name
import dns.reversename
import dnslib
from IPy import IP
from onedns.logger import log
@ -14,6 +19,11 @@ def get_fqdn(name, domain):
return name.idna()
def reverse_to_ip(reverse):
rname = dns.name.Name(reverse.split('.'))
return IP(dns.reversename.to_address(rname))
def get_kwargs_from_dict(d, prefix, lower=False):
tups_list = []
for i in d:

View File

@ -1,3 +1,5 @@
oca==4.10.0
IPy==0.83
dnslib==0.9.6
dnspython==1.14.0
wrapt==1.10.8

View File

@ -22,6 +22,8 @@ setup(
"oca>=4.10.0",
"IPy>=0.83",
"dnslib>=0.9.6",
"dnspython>=1.14.0",
"wrapt>=1.10.8",
],
setup_requires=[
'pytest-runner>=2.9'