Skip to content

Commit fa8a0e9

Browse files
committed
move types for Branches to implementation
1 parent 7c0831f commit fa8a0e9

File tree

6 files changed

+39
-43
lines changed

6 files changed

+39
-43
lines changed

pygit2/_pygit2.pyi

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import tarfile
2-
from collections.abc import Generator
32
from io import DEFAULT_BUFFER_SIZE, IOBase
43
from pathlib import Path
54
from queue import Queue
@@ -28,6 +27,7 @@ from ._libgit2.ffi import (
2827
_Pointer,
2928
)
3029
from .blame import Blame
30+
from .branches import Branches
3131
from .callbacks import CheckoutCallbacks, StashApplyCallbacks
3232
from .config import Config
3333
from .enums import (
@@ -63,7 +63,6 @@ from .index import MergeFileResult
6363
from .packbuilder import PackBuilder
6464
from .references import References
6565
from .remotes import RemoteCollection
66-
from .repository import BaseRepository
6766
from .submodules import SubmoduleCollection
6867

6968
GIT_OBJ_BLOB = Literal[3]
@@ -691,25 +690,6 @@ class _LsRemotesDict(TypedDict):
691690
symref_target: str | None
692691
oid: Oid
693692

694-
class Branches:
695-
local: 'Branches'
696-
remote: 'Branches'
697-
def __init__(
698-
self,
699-
repository: BaseRepository,
700-
flag: BranchType = ...,
701-
commit: Commit | _OidArg | None = None,
702-
) -> None: ...
703-
def __getitem__(self, name: str) -> Branch: ...
704-
def get(self, key: str) -> Branch: ...
705-
def __iter__(self) -> Iterator[str]: ...
706-
def create(
707-
self, name: str, commit: Object | Commit, force: bool = False
708-
) -> Branch: ...
709-
def delete(self, name: str) -> None: ...
710-
def with_commit(self, commit: Object | Commit | _OidArg | None) -> 'Branches': ...
711-
def __contains__(self, name: _OidArg) -> bool: ...
712-
713693
class Repository:
714694
_pointer: GitRepositoryC
715695
_repo: GitRepositoryC

pygit2/branches.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525

2626
from __future__ import annotations
2727

28-
from typing import TYPE_CHECKING
28+
from typing import TYPE_CHECKING, Iterator
2929

30-
from ._pygit2 import Commit, Oid
30+
from ._pygit2 import Branch, Commit, Oid
3131
from .enums import BranchType, ReferenceType
3232

3333
# Need BaseRepository for type hints, but don't let it cause a circular dependency
@@ -36,9 +36,15 @@
3636

3737

3838
class Branches:
39+
local: 'Branches'
40+
remote: 'Branches'
41+
3942
def __init__(
40-
self, repository: BaseRepository, flag: BranchType = BranchType.ALL, commit=None
41-
):
43+
self,
44+
repository: BaseRepository,
45+
flag: BranchType = BranchType.ALL,
46+
commit: Commit | Oid | str | None = None,
47+
) -> None:
4248
self._repository = repository
4349
self._flag = flag
4450
if commit is not None:
@@ -52,7 +58,7 @@ def __init__(
5258
self.local = Branches(repository, flag=BranchType.LOCAL, commit=commit)
5359
self.remote = Branches(repository, flag=BranchType.REMOTE, commit=commit)
5460

55-
def __getitem__(self, name: str):
61+
def __getitem__(self, name: str) -> Branch:
5662
branch = None
5763
if self._flag & BranchType.LOCAL:
5864
branch = self._repository.lookup_branch(name, BranchType.LOCAL)
@@ -65,36 +71,38 @@ def __getitem__(self, name: str):
6571

6672
return branch
6773

68-
def get(self, key: str):
74+
def get(self, key: str) -> Branch:
6975
try:
7076
return self[key]
7177
except KeyError:
72-
return None
78+
return None # type:ignore # next commit
7379

74-
def __iter__(self):
80+
def __iter__(self) -> Iterator[str]:
7581
for branch_name in self._repository.listall_branches(self._flag):
7682
if self._commit is None or self.get(branch_name) is not None:
7783
yield branch_name
7884

79-
def create(self, name: str, commit, force=False):
85+
def create(self, name: str, commit: Commit, force: bool = False) -> Branch:
8086
return self._repository.create_branch(name, commit, force)
8187

82-
def delete(self, name: str):
88+
def delete(self, name: str) -> None:
8389
self[name].delete()
8490

85-
def _valid(self, branch):
91+
def _valid(self, branch: Branch) -> bool:
8692
if branch.type == ReferenceType.SYMBOLIC:
87-
branch = branch.resolve()
93+
branch_direct = branch.resolve()
94+
else:
95+
branch_direct = branch
8896

8997
return (
9098
self._commit is None
91-
or branch.target == self._commit
92-
or self._repository.descendant_of(branch.target, self._commit)
99+
or branch_direct.target == self._commit
100+
or self._repository.descendant_of(branch_direct.target, self._commit)
93101
)
94102

95-
def with_commit(self, commit):
103+
def with_commit(self, commit: Commit | Oid | str | None) -> 'Branches':
96104
assert self._commit is None
97105
return Branches(self._repository, self._flag, commit)
98106

99-
def __contains__(self, name):
107+
def __contains__(self, name: str) -> bool:
100108
return self.get(name) is not None

pygit2/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from types import TracebackType
2929
from typing import (
3030
TYPE_CHECKING,
31-
Generic,
3231
Generator,
32+
Generic,
3333
Iterator,
3434
Optional,
3535
Protocol,
@@ -46,7 +46,7 @@
4646
from ._libgit2.ffi import ArrayC, GitStrrayC, char, char_pointer
4747

4848

49-
def maybe_string(ptr: 'char_pointer' | None) -> str | None:
49+
def maybe_string(ptr: 'char_pointer | None') -> str | None:
5050
if not ptr:
5151
return None
5252

@@ -106,7 +106,7 @@ def ptr_to_bytes(ptr_cdata):
106106

107107

108108
@contextlib.contextmanager
109-
def new_git_strarray() -> Generator[GitStrrayC, None, None]:
109+
def new_git_strarray() -> Generator['GitStrrayC', None, None]:
110110
strarray = ffi.new('git_strarray *')
111111
yield strarray
112112
C.git_strarray_dispose(strarray)

test/test_branch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def test_branches(testrepo: Repository) -> None:
5858

5959
def test_branches_create(testrepo: Repository) -> None:
6060
commit = testrepo[LAST_COMMIT]
61+
assert isinstance(commit, Commit)
6162
reference = testrepo.branches.create('version1', commit)
6263
assert 'version1' in testrepo.branches
6364
reference = testrepo.branches['version1']
@@ -142,7 +143,9 @@ def test_branches_with_commit(testrepo: Repository) -> None:
142143
branches = testrepo.branches.with_commit(LAST_COMMIT)
143144
assert sorted(branches) == ['master']
144145

145-
branches = testrepo.branches.with_commit(testrepo[LAST_COMMIT])
146+
commit = testrepo[LAST_COMMIT]
147+
assert isinstance(commit, Commit)
148+
branches = testrepo.branches.with_commit(commit)
146149
assert sorted(branches) == ['master']
147150

148151
branches = testrepo.branches.remote.with_commit(LAST_COMMIT)

test/test_branch_empty.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def test_branches_remote_getitem(repo: Repository) -> None:
5858

5959
def test_branches_upstream(repo: Repository) -> None:
6060
remote_master = repo.branches.remote['origin/master']
61-
master = repo.branches.create('master', repo[remote_master.target])
61+
commit = repo[remote_master.target]
62+
assert isinstance(commit, Commit)
63+
master = repo.branches.create('master', commit)
6264

6365
assert master.upstream is None
6466
master.upstream = remote_master
@@ -76,7 +78,9 @@ def set_bad_upstream():
7678

7779
def test_branches_upstream_name(repo: Repository) -> None:
7880
remote_master = repo.branches.remote['origin/master']
79-
master = repo.branches.create('master', repo[remote_master.target])
81+
commit = repo[remote_master.target]
82+
assert isinstance(commit, Commit)
83+
master = repo.branches.create('master', commit)
8084

8185
master.upstream = remote_master
8286
assert master.upstream_name == 'refs/remotes/origin/master'

test/test_repository.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,7 @@ def test_worktree_custom_ref(testrepo: Repository) -> None:
906906

907907
# New branch based on head
908908
tip = testrepo.revparse_single('HEAD')
909+
assert isinstance(tip, Commit)
909910
worktree_ref = testrepo.branches.create(branch_name, tip)
910911
# Delete temp path so that it's not present when we attempt to add the
911912
# worktree later

0 commit comments

Comments
 (0)