diff --git a/backend/OBSController.py b/backend/OBSController.py index 73b5d18..2dea994 100644 --- a/backend/OBSController.py +++ b/backend/OBSController.py @@ -11,26 +11,73 @@ def __init__(self): self.event_obs: obsws = None # All events are connected to this to avoid crash if a request is made in an event pass - def validate_ip(self, host: str): - if host in ("localhost", "127.0.0.1"): + def validate_dns(self, host: str) -> bool: + """ + Validate hostname/DNS entry. + + Valid hostname: alphanumeric, hyphens, dots, max 253 chars. + Each label (between dots) max 63 chars, can't start/end with hyphen. + + Returns True if the hostname is valid, False otherwise. + """ + # Valid hostname: alphanumeric, hyphens, dots, max 253 chars + # Each label (between dots) max 63 chars, can't start/end with hyphen + if len(host) > 253: + return False + + # Allow localhost explicitly + if host.lower() == "localhost": return True - # We're explicitly disallowing non-localhost DNS entries here. - # Continuing this pattern for now, but this is probably the wrong thing - # to do long-term. + # Split into labels and validate each + labels = host.split('.') + if not labels: + return False - try: - addr = ipaddress.ip_address(host) + # Pattern for valid hostname labels + for label in labels: + if not label or len(label) > 63: + return False + # Must start and end with alphanumeric + if not (label[0].isalnum() and label[-1].isalnum()): + return False + # Labels can't be all numeric (would be confused with IP) + if label.isdigit(): + # If all labels are numeric, it looks like an IP address + # and should have been caught by ipaddress.ip_address() + if all(l.isdigit() for l in labels): + return False + # Middle characters can be alphanumeric or hyphen + for char in label: + if not (char.isalnum() or char == '-'): + return False + + return True + + def validate_ip(self, host: str): + """ + Validate host address (IPv4, IPv6, or hostname). - # And we're disallowing IPv6 entries here, for compatibility with - # previous implementations. Again, probably the wrong thing - # long-term, but implementing this way to mitigate risk while we're - # in a bad-push state. - if not addr.version == ipaddress.IPv4Address.version: - raise ValueError() + Returns True if the host is valid, False otherwise. + """ + if not host or not host.strip(): + return False + + host = host.strip() + + # Handle bracket-wrapped IPv6 addresses [::1] + if host.startswith('[') and host.endswith(']'): + host = host[1:-1] + + # Try to parse as IP address (IPv4 or IPv6) + try: + ipaddress.ip_address(host) return True except ValueError: - return False + pass + + # Try to validate as hostname/DNS entry + return self.validate_dns(host) def on_connect(self, obs): self.connected = True @@ -58,18 +105,32 @@ def connect_to(self, host=None, port=None, timeout=1, legacy=False, **kwargs): self.event_obs.disconnect() return False + # For IPv6 addresses, wrap in brackets if not already wrapped + # This is required for WebSocket URL construction (ws://[::1]:port) + connection_host = host + try: + addr = ipaddress.ip_address(host) + if isinstance(addr, ipaddress.IPv6Address): + # Only wrap if not already wrapped + if not (host.startswith('[') and host.endswith(']')): + connection_host = f"[{host}]" + log.debug(f"Wrapped IPv6 address: {host} -> {connection_host}") + except ValueError: + # Not an IP address, use as-is (hostname) + pass + try: log.debug(f"Trying to connect to obs with legacy: {legacy}") - super().__init__(host=host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) - self.event_obs = obsws(host=host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) + super().__init__(host=connection_host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) + self.event_obs = obsws(host=connection_host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) self.connect() log.info("Successfully connected to OBS") return True except (obswebsocket.exceptions.ConnectionFailure, ValueError) as e: try: log.error(f"Failed to connect to OBS with legacy: {legacy}, trying with legacy: {not legacy}") - super().__init__(host=host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) - self.event_obs = obsws(host=host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) + super().__init__(host=connection_host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) + self.event_obs = obsws(host=connection_host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs) self.connect() log.info("Successfully connected to OBS")