# Current File : //usr/local/lib/python3.6/site-packages/dns/asyncquery.py
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""Talk to a DNS server."""

import base64
import socket
import struct
import time

import dns.asyncbackend
import dns.exception
import dns.inet
import dns.name
import dns.message
import dns.rcode
import dns.rdataclass
import dns.rdatatype

from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \
    UDPMode, _have_httpx, _have_http2, NoDOH

if _have_httpx:
    import httpx

# Short alias, for brevity.  It is called below as
# ``_lltuple((address, port), af)`` to turn a simple (address, port) pair
# into the tuple handed to the socket-level send/receive helpers.
_lltuple = dns.inet.low_level_address_tuple


def _source_tuple(af, address, port):
    """Build a high level ``(address, port)`` source tuple.

    Returns ``None`` when neither *address* nor *port* is set.  When only
    *port* is set, the wildcard address of the address family *af* is
    used as the address.

    Raises ``NotImplementedError`` if a wildcard address is needed and *af*
    is neither ``socket.AF_INET`` nor ``socket.AF_INET6``.
    """
    if not (address or port):
        return None
    if address is not None:
        return (address, port)
    # Only a port was requested, so bind to the family's wildcard address.
    wildcards = (
        (socket.AF_INET, '0.0.0.0'),
        (socket.AF_INET6, '::'),
    )
    for family, wildcard in wildcards:
        if af == family:
            return (wildcard, port)
    raise NotImplementedError(f'unknown address family {af}')


def _timeout(expiration, now=None):
    """Convert an absolute expiration time into a relative timeout.

    *expiration*, a ``float`` or ``None``, the absolute time at which the
    operation should time out.

    *now*, a ``float`` or ``None``, the time to measure from.  If ``None``,
    the default, the current time is used.

    Returns ``None`` if *expiration* is ``None`` (no timeout), otherwise the
    number of seconds remaining, never less than zero.
    """
    # Compare against None explicitly: a value of 0 is a legitimate time
    # and must not be mistaken for "not supplied".
    if expiration is None:
        return None
    if now is None:
        now = time.time()
    return max(expiration - now, 0)


async def send_udp(sock, what, destination, expiration=None):
    """Send a DNS message on the specified UDP socket.

    *sock*, a ``dns.asyncbackend.DatagramSocket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
    A message object is converted to wire format before it is sent.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.
    """

    payload = what.to_wire() if isinstance(what, dns.message.Message) \
        else what
    sent_time = time.time()
    # The remaining time is measured from the recorded send time.
    remaining = _timeout(expiration, sent_time)
    sent = await sock.sendto(payload, destination, remaining)
    return (sent, sent_time)


async def receive_udp(sock, destination=None, expiration=None,
                      ignore_unexpected=False, one_rr_per_rrset=False,
                      keyring=None, request_mac=b'', ignore_trailing=False,
                      raise_on_truncation=False):
    """Read a DNS message from a UDP socket.

    *sock*, a ``dns.asyncbackend.DatagramSocket``.

    Datagrams are read until one is accepted by the destination check;
    that datagram is then parsed.

    Returns a ``(dns.message.Message, float, tuple)`` tuple of the parsed
    message, the time it was received, and the address it came from.

    See :py:func:`dns.query.receive_udp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    accepted = False
    while not accepted:
        (datagram, from_address) = await sock.recvfrom(65535,
                                                       _timeout(expiration))
        accepted = _matches_destination(sock.family, from_address,
                                        destination, ignore_unexpected)
    received_time = time.time()
    message = dns.message.from_wire(datagram, keyring=keyring,
                                    request_mac=request_mac,
                                    one_rr_per_rrset=one_rr_per_rrset,
                                    ignore_trailing=ignore_trailing,
                                    raise_on_truncation=raise_on_truncation)
    return (message, received_time, from_address)

async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
              ignore_unexpected=False, one_rr_per_rrset=False,
              ignore_trailing=False, raise_on_truncation=False, sock=None,
              backend=None):
    """Send the query *q* to *where* via UDP and return the response.

    *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
    the socket to use for the query.  If ``None``, the default, a
    socket is created and closed again before returning.  Note that if a
    socket is provided, the *source*, *source_port*, and *backend* are
    ignored, and the socket is left open.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Raises ``BadResponse`` if the message received is not a response to *q*.

    See :py:func:`dns.query.udp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """
    query_wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    # Tracks a socket made here; only such a socket is closed on exit.
    own_socket = None
    # After 3.6 is no longer supported, this can use an AsyncExitStack.
    try:
        af = dns.inet.af_for_address(where)
        destination = _lltuple((where, port), af)
        if sock:
            s = sock
        else:
            if not backend:
                backend = dns.asyncbackend.get_default_backend()
            stuple = _source_tuple(af, source, source_port)
            # Some backends need the peer address when the socket is made.
            dtuple = (where, port) \
                if backend.datagram_connection_required() else None
            own_socket = await backend.make_socket(af, socket.SOCK_DGRAM, 0,
                                                   stuple, dtuple)
            s = own_socket
        await send_udp(s, query_wire, destination, expiration)
        (r, received_time, _) = await receive_udp(s, destination, expiration,
                                                  ignore_unexpected,
                                                  one_rr_per_rrset,
                                                  q.keyring, q.mac,
                                                  ignore_trailing,
                                                  raise_on_truncation)
        r.time = received_time - begin_time
        if q.is_response(r):
            return r
        raise BadResponse
    finally:
        if own_socket:
            await own_socket.close()

async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
                            source_port=0, ignore_unexpected=False,
                            one_rr_per_rrset=False, ignore_trailing=False,
                            udp_sock=None, tcp_sock=None, backend=None):
    """Return the response to the query, trying UDP first and falling back
    to TCP if UDP results in a truncated response.

    *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
    the socket to use for the UDP query.  If ``None``, the default, a
    socket is created.  Note that if a socket is provided the *source*,
    *source_port*, and *backend* are ignored for the UDP query.

    *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
    socket to use for the TCP query.  If ``None``, the default, a
    socket is created.  Note that if a socket is provided *where*,
    *source*, *source_port*, and *backend*  are ignored for the TCP query.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Returns a ``(dns.message.Message, bool)`` tuple; the boolean is
    ``True`` if and only if TCP was used.

    See :py:func:`dns.query.udp_with_fallback()` for the documentation
    of the other parameters, exceptions, and return type of this
    method.
    """
    try:
        # Truncation is requested to raise so that it can trigger the
        # TCP retry below.
        udp_response = await udp(q, where, timeout, port, source,
                                 source_port, ignore_unexpected,
                                 one_rr_per_rrset, ignore_trailing, True,
                                 udp_sock, backend)
    except dns.message.Truncated:
        tcp_response = await tcp(q, where, timeout, port, source,
                                 source_port, one_rr_per_rrset,
                                 ignore_trailing, tcp_sock, backend)
        return (tcp_response, True)
    return (udp_response, False)


async def send_tcp(sock, what, expiration=None):
    """Send a DNS message on the specified TCP socket.

    *sock*, a ``dns.asyncbackend.StreamSocket``.

    The message is sent preceded by its length as a two-byte, network
    byte order prefix.

    Returns an ``(int, float)`` tuple of the number of bytes handed to the
    socket (prefix included) and the sent time.

    See :py:func:`dns.query.send_tcp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    if isinstance(what, dns.message.Message):
        what = what.to_wire()
    # Building one buffer costs a copy, but it avoids writev() and avoids
    # a short write of just the prefix being pushed onto the net.
    framed = struct.pack("!H", len(what)) + what
    sent_time = time.time()
    await sock.sendall(framed, _timeout(expiration, sent_time))
    return (len(framed), sent_time)


async def _read_exactly(sock, count, expiration):
    """Read exactly *count* bytes from the stream *sock*.

    Reads are repeated until the requested amount has been collected.

    Raises ``EOFError`` if the stream ends before *count* bytes arrive.
    """
    pieces = []
    remaining = count
    while remaining > 0:
        piece = await sock.recv(remaining, _timeout(expiration))
        if piece == b'':
            # An empty read means the peer closed the stream.
            raise EOFError
        remaining -= len(piece)
        pieces.append(piece)
    return b''.join(pieces)


async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
                      keyring=None, request_mac=b'', ignore_trailing=False):
    """Read a DNS message from a TCP socket.

    *sock*, a ``dns.asyncbackend.StreamSocket``.

    A two-byte, network byte order length prefix is read first, followed
    by that many bytes of message.

    Returns a ``(dns.message.Message, float)`` tuple of the parsed message
    and the time it was received.

    See :py:func:`dns.query.receive_tcp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    prefix = await _read_exactly(sock, 2, expiration)
    length = struct.unpack("!H", prefix)[0]
    body = await _read_exactly(sock, length, expiration)
    received_time = time.time()
    message = dns.message.from_wire(body, keyring=keyring,
                                    request_mac=request_mac,
                                    one_rr_per_rrset=one_rr_per_rrset,
                                    ignore_trailing=ignore_trailing)
    return (message, received_time)


