Skip to content

Commit 7d6605f

Browse files
committed
Implement windows
Signed-off-by: Cristian Le <git@lecris.dev>
1 parent eddd870 commit 7d6605f

File tree

2 files changed

+260
-7
lines changed

2 files changed

+260
-7
lines changed

src/scikit_build_core/repair_wheel/windows.py

Lines changed: 256 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
from __future__ import annotations
66

77
import dataclasses
8-
from typing import TYPE_CHECKING
8+
import os.path
9+
import textwrap
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, ClassVar
912

10-
from .base import WheelRepairer
13+
from .._logging import logger
14+
from .base import WheelRepairer, _get_buildenv_platlib
1115

1216
if TYPE_CHECKING:
1317
from ..file_api.model.codemodel import Target
@@ -22,11 +26,258 @@ def __dir__() -> list[str]:
2226
@dataclasses.dataclass
2327
class WindowsWheelRepairer(WheelRepairer):
2428
"""
25-
Do some windows specific magic.
29+
Patch the package and top-level python module files with ``os.add_dll_directory``.
2630
"""
2731

2832
_platform = "Windows"
2933

34+
PATCH_PY_FILE: ClassVar[str] = textwrap.dedent("""\
35+
# start scikit-build-core Windows patch
36+
def _skbuild_patch_dll_dir():
37+
import os
38+
import os.path
39+
40+
mod_dir = os.path.abspath(os.path.dirname(__file__))
41+
path_to_platlib = os.path.normpath({path_to_platlib!r})
42+
dll_paths = {dll_paths!r}
43+
for path in dll_paths:
44+
path = os.path.normpath(path)
45+
path = os.path.join(mod_dir, path_to_platlib, path)
46+
os.add_dll_directory(path)
47+
48+
_skbuild_patch_dll_dir()
49+
del _skbuild_patch_dll_dir
50+
# end scikit-build-core Windows patch
51+
""")
52+
dll_dirs: set[Path] = dataclasses.field(default_factory=set, init=False)
53+
"""All dll paths used relative to ``platlib``."""
54+
55+
def get_dll_path_from_lib(self, lib_path: Path) -> Path | None:
56+
"""Guess the dll path from lib path."""
57+
dll_path = None
58+
platlib = Path(_get_buildenv_platlib())
59+
lib_path = lib_path.relative_to(platlib)
60+
# Change the `.lib` to `.dll`
61+
if ".dll" in (suffixes := lib_path.suffixes):
62+
# In some cases like msys, they use `.dll.a`, in which case we can't use `with_suffix`
63+
if suffixes[-2] != ".dll":
64+
logger.warning(
65+
"Expected .dll suffix to be the penultimate extension, instead got: {lib_path}",
66+
lib_path=lib_path,
67+
)
68+
return None
69+
# Drop the last suffix it should then be just .dll file
70+
dll_name = lib_path.stem
71+
else:
72+
dll_name = lib_path.with_suffix(".dll").name
73+
# Try to find the dll in the same package directory
74+
if len(lib_path.parts) > 1:
75+
pkg_dir = lib_path.parts[0]
76+
for root, _, files in os.walk(platlib / pkg_dir):
77+
if dll_name in files:
78+
dll_path = Path(root) / dll_name
79+
break
80+
else:
81+
logger.debug(
82+
"Did not find the dll file under {pkg_dir}",
83+
pkg_dir=pkg_dir,
84+
)
85+
if not dll_path:
86+
logger.debug(
87+
"Looking for {dll_name} in all platlib path.",
88+
dll_name=dll_name,
89+
)
90+
for root, _, files in os.walk(platlib):
91+
if dll_name in files:
92+
dll_path = Path(root) / dll_name
93+
break
94+
else:
95+
logger.warning(
96+
"Could not find dll file {dll_name} corresponding to {lib_path}",
97+
dll_name=dll_name,
98+
lib_path=lib_path,
99+
)
100+
return None
101+
logger.debug(
102+
"Found dll file {dll_path}",
103+
dll_path=dll_path,
104+
)
105+
return self.path_relative_site_packages(dll_path)
106+
107+
def get_library_dependencies(self, target: Target) -> list[Target]:
108+
msg = "get_library_dependencies is not generalized for Windows."
109+
raise NotImplementedError(msg)
110+
111+
def get_dependency_dll(self, target: Target) -> list[Path]:
112+
"""Get the dll due to target link dependencies."""
113+
dll_paths = []
114+
for dep in target.dependencies:
115+
dep_target = next(targ for targ in self.targets if targ.id == dep.id)
116+
if dep_target.type != "SHARED_LIBRARY":
117+
logger.debug(
118+
"Skipping dependency {dep_target} of type {type}",
119+
dep_target=dep_target.name,
120+
type=dep_target.type,
121+
)
122+
continue
123+
if not dep_target.install:
124+
logger.warning(
125+
"Dependency {dep_target} is not installed",
126+
dep_target=dep_target.name,
127+
)
128+
continue
129+
dll_artifact = next(
130+
artifact.path.name
131+
for artifact in dep_target.artifacts
132+
if artifact.path.suffix == ".dll"
133+
)
134+
for install_path in self.get_wheel_install_paths(dep_target):
135+
dep_install_path = self.install_dir / install_path
136+
if (dep_install_path / dll_artifact).exists():
137+
break
138+
else:
139+
logger.warning(
140+
"Could not find installed {dll_artifact} location in install paths: {install_path}",
141+
dll_artifact=dll_artifact,
142+
install_path=[
143+
dest.path for dest in dep_target.install.destinations
144+
],
145+
)
146+
continue
147+
dll_path = self.path_relative_site_packages(dep_install_path)
148+
dll_paths.append(dll_path)
149+
return dll_paths
150+
151+
def get_package_dll(self, target: Target) -> list[Path]:
152+
"""
153+
Get the dll due to external package linkage.
154+
155+
Have to use the guess the dll paths until the package targets are exposed.
156+
https://gitlab.kitware.com/cmake/cmake/-/issues/26755
157+
"""
158+
if not target.link:
159+
return []
160+
dll_paths = []
161+
assert target.link.commandFragments is not None
162+
for link_command in target.link.commandFragments:
163+
if link_command.role == "flags":
164+
if not link_command.fragment:
165+
logger.debug(
166+
"Skipping {target} link-flags: {flags}",
167+
target=target.name,
168+
flags=link_command.fragment,
169+
)
170+
continue
171+
if link_command.role != "libraries":
172+
logger.warning(
173+
"File-api link role {role} is not supported. "
174+
"Target={target}, command={command}",
175+
target=target.name,
176+
role=link_command.role,
177+
command=link_command.fragment,
178+
)
179+
continue
180+
# The remaining case should be a path
181+
try:
182+
# TODO: how to best catch if a string is a valid path?
183+
lib_path = Path(link_command.fragment)
184+
if not lib_path.is_absolute():
185+
# If the link_command is a space-separated list of libraries, this should be skipped
186+
logger.debug(
187+
"Skipping non-absolute-path library: {fragment}",
188+
fragment=link_command.fragment,
189+
)
190+
continue
191+
try:
192+
self.path_relative_site_packages(lib_path)
193+
except ValueError:
194+
logger.debug(
195+
"Skipping library outside site-package path: {lib_path}",
196+
lib_path=lib_path,
197+
)
198+
continue
199+
dll_path = self.get_dll_path_from_lib(lib_path)
200+
if not dll_path:
201+
continue
202+
dll_paths.append(dll_path.parent)
203+
except Exception as exc:
204+
logger.warning(
205+
"Could not parse link-library as a path: {fragment}\nexc = {exc}",
206+
fragment=link_command.fragment,
207+
exc=exc,
208+
)
209+
continue
210+
return dll_paths
211+
30212
def patch_target(self, target: Target) -> None:
31-
# TODO: Implement patching
32-
pass
213+
# Here we just gather all dll paths needed for each target
214+
package_dlls = self.get_package_dll(target)
215+
dependency_dlls = self.get_dependency_dll(target)
216+
if not package_dlls and not dependency_dlls:
217+
logger.warning(
218+
"No dll files found for target {target}",
219+
target=target.name,
220+
)
221+
return
222+
logger.debug(
223+
"Found dlls for target {target}:\n"
224+
"package_dlls={package_dlls}\n"
225+
"dependency_dlls={dependency_dlls}\n",
226+
target=target.name,
227+
package_dlls=package_dlls,
228+
dependency_dlls=dependency_dlls,
229+
)
230+
self.dll_dirs.update(package_dlls)
231+
self.dll_dirs.update(dependency_dlls)
232+
233+
def patch_python_file(self, file: Path) -> None:
234+
"""
235+
Patch python package or top-level module.
236+
237+
Make sure the python files have an appropriate ``os.add_dll_directory``
238+
for the scripts directory.
239+
"""
240+
assert self.dll_dirs
241+
assert all(not path.is_absolute() for path in self.dll_dirs)
242+
logger.debug(
243+
"Patching python file: {file}",
244+
file=file,
245+
)
246+
platlib = Path(self.wheel_dirs["platlib"])
247+
content = file.read_text()
248+
mod_dir = file.parent
249+
path_to_platlib = os.path.relpath(platlib, mod_dir)
250+
patch_script = self.PATCH_PY_FILE.format(
251+
path_to_platlib=path_to_platlib,
252+
dll_paths=[str(path) for path in self.dll_dirs],
253+
)
254+
# TODO: Account for the header comments, __future__.annotations, etc.
255+
with file.open("w") as f:
256+
f.write(f"{patch_script}\n" + content)
257+
258+
def repair_wheel(self) -> None:
259+
super().repair_wheel()
260+
platlib = Path(self.wheel_dirs["platlib"])
261+
if not self.dll_dirs:
262+
logger.debug(
263+
"Skipping wheel repair because no site-package dlls were found."
264+
)
265+
return
266+
logger.debug(
267+
"Patching dll directories: {dll_dirs}",
268+
dll_dirs=self.dll_dirs,
269+
)
270+
# TODO: Not handling namespace packages with this
271+
for path in platlib.iterdir():
272+
assert isinstance(path, Path)
273+
if path.is_dir():
274+
pkg_file = path / "__init__.py"
275+
if not pkg_file.exists():
276+
logger.debug(
277+
"Ignoring non-python package: {pkg_file}",
278+
pkg_file=pkg_file,
279+
)
280+
continue
281+
self.patch_python_file(pkg_file)
282+
elif path.suffix == ".py":
283+
self.patch_python_file(path)

tests/test_repair_wheel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def test_full_build(
7272
wheels = list(dist.glob("*.whl"))
7373
isolated.install(*wheels)
7474

75-
isolated.run("main")
76-
isolated.module("repair_wheel")
75+
if platform.system() != "Windows":
76+
# Requires a more specialized patch
77+
isolated.run("main")
78+
isolated.module("repair_wheel")
7779
isolated.execute(
7880
"from repair_wheel._module import hello; hello()",
7981
)

0 commit comments

Comments
 (0)