import select
import socket
import string
import time

import runloop

def fromnum16(num):
    """Encode num as a 2-octet big-endian string (only the low 16 bits are kept)."""
    high, low = divmod(num & 0xFFFF, 256)
    return chr(high) + chr(low)

def fromnum32(num):
    """Encode num as a 4-octet big-endian string (only the low 32 bits are kept)."""
    return ''.join([chr((num >> shift) & 255) for shift in (24, 16, 8, 0)])

def make_query(question, id):
    """Build a DNS request message carrying message ID `id`.

    A 3-tuple (qname, qtype, qclass) produces a standard query (opcode 0)
    with one question entry.  Any other length produces an inverse query:
    the first flags octet is 0x08 (opcode 1), and question[:3] is sent as
    the single answer-section resource record.
    """
    if len(question) != 3:
        # flags 0x0800, QDCOUNT 0, ANCOUNT 1, NSCOUNT 0, ARCOUNT 0
        counts = '\010\000\000\000\000\001\000\000\000\000'
        return fromnum16(id) + counts + make_rr(question[:3])
    # flags 0x0000, QDCOUNT 1, ANCOUNT 0, NSCOUNT 0, ARCOUNT 0
    counts = '\000\000\000\001\000\000\000\000\000\000'
    return fromnum16(id) + counts + make_question(question)

def make_question(question):
    """Encode (qname, qtype, qclass) as a question entry: QNAME, then QTYPE and QCLASS as 16-bit fields."""
    qname, qtype, qclass = question
    parts = [name_to_label_sequence(qname), fromnum16(qtype), fromnum16(qclass)]
    return ''.join(parts)

def make_rr(rr):
    """Encode rr = (rdata, type, class) as a wire-format resource record.

    The owner name is the root (a single zero octet) and the TTL is 0.
    `rdata` is always encoded as a domain name.  As used from make_query,
    the type is CNAME (5), so only name-valued RDATA is supported.
    """
    target, rrtype, rrclass = rr
    encoded = name_to_label_sequence(target)
    fields = [
        '\000',                   # owner name: root
        fromnum16(rrtype),
        fromnum16(rrclass),
        fromnum32(0),             # TTL
        fromnum16(len(encoded)),  # RDLENGTH
        encoded,                  # RDATA
    ]
    return ''.join(fields)

def name_to_label_sequence(name):
    """Encode a dotted domain name as a DNS label sequence.

    Each label is emitted as a length octet followed by the label text, and
    the sequence ends with a zero octet.  A single trailing dot is optional,
    so 'a.b' and 'a.b.' encode identically.  '' and '.' both encode the
    root name as one zero octet.

    Raises ValueError for an empty interior label (e.g. 'a..b') or a label
    longer than 63 octets.  Either would otherwise produce a length octet
    that terminates the name early or reads as a compression pointer.
    """
    if name[-1:] == '.':
        name = name[:-1]
    if name:
        labels = name.split('.')
    else:
        labels = []
    encoded = []
    for label in labels:
        if not label or len(label) > 63:
            raise ValueError('invalid DNS label %r in %r' % (label, name))
        encoded.append(chr(len(label)) + label)
    encoded.append('\000')
    return ''.join(encoded)

def parse_message(message):
    """Parse a complete DNS message.

    Returns (header_fields, questions, answers, nameservers, additionals).
    header_fields is the dict from parse_header.  questions holds
    (qname, qtype, qclass) tuples.  The other three are lists of
    (name, type, class, ttl, rdata) tuples from parse_rr, read in wire
    order using the counts in the header.
    """
    header_fields = parse_header(message[:12])
    pos = 12
    questions = []
    for _ in range(header_fields['qdcount']):
        question, pos = parse_question(message, pos)
        questions.append(question)
    sections = []
    for count_key in ('ancount', 'nscount', 'arcount'):
        records = []
        for _ in range(header_fields[count_key]):
            record, pos = parse_rr(message, pos)
            records.append(record)
        sections.append(records)
    answers, nameservers, additionals = sections
    return header_fields, questions, answers, nameservers, additionals

def parse_header(header):
    """Decode the fixed 12-octet DNS message header.

    Returns a dict with keys id, qr, opcode, aa, tc, rd, ra, z, rcode,
    qdcount, ancount, nscount and arcount (RFC 1035 section 4.1.1 layout).
    """
    flags_high = ord(header[2])   # QR | OPCODE(4) | AA | TC | RD
    flags_low = ord(header[3])    # RA | Z(3) | RCODE(4)
    fields = {}
    fields['id'] = tonum16(header[:2])
    fields['qr'] = flags_high >> 7
    fields['opcode'] = (flags_high >> 3) & 15
    fields['aa'] = (flags_high >> 2) & 1
    fields['tc'] = (flags_high >> 1) & 1
    fields['rd'] = flags_high & 1
    fields['ra'] = flags_low >> 7
    fields['z'] = (flags_low >> 4) & 7
    fields['rcode'] = flags_low & 15
    fields['qdcount'] = tonum16(header[4:6])
    fields['ancount'] = tonum16(header[6:8])
    fields['nscount'] = tonum16(header[8:10])
    fields['arcount'] = tonum16(header[10:12])
    return fields

