1414from github .Tag import Tag
1515from github .Workflow import Workflow
1616
17- from codegen .git .clients .github_client_factory import GithubClientFactory
18- from codegen .git .clients .types import GithubClientType
19- from codegen .git .schemas .github import GithubScope , GithubType
17+ from codegen .git .clients .github_client import GithubClient
2018from codegen .git .schemas .repo_config import RepoConfig
2119from codegen .git .utils .format import format_comparison
2220
@@ -27,33 +25,27 @@ class GitRepoClient:
2725 """Wrapper around PyGithub's Remote Repository."""
2826
2927 repo_config : RepoConfig
30- github_type : GithubType = GithubType .GithubEnterprise
31- gh_client : GithubClientType
32- read_client : Repository
33- access_scope : GithubScope
34- __write_client : Repository | None # Will not be initialized if access scope is read-only
28+ gh_client : GithubClient
29+ _repo : Repository
3530
36- def __init__ (self , repo_config : RepoConfig , github_type : GithubType = GithubType . GithubEnterprise , access_scope : GithubScope = GithubScope . READ ) -> None :
31+ def __init__ (self , repo_config : RepoConfig ) -> None :
3732 self .repo_config = repo_config
38- self .github_type = github_type
39- self .gh_client = GithubClientFactory . create_from_repo ( self .repo_config , github_type )
40- self . read_client = self . _create_client ( GithubScope . READ )
41- self . __write_client = self . _create_client ( GithubScope . WRITE ) if access_scope == GithubScope . WRITE else None
42- self . access_scope = access_scope
43-
44- def _create_client (self , github_scope : GithubScope = GithubScope . READ ) -> Repository :
45- client = self .gh_client .get_repo_by_full_name (self .repo_config .full_name , github_scope = github_scope )
33+ self .gh_client = self . _create_github_client ()
34+ self ._repo = self ._create_client ( )
35+
36+ def _create_github_client ( self ) -> GithubClient :
37+ return GithubClient ()
38+
39+ def _create_client (self ) -> Repository :
40+ client = self .gh_client .get_repo_by_full_name (self .repo_config .full_name )
4641 if not client :
47- msg = f"Repo { self .repo_config .full_name } not found in { self . github_type . value } !"
42+ msg = f"Repo { self .repo_config .full_name } not found!"
4843 raise ValueError (msg )
4944 return client
5045
5146 @property
52- def _write_client (self ) -> Repository :
53- if self .__write_client is None :
54- msg = "Cannot perform write operations with read-only client! Try setting github_scope to GithubScope.WRITE."
55- raise ValueError (msg )
56- return self .__write_client
47+ def repo (self ) -> Repository :
48+ return self ._repo
5749
5850 ####################################################################################################################
5951 # PROPERTIES
@@ -65,7 +57,7 @@ def id(self) -> int:
6557
6658 @property
6759 def default_branch (self ) -> str :
68- return self .read_client .default_branch
60+ return self .repo .default_branch
6961
7062 ####################################################################################################################
7163 # CONTENTS
@@ -76,7 +68,7 @@ def get_contents(self, file_path: str, ref: str | None = None) -> str | None:
7668 if not ref :
7769 ref = self .default_branch
7870 try :
79- file = self .read_client .get_contents (file_path , ref = ref )
71+ file = self .repo .get_contents (file_path , ref = ref )
8072 file_contents = file .decoded_content .decode ("utf-8" ) # type: ignore[union-attr]
8173 return file_contents
8274 except UnknownObjectException :
@@ -100,7 +92,7 @@ def get_last_modified_date_of_path(self, path: str) -> datetime:
10092 str: The last modified date of the directory in ISO format (YYYY-MM-DDTHH:MM:SSZ).
10193
10294 """
103- commits = self .read_client .get_commits (path = path )
95+ commits = self .repo .get_commits (path = path )
10496 if commits .totalCount > 0 :
10597 # Get the date of the latest commit
10698 last_modified_date = commits [0 ].commit .committer .date
@@ -124,7 +116,7 @@ def create_review_comment(
124116 start_line : Opt [int ] = NotSet ,
125117 ) -> None :
126118 # TODO: add protections (ex: can write to PR)
127- writeable_pr = self ._write_client .get_pull (pull .number )
119+ writeable_pr = self .repo .get_pull (pull .number )
128120 writeable_pr .create_review_comment (
129121 body = body ,
130122 commit = commit ,
@@ -140,7 +132,7 @@ def create_issue_comment(
140132 body : str ,
141133 ) -> None :
142134 # TODO: add protections (ex: can write to PR)
143- writeable_pr = self ._write_client .get_pull (pull .number )
135+ writeable_pr = self .repo .get_pull (pull .number )
144136 writeable_pr .create_issue_comment (body = body )
145137
146138 ####################################################################################################################
@@ -163,7 +155,7 @@ def get_pull_by_branch_and_state(
163155 head_branch_name = f"{ self .repo_config .organization_name } :{ head_branch_name } "
164156
165157 # retrieve all pulls ordered by created descending
166- prs = self .read_client .get_pulls (base = base_branch_name , head = head_branch_name , state = state , sort = "created" , direction = "desc" )
158+ prs = self .repo .get_pulls (base = base_branch_name , head = head_branch_name , state = state , sort = "created" , direction = "desc" )
167159 if prs .totalCount > 0 :
168160 return prs [0 ]
169161 else :
@@ -174,7 +166,7 @@ def get_pull_safe(self, number: int) -> PullRequest | None:
174166 TODO: catching UnknownObjectException is common enough to create a decorator
175167 """
176168 try :
177- pr = self .read_client .get_pull (number )
169+ pr = self .repo .get_pull (number )
178170 return pr
179171 except UnknownObjectException as e :
180172 return None
@@ -209,10 +201,10 @@ def create_pull(
209201 if base_branch_name is None :
210202 base_branch_name = self .default_branch
211203 try :
212- pr = self ._write_client .create_pull (title = title or f"Draft PR for { head_branch_name } " , body = body or "" , head = head_branch_name , base = base_branch_name , draft = draft )
204+ pr = self .repo .create_pull (title = title or f"Draft PR for { head_branch_name } " , body = body or "" , head = head_branch_name , base = base_branch_name , draft = draft )
213205 logger .info (f"Created pull request for head branch: { head_branch_name } at { pr .html_url } " )
214206 # NOTE: return a read-only copy to prevent people from editing it
215- return self .read_client .get_pull (pr .number )
207+ return self .repo .get_pull (pr .number )
216208 except GithubException as ge :
217209 logger .warning (f"Failed to create PR got GithubException\n \t { ge } " )
218210 except Exception as e :
@@ -235,15 +227,15 @@ def squash_and_merge(self, base_branch_name: str, head_branch_name: str, squash_
235227 merge = squash_pr .merge (commit_message = squash_commit_msg , commit_title = squash_commit_title , merge_method = "squash" ) # type: ignore[arg-type]
236228
237229 def edit_pull (self , pull : PullRequest , title : Opt [str ] = NotSet , body : Opt [str ] = NotSet , state : Opt [str ] = NotSet ) -> None :
238- writable_pr = self ._write_client .get_pull (pull .number )
230+ writable_pr = self .repo .get_pull (pull .number )
239231 writable_pr .edit (title = title , body = body , state = state )
240232
241233 def add_label_to_pull (self , pull : PullRequest , label : Label ) -> None :
242- writeable_pr = self ._write_client .get_pull (pull .number )
234+ writeable_pr = self .repo .get_pull (pull .number )
243235 writeable_pr .add_to_labels (label )
244236
245237 def remove_label_from_pull (self , pull : PullRequest , label : Label ) -> None :
246- writeable_pr = self ._write_client .get_pull (pull .number )
238+ writeable_pr = self .repo .get_pull (pull .number )
247239 writeable_pr .remove_from_labels (label )
248240
249241 ####################################################################################################################
@@ -264,7 +256,7 @@ def get_or_create_branch(self, new_branch_name: str, base_branch_name: str | Non
264256 def get_branch_safe (self , branch_name : str , attempts : int = 1 , wait_seconds : int = 1 ) -> Branch | None :
265257 for i in range (attempts ):
266258 try :
267- return self .read_client .get_branch (branch_name )
259+ return self .repo .get_branch (branch_name )
268260 except GithubException as e :
269261 if e .status == 404 and i < attempts - 1 :
270262 time .sleep (wait_seconds )
@@ -276,14 +268,14 @@ def create_branch(self, new_branch_name: str, base_branch_name: str | None = Non
276268 if base_branch_name is None :
277269 base_branch_name = self .default_branch
278270
279- base_branch = self .read_client .get_branch (base_branch_name )
271+ base_branch = self .repo .get_branch (base_branch_name )
280272 # TODO: also wrap git ref. low pri b/c the only write operation on refs is creating one
281- self ._write_client .create_git_ref (sha = base_branch .commit .sha , ref = f"refs/heads/{ new_branch_name } " )
273+ self .repo .create_git_ref (sha = base_branch .commit .sha , ref = f"refs/heads/{ new_branch_name } " )
282274 branch = self .get_branch_safe (new_branch_name )
283275 return branch
284276
285277 def create_branch_from_sha (self , new_branch_name : str , base_sha : str ) -> Branch | None :
286- self ._write_client .create_git_ref (ref = f"refs/heads/{ new_branch_name } " , sha = base_sha )
278+ self .repo .create_git_ref (ref = f"refs/heads/{ new_branch_name } " , sha = base_sha )
287279 branch = self .get_branch_safe (new_branch_name )
288280 return branch
289281
@@ -295,7 +287,7 @@ def delete_branch(self, branch_name: str) -> None:
295287
296288 branch_to_delete = self .get_branch_safe (branch_name )
297289 if branch_to_delete :
298- ref_to_delete = self ._write_client .get_git_ref (f"heads/{ branch_name } " )
290+ ref_to_delete = self .repo .get_git_ref (f"heads/{ branch_name } " )
299291 ref_to_delete .delete ()
300292 logger .info (f"Branch: { branch_name } deleted successfully!" )
301293 else :
@@ -307,7 +299,7 @@ def delete_branch(self, branch_name: str) -> None:
307299
308300 def get_commit_safe (self , commit_sha : str ) -> Commit | None :
309301 try :
310- return self .read_client .get_commit (commit_sha )
302+ return self .repo .get_commit (commit_sha )
311303 except UnknownObjectException as e :
312304 logger .warning (f"Commit { commit_sha } not found:\n \t { e } " )
313305 return None
@@ -338,7 +330,7 @@ def compare_branches(self, base_branch_name: str | None, head_branch_name: str,
338330
339331 # NOTE: base utility that other compare functions should try to use
340332 def compare (self , base : str , head : str , show_commits : bool = False ) -> str :
341- comparison = self .read_client .compare (base , head )
333+ comparison = self .repo .compare (base , head )
342334 return format_comparison (comparison , show_commits = show_commits )
343335
344336 ####################################################################################################################
@@ -349,7 +341,7 @@ def compare(self, base: str, head: str, show_commits: bool = False) -> str:
349341 def get_label_safe (self , label_name : str ) -> Label | None :
350342 try :
351343 label_name = label_name .strip ()
352- label = self .read_client .get_label (label_name )
344+ label = self .repo .get_label (label_name )
353345 return label
354346 except UnknownObjectException as e :
355347 return None
@@ -360,10 +352,10 @@ def get_label_safe(self, label_name: str) -> Label | None:
360352 def create_label (self , label_name : str , color : str ) -> Label :
361353 # TODO: also offer description field
362354 label_name = label_name .strip ()
363- self ._write_client .create_label (label_name , color )
355+ self .repo .create_label (label_name , color )
364356 # TODO: is there a way to convert new_label to a read-only label without making another API call?
365357 # NOTE: return a read-only label to prevent people from editing it
366- return self .read_client .get_label (label_name )
358+ return self .repo .get_label (label_name )
367359
368360 def get_or_create_label (self , label_name : str , color : str ) -> Label :
369361 existing_label = self .get_label_safe (label_name )
@@ -377,7 +369,7 @@ def get_or_create_label(self, label_name: str, color: str) -> Label:
377369
378370 def get_check_suite_safe (self , check_suite_id : int ) -> CheckSuite | None :
379371 try :
380- return self .read_client .get_check_suite (check_suite_id )
372+ return self .repo .get_check_suite (check_suite_id )
381373 except UnknownObjectException as e :
382374 return None
383375 except Exception as e :
@@ -390,7 +382,7 @@ def get_check_suite_safe(self, check_suite_id: int) -> CheckSuite | None:
390382
391383 def get_check_run_safe (self , check_run_id : int ) -> CheckRun | None :
392384 try :
393- return self .read_client .get_check_run (check_run_id )
385+ return self .repo .get_check_run (check_run_id )
394386 except UnknownObjectException as e :
395387 return None
396388 except Exception as e :
@@ -406,24 +398,24 @@ def create_check_run(
406398 conclusion : Opt [str ] = NotSet ,
407399 output : Opt [dict [str , str | list [dict [str , str | int ]]]] = NotSet ,
408400 ) -> CheckRun :
409- new_check_run = self ._write_client .create_check_run (name = name , head_sha = head_sha , details_url = details_url , status = status , conclusion = conclusion , output = output )
410- return self .read_client .get_check_run (new_check_run .id )
401+ new_check_run = self .repo .create_check_run (name = name , head_sha = head_sha , details_url = details_url , status = status , conclusion = conclusion , output = output )
402+ return self .repo .get_check_run (new_check_run .id )
411403
412404 ####################################################################################################################
413405 # WORKFLOW
414406 ####################################################################################################################
415407
416408 def get_workflow_safe (self , file_name : str ) -> Workflow | None :
417409 try :
418- return self .read_client .get_workflow (file_name )
410+ return self .repo .get_workflow (file_name )
419411 except UnknownObjectException as e :
420412 return None
421413 except Exception as e :
422414 logger .warning (f"Error getting workflow by file name: { file_name } \n \t { e } " )
423415 return None
424416
425417 def create_workflow_dispatch (self , workflow : Workflow , ref : Branch | Tag | Commit | str , inputs : Opt [dict ] = NotSet ):
426- writeable_workflow = self ._write_client .get_workflow (workflow .id )
418+ writeable_workflow = self .repo .get_workflow (workflow .id )
427419 writeable_workflow .create_dispatch (ref = ref , inputs = inputs )
428420
429421 ####################################################################################################################
@@ -439,5 +431,5 @@ def merge_upstream(self, branch_name: str) -> bool:
439431 """
440432 assert isinstance (branch_name , str ), branch_name
441433 post_parameters = {"branch" : branch_name }
442- status , _ , _ = self ._write_client ._requester .requestJson ("POST" , f"{ self ._write_client .url } /merge-upstream" , input = post_parameters )
434+ status , _ , _ = self .repo ._requester .requestJson ("POST" , f"{ self .repo .url } /merge-upstream" , input = post_parameters )
443435 return status == 200
0 commit comments