#! /usr/bin/python3

import os
import re
import shutil
import ssl
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer


class handler(BaseHTTPRequestHandler):
    """Canned-response request handler for a mock token + image REST service.

    POST exchanges one fixed refresh token for one fixed access token.
    GET serves image listings and a single image download, all of which
    require that access token as a bearer credential.
    """

    def match(self, p):
        """Return the match object if the request path fully matches regex *p*, else None."""
        return re.fullmatch(p, self.path)

    def _reply_empty(self, code):
        """Send a header-only response with the given status *code*."""
        self.send_response(code)
        self.end_headers()

    def do_POST(self):
        """Answer the token endpoint: 200 with an access token for the known refresh request, else 400."""
        # A missing header yields None (TypeError) and a malformed one raises
        # ValueError; a negative length would make read() block until EOF.
        try:
            length = int(self.headers['Content-Length'])
        except (TypeError, ValueError):
            length = -1
        if length < 0:
            self._reply_empty(400)
            return

        data = self.rfile.read(length)
        m = self.match(r"/auth/realms/redhat-external/protocol/openid-connect/token")
        if m and data == b'grant_type=refresh_token&client_id=rhsm-api&refresh_token=valid_token':
            self.send_response(200)
            self.end_headers()
            self.wfile.write(b'{ "access_token": "my_access_token" }\n')
            return

        self._reply_empty(400)

    def do_GET(self):
        """Answer image listing and download requests; anything else (or a bad token) gets 404."""
        # Every endpoint below requires the token handed out by do_POST.
        if self.headers['Authorization'] != 'Bearer my_access_token':
            self._reply_empty(404)
            return

        # RHEL 8.1 is the only version with a downloadable image; the dot is
        # escaped so that e.g. "8x1" is not mistaken for it.
        if self.match(r"/management/v1/images/rhel/8\.1/.*"):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(b'{ "body": [\
                {"imageName": "RHEL 8.1 Boot ISO", "filename": "rhel-8-1-boot.iso",\
                "downloadHref": "https://api.access.redhat.com/management/v1/images/my_image_checksum/download"}\
            ] }\n')
            return

        # Any other version/arch combination has no images.
        if self.match(r"/management/v1/images/rhel/.*/.*"):
            self.send_response(200)
            self.end_headers()
            self.wfile.write(b'{ "body": [] }\n')
            return

        if self.match(r"/management/v1/images/my_image_checksum/download"):
            filepath = "/var/lib/libvirt/images/example.img"
            with open(filepath, 'rb') as f:
                self.send_response(200)
                self.send_header("Content-Type", 'application/octet-stream')
                # NOTE(review): RFC 6266 quoted-strings use double quotes, so these single
                # quotes are sent as part of the name; kept as-is in case a client relies on it.
                self.send_header("Content-Disposition", f"attachment; filename='{os.path.basename(filepath)}'")
                fs = os.fstat(f.fileno())
                self.send_header("Content-Length", str(fs.st_size))
                self.end_headers()
                # Copy in 10 KiB chunks, otherwise the download might be too fast on some platforms
                shutil.copyfileobj(f, self.wfile, 10240)
            return

        self._reply_empty(404)


if __name__ == '__main__':
    # Expects the TLS certificate and private key paths as the first two arguments;
    # fail with a usage hint instead of a bare IndexError when they are missing.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} CERTFILE KEYFILE")

    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=sys.argv[1], keyfile=sys.argv[2])
    context.check_hostname = False

    # Serve the mock API over HTTPS on localhost:443 until the process is killed.
    with HTTPServer(("localhost", 443), handler) as httpd:
        httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
        httpd.serve_forever()
