Skip to content

Commit 1d73c8c

Browse files
committed
Add support for passing HUD url to ghstack checkout
Signed-off-by: Edward Yang <[email protected]> ghstack-source-id: 192936e ghstack-comment-id: 3660975931 Pull-Request: #307
1 parent 2d0ada6 commit 1d73c8c

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

src/ghstack/github_utils.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,19 @@ def get_github_repo_info(
123123
}
124124

125125

126+
# Matches GitHub PR URLs like:
127+
# https://github.com/owner/repo/pull/123
128+
# https://github.com/owner/repo/pull/123/
129+
# https://github.com/owner/repo/pull/123/files
130+
# https://github.com/owner/repo/pull/123/commits
126131
RE_PR_URL = re.compile(
127-
r"^https://(?P<github_url>[^/]+)/(?P<owner>[^/]+)/(?P<name>[^/]+)/pull/(?P<number>[0-9]+)/?$"
132+
r"^https://(?P<github_url>[^/]+)/(?P<owner>[^/]+)/(?P<name>[^/]+)/pull/(?P<number>[0-9]+)(?:/.*)?$"
133+
)
134+
135+
# Matches PyTorch HUD URLs like:
136+
# https://hud.pytorch.org/pr/169404
137+
RE_PYTORCH_HUD_URL = re.compile(
138+
r"^https://hud\.pytorch\.org/pr/(?P<number>[0-9]+)/?$"
128139
)
129140

130141
GitHubPullRequestParams = TypedDict(
@@ -144,6 +155,17 @@ def parse_pull_request(
144155
sh: Optional[ghstack.shell.Shell] = None,
145156
remote_name: Optional[str] = None,
146157
) -> GitHubPullRequestParams:
158+
# Check for PyTorch HUD URL first (hud.pytorch.org/pr/NUMBER)
159+
hud_match = RE_PYTORCH_HUD_URL.match(pull_request)
160+
if hud_match:
161+
number = int(hud_match.group("number"))
162+
return {
163+
"github_url": "github.com",
164+
"owner": "pytorch",
165+
"name": "pytorch",
166+
"number": number,
167+
}
168+
147169
m = RE_PR_URL.match(pull_request)
148170
if not m:
149171
# We can reconstruct the URL if just a PR number is passed

test_github_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
5+
import ghstack.github_utils
6+
7+
8+
class TestParsePullRequest(unittest.TestCase):
9+
def test_github_url_basic(self) -> None:
10+
result = ghstack.github_utils.parse_pull_request(
11+
"https://github.com/pytorch/pytorch/pull/169404"
12+
)
13+
self.assertEqual(result["github_url"], "github.com")
14+
self.assertEqual(result["owner"], "pytorch")
15+
self.assertEqual(result["name"], "pytorch")
16+
self.assertEqual(result["number"], 169404)
17+
18+
def test_github_url_trailing_slash(self) -> None:
19+
result = ghstack.github_utils.parse_pull_request(
20+
"https://github.com/pytorch/pytorch/pull/169404/"
21+
)
22+
self.assertEqual(result["github_url"], "github.com")
23+
self.assertEqual(result["owner"], "pytorch")
24+
self.assertEqual(result["name"], "pytorch")
25+
self.assertEqual(result["number"], 169404)
26+
27+
def test_github_url_files_suffix(self) -> None:
28+
result = ghstack.github_utils.parse_pull_request(
29+
"https://github.com/pytorch/pytorch/pull/169404/files"
30+
)
31+
self.assertEqual(result["github_url"], "github.com")
32+
self.assertEqual(result["owner"], "pytorch")
33+
self.assertEqual(result["name"], "pytorch")
34+
self.assertEqual(result["number"], 169404)
35+
36+
def test_github_url_commits_suffix(self) -> None:
37+
result = ghstack.github_utils.parse_pull_request(
38+
"https://github.com/pytorch/pytorch/pull/169404/commits"
39+
)
40+
self.assertEqual(result["github_url"], "github.com")
41+
self.assertEqual(result["owner"], "pytorch")
42+
self.assertEqual(result["name"], "pytorch")
43+
self.assertEqual(result["number"], 169404)
44+
45+
def test_github_url_commits_with_sha(self) -> None:
46+
result = ghstack.github_utils.parse_pull_request(
47+
"https://github.com/pytorch/pytorch/pull/169404/commits/abc123def"
48+
)
49+
self.assertEqual(result["github_url"], "github.com")
50+
self.assertEqual(result["owner"], "pytorch")
51+
self.assertEqual(result["name"], "pytorch")
52+
self.assertEqual(result["number"], 169404)
53+
54+
def test_pytorch_hud_url_basic(self) -> None:
55+
result = ghstack.github_utils.parse_pull_request(
56+
"https://hud.pytorch.org/pr/169404"
57+
)
58+
self.assertEqual(result["github_url"], "github.com")
59+
self.assertEqual(result["owner"], "pytorch")
60+
self.assertEqual(result["name"], "pytorch")
61+
self.assertEqual(result["number"], 169404)
62+
63+
def test_pytorch_hud_url_trailing_slash(self) -> None:
64+
result = ghstack.github_utils.parse_pull_request(
65+
"https://hud.pytorch.org/pr/169404/"
66+
)
67+
self.assertEqual(result["github_url"], "github.com")
68+
self.assertEqual(result["owner"], "pytorch")
69+
self.assertEqual(result["name"], "pytorch")
70+
self.assertEqual(result["number"], 169404)
71+
72+
def test_different_owner_repo(self) -> None:
73+
result = ghstack.github_utils.parse_pull_request(
74+
"https://github.com/facebook/react/pull/12345"
75+
)
76+
self.assertEqual(result["github_url"], "github.com")
77+
self.assertEqual(result["owner"], "facebook")
78+
self.assertEqual(result["name"], "react")
79+
self.assertEqual(result["number"], 12345)
80+
81+
def test_invalid_url_raises(self) -> None:
82+
with self.assertRaises(RuntimeError):
83+
ghstack.github_utils.parse_pull_request("not-a-valid-url")
84+
85+
def test_invalid_hud_url_raises(self) -> None:
86+
with self.assertRaises(RuntimeError):
87+
ghstack.github_utils.parse_pull_request("https://hud.pytorch.org/not-pr/123")
88+
89+
90+
if __name__ == "__main__":
91+
unittest.main()

0 commit comments

Comments
 (0)