# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for L{twisted.names.server}.
"""
from __future__ import division, absolute_import
from zope.interface.verify import verifyClass
from twisted.internet import defer
from twisted.internet.interfaces import IProtocolFactory
from twisted.names import dns, error, resolve, server
from twisted.python import failure, log
from twisted.trial import unittest
class RaisedArguments(Exception):
"""
An exception containing the arguments raised by L{raiser}.
"""
def __init__(self, args, kwargs):
self.args = args
self.kwargs = kwargs
def raiser(*args, **kwargs):
"""
Raise a L{RaisedArguments} exception containing the supplied arguments.
Used as a fake when testing the call signatures of methods and functions.
"""
raise RaisedArguments(args, kwargs)
class NoResponseDNSServerFactory(server.DNSServerFactory):
"""
A L{server.DNSServerFactory} subclass which does not attempt to reply to any
received messages.
Used for testing logged messages in C{messageReceived} without having to
fake or patch the preceding code which attempts to deliver a response
message.
"""
def allowQuery(self, message, protocol, address):
"""
Deny all queries.
@param message: See L{server.DNSServerFactory.allowQuery}
@param protocol: See L{server.DNSServerFactory.allowQuery}
@param address: See L{server.DNSServerFactory.allowQuery}
@return: L{False}
@rtype: L{bool}
"""
return False
def sendReply(self, protocol, message, address):
"""
A noop send reply.
@param protocol: See L{server.DNSServerFactory.sendReply}
@param message: See L{server.DNSServerFactory.sendReply}
@param address: See L{server.DNSServerFactory.sendReply}
"""
class RaisingDNSServerFactory(server.DNSServerFactory):
"""
A L{server.DNSServerFactory} subclass whose methods raise an exception
containing the supplied arguments.
Used for stopping L{messageReceived} and testing the arguments supplied to
L{allowQuery}.
"""
class AllowQueryArguments(Exception):
"""
Contains positional and keyword arguments in C{args}.
"""
def allowQuery(self, *args, **kwargs):
"""
Raise the arguments supplied to L{allowQuery}.
@param args: Positional arguments which will be recorded in the raised
exception.
@type args: L{tuple}
@param kwargs: Keyword args which will be recorded in the raised
exception.
@type kwargs: L{dict}
"""
raise self.AllowQueryArguments(args, kwargs)
class RaisingProtocol(object):
"""
A partial fake L{IProtocol} whose methods raise an exception containing the
supplied arguments.
"""
class WriteMessageArguments(Exception):
"""
Contains positional and keyword arguments in C{args}.
"""
def writeMessage(self, *args, **kwargs):
"""
Raises the supplied arguments.
@param args: Positional arguments
@type args: L{tuple}
@param kwargs: Keyword args
@type kwargs: L{dict}
"""
raise self.WriteMessageArguments(args, kwargs)
class NoopProtocol(object):
"""
A partial fake L{dns.DNSProtocolMixin} with a noop L{writeMessage} method.
"""
def writeMessage(self, *args, **kwargs):
"""
A noop version of L{dns.DNSProtocolMixin.writeMessage}.
@param args: Positional arguments
@type args: L{tuple}
@param kwargs: Keyword args
@type kwargs: L{dict}
"""
class RaisingResolver(object):
"""
A partial fake L{IResolver} whose methods raise an exception containing the
supplied arguments.
"""
class QueryArguments(Exception):
"""
Contains positional and keyword arguments in C{args}.
"""
def query(self, *args, **kwargs):
"""
Raises the supplied arguments.
@param args: Positional arguments
@type args: L{tuple}
@param kwargs: Keyword args
@type kwargs: L{dict}
"""
raise self.QueryArguments(args, kwargs)
class RaisingCache(object):
"""
A partial fake L{twisted.names.cache.Cache} whose methods raise an exception
containing the supplied arguments.
"""
class CacheResultArguments(Exception):
"""
Contains positional and keyword arguments in C{args}.
"""
def cacheResult(self, *args, **kwargs):
"""
Raises the supplied arguments.
@param args: Positional arguments
@type args: L{tuple}
@param kwargs: Keyword args
@type kwargs: L{dict}
"""
raise self.CacheResultArguments(args, kwargs)
def assertLogMessage(testCase, expectedMessages, callable, *args, **kwargs):
"""
Assert that the callable logs the expected messages when called.
XXX: Put this somewhere where it can be re-used elsewhere. See #6677.
@param testCase: The test case controlling the test which triggers the
logged messages and on which assertions will be called.
@type testCase: L{unittest.SynchronousTestCase}
@param expectedMessages: A L{list} of the expected log messages
@type expectedMessages: L{list}
@param callable: The function which is expected to produce the
C{expectedMessages} when called.
@type callable: L{callable}
@param args: Positional arguments to be passed to C{callable}.
@type args: L{list}
@param kwargs: Keyword arguments to be passed to C{callable}.
@type kwargs: L{dict}
"""
loggedMessages = []
log.addObserver(loggedMessages.append)
testCase.addCleanup(log.removeObserver, loggedMessages.append)
callable(*args, **kwargs)
testCase.assertEqual(
[m['message'][0] for m in loggedMessages],
expectedMessages)
class DNSServerFactoryTests(unittest.TestCase):
"""
Tests for L{server.DNSServerFactory}.
"""
def test_resolverType(self):
"""
L{server.DNSServerFactory.resolver} is a L{resolve.ResolverChain}
instance
"""
self.assertIsInstance(
server.DNSServerFactory().resolver,
resolve.ResolverChain)
def test_resolverDefaultEmpty(self):
"""
L{server.DNSServerFactory.resolver} is an empty L{resolve.ResolverChain}
by default.
"""
self.assertEqual(
server.DNSServerFactory().resolver.resolvers,
[])
def test_authorities(self):
"""
L{server.DNSServerFactory.__init__} accepts an C{authorities}
argument. The value of this argument is a list and is used to extend the
C{resolver} L{resolve.ResolverChain}.
"""
dummyResolver = object()
self.assertEqual(
server.DNSServerFactory(
authorities=[dummyResolver]).resolver.resolvers,
[dummyResolver])
def test_caches(self):
"""
L{server.DNSServerFactory.__init__} accepts a C{caches} argument. The
value of this argument is a list and is used to extend the C{resolver}
L{resolve.ResolverChain}.
"""
dummyResolver = object()
self.assertEqual(
server.DNSServerFactory(
caches=[dummyResolver]).resolver.resolvers,
[dummyResolver])
def test_clients(self):
"""
L{server.DNSServerFactory.__init__} accepts a C{clients} argument. The
value of this argument is a list and is used to extend the C{resolver}
L{resolve.ResolverChain}.
"""
dummyResolver = object()
self.assertEqual(
server.DNSServerFactory(
clients=[dummyResolver]).resolver.resolvers,
[dummyResolver])
def test_resolverOrder(self):
"""
L{server.DNSServerFactory.resolver} contains an ordered list of
authorities, caches and clients.
"""
# Use classes here so that we can see meaningful names in test results
class DummyAuthority(object):
pass
class DummyCache(object):
pass
class DummyClient(object):
pass
self.assertEqual(
server.DNSServerFactory(
authorities=[DummyAuthority],
caches=[DummyCache],
clients=[DummyClient]).resolver.resolvers,
[DummyAuthority, DummyCache, DummyClient])
def test_cacheDefault(self):
"""
L{server.DNSServerFactory.cache} is L{None} by default.
"""
self.assertIsNone(server.DNSServerFactory().cache)
def test_cacheOverride(self):
"""
L{server.DNSServerFactory.__init__} assigns the last object in the
C{caches} list to L{server.DNSServerFactory.cache}.
"""
dummyResolver = object()
self.assertEqual(
server.DNSServerFactory(caches=[object(), dummyResolver]).cache,
dummyResolver)
def test_canRecurseDefault(self):
"""
L{server.DNSServerFactory.canRecurse} is a flag indicating that this
server is capable of performing recursive DNS lookups. It defaults to
L{False}.
"""
self.assertFalse(server.DNSServerFactory().canRecurse)
def test_canRecurseOverride(self):
"""
L{server.DNSServerFactory.__init__} sets C{canRecurse} to L{True} if it
is supplied with C{clients}.
"""
self.assertEqual(
server.DNSServerFactory(clients=[None]).canRecurse, True)
def test_verboseDefault(self):
"""
L{server.DNSServerFactory.verbose} defaults to L{False}.
"""
self.assertFalse(server.DNSServerFactory().verbose)
def test_verboseOverride(self):
"""
L{server.DNSServerFactory.__init__} accepts a C{verbose} argument which
overrides L{server.DNSServerFactory.verbose}.
"""
self.assertTrue(server.DNSServerFactory(verbose=True).verbose)
def test_interface(self):
"""
L{server.DNSServerFactory} implements L{IProtocolFactory}.
"""
self.assertTrue(verifyClass(IProtocolFactory, server.DNSServerFactory))
def test_defaultProtocol(self):
"""
L{server.DNSServerFactory.protocol} defaults to L{dns.DNSProtocol}.
"""
self.assertIs(server.DNSServerFactory.protocol, dns.DNSProtocol)
def test_buildProtocolProtocolOverride(self):
"""
L{server.DNSServerFactory.buildProtocol} builds a protocol by calling
L{server.DNSServerFactory.protocol} with its self as a positional
argument.
"""
class FakeProtocol(object):
factory = None
args = None
kwargs = None
stubProtocol = FakeProtocol()
def fakeProtocolFactory(*args, **kwargs):
stubProtocol.args = args
stubProtocol.kwargs = kwargs
return stubProtocol
f = server.DNSServerFactory()
f.protocol = fakeProtocolFactory
p = f.buildProtocol(addr=None)
self.assertEqual(
(stubProtocol, (f,), {}),
(p, p.args, p.kwargs)
)
def test_verboseLogQuiet(self):
"""
L{server.DNSServerFactory._verboseLog} does not log messages unless
C{verbose > 0}.
"""
f = server.DNSServerFactory()
assertLogMessage(
self,
[],
f._verboseLog,
'Foo Bar'
)
def test_verboseLogVerbose(self):
"""
L{server.DNSServerFactory._verboseLog} logs a message if C{verbose > 0}.
"""
f = server.DNSServerFactory(verbose=1)
assertLogMessage(
self,
['Foo Bar'],
f._verboseLog,
'Foo Bar'
)
def test_messageReceivedLoggingNoQuery(self):
"""
L{server.DNSServerFactory.messageReceived} logs about an empty query if
the message had no queries and C{verbose} is C{>0}.
"""
m = dns.Message()
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["Empty query from ('192.0.2.100', 53)"],
f.messageReceived,
message=m, proto=None, address=('192.0.2.100', 53))
def test_messageReceivedLogging1(self):
"""
L{server.DNSServerFactory.messageReceived} logs the query types of all
queries in the message if C{verbose} is set to C{1}.
"""
m = dns.Message()
m.addQuery(name='example.com', type=dns.MX)
m.addQuery(name='example.com', type=dns.AAAA)
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["MX AAAA query from ('192.0.2.100', 53)"],
f.messageReceived,
message=m, proto=None, address=('192.0.2.100', 53))
def test_messageReceivedLogging2(self):
"""
L{server.DNSServerFactory.messageReceived} logs the repr of all queries
in the message if C{verbose} is set to C{2}.
"""
m = dns.Message()
m.addQuery(name='example.com', type=dns.MX)
m.addQuery(name='example.com', type=dns.AAAA)
f = NoResponseDNSServerFactory(verbose=2)
assertLogMessage(
self,
["<Query example.com MX IN> "
"<Query example.com AAAA IN> query from ('192.0.2.100', 53)"],
f.messageReceived,
message=m, proto=None, address=('192.0.2.100', 53))
def test_messageReceivedTimestamp(self):
"""
L{server.DNSServerFactory.messageReceived} assigns a unix timestamp to
the received message.
"""
m = dns.Message()
f = NoResponseDNSServerFactory()
t = object()
self.patch(server.time, 'time', lambda: t)
f.messageReceived(message=m, proto=None, address=None)
self.assertEqual(m.timeReceived, t)
def test_messageReceivedAllowQuery(self):
"""
L{server.DNSServerFactory.messageReceived} passes all messages to
L{server.DNSServerFactory.allowQuery} along with the receiving protocol
and origin address.
"""
message = dns.Message()
dummyProtocol = object()
dummyAddress = object()
f = RaisingDNSServerFactory()
e = self.assertRaises(
RaisingDNSServerFactory.AllowQueryArguments,
f.messageReceived,
message=message, proto=dummyProtocol, address=dummyAddress)
args, kwargs = e.args
self.assertEqual(args, (message, dummyProtocol, dummyAddress))
self.assertEqual(kwargs, {})
def test_allowQueryFalse(self):
"""
If C{allowQuery} returns C{False},
L{server.DNSServerFactory.messageReceived} calls L{server.sendReply}
with a message whose C{rCode} is L{dns.EREFUSED}.
"""
class SendReplyException(Exception):
pass
class RaisingDNSServerFactory(server.DNSServerFactory):
def allowQuery(self, *args, **kwargs):
return False
def sendReply(self, *args, **kwargs):
raise SendReplyException(args, kwargs)
f = RaisingDNSServerFactory()
e = self.assertRaises(
SendReplyException,
f.messageReceived,
message=dns.Message(), proto=None, address=None)
(proto, message, address), kwargs = e.args
self.assertEqual(message.rCode, dns.EREFUSED)
def _messageReceivedTest(self, methodName, message):
"""
Assert that the named method is called with the given message when it is
passed to L{DNSServerFactory.messageReceived}.
@param methodName: The name of the method which is expected to be
called.
@type methodName: L{str}
@param message: The message which is expected to be passed to the
C{methodName} method.
@type message: L{dns.Message}
"""
# Make it appear to have some queries so that
# DNSServerFactory.allowQuery allows it.
message.queries = [None]
receivedMessages = []
def fakeHandler(message, protocol, address):
receivedMessages.append((message, protocol, address))
protocol = NoopProtocol()
factory = server.DNSServerFactory(None)
setattr(factory, methodName, fakeHandler)
factory.messageReceived(message, protocol)
self.assertEqual(receivedMessages, [(message, protocol, None)])
def test_queryMessageReceived(self):
"""
L{DNSServerFactory.messageReceived} passes messages with an opcode of
C{OP_QUERY} on to L{DNSServerFactory.handleQuery}.
"""
self._messageReceivedTest(
'handleQuery', dns.Message(opCode=dns.OP_QUERY))
def test_inverseQueryMessageReceived(self):
"""
L{DNSServerFactory.messageReceived} passes messages with an opcode of
C{OP_INVERSE} on to L{DNSServerFactory.handleInverseQuery}.
"""
self._messageReceivedTest(
'handleInverseQuery', dns.Message(opCode=dns.OP_INVERSE))
def test_statusMessageReceived(self):
"""
L{DNSServerFactory.messageReceived} passes messages with an opcode of
C{OP_STATUS} on to L{DNSServerFactory.handleStatus}.
"""
self._messageReceivedTest(
'handleStatus', dns.Message(opCode=dns.OP_STATUS))
def test_notifyMessageReceived(self):
"""
L{DNSServerFactory.messageReceived} passes messages with an opcode of
C{OP_NOTIFY} on to L{DNSServerFactory.handleNotify}.
"""
self._messageReceivedTest(
'handleNotify', dns.Message(opCode=dns.OP_NOTIFY))
def test_updateMessageReceived(self):
"""
L{DNSServerFactory.messageReceived} passes messages with an opcode of
C{OP_UPDATE} on to L{DNSServerFactory.handleOther}.
This may change if the implementation ever covers update messages.
"""
self._messageReceivedTest(
'handleOther', dns.Message(opCode=dns.OP_UPDATE))
def test_connectionTracking(self):
"""
The C{connectionMade} and C{connectionLost} methods of
L{DNSServerFactory} cooperate to keep track of all L{DNSProtocol}
objects created by a factory which are connected.
"""
protoA, protoB = object(), object()
factory = server.DNSServerFactory()
factory.connectionMade(protoA)
self.assertEqual(factory.connections, [protoA])
factory.connectionMade(protoB)
self.assertEqual(factory.connections, [protoA, protoB])
factory.connectionLost(protoA)
self.assertEqual(factory.connections, [protoB])
factory.connectionLost(protoB)
self.assertEqual(factory.connections, [])
def test_handleQuery(self):
"""
L{server.DNSServerFactory.handleQuery} takes the first query from the
supplied message and dispatches it to
L{server.DNSServerFactory.resolver.query}.
"""
m = dns.Message()
m.addQuery(b'one.example.com')
m.addQuery(b'two.example.com')
f = server.DNSServerFactory()
f.resolver = RaisingResolver()
e = self.assertRaises(
RaisingResolver.QueryArguments,
f.handleQuery,
message=m, protocol=NoopProtocol(), address=None)
(query,), kwargs = e.args
self.assertEqual(query, m.queries[0])
def test_handleQueryCallback(self):
"""
L{server.DNSServerFactory.handleQuery} adds
L{server.DNSServerFactory.resolver.gotResolverResponse} as a callback to
the deferred returned by L{server.DNSServerFactory.resolver.query}. It
is called with the query response, the original protocol, message and
origin address.
"""
f = server.DNSServerFactory()
d = defer.Deferred()
class FakeResolver(object):
def query(self, *args, **kwargs):
return d
f.resolver = FakeResolver()
gotResolverResponseArgs = []
def fakeGotResolverResponse(*args, **kwargs):
gotResolverResponseArgs.append((args, kwargs))
f.gotResolverResponse = fakeGotResolverResponse
m = dns.Message()
m.addQuery(b'one.example.com')
stubProtocol = NoopProtocol()
dummyAddress = object()
f.handleQuery(message=m, protocol=stubProtocol, address=dummyAddress)
dummyResponse = object()
d.callback(dummyResponse)
self.assertEqual(
gotResolverResponseArgs,
[((dummyResponse, stubProtocol, m, dummyAddress), {})])
def test_handleQueryErrback(self):
"""
L{server.DNSServerFactory.handleQuery} adds
L{server.DNSServerFactory.resolver.gotResolverError} as an errback to
the deferred returned by L{server.DNSServerFactory.resolver.query}. It
is called with the query failure, the original protocol, message and
origin address.
"""
f = server.DNSServerFactory()
d = defer.Deferred()
class FakeResolver(object):
def query(self, *args, **kwargs):
return d
f.resolver = FakeResolver()
gotResolverErrorArgs = []
def fakeGotResolverError(*args, **kwargs):
gotResolverErrorArgs.append((args, kwargs))
f.gotResolverError = fakeGotResolverError
m = dns.Message()
m.addQuery(b'one.example.com')
stubProtocol = NoopProtocol()
dummyAddress = object()
f.handleQuery(message=m, protocol=stubProtocol, address=dummyAddress)
stubFailure = failure.Failure(Exception())
d.errback(stubFailure)
self.assertEqual(
gotResolverErrorArgs,
[((stubFailure, stubProtocol, m, dummyAddress), {})])
def test_gotResolverResponse(self):
"""
L{server.DNSServerFactory.gotResolverResponse} accepts a tuple of
resource record lists and triggers a response message containing those
resource record lists.
"""
f = server.DNSServerFactory()
answers = []
authority = []
additional = []
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.gotResolverResponse,
(answers, authority, additional),
protocol=RaisingProtocol(), message=dns.Message(), address=None)
(message,), kwargs = e.args
self.assertIs(message.answers, answers)
self.assertIs(message.authority, authority)
self.assertIs(message.additional, additional)
def test_gotResolverResponseCallsResponseFromMessage(self):
"""
L{server.DNSServerFactory.gotResolverResponse} calls
L{server.DNSServerFactory._responseFromMessage} to generate a response.
"""
factory = NoResponseDNSServerFactory()
factory._responseFromMessage = raiser
request = dns.Message()
request.timeReceived = 1
e = self.assertRaises(
RaisedArguments,
factory.gotResolverResponse,
([], [], []),
protocol=None, message=request, address=None
)
self.assertEqual(
((), dict(message=request, rCode=dns.OK,
answers=[], authority=[], additional=[])),
(e.args, e.kwargs)
)
def test_responseFromMessageNewMessage(self):
"""
L{server.DNSServerFactory._responseFromMessage} generates a response
message which is a copy of the request message.
"""
factory = server.DNSServerFactory()
request = dns.Message(answer=False, recAv=False)
response = factory._responseFromMessage(message=request),
self.assertIsNot(request, response)
def test_responseFromMessageRecursionAvailable(self):
"""
L{server.DNSServerFactory._responseFromMessage} generates a response
message whose C{recAV} attribute is L{True} if
L{server.DNSServerFactory.canRecurse} is L{True}.
"""
factory = server.DNSServerFactory()
factory.canRecurse = True
response1 = factory._responseFromMessage(
message=dns.Message(recAv=False))
factory.canRecurse = False
response2 = factory._responseFromMessage(
message=dns.Message(recAv=True))
self.assertEqual(
(True, False),
(response1.recAv, response2.recAv))
def test_responseFromMessageTimeReceived(self):
"""
L{server.DNSServerFactory._responseFromMessage} generates a response
message whose C{timeReceived} attribute has the same value as that found
on the request.
"""
factory = server.DNSServerFactory()
request = dns.Message()
request.timeReceived = 1234
response = factory._responseFromMessage(message=request)
self.assertEqual(request.timeReceived, response.timeReceived)
def test_responseFromMessageMaxSize(self):
"""
L{server.DNSServerFactory._responseFromMessage} generates a response
message whose C{maxSize} attribute has the same value as that found
on the request.
"""
factory = server.DNSServerFactory()
request = dns.Message()
request.maxSize = 0
response = factory._responseFromMessage(message=request)
self.assertEqual(request.maxSize, response.maxSize)
def test_messageFactory(self):
"""
L{server.DNSServerFactory} has a C{_messageFactory} attribute which is
L{dns.Message} by default.
"""
self.assertIs(dns.Message, server.DNSServerFactory._messageFactory)
def test_responseFromMessageCallsMessageFactory(self):
"""
L{server.DNSServerFactory._responseFromMessage} calls
C{dns._responseFromMessage} to generate a response
message from the request message. It supplies the request message and
other keyword arguments which should be passed to the response message
initialiser.
"""
factory = server.DNSServerFactory()
self.patch(dns, '_responseFromMessage', raiser)
request = dns.Message()
e = self.assertRaises(
RaisedArguments,
factory._responseFromMessage,
message=request, rCode=dns.OK
)
self.assertEqual(
((), dict(responseConstructor=factory._messageFactory,
message=request, rCode=dns.OK, recAv=factory.canRecurse,
auth=False)),
(e.args, e.kwargs)
)
def test_responseFromMessageAuthoritativeMessage(self):
"""
L{server.DNSServerFactory._responseFromMessage} marks the response
message as authoritative if any of the answer records are authoritative.
"""
factory = server.DNSServerFactory()
response1 = factory._responseFromMessage(
message=dns.Message(), answers=[dns.RRHeader(auth=True)])
response2 = factory._responseFromMessage(
message=dns.Message(), answers=[dns.RRHeader(auth=False)])
self.assertEqual(
(True, False),
(response1.auth, response2.auth),
)
def test_gotResolverResponseLogging(self):
"""
L{server.DNSServerFactory.gotResolverResponse} logs the total number of
records in the response if C{verbose > 0}.
"""
f = NoResponseDNSServerFactory(verbose=1)
answers = [dns.RRHeader()]
authority = [dns.RRHeader()]
additional = [dns.RRHeader()]
assertLogMessage(
self,
["Lookup found 3 records"],
f.gotResolverResponse,
(answers, authority, additional),
protocol=NoopProtocol(), message=dns.Message(), address=None)
def test_gotResolverResponseCaching(self):
"""
L{server.DNSServerFactory.gotResolverResponse} caches the response if at
least one cache was provided in the constructor.
"""
f = NoResponseDNSServerFactory(caches=[RaisingCache()])
m = dns.Message()
m.addQuery(b'example.com')
expectedAnswers = [dns.RRHeader()]
expectedAuthority = []
expectedAdditional = []
e = self.assertRaises(
RaisingCache.CacheResultArguments,
f.gotResolverResponse,
(expectedAnswers, expectedAuthority, expectedAdditional),
protocol=NoopProtocol(), message=m, address=None)
(query, (answers, authority, additional)), kwargs = e.args
self.assertEqual(query.name.name, b'example.com')
self.assertIs(answers, expectedAnswers)
self.assertIs(authority, expectedAuthority)
self.assertIs(additional, expectedAdditional)
def test_gotResolverErrorCallsResponseFromMessage(self):
"""
L{server.DNSServerFactory.gotResolverError} calls
L{server.DNSServerFactory._responseFromMessage} to generate a response.
"""
factory = NoResponseDNSServerFactory()
factory._responseFromMessage = raiser
request = dns.Message()
request.timeReceived = 1
e = self.assertRaises(
RaisedArguments,
factory.gotResolverError,
failure.Failure(error.DomainError()),
protocol=None, message=request, address=None
)
self.assertEqual(
((), dict(message=request, rCode=dns.ENAME)),
(e.args, e.kwargs)
)
def _assertMessageRcodeForError(self, responseError, expectedMessageCode):
"""
L{server.DNSServerFactory.gotResolver} accepts a L{failure.Failure} and
triggers a response message whose rCode corresponds to the DNS error
contained in the C{Failure}.
@param responseError: The L{Exception} instance which is expected to
trigger C{expectedMessageCode} when it is supplied to
C{gotResolverError}
@type responseError: L{Exception}
@param expectedMessageCode: The C{rCode} which is expected in the
message returned by C{gotResolverError} in response to
C{responseError}.
@type expectedMessageCode: L{int}
"""
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.gotResolverError,
failure.Failure(responseError),
protocol=RaisingProtocol(), message=dns.Message(), address=None)
(message,), kwargs = e.args
self.assertEqual(message.rCode, expectedMessageCode)
def test_gotResolverErrorDomainError(self):
"""
L{server.DNSServerFactory.gotResolver} triggers a response message with
an C{rCode} of L{dns.ENAME} if supplied with a L{error.DomainError}.
"""
self._assertMessageRcodeForError(error.DomainError(), dns.ENAME)
def test_gotResolverErrorAuthoritativeDomainError(self):
"""
L{server.DNSServerFactory.gotResolver} triggers a response message with
an C{rCode} of L{dns.ENAME} if supplied with a
L{error.AuthoritativeDomainError}.
"""
self._assertMessageRcodeForError(
error.AuthoritativeDomainError(), dns.ENAME)
def test_gotResolverErrorOtherError(self):
"""
L{server.DNSServerFactory.gotResolver} triggers a response message with
an C{rCode} of L{dns.ESERVER} if supplied with another type of error and
logs the error.
"""
self._assertMessageRcodeForError(KeyError(), dns.ESERVER)
e = self.flushLoggedErrors(KeyError)
self.assertEqual(len(e), 1)
def test_gotResolverErrorLogging(self):
"""
L{server.DNSServerFactory.gotResolver} logs a message if C{verbose > 0}.
"""
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["Lookup failed"],
f.gotResolverError,
failure.Failure(error.DomainError()),
protocol=NoopProtocol(), message=dns.Message(), address=None)
def test_gotResolverErrorResetsResponseAttributes(self):
"""
L{server.DNSServerFactory.gotResolverError} does not allow request
attributes to leak into the response ie it sends a response with AD, CD
set to 0 and empty response record sections.
"""
factory = server.DNSServerFactory()
responses = []
factory.sendReply = (
lambda protocol, response, address: responses.append(response)
)
request = dns.Message(authenticData=True, checkingDisabled=True)
request.answers = [object(), object()]
request.authority = [object(), object()]
request.additional = [object(), object()]
factory.gotResolverError(
failure.Failure(error.DomainError()),
protocol=None, message=request, address=None
)
self.assertEqual([dns.Message(rCode=3, answer=True)], responses)
def test_gotResolverResponseResetsResponseAttributes(self):
"""
L{server.DNSServerFactory.gotResolverResponse} does not allow request
attributes to leak into the response ie it sends a response with AD, CD
set to 0 and none of the records in the request answer sections are
copied to the response.
"""
factory = server.DNSServerFactory()
responses = []
factory.sendReply = (
lambda protocol, response, address: responses.append(response)
)
request = dns.Message(authenticData=True, checkingDisabled=True)
request.answers = [object(), object()]
request.authority = [object(), object()]
request.additional = [object(), object()]
factory.gotResolverResponse(
([], [], []),
protocol=None, message=request, address=None
)
self.assertEqual([dns.Message(rCode=0, answer=True)], responses)
def test_sendReplyWithAddress(self):
"""
If L{server.DNSServerFactory.sendReply} is supplied with a protocol
*and* an address tuple it will supply that address to
C{protocol.writeMessage}.
"""
m = dns.Message()
dummyAddress = object()
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.sendReply,
protocol=RaisingProtocol(),
message=m,
address=dummyAddress)
args, kwargs = e.args
self.assertEqual(args, (m, dummyAddress))
self.assertEqual(kwargs, {})
def test_sendReplyWithoutAddress(self):
"""
If L{server.DNSServerFactory.sendReply} is supplied with a protocol but
no address tuple it will supply only a message to
C{protocol.writeMessage}.
"""
m = dns.Message()
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.sendReply,
protocol=RaisingProtocol(),
message=m,
address=None)
args, kwargs = e.args
self.assertEqual(args, (m,))
self.assertEqual(kwargs, {})
def test_sendReplyLoggingNoAnswers(self):
"""
If L{server.DNSServerFactory.sendReply} logs a "no answers" message if
the supplied message has no answers.
"""
self.patch(server.time, 'time', lambda: 86402)
m = dns.Message()
m.timeReceived = 86401
f = server.DNSServerFactory(verbose=2)
assertLogMessage(
self,
["Replying with no answers", "Processed query in 1.000 seconds"],
f.sendReply,
protocol=NoopProtocol(),
message=m,
address=None)
def test_sendReplyLoggingWithAnswers(self):
"""
If L{server.DNSServerFactory.sendReply} logs a message for answers,
authority, additional if the supplied a message has records in any of
those sections.
"""
self.patch(server.time, 'time', lambda: 86402)
m = dns.Message()
m.answers.append(dns.RRHeader(payload=dns.Record_A('127.0.0.1')))
m.authority.append(dns.RRHeader(payload=dns.Record_A('127.0.0.1')))
m.additional.append(dns.RRHeader(payload=dns.Record_A('127.0.0.1')))
m.timeReceived = 86401
f = server.DNSServerFactory(verbose=2)
assertLogMessage(
self,
['Answers are <A address=127.0.0.1 ttl=None>',
'Authority is <A address=127.0.0.1 ttl=None>',
'Additional is <A address=127.0.0.1 ttl=None>',
'Processed query in 1.000 seconds'],
f.sendReply,
protocol=NoopProtocol(),
message=m,
address=None)
def test_handleInverseQuery(self):
"""
L{server.DNSServerFactory.handleInverseQuery} triggers the sending of a
response message with C{rCode} set to L{dns.ENOTIMP}.
"""
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.handleInverseQuery,
message=dns.Message(), protocol=RaisingProtocol(), address=None)
(message,), kwargs = e.args
self.assertEqual(message.rCode, dns.ENOTIMP)
def test_handleInverseQueryLogging(self):
"""
L{server.DNSServerFactory.handleInverseQuery} logs the message origin
address if C{verbose > 0}.
"""
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["Inverse query from ('::1', 53)"],
f.handleInverseQuery,
message=dns.Message(),
protocol=NoopProtocol(),
address=('::1', 53))
def test_handleStatus(self):
"""
L{server.DNSServerFactory.handleStatus} triggers the sending of a
response message with C{rCode} set to L{dns.ENOTIMP}.
"""
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.handleStatus,
message=dns.Message(), protocol=RaisingProtocol(), address=None)
(message,), kwargs = e.args
self.assertEqual(message.rCode, dns.ENOTIMP)
def test_handleStatusLogging(self):
"""
L{server.DNSServerFactory.handleStatus} logs the message origin address
if C{verbose > 0}.
"""
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["Status request from ('::1', 53)"],
f.handleStatus,
message=dns.Message(),
protocol=NoopProtocol(),
address=('::1', 53))
def test_handleNotify(self):
"""
L{server.DNSServerFactory.handleNotify} triggers the sending of a
response message with C{rCode} set to L{dns.ENOTIMP}.
"""
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.handleNotify,
message=dns.Message(), protocol=RaisingProtocol(), address=None)
(message,), kwargs = e.args
self.assertEqual(message.rCode, dns.ENOTIMP)
def test_handleNotifyLogging(self):
"""
L{server.DNSServerFactory.handleNotify} logs the message origin address
if C{verbose > 0}.
"""
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["Notify message from ('::1', 53)"],
f.handleNotify,
message=dns.Message(),
protocol=NoopProtocol(),
address=('::1', 53))
def test_handleOther(self):
"""
L{server.DNSServerFactory.handleOther} triggers the sending of a
response message with C{rCode} set to L{dns.ENOTIMP}.
"""
f = server.DNSServerFactory()
e = self.assertRaises(
RaisingProtocol.WriteMessageArguments,
f.handleOther,
message=dns.Message(), protocol=RaisingProtocol(), address=None)
(message,), kwargs = e.args
self.assertEqual(message.rCode, dns.ENOTIMP)
def test_handleOtherLogging(self):
"""
L{server.DNSServerFactory.handleOther} logs the message origin address
if C{verbose > 0}.
"""
f = NoResponseDNSServerFactory(verbose=1)
assertLogMessage(
self,
["Unknown op code (0) from ('::1', 53)"],
f.handleOther,
message=dns.Message(),
protocol=NoopProtocol(),
address=('::1', 53))