#!/usr/bin/env python3

import argparse
import http.server
import json
import re
from socketserver import ThreadingMixIn
import sys
from pmtiles.reader import Reader, MmapSource


# https://docs.python.org/3/library/http.server.html
class ThreadingSimpleServer(ThreadingMixIn, http.server.HTTPServer):
    """An ``HTTPServer`` combined with ``ThreadingMixIn``.

    The class adds nothing of its own; the mix-in is listed first so that its
    request-processing methods take precedence over the base server's.
    """


# Command-line interface: the archive path and the port are required
# positionals; --bind and --cors-allow-all are optional.
parser = argparse.ArgumentParser(description="HTTP server for PMTiles archives.")
parser.add_argument("pmtiles_file", help="PMTiles archive to serve")
# Parsing the port as an int makes argparse reject a non-numeric value with a
# normal usage error instead of failing later when the value is converted.
parser.add_argument("port", type=int, help="Port to bind to")
parser.add_argument("--bind", help="Address to bind server to: default localhost")
parser.add_argument(
    "--cors-allow-all",
    help="Return Access-Control-Allow-Origin:* header",
    action="store_true",
)
args = parser.parse_args()

with open(args.pmtiles_file, "r+b") as f:
    # Everything that follows runs inside this ``with`` block, so the archive
    # stays open for as long as the server is running.
    source = MmapSource(f)
    reader = Reader(source)

    # The tile format is read once from the archive metadata; it determines
    # both the URL suffix and the Content-Type of served tiles.
    archive_header = reader.header()
    fmt = archive_header.metadata["format"]

    class Handler(http.server.SimpleHTTPRequestHandler):
        """Serve archive metadata and tiles in response to GET requests.

        Routes (any query string on the request is ignored):
            /metadata           -> the archive metadata, encoded as JSON
            /{z}/{x}/{y}.{fmt}  -> the tile's bytes, or 404 if the reader
                                   returns nothing for that tile

        Every other path is answered with 400. When ``--cors-allow-all`` is
        set, every response (including errors) carries
        ``Access-Control-Allow-Origin: *``.
        """

        # Tile paths must match exactly: raw-string pattern, literal dot,
        # escaped format, and fullmatch() so trailing garbage is rejected.
        _tile_path = re.compile(r"/(\d+)/(\d+)/(\d+)\." + re.escape(fmt))

        # Formats whose media type is not simply "image/<format>".
        _media_types = {"pbf": "application/x-protobuf", "jpg": "image/jpeg"}

        def _reply(self, status, content_type, body):
            """Send a complete response.

            Args:
                status: HTTP status code.
                content_type: Value for the Content-Type header.
                body: Response payload as bytes.
            """
            self.send_response(status)
            if args.cors_allow_all:
                self.send_header("Access-Control-Allow-Origin", "*")
            self.send_header("Content-Type", content_type)
            self.end_headers()
            self.wfile.write(body)

        def do_GET(self):
            """Dispatch a GET request to the metadata or tile route."""
            # self.path is the raw request target; drop the query string so
            # parameters such as cache busters do not break routing.
            path = self.path.split("?", 1)[0]

            if path == "/metadata":
                payload = json.dumps(reader.header().metadata).encode("utf-8")
                self._reply(200, "application/json", payload)
                return

            match = self._tile_path.fullmatch(path)
            if not match:
                self._reply(400, "text/plain; charset=utf-8", b"bad request")
                return

            z, x, y = (int(group) for group in match.groups())
            data = reader.get(z, x, y)
            if not data:
                self._reply(404, "text/plain; charset=utf-8", b"tile not found")
                return

            # NOTE(review): tile bytes are sent exactly as the reader returns
            # them; if the archive stores compressed tiles, a Content-Encoding
            # header may also be needed -- confirm against pmtiles.reader.
            self._reply(200, self._media_types.get(fmt, "image/" + fmt), data)

    # Listen on the advertised address: without --bind this is localhost,
    # matching the banner below and the --bind help text, rather than every
    # interface.
    bind = args.bind or "localhost"
    print(f"serving {bind}:{args.port}/{{z}}/{{x}}/{{y}}.{fmt}, for development only")
    # The context manager closes the listening socket once serve_forever()
    # stops, e.g. when the process is interrupted.
    with ThreadingSimpleServer((bind, int(args.port)), Handler) as httpd:
        httpd.serve_forever()