def tonum16(octets):
    """Decode the first two octets of `octets` as a big-endian unsigned integer."""
    return (ord(octets[0]) << 8) | ord(octets[1])

def tonum32(octets):
    return 16777216L * long(ord(octets[0])) \
	   + 65536L * long(ord(octets[1])) \
	   + 256L * long(ord(octets[2])) \
	   + long(ord(octets[3]))

def parse_question(message, pos):
    """Parse one question entry at `pos`; return ((qname, qtype, qclass), next_pos)."""
    qname, pos = parse_name(message, pos)
    fixed = message[pos:pos + 4]    # QTYPE (2 octets) + QCLASS (2 octets)
    qtype = tonum16(fixed[:2])
    qclass = tonum16(fixed[2:])
    return (qname, qtype, qclass), pos + 4

def parse_name(message, pos):
    """Decode the possibly-compressed domain name starting at offset `pos`.

    Returns (name, next_pos).  A non-root name comes back with a trailing
    '.' (e.g. 'www.example.com.'); the root name is ''.  next_pos is the
    offset just past the name as it appears at `pos`, i.e. just past the
    first compression pointer if one is followed.

    Length octets whose top two bits are not 00 are treated as compression
    pointers.  The walk is iterative, and revisiting a pointer raises
    ValueError.  A malformed reply could otherwise loop forever (previously
    via unbounded recursion).  Running off the end of `message` raises
    IndexError.
    """
    labels = []
    end = None          # set once we know where the name ends at `pos`
    seen_pointers = []  # pointer offsets already followed
    while 1:
        length = ord(message[pos])
        if not length:
            if end is None:
                end = pos + 1
            break
        if not (length >> 6):
            labels.append(message[pos + 1:pos + length + 1])
            pos = pos + length + 1
        else:
            if pos in seen_pointers:
                raise ValueError('compression pointer loop in DNS name')
            seen_pointers.append(pos)
            if end is None:
                end = pos + 2
            pos = 256 * (length & 63) + ord(message[pos + 1])
    return ''.join([label + '.' for label in labels]), end

def rdata_parser_a(rdata, message, pos):
    """Render A-record RDATA as a dotted-quad string such as '127.0.0.1'.

    Only the first four octets of `rdata` are used.  `message` and `pos`
    are accepted to match the rdata_parser signature and are ignored.
    Raises IndexError if `rdata` is shorter than four octets.
    """
    return '.'.join([str(ord(rdata[i])) for i in (0, 1, 2, 3)])

def rdata_parser_domainname(rdata, message, pos):
    """Decode name-valued RDATA that starts at offset `pos` of `message`.

    The raw `rdata` slice is not used.  The name may contain compression
    pointers into the whole message, so it is parsed in place.
    """
    name, _ = parse_name(message, pos)
    return name

# RR TYPE code -> RDATA decoder, called as decoder(rdata, message, pos).
# Types not listed here keep their raw RDATA string.
rdata_parser = {
    1: rdata_parser_a,             # A
    2: rdata_parser_domainname,    # NS
    5: rdata_parser_domainname,    # CNAME
    12: rdata_parser_domainname,   # PTR
}

def parse_rr(message, pos):
    """Parse one resource record at `pos`.

    Returns ((name, type, class, ttl, rdata), next_pos).  rdata is decoded
    through rdata_parser when the type has an entry there; otherwise it is
    the raw RDATA string.
    """
    name, pos = parse_name(message, pos)
    rrtype = tonum16(message[pos:pos + 2])
    rrclass = tonum16(message[pos + 2:pos + 4])
    ttl = tonum32(message[pos + 4:pos + 8])
    rdlength = tonum16(message[pos + 8:pos + 10])
    rdata_start = pos + 10
    rdata_end = rdata_start + rdlength
    rdata = message[rdata_start:rdata_end]
    decoder = rdata_parser.get(rrtype)
    if decoder is not None:
        rdata = decoder(rdata, message, rdata_start)
    return (name, rrtype, rrclass, ttl, rdata), rdata_end

