Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 59 additions & 9 deletions pytest_postgresql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
# along with pytest-postgresql. If not, see <http://www.gnu.org/licenses/>.
"""PostgreSQL executor crafter around pg_ctl."""

import os
import os.path
import platform
import re
import shutil
import signal
import subprocess
import tempfile
import time
Expand Down Expand Up @@ -48,13 +50,20 @@ class PostgreSQLExecutor(TCPExecutor):
<http://www.postgresql.org/docs/current/static/app-pg-ctl.html>`_
"""

BASE_PROC_START_COMMAND = (
'{executable} start -D "{datadir}" '
"-o \"-F -p {port} -c log_destination='stderr' "
"-c logging_collector=off "
"-c unix_socket_directories='{unixsocketdir}' {postgres_options}\" "
'-l "{logfile}" {startparams}'
)
def _get_base_command(self) -> str:
"""Get the base PostgreSQL command, cross-platform compatible."""
# Use unified format without single quotes around values
# This format works on both Windows and Unix systems since PostgreSQL
# configuration values without spaces don't require quotes
return (
'{executable} start -D "{datadir}" '
'-o "-F -p {port} -c log_destination=stderr '
"-c logging_collector=off "
'-c unix_socket_directories={unixsocketdir} {postgres_options}" '
'-l "{logfile}" {startparams}'
)

BASE_PROC_START_COMMAND = "" # Will be set dynamically

VERSION_RE = re.compile(r".* (?P<version>\d+(?:\.\d+)?)")
MIN_SUPPORTED_VERSION = parse("10")
Expand Down Expand Up @@ -108,7 +117,7 @@ def __init__(
self.logfile = logfile
self.startparams = startparams
self.postgres_options = postgres_options
command = self.BASE_PROC_START_COMMAND.format(
command = self._get_base_command().format(
executable=self.executable,
datadir=self.datadir,
port=port,
Expand Down Expand Up @@ -219,17 +228,58 @@ def running(self) -> bool:
status_code = subprocess.getstatusoutput(f'{self.executable} status -D "{self.datadir}"')[0]
return status_code == 0

def _windows_terminate_process(self, sig: Optional[int] = None) -> None:
"""Terminate process on Windows."""
if self.process is None:
return

try:
# On Windows, try to terminate gracefully first
self.process.terminate()
# Give it a chance to terminate gracefully
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
# If it doesn't terminate gracefully, force kill
self.process.kill()
self.process.wait()
except (OSError, AttributeError):
# Process might already be dead or other issues
pass

def _unix_terminate_process(self, sig: Optional[int] = None) -> None:
"""Terminate process on Unix systems."""
if self.process is None:
return

try:
# On Unix systems, use the signal
actual_sig = sig or signal.SIGTERM
os.killpg(self.process.pid, actual_sig)
except (OSError, AttributeError):
# Process might already be dead or other issues
pass

def stop(self: T, sig: Optional[int] = None, exp_sig: Optional[int] = None) -> T:
"""Issue a stop request to executable."""
subprocess.check_output(
f'{self.executable} stop -D "{self.datadir}" -m f',
shell=True,
)
try:
super().stop(sig, exp_sig)
if platform.system() == "Windows":
self._windows_terminate_process(sig)
else:
super().stop(sig, exp_sig)
except ProcessFinishedWithError:
# Finished, leftovers ought to be cleaned afterwards anyway
pass
except AttributeError as e:
# Handle case where os.killpg doesn't exist (shouldn't happen now)
if "killpg" in str(e):
self._windows_terminate_process(sig)
else:
raise
return self

def __del__(self) -> None:
Expand Down
207 changes: 207 additions & 0 deletions tests/test_windows_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""Test Windows compatibility fixes for pytest-postgresql."""

import subprocess
from unittest.mock import MagicMock, patch

from pytest_postgresql.executor import PostgreSQLExecutor


class TestWindowsCompatibility:
"""Test Windows-specific functionality."""

def test_get_base_command_unified(self):
"""Test that base command is unified across platforms."""
executor = PostgreSQLExecutor(
executable="/path/to/pg_ctl",
host="localhost",
port=5432,
datadir="/tmp/data",
unixsocketdir="/tmp/socket",
logfile="/tmp/log",
startparams="-w",
dbname="test",
)

# Test that command format is consistent across platforms
with patch("platform.system", return_value="Windows"):
windows_command = executor._get_base_command()

with patch("platform.system", return_value="Linux"):
unix_command = executor._get_base_command()

# Both should be the same now
assert windows_command == unix_command

# Both should use the simplified format without single quotes
assert "log_destination=stderr" in windows_command
assert "log_destination='stderr'" not in windows_command
assert "unix_socket_directories={unixsocketdir}" in windows_command
assert "unix_socket_directories='{unixsocketdir}'" not in windows_command

