2525
2626from __future__ import annotations
2727
28- from typing import TYPE_CHECKING
28+ from typing import TYPE_CHECKING , Iterator
2929
30- from ._pygit2 import Commit , Oid
30+ from ._pygit2 import Branch , Commit , Oid
3131from .enums import BranchType , ReferenceType
3232
3333# Need BaseRepository for type hints, but don't let it cause a circular dependency
3636
3737
3838class Branches :
39+ local : 'Branches'
40+ remote : 'Branches'
41+
3942 def __init__ (
40- self , repository : BaseRepository , flag : BranchType = BranchType .ALL , commit = None
41- ):
43+ self ,
44+ repository : BaseRepository ,
45+ flag : BranchType = BranchType .ALL ,
46+ commit : Commit | Oid | str | None = None ,
47+ ) -> None :
4248 self ._repository = repository
4349 self ._flag = flag
4450 if commit is not None :
@@ -52,7 +58,7 @@ def __init__(
5258 self .local = Branches (repository , flag = BranchType .LOCAL , commit = commit )
5359 self .remote = Branches (repository , flag = BranchType .REMOTE , commit = commit )
5460
55- def __getitem__ (self , name : str ):
61+ def __getitem__ (self , name : str ) -> Branch :
5662 branch = None
5763 if self ._flag & BranchType .LOCAL :
5864 branch = self ._repository .lookup_branch (name , BranchType .LOCAL )
@@ -65,36 +71,38 @@ def __getitem__(self, name: str):
6571
6672 return branch
6773
68- def get (self , key : str ):
74+ def get (self , key : str ) -> Branch :
6975 try :
7076 return self [key ]
7177 except KeyError :
72- return None
78+ return None # type:ignore # next commit
7379
74- def __iter__ (self ):
80+ def __iter__ (self ) -> Iterator [ str ] :
7581 for branch_name in self ._repository .listall_branches (self ._flag ):
7682 if self ._commit is None or self .get (branch_name ) is not None :
7783 yield branch_name
7884
79- def create (self , name : str , commit , force = False ):
85+ def create (self , name : str , commit : Commit , force : bool = False ) -> Branch :
8086 return self ._repository .create_branch (name , commit , force )
8187
82- def delete (self , name : str ):
88+ def delete (self , name : str ) -> None :
8389 self [name ].delete ()
8490
85- def _valid (self , branch ) :
91+ def _valid (self , branch : Branch ) -> bool :
8692 if branch .type == ReferenceType .SYMBOLIC :
87- branch = branch .resolve ()
93+ branch_direct = branch .resolve ()
94+ else :
95+ branch_direct = branch
8896
8997 return (
9098 self ._commit is None
91- or branch .target == self ._commit
92- or self ._repository .descendant_of (branch .target , self ._commit )
99+ or branch_direct .target == self ._commit
100+ or self ._repository .descendant_of (branch_direct .target , self ._commit )
93101 )
94102
95- def with_commit (self , commit ) :
103+ def with_commit (self , commit : Commit | Oid | str | None ) -> 'Branches' :
96104 assert self ._commit is None
97105 return Branches (self ._repository , self ._flag , commit )
98106
99- def __contains__ (self , name ) :
107+ def __contains__ (self , name : str ) -> bool :
100108 return self .get (name ) is not None
0 commit comments