def exchange(query, server = '194.89.10.13', timeout = 10.0):
    """Send `query` over UDP to port 53 of `server` and wait for one reply.

    Returns the first datagram received (at most 512 octets), or () if
    nothing arrives within `timeout` seconds.  The socket is always closed
    before returning, including when an exception is raised.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(query, 0, (server, 53))
        readable = select.select([sock], [], [], timeout)[0]
        if not readable:
            return ()
        return sock.recvfrom(512, 0)[0]
    finally:
        sock.close()

# Positive-answer caches, keyed by owner name as produced by parse_name;
# each value is (expiry time, rdata).
gethostbyname_cache, gethostbyaddr_cache = {}, {}   # A (type 1), PTR (type 12)

default_server = ''   # filled from /etc/resolv.conf by DNSSocket.__init__
dnssocket = None      # shared DNSSocket, created on demand by ensure_init()

max_op_count = 256    # default cap on the queries a single lookup may issue

class DNSQuery:
    """One outstanding DNS request whose outcome is reported to `answer`.

    `answer` is a caller-supplied object.  This class calls its
    answer(result) and error(reason) methods.  It also records servers that
    replied without an answer in its `_unfound` dict.
    """

    def __init__(self, question, answer, server, op_count = max_op_count):
        # op_count bounds the queries a lookup may spawn (CNAME chasing and
        # referral following each consume one).  Callers can pass a negative
        # value, so require a positive count rather than mere truthiness.
        if op_count > 0:
            self.question = question
            self.answer = answer
            self.server = server
            self.op_count = op_count
            self.query_id = dnssocket.register_query(self.received)
            dnssocket.send_query(make_query(self.question, self.query_id),
                                 server)
            # NOTE(review): the result of this call is discarded.  If
            # runloop.UDPSocket.recvfrom is a plain blocking read, it would
            # consume the reply before DNSSocket.received sees it -- confirm
            # against runloop before changing.
            dnssocket.recvfrom(512)
        else:
            answer.error('Maximum query count for name lookup reached.')

    def received(self, message):
        """Handle the parsed reply `message` (the tuple from parse_message)."""
        header_fields, questions, answers, nameservers, additionals = message
        if header_fields['rcode']:
            self.answer.error(header_fields['rcode'])
        elif answers:
            if len(self.question) == 3:
                qname, qtype, qclass = self.question
                first = answers[0]
                if first[1] == qtype:
                    # Direct answer of the requested type (A, PTR, ...).
                    self.answer.answer(first[4])
                elif first[1] == 5:
                    # CNAME: restart the lookup on the canonical name.
                    if qtype == 1:
                        gethostbyname(first[4], self.answer,
                                      self.op_count - 1)
                    else:
                        self.answer._unfound = {}
                        DNSQuery((first[4], qtype, qclass), self.answer,
                                 default_server, self.op_count - 1)
                else:
                    self.answer.error('Funny response from name server.')
            else:
                # Inverse query: report the owner name of every answer RR.
                self.answer.answer([rr[0] for rr in answers])
        else:
            # No answer from this server: remember it and follow the
            # referral.  Only NS records name servers; the authority section
            # may also hold e.g. SOA records, whose raw rdata is not a name.
            self.answer._unfound[self.server] = None
            ns_records = [rr for rr in nameservers if rr[1] == 2]
            if ns_records:
                NameserverLookup(ns_records, self.question,
                                 self.answer, self.op_count - 1)
            else:
                self.answer.error('Resolver failure, unable to backtrack.')

class NameserverLookup:
    """Follow a referral by trying the referred name servers one at a time.

    The instance is itself passed as the `answer` object to gethostbyname()
    for each server's name.  answer(host) receives that server's address,
    and error(reason) reports that the server's name could not be resolved.
    The final outcome goes to `stored_answer`.
    """

    def __init__(self, nameservers, question, answer, op_count):
        self.nameservers = nameservers    # remaining NS resource records
        self.question = question
        self.stored_answer = answer       # the caller's answer object
        self.op_count = op_count
        self.next_server()

    def next_server(self):
        """Resolve the next listed server, or report failure if none remain."""
        if not self.nameservers:
            # Must use stored_answer: `self.answer` is the bound method below.
            self.stored_answer.error('Resolver failure, unable to backtrack.')
            return
        nameserver = self.nameservers[0][4]
        self.nameservers = self.nameservers[1:]
        self.op_count = self.op_count - 1
        gethostbyname(nameserver, self, self.op_count)

    def answer(self, host):
        """Query `host` unless it already replied without an answer."""
        if host in self.stored_answer._unfound:
            self.next_server()
        else:
            DNSQuery(self.question, self.stored_answer, host,
                     self.op_count - 1)

    def error(self, reason):
        """Resolving the current server's name failed; move on to the next server."""
        self.next_server()

