Skip to content

Commit 4a6ffc2

Browse files
committed
Add merge rules engine for validating PRs before landing
Introduces a new merge_rules module that provides: - MergeRule dataclass for rule configuration (patterns, approvers, checks) - MergeRulesLoader to load rules from .github/merge_rules.yaml - MergeValidator to validate PRs against rules including: - File pattern matching (fnmatch/glob) - Approver validation with team expansion (org/team-slug) - CI check status validation - ValidationResult and MergeValidationError for error handling Also adds new GitHub API methods to GitHubEndpoint: - get_pr_reviews, get_pr_files, get_check_runs - get_team_members, get_file_contents, post_issue_comment Adds pyyaml dependency for YAML parsing. ghstack-source-id: 0407221 ghstack-comment-id: 3807838673 Pull-Request: #317
1 parent 3987ada commit 4a6ffc2

File tree

5 files changed

+435
-0
lines changed

5 files changed

+435
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"typing-extensions>=3",
1515
"click<9,>=8",
1616
"flake8<8.0.0,>=7.0.0",
17+
"pyyaml<7,>=6",
1718
]
1819
name = "ghstack"
1920
version = "0.14.0"

src/ghstack/github.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,42 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
9595
Returns: parsed JSON response
9696
"""
9797
pass
98+
99+
# Merge rules related API methods
100+
101+
def get_pr_reviews(self, owner: str, repo: str, number: int) -> Any:
102+
"""Get reviews for a pull request."""
103+
return self.get(f"repos/{owner}/{repo}/pulls/{number}/reviews")
104+
105+
def get_pr_files(self, owner: str, repo: str, number: int) -> Any:
106+
"""Get files changed in a pull request."""
107+
return self.get(f"repos/{owner}/{repo}/pulls/{number}/files")
108+
109+
def get_check_runs(self, owner: str, repo: str, ref: str) -> Any:
110+
"""Get check runs for a commit ref."""
111+
return self.get(f"repos/{owner}/{repo}/commits/{ref}/check-runs")
112+
113+
def get_team_members(self, org: str, team_slug: str) -> Any:
114+
"""Get members of a team."""
115+
return self.get(f"orgs/{org}/teams/{team_slug}/members")
116+
117+
def get_file_contents(
118+
self, owner: str, repo: str, path: str, ref: str = "HEAD"
119+
) -> str:
120+
"""
121+
Get the contents of a file from the repository.
122+
123+
Returns the decoded file contents as a string.
124+
"""
125+
import base64
126+
127+
result = self.get(f"repos/{owner}/{repo}/contents/{path}?ref={ref}")
128+
content = result.get("content", "")
129+
encoding = result.get("encoding", "")
130+
if encoding == "base64":
131+
return base64.b64decode(content).decode("utf-8")
132+
return content
133+
134+
def post_issue_comment(self, owner: str, repo: str, number: int, body: str) -> Any:
135+
"""Post a comment on an issue or pull request."""
136+
return self.post(f"repos/{owner}/{repo}/issues/{number}/comments", body=body)

src/ghstack/github_fake.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,19 @@ def repository(self, info: GraphQLResolveInfo) -> Repository:
257257
return github_state(info).repositories[self._repository]
258258

259259

260+
@dataclass
261+
class PullRequestReview:
262+
user: str
263+
state: str # APPROVED, CHANGES_REQUESTED, COMMENTED, etc.
264+
265+
266+
@dataclass
267+
class CheckRun:
268+
name: str
269+
status: str # queued, in_progress, completed
270+
conclusion: Optional[str] # success, failure, neutral, cancelled, skipped, etc.
271+
272+
260273
@dataclass
261274
class PullRequest(Node):
262275
baseRef: Optional[Ref]
@@ -274,6 +287,10 @@ class PullRequest(Node):
274287
url: str
275288
reviewers: List[str] = dataclasses.field(default_factory=list)
276289
labels: List[str] = dataclasses.field(default_factory=list)
290+
# Merge rules related fields
291+
files: List[str] = dataclasses.field(default_factory=list)
292+
reviews: List[PullRequestReview] = dataclasses.field(default_factory=list)
293+
check_runs: List[CheckRun] = dataclasses.field(default_factory=list)
277294

278295
def repository(self, info: GraphQLResolveInfo) -> Repository:
279296
return github_state(info).repositories[self._repository]
@@ -464,6 +481,75 @@ def rest(self, method: str, path: str, **kwargs: Any) -> Any:
464481
# For now, pretend all branches are not protected
465482
raise ghstack.github.NotFoundError()
466483

484+
# GET /repos/{owner}/{repo}/pulls/{number}/reviews
485+
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)/reviews$", path):
486+
state = self.state
487+
repo = state.repository(m.group(1), m.group(2))
488+
pr = state.pull_request(repo, GitHubNumber(int(m.group(3))))
489+
return [
490+
{"user": {"login": r.user}, "state": r.state} for r in pr.reviews
491+
]
492+
493+
# GET /repos/{owner}/{repo}/pulls/{number}/files
494+
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)/files$", path):
495+
state = self.state
496+
repo = state.repository(m.group(1), m.group(2))
497+
pr = state.pull_request(repo, GitHubNumber(int(m.group(3))))
498+
return [{"filename": f} for f in pr.files]
499+
500+
# GET /repos/{owner}/{repo}/pulls/{number}
501+
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls/(\d+)$", path):
502+
state = self.state
503+
repo = state.repository(m.group(1), m.group(2))
504+
pr = state.pull_request(repo, GitHubNumber(int(m.group(3))))
505+
head_sha = ""
506+
if pr.headRef:
507+
head_sha = pr.headRef.target.oid
508+
return {
509+
"number": pr.number,
510+
"title": pr.title,
511+
"body": pr.body,
512+
"head": {"sha": head_sha},
513+
"base": {"ref": pr.baseRefName},
514+
}
515+
516+
# GET /repos/{owner}/{repo}/commits/{ref}/check-runs
517+
if m := re.match(
518+
r"^repos/([^/]+)/([^/]+)/commits/([^/]+)/check-runs$", path
519+
):
520+
# For the fake endpoint, we need to find the PR by head SHA
521+
# and return its check runs
522+
state = self.state
523+
ref = m.group(3)
524+
# Search for PR with matching head ref
525+
for pr in state.pull_requests.values():
526+
if pr.headRef and pr.headRef.target.oid == ref:
527+
return {
528+
"total_count": len(pr.check_runs),
529+
"check_runs": [
530+
{
531+
"name": c.name,
532+
"status": c.status,
533+
"conclusion": c.conclusion,
534+
}
535+
for c in pr.check_runs
536+
],
537+
}
538+
# No matching PR found
539+
return {"total_count": 0, "check_runs": []}
540+
541+
# GET /orgs/{org}/teams/{team_slug}/members
542+
if m := re.match(r"^orgs/([^/]+)/teams/([^/]+)/members$", path):
543+
# Return empty list for fake endpoint
544+
return []
545+
546+
# GET /repos/{owner}/{repo}/contents/{path}
547+
if m := re.match(
548+
r"^repos/([^/]+)/([^/]+)/contents/(.+?)(?:\?ref=(.+))?$", path
549+
):
550+
# Return a NotFoundError for the fake endpoint
551+
raise ghstack.github.NotFoundError()
552+
467553
elif method == "post":
468554
if m := re.match(r"^repos/([^/]+)/([^/]+)/pulls$", path):
469555
return self._create_pull(

0 commit comments

Comments
 (0)