Skip to content

Commit c77d787

Browse files
committed
Refactor AttackCategory enum verify
1 parent 0b76192 commit c77d787

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

torchattack/attack.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ class AttackCategory(Enum):
1616

1717
@classmethod
1818
def verify(cls, obj: Union[str, 'AttackCategory']) -> 'AttackCategory':
19-
if obj is not None:
20-
if type(obj) is str:
21-
obj = cls[obj.replace(cls.__name__ + '.', '')]
22-
elif not isinstance(obj, cls):
23-
raise TypeError(
24-
f'Invalid AttackCategory class provided; expected {cls.__name__} '
25-
f'but received {obj.__class__.__name__}.'
26-
)
27-
return obj
19+
if type(obj) is str:
20+
return cls[obj.replace(cls.__name__ + '.', '')]
21+
elif isinstance(obj, cls):
22+
return obj
23+
else:
24+
raise TypeError(
25+
f'Invalid AttackCategory class provided; expected {cls.__name__} '
26+
f'but received {obj.__class__.__name__}.'
27+
)
2828

2929

3030
ATTACK_REGISTRY: dict[str, Type['Attack']] = {}
@@ -128,6 +128,7 @@ def __eq__(self, other: Any) -> bool:
128128
'hooks', # PNAPatchOut, TGR, VDC
129129
'sub_basis', # GeoDA
130130
'generator', # BIA, CDA, LTP
131+
'lbq', # MuMoDIG
131132
]
132133
for attr in eq_name_attrs:
133134
if not (hasattr(self, attr) and hasattr(other, attr)):

0 commit comments

Comments
 (0)