#!/usr/bin/python3

import json
import logging
import re
import sys
from time import sleep

import colorlog
import dns.query
import dns.rcode
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.resolver
import dns.reversename
import dns.tsig
import dns.tsigkeyring
import dns.update

# Log INFO and above; set to logging.DEBUG for verbose output
log_level = logging.INFO

#########################################
#
# Define your nameservers here
#
#########################################

# Default TTL for dsnet records is 5 minutes.
# process_hostname() also treats an existing A record with a TTL above
# this value as belonging to something other than dsnet.
default_ttl = 300

# Declare our internal DNS server
# dsnet_int_nameserver = '10.164.236.1'
# Or leave as 'json' to use "DNS" from dsnetreport.json
dsnet_int_nameserver = 'json'

# Define an external DNS server here if using split horizon
# dsnet_ext_nameserver = '198.51.100.2'
# Or set to 'json' to use "ExternalIP" from dsnetreport.json
# dsnet_ext_nameserver = 'json'
# Or set to None to disable split horizon DNS
dsnet_ext_nameserver = None

# Specifically declare our zone (NOTE THE '.' AT THE END)
dsnet_zone = 'example.com.'
# Or set to 'json' to use "Domain" from dsnetreport.json
# dsnet_zone = 'json'

# Declare our reverse zones here (also with a trailing '.'); reverse names
# are matched against these with str.endswith()
dsnet_reverse_zone = '236.164.10.in-addr.arpa.'
dsnet_reverse6_zone = '0.0.e.a.a.6.0.1.1.3.b.7.0.0.d.f.ip6.arpa.'
# In the future we should automatically determine the reverse zone
# from the 'Network' and 'Network6' parameters in the JSON
# Currently the below does not work correctly:
# dns.reversename.from_address(ipv4_space).to_text()
# dns.reversename.from_address(ipv6_space).to_text()

# Which TSIG key file do we need to use to sign the dynamic updates
dns_tsig_key_file = '/etc/bind/dsnet-update.key'

# Which TXT record are we using to track current peers?
# (one TXT string per peer hostname, stored at <record>.<dsnet_zone>)
dsnet_current_peers_record = '_dsnet_peers'

#########################################

# Console log layout: timestamp, colour-coded level tag, logger name, message
log_format = colorlog.ColoredFormatter(
    "%(asctime)s %(log_color)s[%(levelname)s]%(reset)s %(name)s: %(message)s",
    log_colors={
        'DEBUG': 'cyan',
        'INFO': 'green',
        'WARNING': 'yellow',
        'ERROR': 'red',
        'CRITICAL': 'red,bg_white',
    },
    datefmt="%Y-%m-%dT%H:%M:%S",
)

# Create the script's logger, then hang the coloured stream handler on it
logger = colorlog.getLogger('dsnsupdate')
logger.setLevel(log_level)
handler = colorlog.StreamHandler()
handler.setFormatter(log_format)
logger.addHandler(handler)

# Resolver used for lookups against the internal nameserver; its
# nameserver list is filled in by main()
resolver_int = dns.resolver.Resolver(configure=False)

# A second resolver is only created when split horizon DNS is enabled
if dsnet_ext_nameserver:
    resolver_ext = dns.resolver.Resolver(configure=False)


