404

[ Avaa Bypassed ]




Upload:

Command:

botdev@3.147.78.134: ~ $
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Test cases for L{twisted.names.server}.
"""
from __future__ import division, absolute_import

from zope.interface.verify import verifyClass

from twisted.internet import defer
from twisted.internet.interfaces import IProtocolFactory
from twisted.names import dns, error, resolve, server
from twisted.python import failure, log
from twisted.trial import unittest



class RaisedArguments(Exception):
    """
    Exception carrying the positional and keyword arguments that were passed
    to L{raiser}.

    @ivar args: The positional arguments given to L{raiser}.
    @ivar kwargs: The keyword arguments given to L{raiser}.
    """
    def __init__(self, args, kwargs):
        self.args, self.kwargs = args, kwargs



def raiser(*args, **kwargs):
    """
    Fake callable which always raises L{RaisedArguments}.

    Substitute it for a method or function under test to capture the call
    signature it was invoked with.

    @param args: Positional arguments to capture.
    @param kwargs: Keyword arguments to capture.

    @raise RaisedArguments: Always, carrying C{args} and C{kwargs}.
    """
    captured = RaisedArguments(args, kwargs)
    raise captured



class NoResponseDNSServerFactory(server.DNSServerFactory):
    """
    Fake L{server.DNSServerFactory} which denies every query and whose
    C{sendReply} does nothing.

    Lets tests inspect what C{messageReceived} logs without having to fake or
    patch the response-delivery code that would otherwise run.
    """
    def allowQuery(self, message, protocol, address):
        """
        Reject the query unconditionally.

        @param message: As for L{server.DNSServerFactory.allowQuery}
        @param protocol: As for L{server.DNSServerFactory.allowQuery}
        @param address: As for L{server.DNSServerFactory.allowQuery}

        @return: Always L{False}
        @rtype: L{bool}
        """
        allowed = False
        return allowed


    def sendReply(self, protocol, message, address):
        """
        Discard the reply instead of sending it.

        @param protocol: As for L{server.DNSServerFactory.sendReply}
        @param message: As for L{server.DNSServerFactory.sendReply}
        @param address: As for L{server.DNSServerFactory.sendReply}
        """
        return None



class RaisingDNSServerFactory(server.DNSServerFactory):
    """
    Fake L{server.DNSServerFactory} whose C{allowQuery} raises an exception
    carrying the arguments it was called with.

    Halts L{messageReceived} at the C{allowQuery} step so that tests can
    inspect the arguments passed to L{allowQuery}.
    """

    class AllowQueryArguments(Exception):
        """
        Raised by C{allowQuery}; C{args} holds the positional-argument tuple
        followed by the keyword-argument dict.
        """

    def allowQuery(self, *args, **kwargs):
        """
        Raise L{AllowQueryArguments} wrapping the supplied arguments.

        @param args: Positional arguments to capture in the exception.
        @type args: L{tuple}

        @param kwargs: Keyword arguments to capture in the exception.
        @type kwargs: L{dict}

        @raise AllowQueryArguments: Always.
        """
        captured = self.AllowQueryArguments(args, kwargs)
        raise captured



class RaisingProtocol(object):
    """
    Partial fake L{IProtocol} whose C{writeMessage} raises an exception
    carrying the arguments it was called with.
    """
    class WriteMessageArguments(Exception):
        """
        Raised by C{writeMessage}; C{args} holds the positional-argument tuple
        followed by the keyword-argument dict.
        """

    def writeMessage(self, *args, **kwargs):
        """
        Raise L{WriteMessageArguments} wrapping the supplied arguments.

        @param args: Positional arguments to capture.
        @type args: L{tuple}

        @param kwargs: Keyword arguments to capture.
        @type kwargs: L{dict}

        @raise WriteMessageArguments: Always.
        """
        captured = self.WriteMessageArguments(args, kwargs)
        raise captured



class NoopProtocol(object):
    """
    Minimal stand-in for L{dns.DNSProtocolMixin} whose only method,
    C{writeMessage}, does nothing.
    """
    def writeMessage(self, *args, **kwargs):
        """
        Accept and ignore any arguments, standing in for
        L{dns.DNSProtocolMixin.writeMessage} without writing anything.

        @param args: Ignored positional arguments.
        @type args: L{tuple}

        @param kwargs: Ignored keyword arguments.
        @type kwargs: L{dict}
        """
        return None



class RaisingResolver(object):
    """
    Partial fake L{IResolver} whose C{query} raises an exception carrying the
    arguments it was called with.
    """
    class QueryArguments(Exception):
        """
        Raised by C{query}; C{args} holds the positional-argument tuple
        followed by the keyword-argument dict.
        """


    def query(self, *args, **kwargs):
        """
        Raise L{QueryArguments} wrapping the supplied arguments.

        @param args: Positional arguments to capture.
        @type args: L{tuple}

        @param kwargs: Keyword arguments to capture.
        @type kwargs: L{dict}

        @raise QueryArguments: Always.
        """
        captured = self.QueryArguments(args, kwargs)
        raise captured



class RaisingCache(object):
    """
    Partial fake L{twisted.names.cache.Cache} whose C{cacheResult} raises an
    exception carrying the arguments it was called with.
    """
    class CacheResultArguments(Exception):
        """
        Raised by C{cacheResult}; C{args} holds the positional-argument tuple
        followed by the keyword-argument dict.
        """


    def cacheResult(self, *args, **kwargs):
        """
        Raise L{CacheResultArguments} wrapping the supplied arguments.

        @param args: Positional arguments to capture.
        @type args: L{tuple}

        @param kwargs: Keyword arguments to capture.
        @type kwargs: L{dict}

        @raise CacheResultArguments: Always.
        """
        captured = self.CacheResultArguments(args, kwargs)
        raise captured



def assertLogMessage(testCase, expectedMessages, callable, *args, **kwargs):
    """
    Call C{callable} and assert that it logs exactly C{expectedMessages}.

    A temporary log observer records every event emitted while the callable
    runs. The observer is removed during C{testCase} cleanup.

    XXX: Put this somewhere where it can be re-used elsewhere. See #6677.

    @param testCase: The test case that drives the call and on which the
        assertion is made.
    @type testCase: L{unittest.SynchronousTestCase}

    @param expectedMessages: The log messages expected, in order. Each entry
        is compared against the first element of an event's C{message}.
    @type expectedMessages: L{list}

    @param callable: The function expected to emit C{expectedMessages}.
    @type callable: L{callable}

    @param args: Positional arguments forwarded to C{callable}.
    @type args: L{list}

    @param kwargs: Keyword arguments forwarded to C{callable}.
    @type kwargs: L{dict}
    """
    observedEvents = []
    observer = observedEvents.append
    log.addObserver(observer)
    testCase.addCleanup(log.removeObserver, observer)

    callable(*args, **kwargs)

    actualMessages = [event['message'][0] for event in observedEvents]
    testCase.assertEqual(actualMessages, expectedMessages)



class DNSServerFactoryTests(unittest.TestCase):
    """
    Tests for L{server.DNSServerFactory}.
    """
    def test_resolverType(self):
        """
        The C{resolver} attribute of a default L{server.DNSServerFactory} is
        a L{resolve.ResolverChain} instance.
        """
        factory = server.DNSServerFactory()
        self.assertIsInstance(factory.resolver, resolve.ResolverChain)


    def test_resolverDefaultEmpty(self):
        """
        When constructed without arguments, L{server.DNSServerFactory} has a
        L{resolve.ResolverChain} containing no resolvers.
        """
        chained = server.DNSServerFactory().resolver.resolvers
        self.assertEqual(chained, [])


    def test_authorities(self):
        """
        L{server.DNSServerFactory.__init__} takes an C{authorities} list whose
        members are added to the C{resolver} L{resolve.ResolverChain}.
        """
        authority = object()
        factory = server.DNSServerFactory(authorities=[authority])
        self.assertEqual(factory.resolver.resolvers, [authority])


    def test_caches(self):
        """
        L{server.DNSServerFactory.__init__} takes a C{caches} list whose
        members are added to the C{resolver} L{resolve.ResolverChain}.
        """
        cache = object()
        factory = server.DNSServerFactory(caches=[cache])
        self.assertEqual(factory.resolver.resolvers, [cache])


    def test_clients(self):
        """
        L{server.DNSServerFactory.__init__} takes a C{clients} list whose
        members are added to the C{resolver} L{resolve.ResolverChain}.
        """
        client = object()
        factory = server.DNSServerFactory(clients=[client])
        self.assertEqual(factory.resolver.resolvers, [client])


    def test_resolverOrder(self):
        """
        The resolvers in L{server.DNSServerFactory.resolver} are ordered as
        authorities first, then caches, then clients.
        """
        # Distinct named classes make any ordering failure easy to read in
        # the test output.
        class DummyAuthority(object):
            pass

        class DummyCache(object):
            pass

        class DummyClient(object):
            pass

        factory = server.DNSServerFactory(
            authorities=[DummyAuthority],
            caches=[DummyCache],
            clients=[DummyClient])
        expected = [DummyAuthority, DummyCache, DummyClient]
        self.assertEqual(factory.resolver.resolvers, expected)


    def test_cacheDefault(self):
        """
        Without C{caches}, L{server.DNSServerFactory.cache} is L{None}.
        """
        factory = server.DNSServerFactory()
        self.assertIsNone(factory.cache)


    def test_cacheOverride(self):
        """
        L{server.DNSServerFactory.cache} is set to the final entry of the
        C{caches} list passed to L{server.DNSServerFactory.__init__}.
        """
        lastCache = object()
        factory = server.DNSServerFactory(caches=[object(), lastCache])
        self.assertEqual(factory.cache, lastCache)


    def test_canRecurseDefault(self):
        """
        L{server.DNSServerFactory.canRecurse} flags whether this server can
        perform recursive DNS lookups; it is L{False} unless configured.
        """
        factory = server.DNSServerFactory()
        self.assertFalse(factory.canRecurse)


    def test_canRecurseOverride(self):
        """
        L{server.DNSServerFactory.__init__} sets C{canRecurse} to L{True} if it
        is supplied with C{clients}.
        """
        # assertIs pins the exact value promised by the docstring; the former
        # assertEqual(x, True) would also accept any object equal to True.
        self.assertIs(
            server.DNSServerFactory(clients=[None]).canRecurse, True)


    def test_verboseDefault(self):
        """
        A default L{server.DNSServerFactory} is not verbose.
        """
        factory = server.DNSServerFactory()
        self.assertFalse(factory.verbose)


    def test_verboseOverride(self):
        """
        Passing C{verbose} to L{server.DNSServerFactory.__init__} overrides
        the default value of L{server.DNSServerFactory.verbose}.
        """
        factory = server.DNSServerFactory(verbose=True)
        self.assertTrue(factory.verbose)


    def test_interface(self):
        """
        L{server.DNSServerFactory} provides the L{IProtocolFactory} interface.
        """
        verified = verifyClass(IProtocolFactory, server.DNSServerFactory)
        self.assertTrue(verified)


    def test_defaultProtocol(self):
        """
        Unless overridden, L{server.DNSServerFactory.protocol} is
        L{dns.DNSProtocol}.
        """
        defaultProtocol = server.DNSServerFactory.protocol
        self.assertIs(defaultProtocol, dns.DNSProtocol)


    def test_buildProtocolProtocolOverride(self):
        """
        L{server.DNSServerFactory.buildProtocol} creates its protocol by
        calling L{server.DNSServerFactory.protocol} with the factory itself as
        the only positional argument and no keyword arguments.
        """
        class FakeProtocol(object):
            factory = None
            args = None
            kwargs = None

        stubProtocol = FakeProtocol()

        def recordingProtocolFactory(*args, **kwargs):
            stubProtocol.args, stubProtocol.kwargs = args, kwargs
            return stubProtocol

        factory = server.DNSServerFactory()
        factory.protocol = recordingProtocolFactory
        built = factory.buildProtocol(addr=None)

        self.assertEqual(
            (stubProtocol, (factory,), {}),
            (built, built.args, built.kwargs))


    def test_verboseLogQuiet(self):
        """
        With the default C{verbose} setting,
        L{server.DNSServerFactory._verboseLog} emits nothing; it only logs
        when C{verbose > 0}.
        """
        factory = server.DNSServerFactory()
        assertLogMessage(self, [], factory._verboseLog, 'Foo Bar')


    def test_verboseLogVerbose(self):
        """
        When C{verbose > 0}, L{server.DNSServerFactory._verboseLog} emits the
        given message.
        """
        factory = server.DNSServerFactory(verbose=1)
        assertLogMessage(self, ['Foo Bar'], factory._verboseLog, 'Foo Bar')


    def test_messageReceivedLoggingNoQuery(self):
        """
        When C{verbose} is C{>0}, L{server.DNSServerFactory.messageReceived}
        logs an "Empty query" line for a message that has no queries.
        """
        request = dns.Message()
        factory = NoResponseDNSServerFactory(verbose=1)
        origin = ('192.0.2.100', 53)

        assertLogMessage(
            self,
            ["Empty query from ('192.0.2.100', 53)"],
            factory.messageReceived,
            message=request, proto=None, address=origin)


    def test_messageReceivedLogging1(self):
        """
        With C{verbose} set to C{1},
        L{server.DNSServerFactory.messageReceived} logs the type of every
        query in the message.
        """
        request = dns.Message()
        for queryType in (dns.MX, dns.AAAA):
            request.addQuery(name='example.com', type=queryType)
        factory = NoResponseDNSServerFactory(verbose=1)
        origin = ('192.0.2.100', 53)

        assertLogMessage(
            self,
            ["MX AAAA query from ('192.0.2.100', 53)"],
            factory.messageReceived,
            message=request, proto=None, address=origin)


    def test_messageReceivedLogging2(self):
        """
        With C{verbose} set to C{2},
        L{server.DNSServerFactory.messageReceived} logs the repr of every
        query in the message.
        """
        request = dns.Message()
        for queryType in (dns.MX, dns.AAAA):
            request.addQuery(name='example.com', type=queryType)
        factory = NoResponseDNSServerFactory(verbose=2)
        origin = ('192.0.2.100', 53)

        expected = ("<Query example.com MX IN> "
                    "<Query example.com AAAA IN> query from ('192.0.2.100', 53)")
        assertLogMessage(
            self,
            [expected],
            factory.messageReceived,
            message=request, proto=None, address=origin)


    def test_messageReceivedTimestamp(self):
        """
        L{server.DNSServerFactory.messageReceived} records the value of
        C{time.time()} as the message's C{timeReceived}.
        """
        request = dns.Message()
        factory = NoResponseDNSServerFactory()
        fakeNow = object()
        self.patch(server.time, 'time', lambda: fakeNow)

        factory.messageReceived(message=request, proto=None, address=None)

        self.assertEqual(request.timeReceived, fakeNow)


    def test_messageReceivedAllowQuery(self):
        """
        L{server.DNSServerFactory.messageReceived} hands each message, with
        the protocol it arrived on and its source address, to
        L{server.DNSServerFactory.allowQuery}.
        """
        request = dns.Message()
        fakeProtocol = object()
        fakeAddress = object()

        factory = RaisingDNSServerFactory()
        raised = self.assertRaises(
            RaisingDNSServerFactory.AllowQueryArguments,
            factory.messageReceived,
            message=request, proto=fakeProtocol, address=fakeAddress)
        positional, keywords = raised.args

        self.assertEqual(positional, (request, fakeProtocol, fakeAddress))
        self.assertEqual(keywords, {})


    def test_allowQueryFalse(self):
        """
        If C{allowQuery} returns C{False},
        L{server.DNSServerFactory.messageReceived} calls
        L{server.DNSServerFactory.sendReply} with a message whose C{rCode} is
        L{dns.EREFUSED}.
        """
        class SendReplyException(Exception):
            pass

        # Named distinctly so this local fake does not shadow the
        # module-level RaisingDNSServerFactory.
        class RefusingDNSServerFactory(server.DNSServerFactory):
            def allowQuery(self, *args, **kwargs):
                return False

            def sendReply(self, *args, **kwargs):
                raise SendReplyException(args, kwargs)

        f = RefusingDNSServerFactory()
        e = self.assertRaises(
            SendReplyException,
            f.messageReceived,
            message=dns.Message(), proto=None, address=None)
        # sendReply is called as (protocol, message, address).
        (_proto, message, _address), _kwargs = e.args

        self.assertEqual(message.rCode, dns.EREFUSED)


    def _messageReceivedTest(self, methodName, message):
        """
        Check that L{DNSServerFactory.messageReceived} routes C{message} to
        the handler method named C{methodName}.

        @param methodName: Name of the handler expected to receive the
            message.
        @type methodName: L{str}

        @param message: The message expected to be passed to C{methodName}.
        @type message: L{dns.Message}
        """
        # Give the message a non-empty queries list so that
        # DNSServerFactory.allowQuery allows it.
        message.queries = [None]

        handled = []

        def recordingHandler(message, protocol, address):
            handled.append((message, protocol, address))

        protocol = NoopProtocol()
        factory = server.DNSServerFactory(None)
        setattr(factory, methodName, recordingHandler)
        factory.messageReceived(message, protocol)

        self.assertEqual(handled, [(message, protocol, None)])


    def test_queryMessageReceived(self):
        """
        Messages with opcode C{OP_QUERY} are routed by
        L{DNSServerFactory.messageReceived} to L{DNSServerFactory.handleQuery}.
        """
        request = dns.Message(opCode=dns.OP_QUERY)
        self._messageReceivedTest('handleQuery', request)


    def test_inverseQueryMessageReceived(self):
        """
        Messages with opcode C{OP_INVERSE} are routed by
        L{DNSServerFactory.messageReceived} to
        L{DNSServerFactory.handleInverseQuery}.
        """
        request = dns.Message(opCode=dns.OP_INVERSE)
        self._messageReceivedTest('handleInverseQuery', request)


    def test_statusMessageReceived(self):
        """
        Messages with opcode C{OP_STATUS} are routed by
        L{DNSServerFactory.messageReceived} to
        L{DNSServerFactory.handleStatus}.
        """
        request = dns.Message(opCode=dns.OP_STATUS)
        self._messageReceivedTest('handleStatus', request)


    def test_notifyMessageReceived(self):
        """
        Messages with opcode C{OP_NOTIFY} are routed by
        L{DNSServerFactory.messageReceived} to
        L{DNSServerFactory.handleNotify}.
        """
        request = dns.Message(opCode=dns.OP_NOTIFY)
        self._messageReceivedTest('handleNotify', request)


    def test_updateMessageReceived(self):
        """
        Messages with opcode C{OP_UPDATE} are routed by
        L{DNSServerFactory.messageReceived} to
        L{DNSServerFactory.handleOther}.

        This expectation may change if update messages are ever implemented.
        """
        request = dns.Message(opCode=dns.OP_UPDATE)
        self._messageReceivedTest('handleOther', request)


    def test_connectionTracking(self):
        """
        L{DNSServerFactory.connectionMade} and
        L{DNSServerFactory.connectionLost} together maintain
        C{connections}, the list of currently connected protocols created by
        the factory.
        """
        first, second = object(), object()
        factory = server.DNSServerFactory()

        factory.connectionMade(first)
        self.assertEqual(factory.connections, [first])

        factory.connectionMade(second)
        self.assertEqual(factory.connections, [first, second])

        factory.connectionLost(first)
        self.assertEqual(factory.connections, [second])

        factory.connectionLost(second)
        self.assertEqual(factory.connections, [])


    def test_handleQuery(self):
        """
        L{server.DNSServerFactory.handleQuery} passes only the first query of
        the message to L{server.DNSServerFactory.resolver.query}.
        """
        request = dns.Message()
        for name in (b'one.example.com', b'two.example.com'):
            request.addQuery(name)
        factory = server.DNSServerFactory()
        factory.resolver = RaisingResolver()

        raised = self.assertRaises(
            RaisingResolver.QueryArguments,
            factory.handleQuery,
            message=request, protocol=NoopProtocol(), address=None)
        (firstQuery,), _ = raised.args

        self.assertEqual(firstQuery, request.queries[0])


    def test_handleQueryCallback(self):
        """
        L{server.DNSServerFactory.handleQuery} adds
        L{server.DNSServerFactory.gotResolverResponse} as a callback to the
        deferred returned by L{server.DNSServerFactory.resolver.query}. It is
        called with the query response, the original protocol, message and
        origin address.
        """
        f = server.DNSServerFactory()

        d = defer.Deferred()
        class FakeResolver(object):
            def query(self, *args, **kwargs):
                return d
        f.resolver = FakeResolver()

        gotResolverResponseArgs = []
        def fakeGotResolverResponse(*args, **kwargs):
            gotResolverResponseArgs.append((args, kwargs))
        f.gotResolverResponse = fakeGotResolverResponse

        m = dns.Message()
        m.addQuery(b'one.example.com')
        stubProtocol = NoopProtocol()
        dummyAddress = object()

        f.handleQuery(message=m, protocol=stubProtocol, address=dummyAddress)

        dummyResponse = object()
        d.callback(dummyResponse)

        self.assertEqual(
            gotResolverResponseArgs,
            [((dummyResponse, stubProtocol, m, dummyAddress), {})])


    def test_handleQueryErrback(self):
        """
        L{server.DNSServerFactory.handleQuery} adds
        L{server.DNSServerFactory.gotResolverError} as an errback to the
        deferred returned by L{server.DNSServerFactory.resolver.query}. It is
        called with the query failure, the original protocol, message and
        origin address.
        """
        f = server.DNSServerFactory()

        d = defer.Deferred()
        class FakeResolver(object):
            def query(self, *args, **kwargs):
                return d
        f.resolver = FakeResolver()

        gotResolverErrorArgs = []
        def fakeGotResolverError(*args, **kwargs):
            gotResolverErrorArgs.append((args, kwargs))
        f.gotResolverError = fakeGotResolverError

        m = dns.Message()
        m.addQuery(b'one.example.com')
        stubProtocol = NoopProtocol()
        dummyAddress = object()

        f.handleQuery(message=m, protocol=stubProtocol, address=dummyAddress)

        stubFailure = failure.Failure(Exception())
        d.errback(stubFailure)

        self.assertEqual(
            gotResolverErrorArgs,
            [((stubFailure, stubProtocol, m, dummyAddress), {})])


    def test_gotResolverResponse(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} takes a tuple of
        resource record lists and writes a response message built from those
        same lists.
        """
        factory = server.DNSServerFactory()
        answers, authority, additional = [], [], []

        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.gotResolverResponse,
            (answers, authority, additional),
            protocol=RaisingProtocol(), message=dns.Message(), address=None)
        (written,), _ = raised.args

        self.assertIs(written.answers, answers)
        self.assertIs(written.authority, authority)
        self.assertIs(written.additional, additional)


    def test_gotResolverResponseCallsResponseFromMessage(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} builds its response via
        L{server.DNSServerFactory._responseFromMessage}.
        """
        factory = NoResponseDNSServerFactory()
        factory._responseFromMessage = raiser

        request = dns.Message()
        request.timeReceived = 1

        raised = self.assertRaises(
            RaisedArguments,
            factory.gotResolverResponse,
            ([], [], []),
            protocol=None, message=request, address=None)

        expectedKwargs = dict(
            message=request, rCode=dns.OK,
            answers=[], authority=[], additional=[])
        self.assertEqual(((), expectedKwargs), (raised.args, raised.kwargs))


    def test_responseFromMessageNewMessage(self):
        """
        L{server.DNSServerFactory._responseFromMessage} generates a new
        response message rather than returning the request message itself.
        """
        factory = server.DNSServerFactory()
        request = dns.Message(answer=False, recAv=False)
        # Assign the result directly (no trailing comma). A 1-tuple here would
        # make the identity check below pass no matter what was returned.
        response = factory._responseFromMessage(message=request)

        self.assertIsInstance(response, dns.Message)
        self.assertIsNot(request, response)


    def test_responseFromMessageRecursionAvailable(self):
        """
        The C{recAv} flag of a response built by
        L{server.DNSServerFactory._responseFromMessage} follows
        L{server.DNSServerFactory.canRecurse}, not the request's flag.
        """
        factory = server.DNSServerFactory()

        factory.canRecurse = True
        recursive = factory._responseFromMessage(
            message=dns.Message(recAv=False))

        factory.canRecurse = False
        nonRecursive = factory._responseFromMessage(
            message=dns.Message(recAv=True))

        self.assertEqual(
            (True, False),
            (recursive.recAv, nonRecursive.recAv))


    def test_responseFromMessageTimeReceived(self):
        """
        A response built by L{server.DNSServerFactory._responseFromMessage}
        carries over the request's C{timeReceived} value.
        """
        factory = server.DNSServerFactory()
        request = dns.Message()
        request.timeReceived = 1234

        reply = factory._responseFromMessage(message=request)

        self.assertEqual(request.timeReceived, reply.timeReceived)


    def test_responseFromMessageMaxSize(self):
        """
        A response built by L{server.DNSServerFactory._responseFromMessage}
        carries over the request's C{maxSize} value.
        """
        factory = server.DNSServerFactory()
        request = dns.Message()
        request.maxSize = 0

        reply = factory._responseFromMessage(message=request)

        self.assertEqual(request.maxSize, reply.maxSize)


    def test_messageFactory(self):
        """
        The C{_messageFactory} class attribute of L{server.DNSServerFactory}
        defaults to L{dns.Message}.
        """
        defaultFactory = server.DNSServerFactory._messageFactory
        self.assertIs(dns.Message, defaultFactory)


    def test_responseFromMessageCallsMessageFactory(self):
        """
        L{server.DNSServerFactory._responseFromMessage} delegates to
        C{dns._responseFromMessage}, forwarding the request message and the
        other keyword arguments for the response initialiser.
        """
        factory = server.DNSServerFactory()
        self.patch(dns, '_responseFromMessage', raiser)

        request = dns.Message()
        raised = self.assertRaises(
            RaisedArguments,
            factory._responseFromMessage,
            message=request, rCode=dns.OK)

        expectedKwargs = dict(
            responseConstructor=factory._messageFactory,
            message=request,
            rCode=dns.OK,
            recAv=factory.canRecurse,
            auth=False)
        self.assertEqual(((), expectedKwargs), (raised.args, raised.kwargs))


    def test_responseFromMessageAuthoritativeMessage(self):
        """
        A response from L{server.DNSServerFactory._responseFromMessage} is
        marked authoritative when any answer record is authoritative.
        """
        factory = server.DNSServerFactory()

        authoritative = factory._responseFromMessage(
            message=dns.Message(), answers=[dns.RRHeader(auth=True)])
        nonAuthoritative = factory._responseFromMessage(
            message=dns.Message(), answers=[dns.RRHeader(auth=False)])

        self.assertEqual(
            (True, False),
            (authoritative.auth, nonAuthoritative.auth))


    def test_gotResolverResponseLogging(self):
        """
        When C{verbose > 0}, L{server.DNSServerFactory.gotResolverResponse}
        logs the total record count across all response sections.
        """
        factory = NoResponseDNSServerFactory(verbose=1)
        sections = ([dns.RRHeader()], [dns.RRHeader()], [dns.RRHeader()])

        assertLogMessage(
            self,
            ["Lookup found 3 records"],
            factory.gotResolverResponse,
            sections,
            protocol=NoopProtocol(), message=dns.Message(), address=None)


    def test_gotResolverResponseCaching(self):
        """
        When the factory was constructed with at least one cache,
        L{server.DNSServerFactory.gotResolverResponse} stores the response in
        that cache.
        """
        factory = NoResponseDNSServerFactory(caches=[RaisingCache()])

        request = dns.Message()
        request.addQuery(b'example.com')
        expectedAnswers = [dns.RRHeader()]
        expectedAuthority = []
        expectedAdditional = []

        raised = self.assertRaises(
            RaisingCache.CacheResultArguments,
            factory.gotResolverResponse,
            (expectedAnswers, expectedAuthority, expectedAdditional),
            protocol=NoopProtocol(), message=request, address=None)
        (cachedQuery, (answers, authority, additional)), _ = raised.args

        self.assertEqual(cachedQuery.name.name, b'example.com')
        self.assertIs(answers, expectedAnswers)
        self.assertIs(authority, expectedAuthority)
        self.assertIs(additional, expectedAdditional)


    def test_gotResolverErrorCallsResponseFromMessage(self):
        """
        L{server.DNSServerFactory.gotResolverError} builds its response via
        L{server.DNSServerFactory._responseFromMessage}.
        """
        factory = NoResponseDNSServerFactory()
        factory._responseFromMessage = raiser

        request = dns.Message()
        request.timeReceived = 1

        raised = self.assertRaises(
            RaisedArguments,
            factory.gotResolverError,
            failure.Failure(error.DomainError()),
            protocol=None, message=request, address=None)

        expectedKwargs = dict(message=request, rCode=dns.ENAME)
        self.assertEqual(((), expectedKwargs), (raised.args, raised.kwargs))


    def _assertMessageRcodeForError(self, responseError, expectedMessageCode):
        """
        Assert that L{server.DNSServerFactory.gotResolverError}, given a
        L{failure.Failure} wrapping C{responseError}, writes a response
        message whose C{rCode} is C{expectedMessageCode}.

        @param responseError: The L{Exception} instance which is expected to
            trigger C{expectedMessageCode} when it is supplied to
            C{gotResolverError}
        @type responseError: L{Exception}

        @param expectedMessageCode: The C{rCode} which is expected in the
            message written by C{gotResolverError} in response to
            C{responseError}.
        @type expectedMessageCode: L{int}
        """
        f = server.DNSServerFactory()
        # RaisingProtocol.writeMessage raises, which lets the written response
        # be captured from the exception arguments.
        e = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            f.gotResolverError,
            failure.Failure(responseError),
            protocol=RaisingProtocol(), message=dns.Message(), address=None)
        (message,), kwargs = e.args

        self.assertEqual(message.rCode, expectedMessageCode)


    def test_gotResolverErrorDomainError(self):
        """
        L{server.DNSServerFactory.gotResolverError} triggers a response message
        with an C{rCode} of L{dns.ENAME} if supplied with a
        L{error.DomainError}.
        """
        self._assertMessageRcodeForError(error.DomainError(), dns.ENAME)


    def test_gotResolverErrorAuthoritativeDomainError(self):
        """
        L{server.DNSServerFactory.gotResolverError} triggers a response message
        with an C{rCode} of L{dns.ENAME} if supplied with a
        L{error.AuthoritativeDomainError}.
        """
        self._assertMessageRcodeForError(
            error.AuthoritativeDomainError(), dns.ENAME)


    def test_gotResolverErrorOtherError(self):
        """
        L{server.DNSServerFactory.gotResolverError} triggers a response message
        with an C{rCode} of L{dns.ESERVER} if supplied with another type of
        error and logs the error.
        """
        self._assertMessageRcodeForError(KeyError(), dns.ESERVER)
        # Flushing also stops trial from failing the test on the logged error.
        e = self.flushLoggedErrors(KeyError)
        self.assertEqual(len(e), 1)


    def test_gotResolverErrorLogging(self):
        """
        L{server.DNSServerFactory.gotResolverError} logs a "Lookup failed"
        message if C{verbose > 0}.
        """
        # NoResponseDNSServerFactory makes sendReply a no-op, so only the
        # logging behaviour is exercised.
        f = NoResponseDNSServerFactory(verbose=1)
        assertLogMessage(
            self,
            ["Lookup failed"],
            f.gotResolverError,
            failure.Failure(error.DomainError()),
            protocol=NoopProtocol(), message=dns.Message(), address=None)


    def test_gotResolverErrorResetsResponseAttributes(self):
        """
        L{server.DNSServerFactory.gotResolverError} does not allow request
        attributes to leak into the response, i.e. it sends a response with AD
        and CD set to 0 and empty response record sections.
        """
        factory = server.DNSServerFactory()
        responses = []
        # Capture the reply instead of writing it to a transport.
        factory.sendReply = (
            lambda protocol, response, address: responses.append(response)
        )
        request = dns.Message(authenticData=True, checkingDisabled=True)
        # Placeholder records: if any of them leaked into the response, the
        # comparison against a fresh Message below would fail.
        request.answers = [object(), object()]
        request.authority = [object(), object()]
        request.additional = [object(), object()]
        factory.gotResolverError(
            failure.Failure(error.DomainError()),
            protocol=None, message=request, address=None
        )

        # A DomainError maps to ENAME (see test_gotResolverErrorDomainError).
        self.assertEqual([dns.Message(rCode=dns.ENAME, answer=True)], responses)


    def test_gotResolverResponseResetsResponseAttributes(self):
        """
        L{server.DNSServerFactory.gotResolverResponse} does not allow request
        attributes to leak into the response, i.e. it sends a response with AD
        and CD set to 0 and none of the records in the request answer sections
        are copied to the response.
        """
        factory = server.DNSServerFactory()
        responses = []
        # Capture the reply instead of writing it to a transport.
        factory.sendReply = (
            lambda protocol, response, address: responses.append(response)
        )
        request = dns.Message(authenticData=True, checkingDisabled=True)
        # Placeholder records: if any of them leaked into the response, the
        # comparison against a fresh Message below would fail.
        request.answers = [object(), object()]
        request.authority = [object(), object()]
        request.additional = [object(), object()]

        # Empty (answers, authority, additional) resolver result.
        factory.gotResolverResponse(
            ([], [], []),
            protocol=None, message=request, address=None
        )

        self.assertEqual([dns.Message(rCode=dns.OK, answer=True)], responses)


    def test_sendReplyWithAddress(self):
        """
        When L{server.DNSServerFactory.sendReply} receives a protocol together
        with an address tuple, it passes the message and that address on to
        C{protocol.writeMessage}.
        """
        factory = server.DNSServerFactory()
        destination = object()
        reply = dns.Message()
        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.sendReply,
            protocol=RaisingProtocol(),
            message=reply,
            address=destination)
        positional, keyword = raised.args
        self.assertEqual(positional, (reply, destination))
        self.assertEqual(keyword, {})


    def test_sendReplyWithoutAddress(self):
        """
        When L{server.DNSServerFactory.sendReply} receives a protocol but no
        address tuple, C{protocol.writeMessage} is called with the message
        alone.
        """
        factory = server.DNSServerFactory()
        reply = dns.Message()
        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.sendReply,
            protocol=RaisingProtocol(),
            message=reply,
            address=None)
        positional, keyword = raised.args
        self.assertEqual(positional, (reply,))
        self.assertEqual(keyword, {})


    def test_sendReplyLoggingNoAnswers(self):
        """
        L{server.DNSServerFactory.sendReply} with C{verbose=2} logs a
        "no answers" message if the supplied message has no answers, followed
        by the time taken to process the query.
        """
        # Pin the clock one second after the message's receipt time so the
        # logged processing time is deterministic.
        self.patch(server.time, 'time', lambda: 86402)
        m = dns.Message()
        m.timeReceived = 86401
        f = server.DNSServerFactory(verbose=2)
        assertLogMessage(
            self,
            ["Replying with no answers", "Processed query in 1.000 seconds"],
            f.sendReply,
            protocol=NoopProtocol(),
            message=m,
            address=None)


    def test_sendReplyLoggingWithAnswers(self):
        """
        L{server.DNSServerFactory.sendReply} with C{verbose=2} logs a message
        for each of the answers, authority and additional sections that
        contain records in the supplied message, followed by the time taken
        to process the query.
        """
        # Pin the clock one second after the message's receipt time so the
        # logged processing time is deterministic.
        self.patch(server.time, 'time', lambda: 86402)
        m = dns.Message()
        m.answers.append(dns.RRHeader(payload=dns.Record_A('127.0.0.1')))
        m.authority.append(dns.RRHeader(payload=dns.Record_A('127.0.0.1')))
        m.additional.append(dns.RRHeader(payload=dns.Record_A('127.0.0.1')))
        m.timeReceived = 86401
        f = server.DNSServerFactory(verbose=2)
        assertLogMessage(
            self,
            ['Answers are <A address=127.0.0.1 ttl=None>',
             'Authority is <A address=127.0.0.1 ttl=None>',
             'Additional is <A address=127.0.0.1 ttl=None>',
             'Processed query in 1.000 seconds'],
            f.sendReply,
            protocol=NoopProtocol(),
            message=m,
            address=None)


    def test_handleInverseQuery(self):
        """
        Calling L{server.DNSServerFactory.handleInverseQuery} results in a
        reply being written whose C{rCode} is L{dns.ENOTIMP}.
        """
        factory = server.DNSServerFactory()
        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.handleInverseQuery,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None)
        # Exactly one positional argument (the reply) must have been written.
        (reply,), _ = raised.args
        self.assertEqual(reply.rCode, dns.ENOTIMP)


    def test_handleInverseQueryLogging(self):
        """
        With C{verbose > 0}, L{server.DNSServerFactory.handleInverseQuery}
        logs the address the message came from.
        """
        factory = NoResponseDNSServerFactory(verbose=1)
        origin = ('::1', 53)
        assertLogMessage(
            self,
            ["Inverse query from ('::1', 53)"],
            factory.handleInverseQuery,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=origin)


    def test_handleStatus(self):
        """
        Calling L{server.DNSServerFactory.handleStatus} results in a reply
        being written whose C{rCode} is L{dns.ENOTIMP}.
        """
        factory = server.DNSServerFactory()
        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.handleStatus,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None)
        # Exactly one positional argument (the reply) must have been written.
        (reply,), _ = raised.args
        self.assertEqual(reply.rCode, dns.ENOTIMP)


    def test_handleStatusLogging(self):
        """
        With C{verbose > 0}, L{server.DNSServerFactory.handleStatus} logs the
        address the message came from.
        """
        factory = NoResponseDNSServerFactory(verbose=1)
        origin = ('::1', 53)
        assertLogMessage(
            self,
            ["Status request from ('::1', 53)"],
            factory.handleStatus,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=origin)


    def test_handleNotify(self):
        """
        Calling L{server.DNSServerFactory.handleNotify} results in a reply
        being written whose C{rCode} is L{dns.ENOTIMP}.
        """
        factory = server.DNSServerFactory()
        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.handleNotify,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None)
        # Exactly one positional argument (the reply) must have been written.
        (reply,), _ = raised.args
        self.assertEqual(reply.rCode, dns.ENOTIMP)


    def test_handleNotifyLogging(self):
        """
        With C{verbose > 0}, L{server.DNSServerFactory.handleNotify} logs the
        address the message came from.
        """
        factory = NoResponseDNSServerFactory(verbose=1)
        origin = ('::1', 53)
        assertLogMessage(
            self,
            ["Notify message from ('::1', 53)"],
            factory.handleNotify,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=origin)


    def test_handleOther(self):
        """
        Calling L{server.DNSServerFactory.handleOther} results in a reply
        being written whose C{rCode} is L{dns.ENOTIMP}.
        """
        factory = server.DNSServerFactory()
        raised = self.assertRaises(
            RaisingProtocol.WriteMessageArguments,
            factory.handleOther,
            message=dns.Message(),
            protocol=RaisingProtocol(),
            address=None)
        # Exactly one positional argument (the reply) must have been written.
        (reply,), _ = raised.args
        self.assertEqual(reply.rCode, dns.ENOTIMP)


    def test_handleOtherLogging(self):
        """
        With C{verbose > 0}, L{server.DNSServerFactory.handleOther} logs the
        address the message came from.
        """
        factory = NoResponseDNSServerFactory(verbose=1)
        origin = ('::1', 53)
        # A default dns.Message() carries op code 0, hence "(0)" below.
        assertLogMessage(
            self,
            ["Unknown op code (0) from ('::1', 53)"],
            factory.handleOther,
            message=dns.Message(),
            protocol=NoopProtocol(),
            address=origin)

Filemanager

Name Type Size Permission Actions
__pycache__ Folder 0755
__init__.py File 26 B 0644
test_cache.py File 4.8 KB 0644
test_client.py File 39.31 KB 0644
test_common.py File 4.11 KB 0644
test_dns.py File 150.45 KB 0644
test_examples.py File 5.22 KB 0644
test_hosts.py File 8.35 KB 0644
test_names.py File 42.09 KB 0644
test_resolve.py File 1.05 KB 0644
test_rfc1982.py File 13.66 KB 0644
test_rootresolve.py File 25.04 KB 0644
test_server.py File 40.27 KB 0644
test_srvconnect.py File 9.89 KB 0644
test_tap.py File 4.84 KB 0644
test_util.py File 3.82 KB 0644