Skip to content

Commit b4e45a0

Browse files
committed
Implement windows
Signed-off-by: Cristian Le <git@lecris.dev>
1 parent b16f66e commit b4e45a0

File tree

2 files changed

+261
-7
lines changed

2 files changed

+261
-7
lines changed

src/scikit_build_core/repair_wheel/windows.py

Lines changed: 257 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING
7+
import dataclasses
8+
import os.path
9+
import textwrap
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, ClassVar
812

9-
from . import WheelRepairer
13+
from .._logging import logger
14+
from . import WheelRepairer, _get_buildenv_platlib
1015

1116
if TYPE_CHECKING:
1217
from ..file_api.model.codemodel import Target
@@ -18,13 +23,260 @@ def __dir__() -> list[str]:
1823
return __all__
1924

2025

26+
@dataclasses.dataclass
2127
class WindowsWheelRepairer(WheelRepairer):
2228
"""
23-
Do some windows specific magic.
29+
Patch the package and top-level python module files with ``os.add_dll_directory``.
2430
"""
2531

2632
_platform = "Windows"
2733

34+
PATCH_PY_FILE: ClassVar[str] = textwrap.dedent("""\
35+
# start scikit-build-core Windows patch
36+
def _skbuild_patch_dll_dir():
37+
import os
38+
import os.path
39+
40+
mod_dir = os.path.abspath(os.path.dirname(__file__))
41+
path_to_platlib = os.path.normpath({path_to_platlib!r})
42+
dll_paths = {dll_paths!r}
43+
for path in dll_paths:
44+
path = os.path.normpath(path)
45+
path = os.path.join(mod_dir, path_to_platlib, path)
46+
os.add_dll_directory(path)
47+
48+
_skbuild_patch_dll_dir()
49+
del _skbuild_patch_dll_dir
50+
# end scikit-build-core Windows patch
51+
""")
52+
dll_dirs: set[Path] = dataclasses.field(default_factory=set, init=False)
53+
"""All dll paths used relative to ``platlib``."""
54+
55+
def get_dll_path_from_lib(self, lib_path: Path) -> Path | None:
56+
"""Guess the dll path from lib path."""
57+
dll_path = None
58+
platlib = Path(_get_buildenv_platlib())
59+
lib_path = lib_path.relative_to(platlib)
60+
# Change the `.lib` to `.dll`
61+
if ".dll" in (suffixes := lib_path.suffixes):
62+
# In some cases like msys, they use `.dll.a`, in which case we can't use `with_suffix`
63+
if suffixes[-2] != ".dll":
64+
logger.warning(
65+
"Expected .dll suffix to be the penultimate extension, instead got: {lib_path}",
66+
lib_path=lib_path,
67+
)
68+
return None
69+
# Drop the last suffix it should then be just .dll file
70+
dll_name = lib_path.stem
71+
else:
72+
dll_name = lib_path.with_suffix(".dll").name
73+
# Try to find the dll in the same package directory
74+
if len(lib_path.parts) > 1:
75+
pkg_dir = lib_path.parts[0]
76+
for root, _, files in os.walk(platlib / pkg_dir):
77+
if dll_name in files:
78+
dll_path = Path(root) / dll_name
79+
break
80+
else:
81+
logger.debug(
82+
"Did not find the dll file under {pkg_dir}",
83+
pkg_dir=pkg_dir,
84+
)
85+
if not dll_path:
86+
logger.debug(
87+
"Looking for {dll_name} in all platlib path.",
88+
dll_name=dll_name,
89+
)
90+
for root, _, files in os.walk(platlib):
91+
if dll_name in files:
92+
dll_path = Path(root) / dll_name
93+
break
94+
else:
95+
logger.warning(
96+
"Could not find dll file {dll_name} corresponding to {lib_path}",
97+
dll_name=dll_name,
98+
lib_path=lib_path,
99+
)
100+
return None
101+
logger.debug(
102+
"Found dll file {dll_path}",
103+
dll_path=dll_path,
104+
)
105+
return self.path_relative_site_packages(dll_path)
106+
107+
def get_library_dependencies(self, target: Target) -> list[Target]:
108+
msg = "get_library_dependencies is not generalized for Windows."
109+
raise NotImplementedError(msg)
110+
111+
def get_dependency_dll(self, target: Target) -> list[Path]:
112+
"""Get the dll due to target link dependencies."""
113+
dll_paths = []
114+
for dep in target.dependencies:
115+
dep_target = next(targ for targ in self.targets if targ.id == dep.id)
116+
if dep_target.type != "SHARED_LIBRARY":
117+
logger.debug(
118+
"Skipping dependency {dep_target} of type {type}",
119+
dep_target=dep_target.name,
120+
type=dep_target.type,
121+
)
122+
continue
123+
if not dep_target.install:
124+
logger.warning(
125+
"Dependency {dep_target} is not installed",
126+
dep_target=dep_target.name,
127+
)
128+
continue
129+
dll_artifact = next(
130+
artifact.path.name
131+
for artifact in dep_target.artifacts
132+
if artifact.path.suffix == ".dll"
133+
)
134+
for install_path in self.get_wheel_install_paths(dep_target):
135+
dep_install_path = self.install_dir / install_path
136+
if (dep_install_path / dll_artifact).exists():
137+
break
138+
else:
139+
logger.warning(
140+
"Could not find installed {dll_artifact} location in install paths: {install_path}",
141+
dll_artifact=dll_artifact,
142+
install_path=[
143+
dest.path for dest in dep_target.install.destinations
144+
],
145+
)
146+
continue
147+
dll_path = self.path_relative_site_packages(dep_install_path)
148+
dll_paths.append(dll_path)
149+
return dll_paths
150+
151+
def get_package_dll(self, target: Target) -> list[Path]:
152+
"""
153+
Get the dll due to external package linkage.
154+
155+
Have to use the guess the dll paths until the package targets are exposed.
156+
https://gitlab.kitware.com/cmake/cmake/-/issues/26755
157+
"""
158+
if not target.link:
159+
return []
160+
dll_paths = []
161+
for link_command in target.link.commandFragments:
162+
if link_command.role == "flags":
163+
if not link_command.fragment:
164+
logger.debug(
165+
"Skipping {target} link-flags: {flags}",
166+
target=target.name,
167+
flags=link_command.fragment,
168+
)
169+
continue
170+
if link_command.role != "libraries":
171+
logger.warning(
172+
"File-api link role {role} is not supported. "
173+
"Target={target}, command={command}",
174+
target=target.name,
175+
role=link_command.role,
176+
command=link_command.fragment,
177+
)
178+
continue
179+
# The remaining case should be a path
180+
try:
181+
# TODO: how to best catch if a string is a valid path?
182+
lib_path = Path(link_command.fragment)
183+
if not lib_path.is_absolute():
184+
# If the link_command is a space-separated list of libraries, this should be skipped
185+
logger.debug(
186+
"Skipping non-absolute-path library: {fragment}",
187+
fragment=link_command.fragment,
188+
)
189+
continue
190+
try:
191+
self.path_relative_site_packages(lib_path)
192+
except ValueError:
193+
logger.debug(
194+
"Skipping library outside site-package path: {lib_path}",
195+
lib_path=lib_path,
196+
)
197+
continue
198+
dll_path = self.get_dll_path_from_lib(lib_path)
199+
if not dll_path:
200+
continue
201+
dll_paths.append(dll_path.parent)
202+
except Exception as exc:
203+
logger.warning(
204+
"Could not parse link-library as a path: {fragment}\nexc = {exc}",
205+
fragment=link_command.fragment,
206+
exc=exc,
207+
)
208+
continue
209+
return dll_paths
210+
28211
def patch_target(self, target: Target) -> None:
29-
# TODO: Implement patching
30-
pass
212+
# Here we just gather all dll paths needed for each target
213+
package_dlls = self.get_package_dll(target)
214+
dependency_dlls = self.get_dependency_dll(target)
215+
if not package_dlls and not dependency_dlls:
216+
logger.warning(
217+
"No dll files found for target {target}",
218+
target=target.name,
219+
)
220+
return
221+
logger.debug(
222+
"Found dlls for target {target}:\n"
223+
"package_dlls={package_dlls}\n"
224+
"dependency_dlls={dependency_dlls}\n",
225+
target=target.name,
226+
package_dlls=package_dlls,
227+
dependency_dlls=dependency_dlls,
228+
)
229+
self.dll_dirs.update(package_dlls)
230+
self.dll_dirs.update(dependency_dlls)
231+
232+
def patch_python_file(self, file: Path) -> None:
233+
"""
234+
Patch python package or top-level module.
235+
236+
Make sure the python files have an appropriate ``os.add_dll_directory``
237+
for the scripts directory.
238+
"""
239+
assert self.dll_dirs
240+
assert all(not path.is_absolute() for path in self.dll_dirs)
241+
logger.debug(
242+
"Patching python file: {file}",
243+
file=file,
244+
)
245+
platlib = Path(self.wheel_dirs["platlib"])
246+
content = file.read_text()
247+
mod_dir = file.parent
248+
path_to_platlib = os.path.relpath(platlib, mod_dir)
249+
patch_script = self.PATCH_PY_FILE.format(
250+
path_to_platlib=path_to_platlib,
251+
dll_paths=[str(path) for path in self.dll_dirs],
252+
)
253+
# TODO: Account for the header comments, __future__.annotations, etc.
254+
with file.open("w") as f:
255+
f.write(f"{patch_script}\n" + content)
256+
257+
def repair_wheel(self) -> None:
258+
super().repair_wheel()
259+
platlib = Path(self.wheel_dirs["platlib"])
260+
if not self.dll_dirs:
261+
logger.debug(
262+
"Skipping wheel repair because no site-package dlls were found."
263+
)
264+
return
265+
logger.debug(
266+
"Patching dll directories: {dll_dirs}",
267+
dll_dirs=self.dll_dirs,
268+
)
269+
# TODO: Not handling namespace packages with this
270+
for path in platlib.iterdir():
271+
assert isinstance(path, Path)
272+
if path.is_dir():
273+
pkg_file = path / "__init__.py"
274+
if not pkg_file.exists():
275+
logger.debug(
276+
"Ignoring non-python package: {pkg_file}",
277+
pkg_file=pkg_file,
278+
)
279+
continue
280+
self.patch_python_file(pkg_file)
281+
elif path.suffix == ".py":
282+
self.patch_python_file(path)

tests/test_repair_wheel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def test_full_build(
7272
wheels = list(dist.glob("*.whl"))
7373
isolated.install(*wheels)
7474

75-
isolated.run("main")
76-
isolated.module("repair_wheel")
75+
if platform.system() != "Windows":
76+
# Requires a more specialized patch
77+
isolated.run("main")
78+
isolated.module("repair_wheel")
7779
isolated.execute(
7880
"from repair_wheel._module import hello; hello()",
7981
)

0 commit comments

Comments
 (0)