def ensure_init():
    """Create the shared DNSSocket the first time it is needed."""
    global dnssocket
    if dnssocket:
        return
    dnssocket = DNSSocket()

def cache_answer(now, answer):
    """Cache an A (type 1) or PTR (type 12) record until now + ttl.

    `answer` is a (name, type, class, ttl, rdata) tuple; other types are
    ignored.
    NOTE(review): any A/PTR record passed in is cached with no check of
    which query or zone it belongs to -- consider bailiwick checks.
    """
    name, rrtype, rrclass, ttl, rdata = answer
    if rrtype == 1:
        cache = gethostbyname_cache
    elif rrtype == 12:
        cache = gethostbyaddr_cache
    else:
        return
    cache[name] = now + ttl, rdata

class DNSSocket(runloop.UDPSocket):
    """Shared UDP socket: sends queries and routes replies to callbacks by message ID."""

    def __init__(self):
        global default_server
        runloop.UDPSocket.__init__(self)
        self.queries = {}     # message ID -> callback(parsed message)
        self.query_id = 0     # last ID handed out
        # The first "nameserver" line of resolv.conf sets default_server.
        # NOTE(review): if there is none, default_server stays ''.
        f = open('/etc/resolv.conf')
        try:
            for line in f.readlines():
                elems = line.split()
                if len(elems) >= 2 and elems[0] == 'nameserver':
                    default_server = elems[1]
                    break
        finally:
            f.close()

    def register_query(self, query):
        """Store callback `query` under a fresh 16-bit message ID and return it."""
        # Skip IDs that are still pending, so wrap-around cannot silently
        # replace another query's callback (gives up after a full cycle).
        start = self.query_id
        while 1:
            self.query_id = (self.query_id + 1) & 65535
            if self.query_id not in self.queries or self.query_id == start:
                break
        self.queries[self.query_id] = query
        return self.query_id

    def received(self, data, address):
        """Parse a reply datagram, cache its A/PTR records and dispatch it."""
        try:
            message = parse_message(data)
        except (IndexError, ValueError):
            # Truncated or malformed datagram: drop it rather than raise
            # from this callback.
            return
        query_id = message[0]['id']
        if query_id not in self.queries:
            # Unknown or already-answered ID: neither cache nor dispatch.
            return
        callback = self.queries[query_id]
        # Release the ID so the table does not grow without bound.
        del self.queries[query_id]
        now = time.time()
        # Cache before dispatching: CNAME chasing in the callback relies on
        # records from the same reply already being cached.
        for section in (message[2], message[3], message[4]):
            for rr in section:
                cache_answer(now, rr)
        callback(message)

    def send_query(self, query, server):
        """Send the encoded `query` to UDP port 53 of `server`."""
        self.sendto(query, 0, (server, 53))

def gethostbyname(name, answer, op_count = max_op_count):
    """Resolve `name` to an address and report it via answer.answer()/error().

    A fresh cached entry is answered immediately; otherwise an A query is
    sent to default_server.  Cache keys come from parse_name and therefore
    end in '.', so the fully-qualified form is also checked when `name` is
    given without the trailing dot.
    """
    ensure_init()
    if name[-1:] == '.':
        keys = [name]
    else:
        keys = [name, name + '.']
    for key in keys:
        if key in gethostbyname_cache:
            expires, host = gethostbyname_cache[key]
            if expires > time.time():
                answer.answer(host)
                return
            del gethostbyname_cache[key]   # expired
    answer._unfound = {}
    DNSQuery((name, 1, 1), answer, default_server, op_count)

def gethostbyaddr(address, answer):
    """Reverse-resolve a dotted-quad `address` via a PTR query.

    The name is reported via answer.answer()/error().  The query name is
    the first four octets reversed plus '.in-addr.arpa'.  Cache keys come
    from parse_name and end in '.', so both forms are checked.
    """
    octets = address.split('.')[:4]
    octets.reverse()
    addrname = '.'.join(octets) + '.in-addr.arpa'
    ensure_init()
    for key in (addrname, addrname + '.'):
        if key in gethostbyaddr_cache:
            expires, name = gethostbyaddr_cache[key]
            if expires > time.time():
                answer.answer(name)
                return
            del gethostbyaddr_cache[key]   # expired
    answer._unfound = {}
    DNSQuery((addrname, 12, 1), answer, default_server)

def getaliases(name, answer):
    """Ask default_server which names have `name` as their CNAME target.

    The 4-tuple question makes make_query build an inverse query carrying
    a CNAME record for `name`.  The owner names from the reply are passed
    to answer.answer().
    """
    ensure_init()
    answer._unfound = {}
    inverse_question = (name, 5, 1, 1)
    DNSQuery(inverse_question, answer, default_server)

## End. ##