# Dirty function to load a TSIG key from a file
def load_tsig_key(tsig_file):
    """Read a TSIG key name and secret from a BIND-style key file.

    The file is expected to contain a ``key <name> {`` statement (the name
    may be double-quoted) and a ``secret "<base64>";`` statement.

    Args:
        tsig_file: Path of the key file to read.

    Returns:
        A single-entry dict mapping the key name to its secret, suitable
        for ``dns.tsigkeyring.from_text``.

    Logs an error and exits the process with status 1 when the file cannot
    be read or does not contain both a key name and a secret.
    """
    try:
        # The context manager closes the file even if reading fails
        with open(tsig_file) as key_file:
            lines = key_file.readlines()
    except OSError:
        # Missing file, bad permissions, etc.: log an error and quit
        logger.error("Failed to load TSIG key!")
        sys.exit(1)

    key_name = None
    key_secret = None
    for line in lines:
        # Anchor on the statement keyword at the start of the line, so a
        # base64 secret that happens to contain "key" is never mistaken
        # for the key statement. Surrounding quotes are not part of the name.
        key_match = re.match(r'\s*key\s+"?([^"\s{;]+)', line)
        if key_match:
            key_name = key_match.group(1)
        secret_match = re.match(r'\s*secret\s+"?([^"\s;]+)', line)
        if secret_match:
            key_secret = secret_match.group(1)

    if not key_name:
        # If we don't have a key name, log an error and quit
        logger.error("No key name found!")
        sys.exit(1)

    if not key_secret:
        # If we don't have a secret, log an error and quit
        logger.error("No secrets found!")
        sys.exit(1)

    # dnspython expects a {name: secret} mapping
    return {key_name: key_secret}


def process_hostname(hostname):
    """Return the FQDN inside ``dsnet_zone`` that should be used for a peer.

    Args:
        hostname: Peer name; either a bare label, or a name already ending
            in the zone (with or without the trailing dot).

    Returns:
        The absolute FQDN (with trailing dot). If the name is delegated
        (has NS records) it is returned unchanged. If an A record already
        exists with a TTL above ``default_ttl`` the name is treated as
        taken and ``-dsnet`` is appended to the host part.
    """
    zone_suffix = '.' + dsnet_zone

    # Identify if the hostname supplied is already a valid
    # FQDN for the zone we are managing
    if hostname.endswith(zone_suffix):
        fqdn = hostname
    elif hostname.endswith('.' + dsnet_zone[:-1]):
        # In the zone, just missing the trailing dot
        fqdn = hostname + '.'
    else:
        fqdn = hostname + zone_suffix

    # Check if the name has been delegated
    try:
        resolver_int.query(fqdn, 'NS')
    except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
        # If it's not delegated, that's fine!
        pass
    else:
        # Name has been delegated, and is used as-is
        return fqdn

    # Check if it already exists
    try:
        answer = resolver_int.query(fqdn, 'A')
    except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
        # If the host doesn't exist, that's fine!
        return fqdn

    # If the TTL is over the dsnet default, it's probably a service
    if answer.rrset.ttl > default_ttl:
        # Add a -dsnet suffix to it to prevent spoofing
        # Or more likely, the name is in use in a subnet
        # thus -dsnet should be appended
        logger.info('%s already taken! Using %s-dsnet instead',
                    hostname, hostname)
        # Strip exactly the zone suffix (whatever its length) so the
        # marker lands on the host label rather than in a new label
        fqdn = fqdn[:-len(zone_suffix)] + '-dsnet' + zone_suffix

    return fqdn


