#!/usr/bin/python3
###############################################################################
#                                                                             #
# libloc - A library to determine the location of someone on the Internet     #
#                                                                             #
# Copyright (C) 2017-2021 IPFire Development Team <info@ipfire.org>           #
#                                                                             #
# This library is free software; you can redistribute it and/or               #
# modify it under the terms of the GNU Lesser General Public                  #
# License as published by the Free Software Foundation; either                #
# version 2.1 of the License, or (at your option) any later version.          #
#                                                                             #
# This library is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU           #
# Lesser General Public License for more details.                             #
#                                                                             #
###############################################################################

import argparse
import datetime
import ipaddress
import logging
import os
import re
import shutil
import socket
import sys
import time

# Load our location module
import location
import location.downloader
import location.export

from location.i18n import _

# Setup logging
log = logging.getLogger("location")

# Output formatters

class CLI(object):
	def parse_cli(self):
		"""
			Builds the argument parser, evaluates the command line and
			returns the parsed namespace

			The log level is adjusted if --debug or --quiet was passed. If no
			sub-command was selected, the usage is printed and the process
			exits with code 2.
		"""
		parser = argparse.ArgumentParser(
			description=_("Location Database Command Line Interface"),
		)
		commands = parser.add_subparsers()

		def add_output_options(command):
			# --family and --format are declared identically by all
			# commands that list networks
			command.add_argument("--family", choices=("ipv6", "ipv4"))
			command.add_argument(
				"--format",
				choices=location.export.formats.keys(),
				default="list",
			)

		# Options that apply to every command
		parser.add_argument(
			"--debug", action="store_true", help=_("Enable debug output"),
		)
		parser.add_argument(
			"--quiet", action="store_true", help=_("Enable quiet mode"),
		)
		parser.add_argument(
			"--version", action="version", version="%(prog)s 0.9.17",
		)
		parser.add_argument(
			"--database", "-d",
			default=location.DATABASE_PATH,
			help=_("Path to database"),
		)
		parser.add_argument(
			"--public-key", "-k",
			default="/var/lib/location/signing-key.pem",
			help=_("Public Signing Key"),
		)

		# version
		cmd = commands.add_parser("version", help=_("Show database version"))
		cmd.set_defaults(func=self.handle_version)

		# lookup
		cmd = commands.add_parser(
			"lookup", help=_("Lookup one or multiple IP addresses"),
		)
		cmd.add_argument("address", nargs="+")
		cmd.set_defaults(func=self.handle_lookup)

		# dump
		cmd = commands.add_parser("dump", help=_("Dump the entire database"))
		cmd.add_argument("output", nargs="?", type=argparse.FileType("w"))
		cmd.set_defaults(func=self.handle_dump)

		# update
		cmd = commands.add_parser("update", help=_("Update database"))
		cmd.add_argument(
			"--cron",
			choices=("daily", "weekly", "monthly"),
			help=_("Update the library only once per interval"),
		)
		cmd.set_defaults(func=self.handle_update)

		# verify
		cmd = commands.add_parser(
			"verify", help=_("Verify the downloaded database"),
		)
		cmd.set_defaults(func=self.handle_verify)

		# get-as
		cmd = commands.add_parser(
			"get-as",
			help=_("Get information about one or multiple Autonomous Systems"),
		)
		cmd.add_argument("asn", nargs="+")
		cmd.set_defaults(func=self.handle_get_as)

		# search-as
		cmd = commands.add_parser(
			"search-as",
			help=_("Search for Autonomous Systems that match the string"),
		)
		cmd.add_argument("query", nargs=1)
		cmd.set_defaults(func=self.handle_search_as)

		# list-networks-by-as
		cmd = commands.add_parser(
			"list-networks-by-as", help=_("Lists all networks in an AS"),
		)
		cmd.add_argument("asn", nargs=1, type=int)
		add_output_options(cmd)
		cmd.set_defaults(func=self.handle_list_networks_by_as)

		# list-networks-by-cc
		cmd = commands.add_parser(
			"list-networks-by-cc", help=_("Lists all networks in a country"),
		)
		cmd.add_argument("country_code", nargs=1)
		add_output_options(cmd)
		cmd.set_defaults(func=self.handle_list_networks_by_cc)

		# list-networks-by-flags
		cmd = commands.add_parser(
			"list-networks-by-flags", help=_("Lists all networks with flags"),
		)
		cmd.add_argument(
			"--anonymous-proxy", action="store_true", help=_("Anonymous Proxies"),
		)
		cmd.add_argument(
			"--satellite-provider", action="store_true", help=_("Satellite Providers"),
		)
		cmd.add_argument(
			"--anycast", action="store_true", help=_("Anycasts"),
		)
		cmd.add_argument(
			"--drop", action="store_true", help=_("Hostile Networks safe to drop"),
		)
		add_output_options(cmd)
		cmd.set_defaults(func=self.handle_list_networks_by_flags)

		# list-bogons
		cmd = commands.add_parser("list-bogons", help=_("Lists all bogons"))
		add_output_options(cmd)
		cmd.set_defaults(func=self.handle_list_bogons)

		# list-countries
		cmd = commands.add_parser(
			"list-countries", help=_("Lists all countries"),
		)
		cmd.add_argument(
			"--show-name", action="store_true", help=_("Show the name of the country"),
		)
		cmd.add_argument(
			"--show-continent", action="store_true", help=_("Show the continent"),
		)
		cmd.set_defaults(func=self.handle_list_countries)

		# export (declares its own --format/--family since they carry help texts)
		cmd = commands.add_parser(
			"export",
			help=_("Exports data in many formats to load it into packet filters"),
		)
		cmd.add_argument(
			"--format",
			help=_("Output format"),
			choices=location.export.formats.keys(),
			default="list",
		)
		cmd.add_argument("--directory", help=_("Output directory"))
		cmd.add_argument(
			"--family",
			help=_("Specify address family"),
			choices=("ipv6", "ipv4"),
		)
		cmd.add_argument(
			"objects", nargs="*", help=_("List country codes or ASNs to export"),
		)
		cmd.set_defaults(func=self.handle_export)

		args = parser.parse_args()

		# Configure logging (--debug wins over --quiet)
		level = None
		if args.debug:
			level = logging.DEBUG
		elif args.quiet:
			level = logging.WARNING

		if level is not None:
			location.logger.set_level(level)

		# Without a sub-command there is no handler to call
		if "func" not in args:
			parser.print_usage()
			sys.exit(2)

		return args

	def run(self):
		"""
			Parses the command line, opens the database and calls the
			handler of the selected command

			Finishes by calling sys.exit() with the return value of the
			handler, or with 0 if that value is falsy. A ValueError raised
			by the handler results in exit code 2, any other Exception in
			exit code 1.
		"""
		args = self.parse_cli()

		# Open the database
		try:
			db = location.Database(args.database)
		except FileNotFoundError as e:
			# Only the update command may continue without a database
			if args.func != self.handle_update:
				sys.stderr.write("location: Could not open database %s: %s\n" \
					% (args.database, e))
				sys.exit(1)

			db = None

		# Translate family (if present); anything unknown becomes 0
		if "family" in args:
			families = {
				"ipv6" : socket.AF_INET6,
				"ipv4" : socket.AF_INET,
			}
			args.family = families.get(args.family, 0)

		# Call the handler
		try:
			status = args.func(db, args)

		# Invalid inputs
		except ValueError as e:
			sys.stderr.write("%s\n" % e)
			status = 2

		# Anything else that went wrong
		except Exception as e:
			sys.stderr.write("%s\n" % e)
			status = 1

		# Handlers that return nothing count as success
		sys.exit(status or 0)

	def handle_version(self, db, ns):
		"""
			Prints the creation time of the database (db.created_at) as a
			GMT timestamp
		"""
		created_at = time.gmtime(db.created_at)

		print(time.strftime("%a, %d %b %Y %H:%M:%S GMT", created_at))

	def handle_lookup(self, db, ns):
		"""
			Looks up all given IP addresses and prints the network, country,
			Autonomous System and flags of each match

			Returns 2 as soon as an address is rejected by db.lookup() with
			a ValueError, 1 if nothing was found for at least one address,
			and 0 otherwise.
		"""
		row = "  %-24s: %s"

		# Flags that are reported if they are set on the network
		flag_labels = (
			(location.NETWORK_FLAG_ANONYMOUS_PROXY, _("Anonymous Proxy")),
			(location.NETWORK_FLAG_SATELLITE_PROVIDER, _("Satellite Provider")),
			(location.NETWORK_FLAG_ANYCAST, _("Anycast")),
			(location.NETWORK_FLAG_DROP, _("Hostile Network safe to drop")),
		)

		status = 0

		for address in ns.address:
			try:
				network = db.lookup(address)
			except ValueError:
				print(_("Invalid IP address: %s") % address, file=sys.stderr)
				return 2

			# Nothing found? Remember the failure and go on
			if not network:
				print(_("Nothing found for %(address)s") % {
					"address" : address,
					"network" : network,
				}, file=sys.stderr)
				status = 1
				continue

			print("%s:" % address)
			print(row % (_("Network"), network))

			# Country (falls back to the code if the country is unknown)
			if network.country_code:
				country = db.get_country(network.country_code)

				if country:
					country_name = country.name
				else:
					country_name = network.country_code

				print(row % (_("Country"), country_name))

			# Autonomous System (falls back to the plain number)
			if network.asn:
				autonomous_system = db.get_as(network.asn)

				print(row % (
					_("Autonomous System"),
					autonomous_system or "AS%s" % network.asn,
				))

			# Flags
			for flag, label in flag_labels:
				if network.has_flag(flag):
					print(row % (label, _("yes")))

		return status

	def handle_dump(self, db, ns):
		"""
			Writes the content of the database as text: a commented header
			with the metadata, followed by all Autonomous Systems and all
			networks

			The output goes to ns.output, or to standard output if no output
			file was given.
		"""
		out = ns.output or sys.stdout

		# All records are written as "key value" with a padded key
		record = "%-24s %s\n"

		# Header
		out.write("#\n# Location Database Export\n#\n")

		created_at = time.strftime(
			"%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
		)
		out.write("# Generated: %s\n" % created_at)

		# Vendor and license are optional
		for template, value in (
			("# Vendor:    %s\n", db.vendor),
			("# License:   %s\n", db.license),
		):
			if value:
				out.write(template % value)

		out.write("#\n")

		# The description is written as comment, line by line
		if db.description:
			for text in db.description.splitlines():
				out.write("%s\n" % ("# %s" % text).rstrip())

			out.write("#\n")

		# Autonomous Systems
		for autonomous_system in db.ases:
			out.write("\n")
			out.write(record % ("aut-num:", "AS%s" % autonomous_system.number))
			out.write(record % ("name:", autonomous_system.name))

		flag_keys = {
			location.NETWORK_FLAG_ANONYMOUS_PROXY    : "is-anonymous-proxy:",
			location.NETWORK_FLAG_SATELLITE_PROVIDER : "is-satellite-provider:",
			location.NETWORK_FLAG_ANYCAST            : "is-anycast:",
			location.NETWORK_FLAG_DROP               : "drop:",
		}

		# Networks
		for network in db.networks:
			out.write("\n")
			out.write(record % ("net:", network))

			if network.country_code:
				out.write(record % ("country:", network.country_code))

			if network.asn:
				out.write(record % ("aut-num:", network.asn))

			# One line for each flag that is set
			for flag, key in flag_keys.items():
				if network.has_flag(flag):
					out.write(record % (key, "yes"))

	def handle_get_as(self, db, ns):
		"""
			Prints number and name of all requested Autonomous Systems

			Returns 1 if at least one ASN was not a number or could not be
			found in the database, otherwise 0.
		"""
		failed = False

		for value in ns.asn:
			try:
				number = int(value)
			except ValueError:
				print(_("Invalid ASN: %s") % value, file=sys.stderr)
				failed = True
				continue

			# Fetch the AS from the database
			autonomous_system = db.get_as(number)

			if not autonomous_system:
				print(_("Could not find AS%s") % number, file=sys.stderr)
				failed = True
				continue

			print(_("AS%(asn)s belongs to %(name)s") % {
				"asn"  : autonomous_system.number,
				"name" : autonomous_system.name,
			})

		return 1 if failed else 0

	def handle_search_as(self, db, ns):
		"""
			Prints all Autonomous Systems that db.search_as() returns for
			each of the given queries
		"""
		# Evaluated lazily, so the queries are still processed one by one
		matches = (
			match for query in ns.query for match in db.search_as(query)
		)

		for match in matches:
			print(match)

	def handle_update(self, db, ns):
		"""
			Downloads a new database and moves it to ns.database

			db may be None, in which case the checks against the local
			database are skipped.

			Returns 0 after a new database has been moved into place, 1 if
			the download failed, 3 if the database was updated recently
			(with --cron) or nothing was downloaded, and None if the local
			database already is the latest version.
		"""
		# With --cron, do nothing if the database is younger than the interval
		if ns.cron and db:
			if ns.cron == "daily":
				days = 1
			elif ns.cron == "weekly":
				days = 7
			elif ns.cron == "monthly":
				days = 30

			max_age = datetime.timedelta(days=days).total_seconds()

			if db.created_at >= (time.time() - max_age):
				log.info(
					_("The database has been updated recently"),
				)
				return 3

		# Find out what the latest version is
		latest = location.discover_latest_version()

		# Nothing to do if the local database is not older than that
		if db and latest and db.created_at >= latest:
			log.info("Already on the latest version")
			return

		downloader = location.downloader.Downloader()

		# Try downloading a new database
		try:
			tmpfile = downloader.download(
				public_key=ns.public_key,
				timestamp=latest,
				# Download into the directory of the database
				tmpdir=os.path.dirname(ns.database),
			)

		# If no file could be downloaded, log a message
		except FileNotFoundError:
			log.error("Could not download a new database")
			return 1

		# If we have not received a new file, there is nothing to do
		if not tmpfile:
			return 3

		# Move the downloaded file to its destination
		shutil.move(tmpfile.name, ns.database)

		return 0

	def handle_verify(self, db, ns):
		"""
			Verifies the database using the key file at ns.public_key

			Returns 0 if db.verify() succeeded, otherwise 1.
		"""
		with open(ns.public_key, "r") as key:
			verified = db.verify(key)

		if not verified:
			log.error("Could not verify database")
			return 1

		log.info("Database successfully verified")
		return 0

	def __get_output_formatter(self, ns):
		"""
			Returns the output formatter class registered for ns.format in
			location.export.formats, or location.export.OutputFormatter if
			there is none
		"""
		try:
			return location.export.formats[ns.format]
		except KeyError:
			return location.export.OutputFormatter

	def handle_list_countries(self, db, ns):
		"""
			Prints one line for each country in the database

			Each line starts with the country code. With ns.show_continent
			the continent code follows, and with ns.show_name the name of
			the country is appended.
		"""
		# Collect the attributes to show in output order
		attributes = ["code"]

		if ns.show_continent:
			attributes.append("continent_code")

		if ns.show_name:
			attributes.append("name")

		for country in db.countries:
			print(" ".join(getattr(country, attr) for attr in attributes))

	def handle_list_networks_by_as(self, db, ns):
		"""
			Writes all networks that db.search_networks() returns for each
			given ASN to standard output using the selected output formatter
		"""
		formatter = self.__get_output_formatter(ns)

		for asn in ns.asn:
			output = formatter("AS%s" % asn, family=ns.family, f=sys.stdout)

			# Write all matching networks
			networks = db.search_networks(asns=[asn], family=ns.family)
			for network in networks:
				output.write(network)

			output.finish()

	def handle_list_networks_by_cc(self, db, ns):
		"""
			Writes all networks that db.search_networks() returns for each
			given country code to standard output using the selected output
			formatter
		"""
		formatter = self.__get_output_formatter(ns)

		for code in ns.country_code:
			output = formatter(code, family=ns.family, f=sys.stdout)

			# Write all matching networks
			networks = db.search_networks(country_codes=[code], family=ns.family)
			for network in networks:
				output.write(network)

			output.finish()

	def handle_list_networks_by_flags(self, db, ns):
		"""
			Writes all networks that db.search_networks() returns for the
			selected flags to standard output using the selected output
			formatter

			Raises ValueError if no flag has been selected.
		"""
		selection = (
			(ns.anonymous_proxy, location.NETWORK_FLAG_ANONYMOUS_PROXY),
			(ns.satellite_provider, location.NETWORK_FLAG_SATELLITE_PROVIDER),
			(ns.anycast, location.NETWORK_FLAG_ANYCAST),
			(ns.drop, location.NETWORK_FLAG_DROP),
		)

		# Combine all selected flags into one bitmask
		flags = 0
		for selected, flag in selection:
			if selected:
				flags |= flag

		if not flags:
			raise ValueError(_("You must at least pass one flag"))

		formatter = self.__get_output_formatter(ns)
		output = formatter("custom", family=ns.family, f=sys.stdout)

		for network in db.search_networks(flags=flags, family=ns.family):
			output.write(network)

		output.finish()

	def handle_list_bogons(self, db, ns):
		"""
			Writes everything that db.list_bogons() returns to standard
			output using the selected output formatter
		"""
		formatter = self.__get_output_formatter(ns)
		output = formatter("bogons", family=ns.family, f=sys.stdout)

		for network in db.list_bogons(family=ns.family):
			output.write(network)

		output.finish()

	def handle_export(self, db, ns):
		"""
			Sorts ns.objects into ASNs ("AS<number>") and country codes and
			passes them to location.export.Exporter together with
			ns.directory and the address families

			Objects that are neither are skipped with a warning. If no
			countries and no ASNs remain, all countries of the database
			plus the special codes A1, A2, A3 and XD are exported.
		"""
		# Export both families unless one has been selected
		if ns.family:
			families = [ ns.family ]
		else:
			families = [ socket.AF_INET6, socket.AF_INET ]

		# Codes that are accepted in addition to valid country codes
		special_codes = ("A1", "A2", "A3", "XD")

		asns, countries = [], []

		for arg in ns.objects:
			match = re.match(r"^AS(\d+)$", arg)
			if match:
				asns.append(int(match.group(1)))
				continue

			if location.country_code_is_valid(arg) or arg in special_codes:
				countries.append(arg)
				continue

			log.warning("Invalid argument: %s" % arg)

		# Default to exporting all countries
		if not (countries or asns):
			countries = list(special_codes) + [country.code for country in db.countries]

		# Select the output format
		formatter = self.__get_output_formatter(ns)

		exporter = location.export.Exporter(db, formatter)
		exporter.export(ns.directory, countries=countries, asns=asns, families=families)


