From abd77e4db6acbbf74b918797f033d0342025d704 Mon Sep 17 00:00:00 2001
From: Andrej Shadura <andrew.shadura@collabora.co.uk>
Date: Wed, 19 Apr 2023 15:33:31 +0200
Subject: [PATCH] Add configuration de/serialisation code

This is necessary to enable the client to be able to only specify
minimal configuration, and fetch the rest from the server.

Signed-off-by: Andrej Shadura <andrew.shadura@collabora.co.uk>
---
 obs_proxy/config.py  | 101 +++++++++++++++++++++------
 tests/test_config.py | 161 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 240 insertions(+), 22 deletions(-)
 create mode 100644 tests/test_config.py

diff --git a/obs_proxy/config.py b/obs_proxy/config.py
index e312e11..e6f2444 100644
--- a/obs_proxy/config.py
+++ b/obs_proxy/config.py
@@ -14,10 +14,10 @@
 import os
 import sys
 from configparser import ConfigParser
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from logging import error
 from pathlib import Path
-from typing import Iterable
+from typing import Iterable, Mapping
 
 HTTP_TIMEOUT = 20 * 60
 
@@ -49,6 +49,10 @@ class ServerConfig:
     keyfile: Path = None
     certfile: Path = None
 
+    def __post_init__(self):
+        self.keyfile = path_or_none(self.keyfile)
+        self.certfile = path_or_none(self.certfile)
+
 
 @dataclass
 class ClientConfig:
@@ -88,7 +92,7 @@ class Config:
     client: ClientConfig
     auth: AuthConfig
     backend: BackendConfig
-    worker_ports: Iterable[int]
+    worker_ports: range
     debug: bool
     tracedir: Path
     cachedir: Path
@@ -97,10 +101,15 @@ class Config:
     def parse_file(cls, f):
         parser = ConfigParser()
         print(f"Parsing {f}")
-        parser.read(f)
-        startport = parser["workers"].getint("startport")
-        endport = parser["workers"].getint("endport")
-        worker_ports = range(startport, endport + 1)
+        if isinstance(f, os.PathLike) or isinstance(f, str):
+            f = open(f)
+        parser.read_file(f)
+        if "workers" in parser.sections():
+            startport = parser["workers"].getint("startport")
+            endport = parser["workers"].getint("endport")
+            worker_ports = range(startport, endport + 1)
+        else:
+            worker_ports = None
 
         server = {
             "host": parser["server"].get("host"),
@@ -109,13 +118,16 @@ class Config:
             "http2": parser["server"].getboolean("http2", False),
             "tls": parser["server"].getboolean("tls", False),
             "insecure": parser["server"].getboolean("insecure", False),
-            "keyfile": path_or_none(parser["server"]["keyfile"]),
-            "certfile": path_or_none(parser["server"]["certfile"]),
+            "keyfile": parser["server"].get("keyfile"),
+            "certfile": parser["server"].get("certfile"),
             "buffer_uploads": parser["server"].getboolean("buffer_uploads", False),
         }
 
