Skip to content

Commit d886a3d

Browse files
committed
[WIP] Implement windows
Signed-off-by: Cristian Le <git@lecris.dev>
1 parent a9cd9db commit d886a3d

File tree

2 files changed

+248
-6
lines changed

2 files changed

+248
-6
lines changed

src/scikit_build_core/repair_wheel/windows.py

Lines changed: 245 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING
7+
import dataclasses
8+
import os.path
9+
import textwrap
10+
from pathlib import Path
11+
from typing import TYPE_CHECKING, ClassVar
812

9-
from . import WheelRepairer
13+
from .._logging import logger
14+
from . import WheelRepairer, _get_buildenv_platlib
1015

1116
if TYPE_CHECKING:
1217
from ..file_api.model.codemodel import Target
@@ -18,13 +23,248 @@ def __dir__() -> list[str]:
1823
return __all__
1924

2025

26+
@dataclasses.dataclass
2127
class WindowsWheelRepairer(WheelRepairer):
2228
"""
23-
Do some windows specific magic.
29+
Patch the package and top-level python module files with ``os.add_dll_directory``.
2430
"""
2531

2632
_platform = "Windows"
2733

34+
PATCH_PY_FILE: ClassVar[str] = textwrap.dedent("""\
35+
# start scikit-build-core Windows patch
36+
def _skbuild_patch_dll_dir():
37+
import os
38+
import os.path
39+
40+
mod_dir = os.path.abspath(os.path.dirname(__file__))
41+
path_to_platlib = os.path.normpath({path_to_platlib!r})
42+
dll_paths = {dll_paths!r}
43+
for path in dll_paths:
44+
path = os.path.normpath(path)
45+
path = os.path.join(mod_dir, path_to_platlib, path)
46+
os.add_dll_directory(path)
47+
48+
_skbuild_patch_dll_dir()
49+
del _skbuild_patch_dll_dir
50+
# end scikit-build-core Windows patch
51+
""")
52+
dll_dirs: set[Path] = dataclasses.field(default_factory=set, init=False)
53+
"""All dll paths used relative to ``platlib``."""
54+
55+
def get_dll_path_from_lib(self, lib_path: Path) -> Path | None:
56+
"""Guess the dll path from lib path."""
57+
dll_path = None
58+
platlib = Path(_get_buildenv_platlib())
59+
lib_path = lib_path.relative_to(platlib)
60+
dll_name = lib_path.with_suffix(".dll").name
61+
# Try to find the dll in the same package directory
62+
if len(lib_path.parts) > 1:
63+
pkg_dir = lib_path.parts[0]
64+
for root, _, files in os.walk(platlib / pkg_dir):
65+
if dll_name in files:
66+
dll_path = Path(root) / dll_name
67+
break
68+
else:
69+
logger.debug(
70+
"Did not find the dll file under {pkg_dir}",
71+
pkg_dir=pkg_dir,
72+
)
73+
if not dll_path:
74+
logger.debug(
75+
"Looking for {dll_name} in all platlib path.",
76+
dll_name=dll_name,
77+
)
78+
for root, _, files in os.walk(platlib):
79+
if dll_name in files:
80+
dll_path = Path(root) / dll_name
81+
break
82+
else:
83+
logger.warning(
84+
"Could not find dll file {dll_name} corresponding to {lib_path}",
85+
dll_name=dll_name,
86+
lib_path=lib_path,
87+
)
88+
return None
89+
logger.debug(
90+
"Found dll file {dll_path}",
91+
dll_path=dll_path,
92+
)
93+
return self.path_relative_site_packages(dll_path)
94+
95+
def get_library_dependencies(self, target: Target) -> list[Target]:
96+
msg = "get_library_dependencies is not generalized for Windows."
97+
raise NotImplementedError(msg)
98+
99+
def get_dependency_dll(self, target: Target) -> list[Path]:
100+
"""Get the dll due to target link dependencies."""
101+
dll_paths = []
102+
for dep in target.dependencies:
103+
dep_target = next(targ for targ in self.targets if targ.id == dep.id)
104+
if dep_target.type != "SHARED_LIBRARY":
105+
logger.debug(
106+
"Skipping dependency {dep_target} of type {type}",
107+
dep_target=dep_target.name,
108+
type=dep_target.type,
109+
)
110+
continue
111+
if not dep_target.install:
112+
logger.warning(
113+
"Dependency {dep_target} is not installed",
114+
dep_target=dep_target.name,
115+
)
116+
continue
117+
dll_artifact = next(
118+
artifact.path
119+
for artifact in dep_target.artifacts
120+
if artifact.path.suffix == ".dll"
121+
)
122+
for install_path in self.get_wheel_install_paths(dep_target):
123+
dep_install_path = self.install_dir / install_path
124+
if (dep_install_path / dll_artifact).exists():
125+
break
126+
else:
127+
logger.warning(
128+
"Could not find installed {dll_artifact} location in install paths: {install_path}",
129+
dll_artifact=dll_artifact,
130+
install_path=[
131+
dest.path for dest in dep_target.install.destinations
132+
],
133+
)
134+
continue
135+
dll_path = self.path_relative_site_packages(dep_install_path)
136+
dll_paths.append(dll_path)
137+
return dll_paths
138+
139+
def get_package_dll(self, target: Target) -> list[Path]:
140+
"""
141+
Get the dll due to external package linkage.
142+
143+
Have to use the guess the dll paths until the package targets are exposed.
144+
https://gitlab.kitware.com/cmake/cmake/-/issues/26755
145+
"""
146+
if not target.link:
147+
return []
148+
dll_paths = []
149+
for link_command in target.link.commandFragments:
150+
if link_command.role == "flags":
151+
if not link_command.fragment:
152+
logger.debug(
153+
"Skipping {target} link-flags: {flags}",
154+
target=target.name,
155+
flags=link_command.fragment,
156+
)
157+
continue
158+
if link_command.role != "libraries":
159+
logger.warning(
160+
"File-api link role {role} is not supported. "
161+
"Target={target}, command={command}",
162+
target=target.name,
163+
role=link_command.role,
164+
command=link_command.fragment,
165+
)
166+
continue
167+
# The remaining case should be a path
168+
try:
169+
# TODO: how to best catch if a string is a valid path?
170+
lib_path = Path(link_command.fragment)
171+
if not lib_path.is_absolute():
172+
# If the link_command is a space-separated list of libraries, this should be skipped
173+
logger.debug(
174+
"Skipping non-absolute-path library: {fragment}",
175+
fragment=link_command.fragment,
176+
)
177+
continue
178+
try:
179+
lib_path = self.path_relative_site_packages(lib_path)
180+
except ValueError:
181+
logger.debug(
182+
"Skipping library outside site-package path: {lib_path}",
183+
lib_path=lib_path,
184+
)
185+
continue
186+
dll_path = self.get_dll_path_from_lib(lib_path)
187+
if not dll_path:
188+
continue
189+
dll_paths.append(dll_path)
190+
except Exception as exc:
191+
logger.warning(
192+
"Could not parse link-library as a path: {fragment}\nexc = {exc}",
193+
fragment=link_command.fragment,
194+
exc=exc,
195+
)
196+
continue
197+
return dll_paths
198+
28199
def patch_target(self, target: Target) -> None:
29-
# TODO: Implement patching
30-
pass
200+
# Here we just gather all dll paths needed for each target
201+
package_dlls = self.get_package_dll(target)
202+
dependency_dlls = self.get_dependency_dll(target)
203+
if not package_dlls and not dependency_dlls:
204+
logger.warning(
205+
"No dll files found for target {target}",
206+
target=target.name,
207+
)
208+
return
209+
logger.debug(
210+
"Found dlls for target {target}:\n"
211+
"package_dlls={package_dlls}\n"
212+
"dependency_dlls={dependency_dlls}\n",
213+
target=target.name,
214+
package_dlls=package_dlls,
215+
dependency_dlls=dependency_dlls,
216+
)
217+
self.dll_dirs.update(package_dlls)
218+
self.dll_dirs.update(dependency_dlls)
219+
220+
def patch_python_file(self, file: Path) -> None:
221+
"""
222+
Patch python package or top-level module.
223+
224+
Make sure the python files have an appropriate ``os.add_dll_directory``
225+
for the scripts directory.
226+
"""
227+
assert self.dll_dirs
228+
assert all(not path.is_absolute() for path in self.dll_dirs)
229+
logger.debug(
230+
"Patching python file: {file}",
231+
file=file,
232+
)
233+
platlib = Path(self.wheel_dirs["platlib"])
234+
content = file.read_text()
235+
mod_dir = file.parent
236+
path_to_platlib = os.path.relpath(platlib, mod_dir)
237+
patch_script = self.PATCH_PY_FILE.format(
238+
path_to_platlib=path_to_platlib,
239+
dll_paths=[str(path) for path in self.dll_dirs],
240+
)
241+
# TODO: Account for the header comments, __future__.annotations, etc.
242+
with file.open("w") as f:
243+
f.write(f"{patch_script}\n" + content)
244+
245+
def repair_wheel(self) -> None:
246+
super().repair_wheel()
247+
platlib = Path(self.wheel_dirs["platlib"])
248+
if not self.dll_dirs:
249+
logger.debug(
250+
"Skipping wheel repair because no site-package dlls were found."
251+
)
252+
return
253+
logger.debug(
254+
"Patching dll directories: {dll_dirs}",
255+
dll_dirs=self.dll_dirs,
256+
)
257+
# TODO: Not handling namespace packages with this
258+
for path in platlib.iterdir():
259+
assert isinstance(path, Path)
260+
if path.is_dir():
261+
pkg_file = path / "__init__.py"
262+
if not pkg_file.exists():
263+
logger.debug(
264+
"Ignoring non-python package: {pkg_file}",
265+
pkg_file=pkg_file,
266+
)
267+
continue
268+
self.patch_python_file(pkg_file)
269+
elif path.suffix == ".py":
270+
self.patch_python_file(path)

tests/test_repair_wheel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def test_full_build(
7272
wheels = list(dist.glob("*.whl"))
7373
isolated.install(*wheels)
7474

75-
isolated.run("main")
75+
if platform.system() != "Windows":
76+
# For some reason isolated.run cannot run this on windows
77+
isolated.run("main")
7678
isolated.module("repair_wheel")
7779
isolated.execute(
7880
"from repair_wheel._module import hello; hello()",

0 commit comments

Comments
 (0)