Skip to content

Commit d49b1da

Browse files
committed
Preserve the signature of the wrappers in gpt_layer_specs
1 parent fb474d0 commit d49b1da

3 files changed

Lines changed: 340 additions & 1 deletion

File tree

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
TransformerLayerSubmodules,
3838
get_transformer_layer_offset,
3939
)
40+
from megatron.core.typed_torch import copy_signature
4041
from megatron.core.utils import is_te_min_version
4142

4243
try:
@@ -166,6 +167,7 @@ def get_gpt_layer_with_inference_submodules(
166167

167168

168169
@functools.wraps(get_gpt_layer_with_inference_submodules)
170+
@copy_signature(get_gpt_layer_with_inference_submodules)
169171
def get_gpt_layer_with_inference_spec(*args, **kwargs) -> ModuleSpec:
170172
"""Use this spec to use inference optimized linear layers.
171173
@@ -306,6 +308,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
306308

307309

308310
@functools.wraps(get_gpt_layer_with_transformer_engine_submodules)
311+
@copy_signature(get_gpt_layer_with_transformer_engine_submodules)
309312
def get_gpt_layer_with_transformer_engine_spec(*args, **kwargs) -> ModuleSpec:
310313
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
311314
@@ -431,6 +434,7 @@ def get_gpt_layer_local_submodules(
431434

432435

433436
@functools.wraps(get_gpt_layer_local_submodules)
437+
@copy_signature(get_gpt_layer_local_submodules)
434438
def get_gpt_layer_local_spec(*args, **kwargs) -> ModuleSpec:
435439
"""Use this spec for an implementation using only modules in Megatron-Core.
436440

megatron/core/typed_torch.py

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
"""Utilities for improved type hinting with torch interfaces."""
33
from __future__ import annotations
44

5+
import inspect
56
from collections.abc import Callable
6-
from typing import Generic, ParamSpec, Protocol, TypeVar
7+
from typing import Any, Concatenate, Generic, Literal, ParamSpec, Protocol, TypeVar, overload
78

89
import torch
910

@@ -48,3 +49,154 @@ def not_none(value: T | None) -> T:
4849
if value is None:
4950
raise ValueError('Expected value to be not None')
5051
return value
52+
53+
54+
R_src = TypeVar('R_src')
55+
R_dst = TypeVar('R_dst')
56+
P_src = ParamSpec('P_src')
57+
P_dst = ParamSpec('P_dst')
58+
First_dst = TypeVar('First_dst')
59+
60+
61+
@overload
62+
def copy_signature(
63+
source: Callable[P_src, Any],
64+
/,
65+
*,
66+
handle_return_type: Literal['preserve'] = 'preserve',
67+
handle_first_src_param: Literal['copy'] = 'copy',
68+
handle_first_dst_param: Literal['drop'] = 'drop',
69+
) -> Callable[[Callable[..., R_dst]], Callable[P_src, R_dst]]: ...
70+
71+
72+
@overload
73+
def copy_signature(
74+
source: Callable[P_src, R_src],
75+
/,
76+
*,
77+
handle_return_type: Literal['overwrite'],
78+
handle_first_src_param: Literal['copy'] = 'copy',
79+
handle_first_dst_param: Literal['drop'] = 'drop',
80+
) -> Callable[[Callable[..., Any]], Callable[P_src, R_src]]: ...
81+
82+
83+
@overload
84+
def copy_signature(
85+
source: Callable[Concatenate[Any, P_src], Any],
86+
/,
87+
*,
88+
handle_return_type: Literal['preserve'] = 'preserve',
89+
handle_first_src_param: Literal['skip'],
90+
handle_first_dst_param: Literal['drop'] = 'drop',
91+
) -> Callable[[Callable[..., R_dst]], Callable[P_src, R_dst]]: ...
92+
93+
94+
@overload
95+
def copy_signature(
96+
source: Callable[Concatenate[Any, P_src], R_src],
97+
/,
98+
*,
99+
handle_return_type: Literal['overwrite'],
100+
handle_first_src_param: Literal['skip'],
101+
handle_first_dst_param: Literal['drop'] = 'drop',
102+
) -> Callable[[Callable[..., Any]], Callable[P_src, R_src]]: ...
103+
104+
105+
@overload
106+
def copy_signature(
107+
source: Callable[P_src, Any],
108+
/,
109+
*,
110+
handle_return_type: Literal['preserve'] = 'preserve',
111+
handle_first_src_param: Literal['copy'] = 'copy',
112+
handle_first_dst_param: Literal['preserve'],
113+
) -> Callable[
114+
[Callable[Concatenate[First_dst, ...], R_dst]], Callable[Concatenate[First_dst, P_src], R_dst]
115+
]: ...
116+
117+
118+
@overload
119+
def copy_signature(
120+
source: Callable[P_src, R_src],
121+
/,
122+
*,
123+
handle_return_type: Literal['overwrite'],
124+
handle_first_src_param: Literal['copy'] = 'copy',
125+
handle_first_dst_param: Literal['preserve'],
126+
) -> Callable[
127+
[Callable[Concatenate[First_dst, ...], Any]], Callable[Concatenate[First_dst, P_src], R_src]
128+
]: ...
129+
130+
131+
@overload
132+
def copy_signature(
133+
source: Callable[Concatenate[Any, P_src], Any],
134+
/,
135+
*,
136+
handle_return_type: Literal['preserve'] = 'preserve',
137+
handle_first_src_param: Literal['skip'],
138+
handle_first_dst_param: Literal['preserve'],
139+
) -> Callable[
140+
[Callable[Concatenate[First_dst, ...], R_dst]], Callable[Concatenate[First_dst, P_src], R_dst]
141+
]: ...
142+
143+
144+
@overload
145+
def copy_signature(
146+
source: Callable[Concatenate[Any, P_src], R_src],
147+
/,
148+
*,
149+
handle_return_type: Literal['overwrite'],
150+
handle_first_src_param: Literal['skip'],
151+
handle_first_dst_param: Literal['preserve'],
152+
) -> Callable[
153+
[Callable[Concatenate[First_dst, ...], Any]], Callable[Concatenate[First_dst, P_src], R_src]
154+
]: ...
155+
156+
157+
def copy_signature(
158+
source: Callable[..., Any],
159+
/,
160+
*,
161+
handle_return_type: Literal['preserve', 'overwrite'] = 'preserve',
162+
handle_first_src_param: Literal['copy', 'skip'] = 'copy',
163+
handle_first_dst_param: Literal['preserve', 'drop'] = 'drop',
164+
):
165+
"""Decorator to copy the signature from one function to another.
166+
167+
Args:
168+
source: The function or callable from which to copy the signature.
169+
handle_return_type: How to handle the return type annotation.
170+
'preserve' to keep the decorated function's return type,
171+
'overwrite' to use the source function's return type.
172+
handle_first_src_param: How to handle the first parameter of the source function.
173+
'copy' to include it in the decorated function's signature,
174+
'skip' to exclude it. Useful for removing 'self' or 'cls'.
175+
handle_first_dst_param: How to handle the first parameter of the decorated function.
176+
'preserve' to keep it in the decorated function's signature,
177+
'drop' to exclude it. Useful for preserving 'self' or 'cls'.
178+
179+
Returns:
180+
A decorator that copies the signature from `func` to the decorated function.
181+
"""
182+
source_signature = inspect.signature(source)
183+
184+
def decorator(decorated: Callable[..., Any], /) -> Callable[..., Any]:
185+
dest_signature = inspect.signature(decorated)
186+
new_params = []
187+
if handle_first_dst_param == 'preserve':
188+
new_params.append(next(iter(dest_signature.parameters.values())))
189+
src_params_iter = iter(source_signature.parameters.values())
190+
if handle_first_src_param == 'skip':
191+
next(src_params_iter)
192+
new_params.extend(src_params_iter)
193+
new_signature = dest_signature.replace(parameters=new_params)
194+
if handle_return_type == 'overwrite':
195+
new_signature = new_signature.replace(
196+
return_annotation=source_signature.return_annotation
197+
)
198+
199+
decorated.__signature__ = new_signature # type: ignore
200+
return decorated
201+
202+
return decorator
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import inspect
2+
from typing import Any
3+
4+
import pytest
5+
6+
from megatron.core.typed_torch import copy_signature, not_none
7+
8+
9+
def source_func(a: int, *, b: str) -> str:
    """Reference function whose signature is copied; returns ``str(a)`` followed by ``b``."""
    prefix = str(a)
    return prefix + b
12+
13+
14+
class SourceClass:
    """Reference class providing a method whose signature is copied."""

    def method(self, a: int, *, b: str) -> str:
        """Return ``str(a)`` followed by ``b``."""
        prefix = str(a)
        return prefix + b
20+
21+
22+
@copy_signature(source_func)
def dest_func_from_func(*args: Any, **kwargs: Any) -> list[str]:
    """Forward to ``source_func`` and return its result in a one-element list."""
    result = source_func(*args, **kwargs)
    return [result]
26+
27+
28+
@copy_signature(source_func, handle_return_type='overwrite')
def dest_func_from_func_overwrite(*args: Any, **kwargs: Any) -> object:
    """Forward to ``source_func``; the return annotation is taken from ``source_func``."""
    result = source_func(*args, **kwargs)
    return result
32+
33+
34+
@copy_signature(SourceClass.method, handle_first_src_param='skip')
def dest_func_from_method(*args: Any, **kwargs: Any) -> int:
    """Call ``SourceClass.method`` on a fresh instance and return the result's length."""
    instance = SourceClass()
    text = instance.method(*args, **kwargs)
    return len(text)
38+
39+
40+
@copy_signature(SourceClass.method, handle_return_type='overwrite', handle_first_src_param='skip')
def dest_func_from_method_overwrite(*args: Any, **kwargs: Any) -> object:
    """Call ``SourceClass.method`` on a fresh instance; return annotation comes from the source."""
    instance = SourceClass()
    return instance.method(*args, **kwargs)
44+
45+
46+
class DestClass:
    """Holder for methods whose signatures are copied from other callables."""

    @copy_signature(source_func, handle_first_dst_param='preserve')
    def dest_method_from_func(self, *args: Any, **kwargs: Any) -> list[str]:
        """Forward to ``source_func`` and return its result in a one-element list."""
        result = source_func(*args, **kwargs)
        return [result]

    @copy_signature(source_func, handle_return_type='overwrite', handle_first_dst_param='preserve')
    def dest_method_from_func_overwrite(self, *args: Any, **kwargs: Any) -> object:
        """Forward to ``source_func``; the return annotation is taken from ``source_func``."""
        result = source_func(*args, **kwargs)
        return result

    @classmethod
    @copy_signature(
        SourceClass.method, handle_first_src_param='skip', handle_first_dst_param='preserve'
    )
    def dest_method_from_method(cls, *args: Any, **kwargs: Any) -> int:
        """Class method: call ``SourceClass.method`` and return the result's length."""
        instance = SourceClass()
        text = instance.method(*args, **kwargs)
        return len(text)

    @copy_signature(
        SourceClass.method,
        handle_return_type='overwrite',
        handle_first_src_param='skip',
        handle_first_dst_param='preserve',
    )
    def dest_method_from_method_overwrite(self, *args: Any, **kwargs: Any) -> object:
        """Call ``SourceClass.method``; the return annotation is taken from that method."""
        instance = SourceClass()
        return instance.method(*args, **kwargs)
76+
77+
78+
class TestCopySignature:
    """Checks call results and runtime signatures of the ``copy_signature`` fixtures."""

    @staticmethod
    def _expected_signature(return_annotation, *, with_self=False):
        """Build ``([self,] a: int, *, b: str) -> return_annotation`` for comparison."""
        params = []
        if with_self:
            params.append(inspect.Parameter('self', inspect.Parameter.POSITIONAL_OR_KEYWORD))
        params.append(
            inspect.Parameter('a', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int)
        )
        params.append(inspect.Parameter('b', inspect.Parameter.KEYWORD_ONLY, annotation=str))
        return inspect.Signature(params, return_annotation=return_annotation)

    def test_original_return_type(self):
        """Test that the original return types are preserved."""
        # NOTE(review): the annotated locals presumably double as static type checks
        # of the overloads - confirm before removing the annotations.
        f2f: list[str] = dest_func_from_func(1, b='a')
        assert f2f == ['1a']
        assert inspect.signature(dest_func_from_func) == self._expected_signature(list[str])

        m2f: int = dest_func_from_method(1, b='a')
        assert m2f == 2
        assert inspect.signature(dest_func_from_method) == self._expected_signature(int)

        f2m: list[str] = DestClass().dest_method_from_func(
            1, b='a'
        ) + DestClass.dest_method_from_func(DestClass(), 1, b='a')
        assert f2m == ['1a', '1a']
        assert inspect.signature(DestClass.dest_method_from_func) == self._expected_signature(
            list[str], with_self=True
        )

        m2m: int = DestClass.dest_method_from_method(1, b='a')
        assert m2m == 2
        assert inspect.signature(DestClass.dest_method_from_method) == self._expected_signature(
            int
        )

    def test_overwritten_return_type(self):
        """Test that the return types are overwritten correctly."""
        f2f: str = dest_func_from_func_overwrite(1, b='a')
        assert f2f == '1a'
        assert inspect.signature(dest_func_from_func_overwrite) == self._expected_signature(str)

        m2f: str = dest_func_from_method_overwrite(1, b='a')
        assert m2f == '1a'
        assert inspect.signature(dest_func_from_method_overwrite) == self._expected_signature(str)

        f2m: str = DestClass().dest_method_from_func_overwrite(
            1, b='a'
        ) + DestClass.dest_method_from_func_overwrite(DestClass(), 1, b='a')
        assert f2m == '1a1a'
        assert inspect.signature(
            DestClass.dest_method_from_func_overwrite
        ) == self._expected_signature(str, with_self=True)

        m2m: str = DestClass().dest_method_from_method_overwrite(1, b='a')
        assert m2m == '1a'
        assert inspect.signature(
            DestClass.dest_method_from_method_overwrite
        ) == self._expected_signature(str, with_self=True)
169+
170+
171+
class TestNotNone:
    """Tests for ``not_none``."""

    def test_none(self):
        """``None`` must be rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match=r'Expected value to be not None'):
            not_none(None)

    def test_not_none(self):
        """A non-``None`` value must be returned as given."""
        expected = 42
        assert not_none(expected) == expected

0 commit comments

Comments
 (0)