async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
              backend=None):
    """Send the query *q* to *where* via TCP and return the response.

    *sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
    socket to use for the query.  If ``None``, the default, a socket
    is created and closed again before returning.  Note that if a socket
    is provided *where*, *port*, *source*, *source_port*, and *backend*
    are ignored, and the socket is left open.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Raises ``BadResponse`` if the message received is not a response to *q*.

    See :py:func:`dns.query.tcp()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    query_wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    # Tracks a socket made here; only such a socket is closed on exit.
    own_socket = None
    # After 3.6 is no longer supported, this can use an AsyncExitStack.
    try:
        if sock:
            # Verify that the socket is connected, as if it's not connected,
            # it's not writable, and the polling in send_tcp() will time out or
            # hang forever.
            await sock.getpeername()
            s = sock
        else:
            # These are simple (address, port) pairs, not
            # family-dependent tuples you pass to lowlevel socket
            # code.
            af = dns.inet.af_for_address(where)
            stuple = _source_tuple(af, source, source_port)
            dtuple = (where, port)
            if not backend:
                backend = dns.asyncbackend.get_default_backend()
            own_socket = await backend.make_socket(af, socket.SOCK_STREAM, 0,
                                                   stuple, dtuple, timeout)
            s = own_socket
        await send_tcp(s, query_wire, expiration)
        (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
                                               q.keyring, q.mac,
                                               ignore_trailing)
        r.time = received_time - begin_time
        if q.is_response(r):
            return r
        raise BadResponse
    finally:
        if own_socket:
            await own_socket.close()

async def tls(q, where, timeout=None, port=853, source=None, source_port=0,
              one_rr_per_rrset=False, ignore_trailing=False, sock=None,
              backend=None, ssl_context=None, server_hostname=None):
    """Return the response obtained after sending a query via TLS.

    *sock*, an ``asyncbackend.StreamSocket``, or ``None``, the socket
    to use for the query.  If ``None``, the default, a socket is
    created.  Note that if a socket is provided, it must be a
    connected SSL stream socket, and *where*, *port*,
    *source*, *source_port*, *backend*, *ssl_context*, and *server_hostname*
    are ignored.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    *ssl_context*, an ``ssl.SSLContext``, or ``None``.  If ``None``, the
    default, a default context is created; hostname checking is turned
    off on that default context when *server_hostname* is ``None``.  A
    context supplied by the caller is used as given.

    *server_hostname*, a ``str``, or ``None``, passed through to the
    backend when the socket is created.

    See :py:func:`dns.query.tls()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """
    # After 3.6 is no longer supported, this can use an AsyncExitStack.
    (begin_time, expiration) = _compute_times(timeout)
    if not sock:
        if ssl_context is None:
            ssl_context = ssl.create_default_context()
            if server_hostname is None:
                ssl_context.check_hostname = False
        # A caller-supplied ssl_context and server_hostname are kept
        # unchanged.  Discarding them here would hand the backend no TLS
        # context at all for a connection that was asked to use TLS.
        af = dns.inet.af_for_address(where)
        stuple = _source_tuple(af, source, source_port)
        dtuple = (where, port)
        if not backend:
            backend = dns.asyncbackend.get_default_backend()
        s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
                                      dtuple, timeout, ssl_context,
                                      server_hostname)
    else:
        s = sock
    try:
        # Connecting used up part of the budget; hand tcp() what is left.
        timeout = _timeout(expiration)
        response = await tcp(q, where, timeout, port, source, source_port,
                             one_rr_per_rrset, ignore_trailing, s, backend)
        end_time = time.time()
        # Report the time for the whole exchange, connection setup included.
        response.time = end_time - begin_time
        return response
    finally:
        # Only a socket created here is closed here.
        if not sock and s:
            await s.close()

async def https(q, where, timeout=None, port=443, source=None, source_port=0,
                one_rr_per_rrset=False, ignore_trailing=False, client=None,
                path='/dns-query', post=True, verify=True):
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *where*, a ``str``.  If it is an IPv4 or IPv6 address, the URL is
    built from it together with *port* and *path*; otherwise it is used
    as the URL as given.

    *source*, a ``str`` or ``None``, the local address to send from.  It is
    only used when no *client* is provided.  For compatibility, a
    sequence whose first element is the address is also accepted.

    *client*, a ``httpx.AsyncClient``.  If provided, the client to use for
    the query; it is not closed by this function.  If not provided, a
    client is created for the query and closed afterwards.

    *post*, a ``bool``.  If ``True``, the default, the query is sent with
    POST; otherwise it is sent with GET, base64url-encoded without padding.

    Unlike the other dnspython async functions, a backend cannot be provided
    in this function because httpx always auto-detects the async backend.

    Raises ``NoDOH`` if httpx is not available, ``ValueError`` if the
    HTTP status code is not in the 2xx range, and ``BadResponse`` if the
    message received is not a response to *q*.

    See :py:func:`dns.query.https()` for the documentation of the other
    parameters, exceptions, and return type of this method.
    """

    if not _have_httpx:
        raise NoDOH('httpx is not available.')  # pragma: no cover

    wire = q.to_wire()
    try:
        af = dns.inet.af_for_address(where)
    except ValueError:
        af = None
    headers = {
        "accept": "application/dns-message"
    }
    if af == socket.AF_INET:
        url = 'https://{}:{}{}'.format(where, port, path)
    elif af == socket.AF_INET6:
        url = 'https://[{}]:{}{}'.format(where, port, path)
    else:
        # Not an address literal, so *where* is taken to be the URL.
        url = where

    # After 3.6 is no longer supported, this can use an AsyncExitStack
    client_to_close = None
    try:
        if not client:
            transport = None
            if source is not None:
                # *source* is an address string; indexing it would select
                # only its first character.
                if isinstance(source, str):
                    local_address = source
                else:
                    local_address = source[0]
                # A client does not apply its own verify/http settings to
                # a transport handed to it, so they are given to the
                # transport as well.
                transport = httpx.AsyncHTTPTransport(
                    local_address=local_address, verify=verify,
                    http1=True, http2=_have_http2)
            client = httpx.AsyncClient(http1=True, http2=_have_http2,
                                       verify=verify, transport=transport)
            client_to_close = client

        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        if post:
            headers.update({
                "content-type": "application/dns-message",
                "content-length": str(len(wire))
            })
            response = await client.post(url, headers=headers, content=wire,
                                         timeout=timeout)
        else:
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            wire = wire.decode()  # httpx does a repr() if we give it bytes
            response = await client.get(url, headers=headers, timeout=timeout,
                                        params={"dns": wire})
    finally:
        if client_to_close:
            await client_to_close.aclose()

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError('{} responded with status code {}'
                         '\nResponse body: {}'.format(where,
                                                      response.status_code,
                                                      response.content))
    # Use q.mac, as udp() and tcp() do, so a TSIG-signed response is
    # validated against the MAC of the query that was sent.
    r = dns.message.from_wire(response.content,
                              keyring=q.keyring,
                              request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset,
                              ignore_trailing=ignore_trailing)
    # NOTE(review): httpx documents ``elapsed`` as a ``datetime.timedelta``,
    # whereas udp() and tcp() store a number of seconds here.  Left as is
    # so existing callers are unaffected -- confirm before changing.
    r.time = response.elapsed
    if not q.is_response(r):
        raise BadResponse
    return r

async def inbound_xfr(where, txn_manager, query=None,
                      port=53, timeout=None, lifetime=None, source=None,
                      source_port=0, udp_mode=UDPMode.NEVER, backend=None):
    """Conduct an inbound transfer and apply it via a transaction from the
    txn_manager.

    *query*, the transfer query to send, or ``None``.  If ``None``, the
    default, a query is made from *txn_manager*.

    *timeout*, a ``float`` or ``None``, the time allowed for each
    individual message; it never extends past *lifetime*.

    *lifetime*, a ``float`` or ``None``, the time allowed for the whole
    transfer.

    *udp_mode*, a ``dns.query.UDPMode``.  UDP is only tried for an IXFR
    query, and only when this is not ``UDPMode.NEVER``.  If the transfer
    then asks for TCP, it is retried over TCP unless the mode is
    ``UDPMode.ONLY``, in which case the exception is propagated.

    *backend*, a ``dns.asyncbackend.Backend``, or ``None``.  If ``None``,
    the default, then dnspython will use the default backend.

    Raises ``dns.exception.FormError`` if the query has a keyring but the
    last message received had no TSIG.

    See :py:func:`dns.query.inbound_xfr()` for the documentation of
    the other parameters, exceptions, and return type of this method.
    """
    if query is None:
        (query, serial) = dns.xfr.make_query(txn_manager)
    else:
        serial = dns.xfr.extract_serial_from_query(query)
    rdtype = query.question[0].rdtype
    is_ixfr = rdtype == dns.rdatatype.IXFR
    origin = txn_manager.from_wire_origin()
    wire = query.to_wire()
    af = dns.inet.af_for_address(where)
    stuple = _source_tuple(af, source, source_port)
    dtuple = (where, port)
    (_, expiration) = _compute_times(lifetime)
    # The backend does not change between attempts; resolve it once.
    if not backend:
        backend = dns.asyncbackend.get_default_backend()
    retry = True
    while retry:
        retry = False
        if is_ixfr and udp_mode != UDPMode.NEVER:
            sock_type = socket.SOCK_DGRAM
            is_udp = True
        else:
            sock_type = socket.SOCK_STREAM
            is_udp = False
        s = await backend.make_socket(af, sock_type, 0, stuple, dtuple,
                                      _timeout(expiration))
        async with s:
            if is_udp:
                await s.sendto(wire, dtuple, _timeout(expiration))
            else:
                tcpmsg = struct.pack("!H", len(wire)) + wire
                # sendall() takes a relative timeout, not the absolute
                # expiration time.
                await s.sendall(tcpmsg, _timeout(expiration))
            with dns.xfr.Inbound(txn_manager, rdtype, serial,
                                 is_udp) as inbound:
                done = False
                tsig_ctx = None
                while not done:
                    # Each message gets its own deadline, capped by the
                    # deadline of the transfer as a whole.
                    (_, mexpiration) = _compute_times(timeout)
                    if mexpiration is None or \
                       (expiration is not None and mexpiration > expiration):
                        mexpiration = expiration
                    if is_udp:
                        destination = _lltuple((where, port), af)
                        while True:
                            # Kept in its own name so that the *timeout*
                            # parameter still holds the per-message value
                            # for later messages and for a TCP retry.
                            recv_timeout = _timeout(mexpiration)
                            (rwire, from_address) = await s.recvfrom(
                                65535, recv_timeout)
                            if _matches_destination(af, from_address,
                                                    destination, True):
                                break
                    else:
                        ldata = await _read_exactly(s, 2, mexpiration)
                        (l,) = struct.unpack("!H", ldata)
                        rwire = await _read_exactly(s, l, mexpiration)
                    r = dns.message.from_wire(rwire, keyring=query.keyring,
                                              request_mac=query.mac, xfr=True,
                                              origin=origin, tsig_ctx=tsig_ctx,
                                              multi=(not is_udp),
                                              one_rr_per_rrset=is_ixfr)
                    try:
                        done = inbound.process_message(r)
                    except dns.xfr.UseTCP:
                        assert is_udp  # should not happen if we used TCP!
                        if udp_mode == UDPMode.ONLY:
                            raise
                        # Leave this attempt and start over using TCP.
                        done = True
                        retry = True
                        udp_mode = UDPMode.NEVER
                        continue
                    tsig_ctx = r.tsig_ctx
                if not retry and query.keyring and not r.had_tsig:
                    raise dns.exception.FormError("missing TSIG")