From abd77e4db6acbbf74b918797f033d0342025d704 Mon Sep 17 00:00:00 2001 From: Andrej Shadura <andrew.shadura@collabora.co.uk> Date: Wed, 19 Apr 2023 15:33:31 +0200 Subject: [PATCH] Add configuration de/serialisation code This is necessary to enable the client to be able to only specify minimal configuration, and fetch the rest from the server. Signed-off-by: Andrej Shadura <andrew.shadura@collabora.co.uk> --- obs_proxy/config.py | 101 +++++++++++++++++++++------ tests/test_config.py | 161 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+), 22 deletions(-) create mode 100644 tests/test_config.py diff --git a/obs_proxy/config.py b/obs_proxy/config.py index e312e11..e6f2444 100644 --- a/obs_proxy/config.py +++ b/obs_proxy/config.py @@ -14,10 +14,10 @@ import os import sys from configparser import ConfigParser -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from logging import error from pathlib import Path -from typing import Iterable +from typing import Iterable, Mapping HTTP_TIMEOUT = 20 * 60 @@ -49,6 +49,10 @@ class ServerConfig: keyfile: Path = None certfile: Path = None + def __post_init__(self): + self.keyfile = path_or_none(self.keyfile) + self.certfile = path_or_none(self.certfile) + @dataclass class ClientConfig: @@ -88,7 +92,7 @@ class Config: client: ClientConfig auth: AuthConfig backend: BackendConfig - worker_ports: Iterable[int] + worker_ports: range debug: bool tracedir: Path cachedir: Path @@ -97,10 +101,15 @@ class Config: def parse_file(cls, f): parser = ConfigParser() print(f"Parsing {f}") - parser.read(f) - startport = parser["workers"].getint("startport") - endport = parser["workers"].getint("endport") - worker_ports = range(startport, endport + 1) + if isinstance(f, os.PathLike) or isinstance(f, str): + f = open(f) + parser.read_file(f) + if "workers" in parser.sections(): + startport = parser["workers"].getint("startport") + endport = parser["workers"].getint("endport") + worker_ports = range(startport, endport + 1) + else: + worker_ports = None server = { "host": parser["server"].get("host"), @@ -109,13 +118,16 @@ class Config: "http2": parser["server"].getboolean("http2", False), "tls": parser["server"].getboolean("tls", False), "insecure": parser["server"].getboolean("insecure", False), - "keyfile": path_or_none(parser["server"]["keyfile"]), - "certfile": path_or_none(parser["server"]["certfile"]), + "keyfile": parser["server"].get("keyfile"), + "certfile": parser["server"].get("certfile"), "buffer_uploads": parser["server"].getboolean("buffer_uploads", False), } + if "client" not in parser.sections(): + parser.add_section("client") + client = { - "host": parser["client"].get("host"), + "host": parser["client"].get("host", "0.0.0.0"), "port": parser["client"].getint("port", 5000), # buffer uploads on clients only "buffer_uploads": parser["client"].getboolean("buffer_uploads", True), @@ -124,18 +136,19 @@ class Config: error("HTTP proxies are no longer supported, please adjust your configuration and remove proxy=.") sys.exit(1) + if "debug" not in parser.sections(): + parser.add_section("debug") + debug = ( parser["server"].getboolean("debug", False) or parser["client"].getboolean("debug", False) or - ('debug' in parser and parser["debug"].getboolean("enabled", False)) + parser["debug"].getboolean("enabled", False) ) tracedir = path_or_none( parser["server"].get("tracedir", None) or parser["client"].get("tracedir", None) or - ('debug' in parser and parser["debug"].get("tracedir", None)), + parser["debug"].get("tracedir", None), ) - if tracedir: - tracedir.mkdir(parents=True, exist_ok=True) auth = { "username": parser["auth"].get("username"), @@ -145,14 +158,20 @@ class Config: cachedir = path_or_none(parser["cache"].get("cachedir", None)) if 'cache' in parser else None - srcserver_host, srcserver_port = parser["backend"]["srcserver"].split(":") - repserver_host, repserver_port = parser["backend"]["repserver"].split(":") - backend_host = parser["backend"]["host"] - - backend = { - "srcserver_uri": "http://%s:%s" % (srcserver_host or backend_host, srcserver_port), - "repserver_uri": "http://%s:%s" % (repserver_host or backend_host, repserver_port), - } + if "backend" in parser.sections(): + srcserver_host, srcserver_port = parser["backend"]["srcserver"].split(":") + repserver_host, repserver_port = parser["backend"]["repserver"].split(":") + backend_host = parser["backend"].get("host") + + backend = { + "srcserver_uri": "http://%s:%s" % (srcserver_host or backend_host, srcserver_port), + "repserver_uri": "http://%s:%s" % (repserver_host or backend_host, repserver_port), + } + else: + backend = { + "srcserver_uri": None, + "repserver_uri": None, + } return cls( server=ServerConfig(**server), @@ -166,6 +185,44 @@ class Config: ) + def __post_init__(self): + if self.tracedir: + self.tracedir.mkdir(parents=True, exist_ok=True) + + + def as_dict(self): + def serialize(o): + if is_dataclass(o): + return serialize(o.__dict__) + if isinstance(o, Mapping): + return {k: serialize(v) for k, v in o.items()} + if isinstance(o, range): + return {'start': o.start, 'end': o.stop - 1} + if isinstance(o, os.PathLike): + return str(o) + return o + + return serialize(self) + + + def update_from_dict(self, d): + def deserialize(o, c, d): + if is_dataclass(o): + for k in d.keys(): + if k in o.__annotations__: + cls = o.__annotations__[k] + o.__dict__[k] = deserialize(o.__dict__[k], cls, d[k]) + elif c is range: + o = range(d['start'], d['end'] + 1) + elif c is Path: + o = path_or_none(d) + else: + o = d + return o + + deserialize(self, type(self), d) + + appname = 'obs-proxy' diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..c8eaaf4 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,161 @@ +import io + +import pytest +from obs_proxy.config import * + + +basic_config = """ +[server] +host = 1.2.3.4 +port = 1234 +http2 = yes +tls = no +[client] +host = 2.3.4.5 +port = 4321 +[backend] +srcserver = backend1:2345 +repserver = backend2:5252 +[auth] +token = test-token +[workers] +startport = 6123 +endport = 6134 +""" + + +basic_config_dict = { + 'server': { + 'host': '1.2.3.4', + 'port': 1234, + 'prefix': '', + 'buffer_uploads': False, + 'certfile': None, + 'http2': True, + 'insecure': False, + 'keyfile': None, + 'tls': False, + }, + 'client': { + 'host': '2.3.4.5', + 'port': 4321, + 'buffer_uploads': True, + }, + 'auth': { + 'username': None, + 'password': None, + 'token': 'test-token', + }, + 'backend': { + 'repserver_uri': 'http://backend2:5252', + 'srcserver_uri': 'http://backend1:2345', + }, + 'debug': False, + 'tracedir': None, + 'cachedir': None, + 'worker_ports': { + 'start': 6123, + 'end': 6134, + }, + +} + + +def test_config_parser(): + f = io.StringIO(basic_config) + c = Config.parse_file(f) + assert c.server == ServerConfig( + host='1.2.3.4', + port=1234, + prefix="", + http2=True, + tls=False, + insecure=False, + buffer_uploads=False, + keyfile=None, + certfile=None, + ) + assert c.client == ClientConfig( + host='2.3.4.5', + port=4321, + buffer_uploads=True, + ) + assert c.auth == AuthConfig( + username=None, + password=None, + token='test-token', + ) + assert c.backend == BackendConfig( + srcserver_uri='http://backend1:2345', + repserver_uri='http://backend2:5252', + ) + assert c.debug is False + assert c.worker_ports == range(6123, 6135) + assert c.tracedir is None + assert c.cachedir is None + + +def test_config_parser_dict(): + f = io.StringIO(basic_config) + c = Config.parse_file(f) + assert c.as_dict() == basic_config_dict + + +mini_config = """ +[server] +host = 3.4.5.6 +port = 5678 +prefix = /baz +[auth] +username = luser +password = drosswap +""" + +def test_mini_config(): + f = io.StringIO(mini_config) + c = Config.parse_file(f) + assert c == Config( + server=ServerConfig( + host='3.4.5.6', + port=5678, + prefix='/baz', + http2=False, + tls=False, + insecure=False, + buffer_uploads=False, + keyfile=None, + certfile=None + ), + client=ClientConfig( + host='0.0.0.0', + port=5000, + buffer_uploads=True + ), + auth=AuthConfig( + username='luser', + password='drosswap', + token=None, + ), + backend=BackendConfig( + srcserver_uri=None, + repserver_uri=None + ), + worker_ports=None, + debug=False, + tracedir=None, + cachedir=None + ) + + +def test_mini_config_update(): + f = io.StringIO(mini_config) + mini_conf = Config.parse_file(f) + + f = io.StringIO(basic_config) + full_conf = Config.parse_file(f) + + mini_conf.update_from_dict(full_conf.as_dict()) + + assert mini_conf.as_dict() == basic_config_dict + + assert mini_conf.worker_ports == range(6123, 6135) -- GitLab