Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 79 additions & 18 deletions backend/OBSController.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,73 @@ def __init__(self):
self.event_obs: obsws = None # All events are connected to this to avoid crash if a request is made in an event
pass

def validate_ip(self, host: str):
if host in ("localhost", "127.0.0.1"):
def validate_dns(self, host: str) -> bool:
"""
Validate hostname/DNS entry.

Valid hostname: alphanumeric, hyphens, dots, max 253 chars.
Each label (between dots) max 63 chars, can't start/end with hyphen.

Returns True if the hostname is valid, False otherwise.
"""
# Valid hostname: alphanumeric, hyphens, dots, max 253 chars
# Each label (between dots) max 63 chars, can't start/end with hyphen
if len(host) > 253:
return False

# Allow localhost explicitly
if host.lower() == "localhost":
return True

# We're explicitly disallowing non-localhost DNS entries here.
# Continuing this pattern for now, but this is probably the wrong thing
# to do long-term.
# Split into labels and validate each
labels = host.split('.')
if not labels:
return False

try:
addr = ipaddress.ip_address(host)
# Pattern for valid hostname labels
for label in labels:
if not label or len(label) > 63:
return False
# Must start and end with alphanumeric
if not (label[0].isalnum() and label[-1].isalnum()):
return False
# Labels can't be all numeric (would be confused with IP)
if label.isdigit():
# If all labels are numeric, it looks like an IP address
# and should have been caught by ipaddress.ip_address()
if all(l.isdigit() for l in labels):
return False
# Middle characters can be alphanumeric or hyphen
for char in label:
if not (char.isalnum() or char == '-'):
return False

return True

def validate_ip(self, host: str):
"""
Validate host address (IPv4, IPv6, or hostname).

# And we're disallowing IPv6 entries here, for compatibility with
# previous implementations. Again, probably the wrong thing
# long-term, but implementing this way to mitigate risk while we're
# in a bad-push state.
if not addr.version == ipaddress.IPv4Address.version:
raise ValueError()
Returns True if the host is valid, False otherwise.
"""
if not host or not host.strip():
return False

host = host.strip()

# Handle bracket-wrapped IPv6 addresses [::1]
if host.startswith('[') and host.endswith(']'):
host = host[1:-1]

# Try to parse as IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(host)
return True
except ValueError:
return False
pass

# Try to validate as hostname/DNS entry
return self.validate_dns(host)

def on_connect(self, obs):
self.connected = True
Expand Down Expand Up @@ -58,18 +105,32 @@ def connect_to(self, host=None, port=None, timeout=1, legacy=False, **kwargs):
self.event_obs.disconnect()
return False

# For IPv6 addresses, wrap in brackets if not already wrapped
# This is required for WebSocket URL construction (ws://[::1]:port)
connection_host = host
try:
addr = ipaddress.ip_address(host)
if isinstance(addr, ipaddress.IPv6Address):
# Only wrap if not already wrapped
if not (host.startswith('[') and host.endswith(']')):
connection_host = f"[{host}]"
log.debug(f"Wrapped IPv6 address: {host} -> {connection_host}")
except ValueError:
# Not an IP address, use as-is (hostname)
pass

try:
log.debug(f"Trying to connect to obs with legacy: {legacy}")
super().__init__(host=host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
super().__init__(host=connection_host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=connection_host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.connect()
log.info("Successfully connected to OBS")
return True
except (obswebsocket.exceptions.ConnectionFailure, ValueError) as e:
try:
log.error(f"Failed to connect to OBS with legacy: {legacy}, trying with legacy: {not legacy}")
super().__init__(host=host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
super().__init__(host=connection_host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=connection_host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.connect()
log.info("Successfully connected to OBS")

Expand Down