def get_current_peers(peer_txt_record):
    """Describe every peer that is currently published in DNS.

    The TXT record *peer_txt_record* holds one string per peer hostname.
    For each of them the internal resolver is asked for the delegation
    status, the A/AAAA records and the PTR records found at the matching
    reverse names; with split horizon enabled the external resolver is
    asked for the A record as well.

    Returns:
        A dict keyed by peer hostname. Each value is a dict with the keys
        'fqdn', 'delegated', 'ip', 'reverse', 'reverse_ptr', 'ip6',
        'reverse6', 'reverse6_ptr' and, when an external nameserver is
        configured, 'ext_ip'. Records that do not exist are None.

    Logs an error and exits with status 1 if the TXT record itself cannot
    be retrieved.
    """
    absent = (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer)

    def first_answer(resolver, name, rdtype):
        # Text form of the first record in the answer, or None when the
        # name or the record type does not exist
        try:
            return resolver.query(name, rdtype)[0].to_text()
        except absent:
            return None

    def address_records(fqdn, rdtype, incomplete_prefix):
        # Forward address, the reverse name derived from it, and the PTR
        # currently stored at that reverse name (None for whatever is missing)
        address = first_answer(resolver_int, fqdn, rdtype)
        if address is None:
            logger.debug(incomplete_prefix + fqdn)
            return None, None, None
        reverse_name = dns.reversename.from_address(address).to_text()
        pointer = first_answer(resolver_int, reverse_name, 'PTR')
        if pointer is None:
            logger.debug(incomplete_prefix + fqdn)
        return address, reverse_name, pointer

    published = {}
    try:
        # Every string in the TXT record set names one peer
        for txt_entry in resolver_int.query(peer_txt_record, 'TXT'):
            name = txt_entry.strings[0].decode()
            published[name] = {}
            fqdn = process_hostname(name)
            details = published[name]
            details['fqdn'] = fqdn

            # Delegation: any NS record at the name counts
            nameserver = first_answer(resolver_int, fqdn, 'NS')
            if nameserver is not None:
                logger.debug(fqdn + ' has been delegated to ' + nameserver)
            details['delegated'] = nameserver is not None

            # IPv4 forward and reverse
            (details['ip'],
             details['reverse'],
             details['reverse_ptr']) = address_records(
                 fqdn, 'A', 'Incomplete IPv4 records for ')

            # IPv6 forward and reverse
            (details['ip6'],
             details['reverse6'],
             details['reverse6_ptr']) = address_records(
                 fqdn, 'AAAA', 'Incomplete IPv6 records for ')

            # External view, only with split horizon
            if dsnet_ext_nameserver:
                details['ext_ip'] = first_answer(resolver_ext, fqdn, 'A')

    except absent:
        # The TXT record doesn't exist, so we have no idea what is
        # currently in DNS and it needs fixing manually.
        # DNS itself is working fine, however.
        logger.error("Couldn't retrieve current list of peers! Exiting...")
        sys.exit(1)

    # Every current peer was processed successfully
    return published


def process_peer_json(json_data):
    """Work out what should be in DNS for each peer in the dsnet report.

    Args:
        json_data: The parsed dsnet report. Its 'Peers' entry is a list of
            peer dicts from which 'Hostname', 'IP', 'IP6' and (with split
            horizon) 'Online' and 'ExternalIP' are read.

    Returns:
        A dict keyed by peer hostname. Each value holds 'fqdn', 'ip',
        'ip6', 'reverse', 'reverse_ptr', 'reverse6', 'reverse6_ptr' and,
        when an external nameserver is configured, 'ext_ip'. An address
        family the peer does not have is represented by None in the
        address, reverse name and PTR target alike.

    Raises:
        KeyError: If the report has no 'Peers' entry or a peer has no
            'Hostname'.
    """
    new_peers = {}
    # A null peer list is treated as "no peers"; a report without the key
    # at all is malformed and is deliberately not treated as empty
    for peer_entry in json_data['Peers'] or []:
        peer = peer_entry['Hostname']
        # Get a safe FQDN
        fqdn = process_hostname(peer)

        details = {
            'fqdn': fqdn,
            # Normalise missing/empty addresses to None so they compare
            # equal to what get_current_peers() reports for absent records
            'ip': peer_entry.get('IP') or None,
            'ip6': peer_entry.get('IP6') or None,
        }

        if dsnet_ext_nameserver:
            # Only publish an external IP while the peer is online
            if peer_entry['Online']:
                details['ext_ip'] = peer_entry['ExternalIP'] or None
            else:
                details['ext_ip'] = None

        # Construct the reverse records for each address family in use
        for address_key, reverse_key, ptr_key in (
                ('ip', 'reverse', 'reverse_ptr'),
                ('ip6', 'reverse6', 'reverse6_ptr')):
            address = details[address_key]
            if address:
                reverse_name = dns.reversename.from_address(address)
                details[reverse_key] = reverse_name.to_text()
                details[ptr_key] = fqdn
            else:
                details[reverse_key] = None
                details[ptr_key] = None

        new_peers[peer] = details

    # Return a dict of what needs to be in DNS
    return new_peers


