resolver: add thread mutex and use new zone module
parent
2c0283fdb7
commit
b735c84a04
|
@ -1,14 +1,16 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
import time
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
import dnslib
|
import dnslib
|
||||||
from dnslib import server
|
from dnslib import server
|
||||||
from dnslib import dns
|
|
||||||
|
|
||||||
from IPy import IP
|
from wrapt import synchronized
|
||||||
|
|
||||||
|
from onedns import zone
|
||||||
from onedns import utils
|
from onedns import utils
|
||||||
|
from onedns import exception
|
||||||
|
from onedns.logger import log
|
||||||
|
|
||||||
|
|
||||||
class DynamicResolver(server.BaseResolver):
|
class DynamicResolver(server.BaseResolver):
|
||||||
|
@ -16,16 +18,18 @@ class DynamicResolver(server.BaseResolver):
|
||||||
Dynamic In-Memory DNS Resolver
|
Dynamic In-Memory DNS Resolver
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, domain, one_kwargs={}):
|
_lock = threading.RLock()
|
||||||
|
|
||||||
|
def __init__(self, domain):
|
||||||
"""
|
"""
|
||||||
Initialise resolver from zone list
|
Initialise resolver from zone list
|
||||||
Stores RRs as a list of (label, type, rr) tuples
|
Stores RRs as a list of (label, type, rr) tuples
|
||||||
"""
|
"""
|
||||||
self.domain = domain
|
self.zone = zone.Zone(domain)
|
||||||
self.zone = []
|
|
||||||
self._tcp_server = None
|
self._tcp_server = None
|
||||||
self._udp_server = None
|
self._udp_server = None
|
||||||
|
|
||||||
|
@synchronized(_lock)
|
||||||
def resolve(self, request, handler):
|
def resolve(self, request, handler):
|
||||||
"""
|
"""
|
||||||
Respond to DNS request - parameters are request packet & handler.
|
Respond to DNS request - parameters are request packet & handler.
|
||||||
|
@ -33,57 +37,46 @@ class DynamicResolver(server.BaseResolver):
|
||||||
"""
|
"""
|
||||||
reply = request.reply()
|
reply = request.reply()
|
||||||
qname = request.q.qname
|
qname = request.q.qname
|
||||||
qtype = dnslib.QTYPE[request.q.qtype]
|
qtype = request.q.qtype
|
||||||
A_RECORDS = ['A', 'AAAA']
|
try:
|
||||||
for name, rtype, rr in self.zone:
|
if qtype in (dnslib.QTYPE.A, dnslib.QTYPE.AAAA):
|
||||||
# Check if label & type match
|
forward = self.zone.get_forward(qname)
|
||||||
if qname == name and (qtype in [rtype, 'ANY'] or rtype == 'CNAME'):
|
reply.add_answer(forward)
|
||||||
reply.add_answer(rr)
|
elif qtype == dnslib.QTYPE.PTR:
|
||||||
# Check for A/AAAA records associated with reply and
|
reverse = self.zone.get_reverse(
|
||||||
# add in additional section
|
utils.reverse_to_ip(qname.idna()))
|
||||||
if rtype in ['CNAME', 'NS', 'MX', 'PTR']:
|
reply.add_answer(reverse)
|
||||||
for a_name, a_rtype, a_rr in self.zone:
|
forward = self.zone.get_forward(str(reverse.rdata))
|
||||||
if a_name == rr.rdata.label and a_rtype in A_RECORDS:
|
if forward:
|
||||||
reply.add_ar(a_rr)
|
reply.add_ar(forward)
|
||||||
if not reply.rr:
|
except exception.RecordDoesNotExist:
|
||||||
reply.header.rcode = dnslib.RCODE.NXDOMAIN
|
reply.header.rcode = dnslib.RCODE.NXDOMAIN
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
@synchronized(_lock)
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.zone = []
|
self.zone.clear()
|
||||||
|
|
||||||
def add_host(self, name, ip, zone=None):
|
@synchronized(_lock)
|
||||||
zone = zone or self.zone
|
def load(self, zone):
|
||||||
self._add_forward(name, ip, zone=zone)
|
self.zone = zone
|
||||||
self._add_reverse(ip, name, zone=zone)
|
|
||||||
|
|
||||||
def _get_fqdn(self, name):
|
@synchronized(_lock)
|
||||||
return utils.get_fqdn(name, self.domain)
|
def add_host(self, name, ip):
|
||||||
|
self.zone.add_host(name, ip)
|
||||||
|
|
||||||
def _add_forward(self, name, ip, zone=None):
|
@synchronized(_lock)
|
||||||
zone = zone or self.zone
|
def remove_host(self, name, ip):
|
||||||
f = dnslib.RR(rname=dnslib.DNSLabel(self._get_fqdn(name)),
|
self.zone.remove_host(name, ip)
|
||||||
rtype=dnslib.QTYPE.reverse['A'],
|
|
||||||
rclass=dnslib.CLASS.reverse['IN'],
|
|
||||||
rdata=dns.A(ip))
|
|
||||||
zone.append((f.rname, 'A', f))
|
|
||||||
|
|
||||||
def _add_reverse(self, ip, name, zone=None):
|
|
||||||
zone = zone or self.zone
|
|
||||||
ip = IP(ip)
|
|
||||||
r = dnslib.RR(rname=dnslib.DNSLabel(ip.reverseName()),
|
|
||||||
rtype=dnslib.QTYPE.reverse['PTR'],
|
|
||||||
rclass=dnslib.CLASS.reverse['IN'],
|
|
||||||
rdata=dns.PTR(self._get_fqdn(name)))
|
|
||||||
zone.append((r.rname, 'PTR', r))
|
|
||||||
|
|
||||||
def start(self, dns_address='0.0.0.0', dns_port=53,
|
def start(self, dns_address='0.0.0.0', dns_port=53,
|
||||||
api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0,
|
api_address='127.0.0.1', api_port=8000, tcp=False, udplen=0,
|
||||||
log="request,reply,truncated,error", log_prefix=False):
|
log_components="request,reply,truncated,error",
|
||||||
logger = server.DNSLogger(log, log_prefix)
|
log_prefix=False):
|
||||||
|
logger = server.DNSLogger(log_components, log_prefix)
|
||||||
|
|
||||||
print("Starting OneDNS (%s:%d) [%s]" %
|
log.info("Starting OneDNS (%s:%d) [%s]" %
|
||||||
(dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP"))
|
(dns_address or "*", dns_port, "UDP/TCP" if tcp else "UDP"))
|
||||||
|
|
||||||
server.DNSHandler.udplen = udplen
|
server.DNSHandler.udplen = udplen
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,6 @@ import dnslib
|
||||||
from IPy import IP
|
from IPy import IP
|
||||||
|
|
||||||
from onedns import zone
|
from onedns import zone
|
||||||
from onedns import server
|
|
||||||
from onedns import resolver
|
from onedns import resolver
|
||||||
from onedns.tests import vcr
|
from onedns.tests import vcr
|
||||||
from onedns.clients import one
|
from onedns.clients import one
|
||||||
|
@ -30,6 +29,7 @@ TEST_GET_FORWARD_DATA = [
|
||||||
HOST + '.',
|
HOST + '.',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def dns(request):
|
def dns(request):
|
||||||
dns = resolver.DynamicResolver(domain=DOMAIN)
|
dns = resolver.DynamicResolver(domain=DOMAIN)
|
||||||
|
|
|
@ -1,14 +1,22 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import dnslib
|
import dnslib
|
||||||
|
|
||||||
|
from IPy import IP
|
||||||
|
|
||||||
|
from onedns import zone
|
||||||
from onedns.tests import utils
|
from onedns.tests import utils
|
||||||
from onedns.tests import conftest
|
from onedns.tests import conftest
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_and_add_test_host(dns, name, ip):
|
||||||
|
dns.clear()
|
||||||
|
dns.add_host(name, ip)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
|
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
|
||||||
def test_lookup(dns, qname, qtype, output):
|
def test_lookup(dns, qname, qtype, output):
|
||||||
dns.clear()
|
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
|
||||||
dns.add_host(conftest.HOST, conftest.HOST_IP)
|
|
||||||
try:
|
try:
|
||||||
a = utils.dnsquery(qname, qtype)
|
a = utils.dnsquery(qname, qtype)
|
||||||
assert a.short() == output
|
assert a.short() == output
|
||||||
|
@ -25,15 +33,35 @@ def test_nxdomain(dns):
|
||||||
dns.close()
|
dns.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name,ip", conftest.TEST_GET_FQDN_DATA)
|
|
||||||
def test_get_fqdn(dns, name, ip):
|
|
||||||
dns.clear()
|
|
||||||
dns.add_host(name, ip)
|
|
||||||
assert dns.zone[0][0].label[0] == name.split('.')[0]
|
|
||||||
assert '.'.join(dns.zone[0][0].label[1:]) == conftest.DOMAIN
|
|
||||||
|
|
||||||
|
|
||||||
def test_daemon(dns):
|
def test_daemon(dns):
|
||||||
dns.close()
|
dns.close()
|
||||||
dns.daemon(dns_address=conftest.INTERFACE, dns_port=conftest.PORT,
|
dns.daemon(dns_address=conftest.INTERFACE, dns_port=conftest.PORT,
|
||||||
tcp=True, testing=True)
|
tcp=True, testing=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("qname,qtype,output", conftest.TEST_LOOKUP_DATA)
|
||||||
|
def test_remove_host(dns, qname, qtype, output):
|
||||||
|
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
|
||||||
|
try:
|
||||||
|
a = utils.dnsquery(qname, qtype)
|
||||||
|
assert a.short() == output
|
||||||
|
dns.remove_host(conftest.HOST, conftest.HOST_IP)
|
||||||
|
a = utils.dnsquery(qname, qtype)
|
||||||
|
assert dnslib.RCODE.get(a.header.rcode) == 'NXDOMAIN'
|
||||||
|
finally:
|
||||||
|
dns.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_zone(dns):
|
||||||
|
new_ip = IP(IP(conftest.HOST_IP).int() + 1)
|
||||||
|
new_zone = zone.Zone(conftest.DOMAIN)
|
||||||
|
new_zone.add_host(conftest.HOST, new_ip)
|
||||||
|
_clear_and_add_test_host(dns, conftest.HOST, conftest.HOST_IP)
|
||||||
|
try:
|
||||||
|
a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A)
|
||||||
|
assert a.short() == conftest.HOST_IP
|
||||||
|
dns.load(new_zone)
|
||||||
|
a = utils.dnsquery(conftest.HOST, dnslib.QTYPE.A)
|
||||||
|
assert a.short() == str(new_ip)
|
||||||
|
finally:
|
||||||
|
dns.close()
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import dns.name
|
||||||
|
import dns.reversename
|
||||||
|
|
||||||
import dnslib
|
import dnslib
|
||||||
|
|
||||||
|
from IPy import IP
|
||||||
|
|
||||||
from onedns.logger import log
|
from onedns.logger import log
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,6 +19,11 @@ def get_fqdn(name, domain):
|
||||||
return name.idna()
|
return name.idna()
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_to_ip(reverse):
|
||||||
|
rname = dns.name.Name(reverse.split('.'))
|
||||||
|
return IP(dns.reversename.to_address(rname))
|
||||||
|
|
||||||
|
|
||||||
def get_kwargs_from_dict(d, prefix, lower=False):
|
def get_kwargs_from_dict(d, prefix, lower=False):
|
||||||
tups_list = []
|
tups_list = []
|
||||||
for i in d:
|
for i in d:
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
oca==4.10.0
|
oca==4.10.0
|
||||||
IPy==0.83
|
IPy==0.83
|
||||||
dnslib==0.9.6
|
dnslib==0.9.6
|
||||||
|
dnspython==1.14.0
|
||||||
|
wrapt==1.10.8
|
||||||
|
|
Loading…
Reference in New Issue