+        if "client" not in parser.sections():
+            parser.add_section("client")
+
         client = {
-            "host": parser["client"].get("host"),
+            "host": parser["client"].get("host", "0.0.0.0"),
             "port": parser["client"].getint("port", 5000),
             # buffer uploads on clients only
             "buffer_uploads": parser["client"].getboolean("buffer_uploads", True),
@@ -124,18 +136,19 @@ class Config:
             error("HTTP proxies are no longer supported, please adjust your configuration and remove proxy=.")
             sys.exit(1)
 
+        if "debug" not in parser.sections():
+            parser.add_section("debug")
+
         debug = (
             parser["server"].getboolean("debug", False) or
             parser["client"].getboolean("debug", False) or
-            ('debug' in parser and parser["debug"].getboolean("enabled", False))
+            parser["debug"].getboolean("enabled", False)
         )
         tracedir = path_or_none(
             parser["server"].get("tracedir", None) or
             parser["client"].get("tracedir", None) or
-            ('debug' in parser and parser["debug"].get("tracedir", None)),
+            parser["debug"].get("tracedir", None),
         )
-        if tracedir:
-            tracedir.mkdir(parents=True, exist_ok=True)
 
         auth = {
             "username": parser["auth"].get("username"),
@@ -145,14 +158,20 @@ class Config:
 
         cachedir = path_or_none(parser["cache"].get("cachedir", None)) if 'cache' in parser else None
 
-        srcserver_host, srcserver_port = parser["backend"]["srcserver"].split(":")
-        repserver_host, repserver_port = parser["backend"]["repserver"].split(":")
-        backend_host = parser["backend"]["host"]
-
-        backend = {
-            "srcserver_uri": "http://%s:%s" % (srcserver_host or backend_host, srcserver_port),
-            "repserver_uri": "http://%s:%s" % (repserver_host or backend_host, repserver_port),
-        }
+        if "backend" in parser.sections():
+            srcserver_host, srcserver_port = parser["backend"]["srcserver"].split(":")
+            repserver_host, repserver_port = parser["backend"]["repserver"].split(":")
+            backend_host = parser["backend"].get("host")
+
+            backend = {
+                "srcserver_uri": "http://%s:%s" % (srcserver_host or backend_host, srcserver_port),
+                "repserver_uri": "http://%s:%s" % (repserver_host or backend_host, repserver_port),
+            }
+        else:
+            backend = {
+                "srcserver_uri": None,
+                "repserver_uri": None,
+            }
 
         return cls(
             server=ServerConfig(**server),
@@ -166,6 +185,44 @@ class Config:
         )
 
 
+    def __post_init__(self):
+        if self.tracedir:
+            self.tracedir.mkdir(parents=True, exist_ok=True)
+
+
+    def as_dict(self):
+        def serialize(o):
+            if is_dataclass(o):
+                return serialize(o.__dict__)
+            if isinstance(o, Mapping):
+                return {k: serialize(v) for k, v in o.items()}
+            if isinstance(o, range):
+                return {'start': o.start, 'end': o.stop - 1}
+            if isinstance(o, os.PathLike):
+                return str(o)
+            return o
+
+        return serialize(self)
+
+
+    def update_from_dict(self, d):
+        def deserialize(o, c, d):
+            if is_dataclass(o):
+                for k in d.keys():
+                    if k in o.__annotations__:
+                        cls = o.__annotations__[k]
+                        o.__dict__[k] = deserialize(o.__dict__[k], cls, d[k])
+            elif c is range:
+                o = range(d['start'], d['end'] + 1)
+            elif c is Path:
+                o = path_or_none(d)
+            else:
+                o = d
+            return o
+
+        deserialize(self, type(self), d)
+
+
 appname = 'obs-proxy'
 
 
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..c8eaaf4
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,161 @@
+import io
+
+import pytest
+from obs_proxy.config import *
+
+
+basic_config = """
+[server]
+host = 1.2.3.4
+port = 1234
+http2 = yes
+tls = no
+[client]
+host = 2.3.4.5
+port = 4321
+[backend]
+srcserver = backend1:2345
+repserver = backend2:5252
+[auth]
+token = test-token
+[workers]
+startport = 6123
+endport = 6134
+"""
+
+
+basic_config_dict = {
+    'server': {
+        'host': '1.2.3.4',
+        'port': 1234,
+        'prefix': '',
+        'buffer_uploads': False,
+        'certfile': None,
+        'http2': True,
+        'insecure': False,
+        'keyfile': None,
+        'tls': False,
+    },
+    'client': {
+        'host': '2.3.4.5',
+        'port': 4321,
+        'buffer_uploads': True,
+    },
+    'auth': {
+        'username': None,
+        'password': None,
+        'token': 'test-token',
+    },
+    'backend': {
+        'repserver_uri': 'http://backend2:5252',
+        'srcserver_uri': 'http://backend1:2345',
+    },
+    'debug': False,
+    'tracedir': None,
+    'cachedir': None,
+    'worker_ports': {
+        'start': 6123,
+        'end': 6134,
+    },
+
+}
+
+
+def test_config_parser():
+    f = io.StringIO(basic_config)
+    c = Config.parse_file(f)
+    assert c.server == ServerConfig(
+        host='1.2.3.4',
+        port=1234,
+        prefix="",
+        http2=True,
+        tls=False,
+        insecure=False,
+        buffer_uploads=False,
+        keyfile=None,
+        certfile=None,
+    )
+    assert c.client == ClientConfig(
+        host='2.3.4.5',
+        port=4321,
+        buffer_uploads=True,
+    )
+    assert c.auth == AuthConfig(
+        username=None,
+        password=None,
+        token='test-token',
+    )
+    assert c.backend == BackendConfig(
+        srcserver_uri='http://backend1:2345',
+        repserver_uri='http://backend2:5252',
+    )
+    assert c.debug is False
+    assert c.worker_ports == range(6123, 6135)
+    assert c.tracedir is None
+    assert c.cachedir is None
+
+
+def test_config_parser_dict():
+    f = io.StringIO(basic_config)
+    c = Config.parse_file(f)
+    assert c.as_dict() == basic_config_dict
+
+
+mini_config = """
+[server]
+host = 3.4.5.6
+port = 5678
+prefix = /baz
+[auth]
+username = luser
+password = drosswap
+"""
+
+def test_mini_config():
+    f = io.StringIO(mini_config)
+    c = Config.parse_file(f)
+    assert c == Config(
+        server=ServerConfig(
+            host='3.4.5.6',
+            port=5678,
+            prefix='/baz',
+            http2=False,
+            tls=False,
+            insecure=False,
+            buffer_uploads=False,
+            keyfile=None,
+            certfile=None
+        ),
+        client=ClientConfig(
+            host='0.0.0.0',
+            port=5000,
+            buffer_uploads=True
+        ),
+        auth=AuthConfig(
+            username='luser',
+            password='drosswap',
+            token=None,
+        ),
+        backend=BackendConfig(
+            srcserver_uri=None,
+            repserver_uri=None
+        ),
+        worker_ports=None,
+        debug=False,
+        tracedir=None,
+        cachedir=None
+    )
+
+
+def test_mini_config_update():
+    f = io.StringIO(mini_config)
+    mini_conf = Config.parse_file(f)
+
+    f = io.StringIO(basic_config)
+    full_conf = Config.parse_file(f)
+
+    mini_conf.update_from_dict(full_conf.as_dict())
+
+    assert mini_conf.as_dict() == basic_config_dict
+
+    assert mini_conf.worker_ports == range(6123, 6135)
-- 
GitLab