def main():
    """Synchronise the dsnet DNS zones with a dsnet report.

    Reads the JSON report named by the first command-line argument, works
    out which peers were added, changed or removed relative to what DNS
    currently publishes, and sends the resulting TSIG-signed dynamic
    updates to the internal (and, with split horizon, external) nameserver.

    Always terminates the process via ``sys.exit``: status 0 when there was
    nothing to do or every update sent was accepted; status 1 on a missing
    argument, a TSIG key failure, or an update the server rejected.
    """
    global dsnet_zone, dsnet_int_nameserver, dsnet_ext_nameserver

    logger.info('Updating dsnet DNS zone')
    # We should have a json file as an argument
    if len(sys.argv) < 2:
        # Quit if not present
        logger.error('I need JSON to live!')
        sys.exit(1)

    with open(sys.argv[1]) as update_file:
        # Open and load that JSON file
        dsnet_json = json.load(update_file)

    # If we're using the JSON data for our zone then pull that in
    if dsnet_zone.lower() == 'json':
        dsnet_zone = dsnet_json['Domain']
    # Just in case people forget the trailing dot...
    if not dsnet_zone.endswith('.'):
        dsnet_zone = dsnet_zone + '.'
    logger.debug('Using DNS zone: ' + dsnet_zone)

    # Create the full FQDN for our peer list txt record
    dsnet_current_peers_txt = dsnet_current_peers_record + '.' + dsnet_zone

    # If we're using the JSON data for our int nameserver then pull that in
    if dsnet_int_nameserver.lower() == 'json':
        dsnet_int_nameserver = dsnet_json['DNS']
    logger.debug('Using internal nameserver: ' + dsnet_int_nameserver)

    # If we're using the JSON data for our ext nameserver then pull that in
    if dsnet_ext_nameserver:
        if dsnet_ext_nameserver.lower() == 'json':
            dsnet_ext_nameserver = dsnet_json['ExternalIP']
        logger.debug('Using external nameserver: ' + dsnet_ext_nameserver)
    else:
        logger.debug('No external nameserver specified!')

    # Add these to the resolver objects
    resolver_int.nameservers = [dsnet_int_nameserver]
    if dsnet_ext_nameserver:
        resolver_ext.nameservers = [dsnet_ext_nameserver]

    # The reverse zones are configured statically; the address space from
    # the JSON is only reported for debugging
    logger.debug('Using IPv4 address space %s', dsnet_json.get('Network'))
    logger.debug('with reverse zone ' + dsnet_reverse_zone)
    logger.debug('Using IPv6 address space %s', dsnet_json.get('Network6'))
    logger.debug('with reverse zone ' + dsnet_reverse6_zone)

    # Get a list of what's currently in DNS
    current_peers = get_current_peers(dsnet_current_peers_txt)
    logger.debug("Current peers:")
    logger.debug(current_peers)

    # Work out what needs to be in DNS
    new_peers = process_peer_json(dsnet_json)
    logger.debug("New peers:")
    logger.debug(new_peers)

    # A peer only in new_peers is new; a peer only in current_peers
    # has been deleted
    add_peers = [peer for peer in new_peers if peer not in current_peers]
    delete_peers = [peer for peer in current_peers if peer not in new_peers]

    # Peers present on both sides, grouped by which record is out of date.
    # update_ext_peers simply stays empty without split horizon.
    update_int_peers = []
    update_int6_peers = []
    update_ext_peers = []
    update_ptr_peers = []
    update_ptr6_peers = []

    reverse_families = (
        ('reverse', 'reverse_ptr', dsnet_reverse_zone, 'IPv4',
         update_ptr_peers),
        ('reverse6', 'reverse6_ptr', dsnet_reverse6_zone, 'IPv6',
         update_ptr6_peers),
    )

    for peer, new in new_peers.items():
        if peer not in current_peers:
            continue
        current = current_peers[peer]

        # Forward records of a peer delegated to its own DNS are left alone
        if not current['delegated']:
            if new['ip'] != current['ip']:
                update_int_peers.append(peer)
            if new['ip6'] != current['ip6']:
                update_int6_peers.append(peer)

        if dsnet_ext_nameserver and new['ext_ip'] != current['ext_ip']:
            update_ext_peers.append(peer)

        # NOTE(review): only the PTR target is compared, so when a peer's
        # address changes the PTR for the new address is not created until
        # a later run, and the PTR at the old address is never removed.
        for reverse_key, ptr_key, zone, family, pending in reverse_families:
            if new[ptr_key] == current[ptr_key]:
                continue
            # The record to touch lives at the new reverse name, or at the
            # current one when the peer no longer has this address family
            target = new[reverse_key] or current[reverse_key]
            if target.endswith(zone):
                pending.append(peer)
            else:
                logger.warning('%s internal %s not in our reverse zone!',
                               peer, family)

    def announce(heading, peers, field=None):
        # Log a heading followed by one line per peer (and its new value)
        if not peers:
            return
        logger.info(heading)
        for name in peers:
            if field is None:
                logger.info(" - " + name)
            else:
                logger.info(" - " + name + ": " + str(new_peers[name][field]))

    announce("Adding peers:", add_peers)
    announce("Updating internal IPv4 peers:", update_int_peers, 'ip')
    announce("Updating internal IPv6 peers:", update_int6_peers, 'ip6')
    announce("Updating external peers:", update_ext_peers, 'ext_ip')
    announce("Updating IPv4 reverse peers:", update_ptr_peers, 'reverse_ptr')
    announce("Updating IPv6 reverse peers:", update_ptr6_peers, 'reverse6_ptr')
    announce("Deleting peers:", delete_peers)

    # If there's nothing in any of these lists, we don't need to do anything!
    if not any((add_peers, delete_peers, update_int_peers, update_int6_peers,
                update_ext_peers, update_ptr_peers, update_ptr6_peers)):
        logger.info("Nothing to do! Exiting...")
        sys.exit(0)

    # Load the TSIG key from file and add it to the keyring
    keyring = dns.tsigkeyring.from_text(load_tsig_key(dns_tsig_key_file))

    # Set up the update entries for each zone
    update_int = dns.update.Update(dsnet_zone, keyring=keyring)
    update_ext = dns.update.Update(dsnet_zone, keyring=keyring)
    update_reverse = dns.update.Update(dsnet_reverse_zone, keyring=keyring)
    update_reverse6 = dns.update.Update(dsnet_reverse6_zone, keyring=keyring)

    # Manage the TXT record first
    # Only change the TXT strings of peers we are adding...
    for peer in add_peers:
        update_int.add(dsnet_current_peers_txt, default_ttl, 'TXT', peer)
    # ...or deleting
    for peer in delete_peers:
        # Construct an rdata object so we can delete a SPECIFIC record
        rdata = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.TXT,
                                    peer)
        update_int.delete(dsnet_current_peers_txt, rdata)

    # For new peers
    for peer in add_peers:
        new = new_peers[peer]
        fqdn = new['fqdn']

        # Add the A record and reverse if there is an IPv4
        if new['ip']:
            update_int.replace(fqdn, default_ttl, 'A', new['ip'])
            if new['reverse'].endswith(dsnet_reverse_zone):
                update_reverse.replace(new['reverse'], default_ttl,
                                       'PTR', fqdn)
            else:
                logger.warning('%s internal %s not in our reverse zone!',
                               peer, 'IPv4')

        # Add the AAAA record and reverse if there is an IPv6
        if new['ip6']:
            update_int.replace(fqdn, default_ttl, 'AAAA', new['ip6'])
            if new['reverse6'].endswith(dsnet_reverse6_zone):
                # The PTR must live at the IPv6 reverse name
                update_reverse6.replace(new['reverse6'], default_ttl,
                                        'PTR', fqdn)
            else:
                logger.warning('%s internal %s not in our reverse zone!',
                               peer, 'IPv6')

        # An external IP if present
        if dsnet_ext_nameserver and new['ext_ip']:
            update_ext.replace(fqdn, default_ttl, 'A', new['ext_ip'])

    # Update IPv4 records as needed
    for peer in update_int_peers:
        if new_peers[peer]['ip']:
            update_int.replace(new_peers[peer]['fqdn'], default_ttl,
                               'A', new_peers[peer]['ip'])
        else:
            # Delete if removed for some reason
            update_int.delete(current_peers[peer]['fqdn'], 'A')

    # Update IPv6 records as needed
    for peer in update_int6_peers:
        if new_peers[peer]['ip6']:
            update_int.replace(new_peers[peer]['fqdn'], default_ttl,
                               'AAAA', new_peers[peer]['ip6'])
        else:
            # Delete if removed for some reason
            update_int.delete(current_peers[peer]['fqdn'], 'AAAA')

    # Update external IPs if needed (the list is empty without split horizon)
    for peer in update_ext_peers:
        if new_peers[peer]['ext_ip']:
            update_ext.replace(new_peers[peer]['fqdn'], default_ttl,
                               'A', new_peers[peer]['ext_ip'])
        else:
            # Delete if host has disconnected
            update_ext.delete(current_peers[peer]['fqdn'], 'A')

    # Update reverse IPv4 records as needed
    for peer in update_ptr_peers:
        if new_peers[peer]['reverse']:
            update_reverse.replace(new_peers[peer]['reverse'], default_ttl,
                                   'PTR', new_peers[peer]['fqdn'])
        else:
            # Delete if removed for some reason
            update_reverse.delete(current_peers[peer]['reverse'], 'PTR')

    # Update reverse IPv6 records as needed
    for peer in update_ptr6_peers:
        if new_peers[peer]['reverse6']:
            update_reverse6.replace(new_peers[peer]['reverse6'], default_ttl,
                                    'PTR', new_peers[peer]['fqdn'])
        else:
            # Delete if removed for some reason
            update_reverse6.delete(current_peers[peer]['reverse6'], 'PTR')

    # For deleted peers
    for peer in delete_peers:
        current = current_peers[peer]
        # Delete the forward records
        update_int.delete(current['fqdn'], 'A')
        update_int.delete(current['fqdn'], 'AAAA')
        # Delete the external IP record if it exists
        if dsnet_ext_nameserver and current['ext_ip']:
            update_ext.delete(current['fqdn'], 'A')
        # Delete only the reverse records that exist and are in our zones
        reverse = current['reverse']
        if reverse and reverse.endswith(dsnet_reverse_zone):
            update_reverse.delete(reverse, 'PTR')
        reverse6 = current['reverse6']
        if reverse6 and reverse6.endswith(dsnet_reverse6_zone):
            update_reverse6.delete(reverse6, 'PTR')

    # Queue the updates in the order they are sent
    outgoing = [('internal forward zone', update_int, dsnet_int_nameserver)]
    if dsnet_ext_nameserver:
        outgoing.append(('external forward zone', update_ext,
                         dsnet_ext_nameserver))
    outgoing.append(('IPv4 reverse zone', update_reverse,
                     dsnet_int_nameserver))
    outgoing.append(('IPv6 reverse zone', update_reverse6,
                     dsnet_int_nameserver))

    failed = False
    try:
        for description, update, nameserver in outgoing:
            # In an UPDATE message the changes are carried in the authority
            # section; don't bother the server with an empty update
            if not update.authority:
                continue
            logger.debug(update)
            # Send via TCP because the updates can be LONG
            response = dns.query.tcp(update, nameserver, timeout=10)
            rcode = response.rcode()
            if rcode != dns.rcode.NOERROR:
                # Don't let a rejected update pass as success
                logger.error('Update of %s rejected by %s: %s', description,
                             nameserver, dns.rcode.to_text(rcode))
                failed = True
    except dns.tsig.PeerBadKey:
        # The server did not recognise our TSIG key
        logger.error("TSIG key failure on update!")
        sys.exit(1)

    # All done!
    sys.exit(1 if failed else 0)


# Run the update only when executed as a script, not when imported;
# main() ends the process itself via sys.exit()
if __name__ == '__main__':
    main()