def format_timedelta(t):
	"""
		Formats a datetime.timedelta as a human-readable string

		Only the components that are non-zero (days, hours, minutes and
		seconds) are included. They are joined by ", " and inserted into
		the "%s ago" message. If all components are zero, the "Now"
		message is returned.
	"""
	s = []

	if t.days:
		s.append(
			_("One Day", "%(days)s Days", t.days) % { "days" : t.days, }
		)

	# t.seconds holds the remainder within the day (0..86399), so it is
	# split up into hours, minutes and seconds
	hours = t.seconds // 3600
	if hours:
		s.append(
			_("One Hour", "%(hours)s Hours", hours) % { "hours" : hours, }
		)

	minutes = (t.seconds % 3600) // 60
	if minutes:
		s.append(
			_("One Minute", "%(minutes)s Minutes", minutes) % { "minutes" : minutes, }
		)

	# Check the remainder and not t.seconds, which would emit "0 Seconds"
	# for every delta that is a whole number of minutes
	seconds = t.seconds % 60
	if seconds:
		s.append(
			_("One Second", "%(seconds)s Seconds", seconds) % { "seconds" : seconds, }
		)

	if not s:
		return _("Now")

	return _("%s ago") % ", ".join(s)

def main():
	"""
		Creates the command line interface and runs it
	"""
	CLI().run()

main()