def test_windows_terminate_process(self):
"""Test Windows process termination."""
executor = PostgreSQLExecutor(
executable="/path/to/pg_ctl",
host="localhost",
port=5432,
datadir="/tmp/data",
unixsocketdir="/tmp/socket",
logfile="/tmp/log",
startparams="-w",
dbname="test",
)

# Mock process
mock_process = MagicMock()
executor.process = mock_process

# No need to mock platform.system() since the method doesn't check it anymore
executor._windows_terminate_process()

# Should call terminate first
mock_process.terminate.assert_called_once()
mock_process.wait.assert_called()

def test_windows_terminate_process_force_kill(self):
"""Test Windows process termination with force kill on timeout."""
executor = PostgreSQLExecutor(
executable="/path/to/pg_ctl",
host="localhost",
port=5432,
datadir="/tmp/data",
unixsocketdir="/tmp/socket",
logfile="/tmp/log",
startparams="-w",
dbname="test",
)

# Mock process that times out
mock_process = MagicMock()
mock_process.wait.side_effect = [subprocess.TimeoutExpired(cmd="test", timeout=5), None]
executor.process = mock_process

# No need to mock platform.system() since the method doesn't check it anymore
executor._windows_terminate_process()

# Should call terminate, wait (timeout), then kill, then wait again
mock_process.terminate.assert_called_once()
mock_process.kill.assert_called_once()
assert mock_process.wait.call_count == 2

def test_stop_method_windows(self):
"""Test stop method on Windows."""
executor = PostgreSQLExecutor(
executable="/path/to/pg_ctl",
host="localhost",
port=5432,
datadir="/tmp/data",
unixsocketdir="/tmp/socket",
logfile="/tmp/log",
startparams="-w",
dbname="test",
)

# Mock subprocess and process
with (
patch("subprocess.check_output") as mock_subprocess,
patch("platform.system", return_value="Windows"),
patch.object(executor, "_windows_terminate_process") as mock_terminate,
):

result = executor.stop()

# Should call pg_ctl stop and Windows terminate
mock_subprocess.assert_called_once()
mock_terminate.assert_called_once_with(None)
assert result is executor

def test_stop_method_unix(self):
"""Test stop method on Unix systems."""
executor = PostgreSQLExecutor(
executable="/path/to/pg_ctl",
host="localhost",
port=5432,
datadir="/tmp/data",
unixsocketdir="/tmp/socket",
logfile="/tmp/log",
startparams="-w",
dbname="test",
)

# Mock subprocess and super().stop
with (
patch("subprocess.check_output") as mock_subprocess,
patch("platform.system", return_value="Linux"),
patch("pytest_postgresql.executor.TCPExecutor.stop") as mock_super_stop,
):

mock_super_stop.return_value = executor
result = executor.stop()

# Should call pg_ctl stop and parent class stop
mock_subprocess.assert_called_once()
mock_super_stop.assert_called_once_with(None, None)
assert result is executor

def test_stop_method_fallback_on_killpg_error(self):
"""Test stop method falls back to Windows termination on killpg AttributeError."""
executor = PostgreSQLExecutor(
executable="/path/to/pg_ctl",
host="localhost",
port=5432,
datadir="/tmp/data",
unixsocketdir="/tmp/socket",
logfile="/tmp/log",
startparams="-w",
dbname="test",
)

# Mock subprocess and super().stop to raise AttributeError
with (
patch("subprocess.check_output") as mock_subprocess,
patch("platform.system", return_value="Linux"),
patch(
"pytest_postgresql.executor.TCPExecutor.stop",
side_effect=AttributeError("module 'os' has no attribute 'killpg'"),
),
patch.object(executor, "_windows_terminate_process") as mock_terminate,
):

result = executor.stop()

# Should call pg_ctl stop, fail on super().stop, then use Windows terminate
mock_subprocess.assert_called_once()
mock_terminate.assert_called_once_with(None)
assert result is executor

def test_command_formatting_windows(self):
"""Test that command is properly formatted for Windows."""
with patch("platform.system", return_value="Windows"):
executor = PostgreSQLExecutor(
executable="C:/Program Files/PostgreSQL/bin/pg_ctl.exe",
host="localhost",
port=5555,
datadir="C:/temp/data",
unixsocketdir="C:/temp/socket",
logfile="C:/temp/log.txt",
startparams="-w -s",
dbname="testdb",
postgres_options="-c shared_preload_libraries=test",
)

# The command should be properly formatted without quotes around stderr
expected_parts = [
"C:/Program Files/PostgreSQL/bin/pg_ctl.exe start",
'-D "C:/temp/data"',
'-o "-F -p 5555 -c log_destination=stderr',
"-c logging_collector=off",
"-c unix_socket_directories=C:/temp/socket",
'-c shared_preload_libraries=test"',
'-l "C:/temp/log.txt"',
"-w -s",
]

# Check if all expected parts are in the command
command = executor.command
for part in expected_parts:
assert part in command, f"Expected '{part}' in command: {command}"