Skip to content

Commit 1ebe8a6

Browse files
committed
f use a flag enum
1 parent 8618fb3 commit 1ebe8a6

File tree

1 file changed

+55
-34
lines changed

1 file changed

+55
-34
lines changed

lightning/src/ln/interception_tests.rs

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,38 @@ enum ForwardingMods {
3030
None,
3131
}
3232

33-
fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification: ForwardingMods) {
33+
#[derive(Clone, Copy, Debug)]
34+
enum Flag {
35+
InterceptSCIDs = (HTLCInterceptionFlags::ToInterceptSCIDs as usize).trailing_zeros() as isize,
36+
OfflinePrivChans =
37+
(HTLCInterceptionFlags::ToOfflinePrivateChannels as usize).trailing_zeros() as isize,
38+
OnlinePrivChans =
39+
(HTLCInterceptionFlags::ToOnlinePrivateChannels as usize).trailing_zeros() as isize,
40+
PublicChans = (HTLCInterceptionFlags::ToPublicChannels as usize).trailing_zeros() as isize,
41+
UnknownSCIDs = (HTLCInterceptionFlags::ToUnknownSCIDs as usize).trailing_zeros() as isize,
42+
}
43+
44+
impl Flag {
45+
fn all() -> [Flag; 5] {
46+
use Flag::*;
47+
let all = [InterceptSCIDs, OfflinePrivChans, OnlinePrivChans, PublicChans, UnknownSCIDs];
48+
49+
// Make sure that our list of flags is actually all HTLCs
50+
let mut all_flags = 0;
51+
for flag in all {
52+
all_flags |= flag.mask();
53+
}
54+
assert_eq!(all_flags, HTLCInterceptionFlags::AllValidHTLCs as u8);
55+
56+
all
57+
}
58+
59+
fn mask(&self) -> u8 {
60+
1 << *self as usize
61+
}
62+
}
63+
64+
fn do_test_htlc_interception_flags(flags_bitmask: u8, flag: Flag, modification: ForwardingMods) {
3465
// Tests that the `htlc_interception_flags` bitmask given by `flags_bitmask` correctly
3566
// intercepts (or doesn't intercept) an HTLC which is of type `flag_bit`
3667
let chanmon_cfgs = create_chanmon_cfgs(3);
@@ -53,32 +84,31 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
5384

5485
// First open the right type of channel (and get it in the right state) for the bit we're
5586
// testing.
56-
let (target_scid, target_chan_id) = match flag_bit {
57-
1 | 2 => {
87+
let (target_scid, target_chan_id) = match flag {
88+
Flag::OfflinePrivChans | Flag::OnlinePrivChans => {
5889
create_unannounced_chan_between_nodes_with_value(&nodes, 1, 2, 100000, 0);
5990
let chan_id = nodes[2].node.list_channels()[0].channel_id;
6091
let scid = nodes[2].node.list_channels()[0].short_channel_id.unwrap();
61-
if (1 << flag_bit) == HTLCInterceptionFlags::ToOfflinePrivateChannels as u8 {
92+
if flag.mask() == HTLCInterceptionFlags::ToOfflinePrivateChannels as u8 {
6293
nodes[1].node.peer_disconnected(node_2_id);
6394
nodes[2].node.peer_disconnected(node_1_id);
6495
} else {
65-
assert_eq!(1 << flag_bit, HTLCInterceptionFlags::ToOnlinePrivateChannels as u8);
96+
assert_eq!(flag.mask(), HTLCInterceptionFlags::ToOnlinePrivateChannels as u8);
6697
}
6798
(scid, chan_id)
6899
},
69-
0 | 3 | 4 => {
100+
Flag::InterceptSCIDs | Flag::PublicChans | Flag::UnknownSCIDs => {
70101
let (chan_upd, _, chan_id, _) = create_announced_chan_between_nodes(&nodes, 1, 2);
71-
if (1 << flag_bit) == HTLCInterceptionFlags::ToInterceptSCIDs as u8 {
102+
if flag.mask() == HTLCInterceptionFlags::ToInterceptSCIDs as u8 {
72103
(nodes[1].node.get_intercept_scid(), chan_id)
73-
} else if (1 << flag_bit) == HTLCInterceptionFlags::ToPublicChannels as u8 {
104+
} else if flag.mask() == HTLCInterceptionFlags::ToPublicChannels as u8 {
74105
(chan_upd.contents.short_channel_id, chan_id)
75-
} else if (1 << flag_bit) == HTLCInterceptionFlags::ToUnknownSCIDs as u8 {
106+
} else if flag.mask() == HTLCInterceptionFlags::ToUnknownSCIDs as u8 {
76107
(42424242, chan_id)
77108
} else {
78109
panic!();
79110
}
80111
},
81-
_ => panic!("Invalid flag_bit: {}", flag_bit),
82112
};
83113

84114
// Start every node on the same block height to ensure we don't hit spurious CLTV issues
@@ -95,7 +125,7 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
95125
get_route_and_payment_hash!(nodes[0], nodes[2], pay_params, amt_msat);
96126
route.paths[0].hops[1].short_channel_id = target_scid;
97127

98-
let interception_bit_match = (flags_bitmask & (1 << flag_bit)) != 0;
128+
let interception_bit_match = (flags_bitmask & flag.mask()) != 0;
99129
match modification {
100130
ForwardingMods::FeeTooLow => {
101131
assert!(
@@ -136,13 +166,13 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
136166
if let Event::HTLCIntercepted { intercept_id: id, requested_next_hop_scid, .. } = &events[0]
137167
{
138168
assert_eq!(*requested_next_hop_scid, target_scid,
139-
"Bitmask {flags_bitmask:#x}: Expected interception for bit {flag_bit} to target SCID {target_scid}");
169+
"Bitmask {flags_bitmask:#x}: Expected interception for bit {flag:?} to target SCID {target_scid}");
140170
intercept_id = *id;
141171
} else {
142172
panic!("{events:?}");
143173
}
144174

145-
if (1 << flag_bit) == HTLCInterceptionFlags::ToOfflinePrivateChannels as u8 {
175+
if flag.mask() == HTLCInterceptionFlags::ToOfflinePrivateChannels as u8 {
146176
let mut reconnect_args = ReconnectArgs::new(&nodes[1], &nodes[2]);
147177
reconnect_args.send_channel_ready = (true, true);
148178
reconnect_nodes(reconnect_args);
@@ -165,8 +195,8 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
165195
} else {
166196
// If we were not set to intercept, check that the HTLC either failed or was
167197
// automatically forwarded as appropriate.
168-
match (modification, flag_bit) {
169-
(ForwardingMods::None, 2 | 3) => {
198+
match (modification, flag) {
199+
(ForwardingMods::None, Flag::OnlinePrivChans | Flag::PublicChans) => {
170200
check_added_monitors(&nodes[1], 1);
171201

172202
let forward_ev = SendEvent::from_node(&nodes[1]);
@@ -192,18 +222,18 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
192222
ForwardingMods::None => None,
193223
};
194224
let (expected_failure_type, reason);
195-
if (1 << flag_bit) == HTLCInterceptionFlags::ToOfflinePrivateChannels as u8 {
225+
if flag.mask() == HTLCInterceptionFlags::ToOfflinePrivateChannels as u8 {
196226
expected_failure_type = HTLCHandlingFailureType::Forward {
197227
node_id: Some(node_2_id),
198228
channel_id: target_chan_id,
199229
};
200230
reason = reason_from_mod.unwrap_or(LocalHTLCFailureReason::PeerOffline);
201-
} else if (1 << flag_bit) == HTLCInterceptionFlags::ToInterceptSCIDs as u8 {
231+
} else if flag.mask() == HTLCInterceptionFlags::ToInterceptSCIDs as u8 {
202232
expected_failure_type = HTLCHandlingFailureType::InvalidForward {
203233
requested_forward_scid: target_scid,
204234
};
205235
reason = reason_from_mod.unwrap_or(LocalHTLCFailureReason::UnknownNextPeer);
206-
} else if (1 << flag_bit) == HTLCInterceptionFlags::ToUnknownSCIDs as u8 {
236+
} else if flag.mask() == HTLCInterceptionFlags::ToUnknownSCIDs as u8 {
207237
expected_failure_type = HTLCHandlingFailureType::InvalidForward {
208238
requested_forward_scid: target_scid,
209239
};
@@ -235,17 +265,14 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
235265
}
236266

237267
const MAX_BITMASK: u8 = HTLCInterceptionFlags::AllValidHTLCs as u8;
238-
const MAX_FLAG: u8 = 4;
239268

240269
#[test]
241270
fn test_htlc_interception_flags() {
242271
// Test all 2^5 = 32 combinations of the HTLCInterceptionFlags bitmask
243272
// For each combination, test 5 different HTLC forwards and verify correct interception behavior
244-
assert_eq!((1 << MAX_FLAG + 1) - 1, MAX_BITMASK);
245-
246273
for flags_bitmask in 0..=MAX_BITMASK {
247-
for flag_bit in 0..=MAX_FLAG {
248-
do_test_htlc_interception_flags(flags_bitmask, flag_bit, ForwardingMods::None);
274+
for flag in Flag::all() {
275+
do_test_htlc_interception_flags(flags_bitmask, flag, ForwardingMods::None);
249276
}
250277
}
251278
}
@@ -254,24 +281,18 @@ fn test_htlc_interception_flags() {
254281
fn test_htlc_bad_for_chan_config() {
255282
// Test that interception won't be done if an HTLC fails to meet the target channel's channel
256283
// config.
257-
let have_chan_flags = [
258-
HTLCInterceptionFlags::ToOfflinePrivateChannels,
259-
HTLCInterceptionFlags::ToOnlinePrivateChannels,
260-
HTLCInterceptionFlags::ToPublicChannels,
261-
];
284+
let have_chan_flags = [Flag::OfflinePrivChans, Flag::OnlinePrivChans, Flag::PublicChans];
262285
for flag in have_chan_flags {
263-
assert_eq!((flag as u8).count_ones(), 1);
264-
let bit = (flag as u8).trailing_zeros() as u8;
265-
do_test_htlc_interception_flags(flag as u8, bit, ForwardingMods::FeeTooLow);
266-
do_test_htlc_interception_flags(flag as u8, bit, ForwardingMods::CLTVBelowConfig);
286+
do_test_htlc_interception_flags(flag.mask(), flag, ForwardingMods::FeeTooLow);
287+
do_test_htlc_interception_flags(flag.mask(), flag, ForwardingMods::CLTVBelowConfig);
267288
}
268289
}
269290

270291
#[test]
271292
fn test_htlc_bad_no_chan() {
272293
// Test that setting the CLTV below the hard-coded minimum fails whether we're intercepting for
273294
// a channel or not.
274-
for flag_bit in 0..=MAX_FLAG {
275-
do_test_htlc_interception_flags(1 << flag_bit, flag_bit, ForwardingMods::CLTVBelowMin);
295+
for flag in Flag::all() {
296+
do_test_htlc_interception_flags(flag.mask(), flag, ForwardingMods::CLTVBelowMin);
276297
}
277298
}

0 commit comments

Comments
 (0)