@@ -30,7 +30,38 @@ enum ForwardingMods {
3030 None ,
3131}
3232
33- fn do_test_htlc_interception_flags ( flags_bitmask : u8 , flag_bit : u8 , modification : ForwardingMods ) {
33+ #[ derive( Clone , Copy , Debug ) ]
34+ enum Flag {
35+ InterceptSCIDs = ( HTLCInterceptionFlags :: ToInterceptSCIDs as usize ) . trailing_zeros ( ) as isize ,
36+ OfflinePrivChans =
37+ ( HTLCInterceptionFlags :: ToOfflinePrivateChannels as usize ) . trailing_zeros ( ) as isize ,
38+ OnlinePrivChans =
39+ ( HTLCInterceptionFlags :: ToOnlinePrivateChannels as usize ) . trailing_zeros ( ) as isize ,
40+ PublicChans = ( HTLCInterceptionFlags :: ToPublicChannels as usize ) . trailing_zeros ( ) as isize ,
41+ UnknownSCIDs = ( HTLCInterceptionFlags :: ToUnknownSCIDs as usize ) . trailing_zeros ( ) as isize ,
42+ }
43+
44+ impl Flag {
45+ fn all ( ) -> [ Flag ; 5 ] {
46+ use Flag :: * ;
47+ let all = [ InterceptSCIDs , OfflinePrivChans , OnlinePrivChans , PublicChans , UnknownSCIDs ] ;
48+
49+ // Make sure that our list of flags is actually all HTLCs
50+ let mut all_flags = 0 ;
51+ for flag in all {
52+ all_flags |= flag. mask ( ) ;
53+ }
54+ assert_eq ! ( all_flags, HTLCInterceptionFlags :: AllValidHTLCs as u8 ) ;
55+
56+ all
57+ }
58+
59+ fn mask ( & self ) -> u8 {
60+ 1 << * self as usize
61+ }
62+ }
63+
64+ fn do_test_htlc_interception_flags ( flags_bitmask : u8 , flag : Flag , modification : ForwardingMods ) {
3465 // Tests that the `htlc_interception_flags` bitmask given by `flags_bitmask` correctly
3566 // intercepts (or doesn't intercept) an HTLC which is of type `flag_bit`
3667 let chanmon_cfgs = create_chanmon_cfgs ( 3 ) ;
@@ -53,32 +84,31 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
5384
5485 // First open the right type of channel (and get it in the right state) for the bit we're
5586 // testing.
56- let ( target_scid, target_chan_id) = match flag_bit {
57- 1 | 2 => {
87+ let ( target_scid, target_chan_id) = match flag {
88+ Flag :: OfflinePrivChans | Flag :: OnlinePrivChans => {
5889 create_unannounced_chan_between_nodes_with_value ( & nodes, 1 , 2 , 100000 , 0 ) ;
5990 let chan_id = nodes[ 2 ] . node . list_channels ( ) [ 0 ] . channel_id ;
6091 let scid = nodes[ 2 ] . node . list_channels ( ) [ 0 ] . short_channel_id . unwrap ( ) ;
61- if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToOfflinePrivateChannels as u8 {
92+ if flag . mask ( ) == HTLCInterceptionFlags :: ToOfflinePrivateChannels as u8 {
6293 nodes[ 1 ] . node . peer_disconnected ( node_2_id) ;
6394 nodes[ 2 ] . node . peer_disconnected ( node_1_id) ;
6495 } else {
65- assert_eq ! ( 1 << flag_bit , HTLCInterceptionFlags :: ToOnlinePrivateChannels as u8 ) ;
96+ assert_eq ! ( flag . mask ( ) , HTLCInterceptionFlags :: ToOnlinePrivateChannels as u8 ) ;
6697 }
6798 ( scid, chan_id)
6899 } ,
69- 0 | 3 | 4 => {
100+ Flag :: InterceptSCIDs | Flag :: PublicChans | Flag :: UnknownSCIDs => {
70101 let ( chan_upd, _, chan_id, _) = create_announced_chan_between_nodes ( & nodes, 1 , 2 ) ;
71- if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToInterceptSCIDs as u8 {
102+ if flag . mask ( ) == HTLCInterceptionFlags :: ToInterceptSCIDs as u8 {
72103 ( nodes[ 1 ] . node . get_intercept_scid ( ) , chan_id)
73- } else if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToPublicChannels as u8 {
104+ } else if flag . mask ( ) == HTLCInterceptionFlags :: ToPublicChannels as u8 {
74105 ( chan_upd. contents . short_channel_id , chan_id)
75- } else if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToUnknownSCIDs as u8 {
106+ } else if flag . mask ( ) == HTLCInterceptionFlags :: ToUnknownSCIDs as u8 {
76107 ( 42424242 , chan_id)
77108 } else {
78109 panic ! ( ) ;
79110 }
80111 } ,
81- _ => panic ! ( "Invalid flag_bit: {}" , flag_bit) ,
82112 } ;
83113
84114 // Start every node on the same block height to ensure we don't hit spurious CLTV issues
@@ -95,7 +125,7 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
95125 get_route_and_payment_hash ! ( nodes[ 0 ] , nodes[ 2 ] , pay_params, amt_msat) ;
96126 route. paths [ 0 ] . hops [ 1 ] . short_channel_id = target_scid;
97127
98- let interception_bit_match = ( flags_bitmask & ( 1 << flag_bit ) ) != 0 ;
128+ let interception_bit_match = ( flags_bitmask & flag . mask ( ) ) != 0 ;
99129 match modification {
100130 ForwardingMods :: FeeTooLow => {
101131 assert ! (
@@ -136,13 +166,13 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
136166 if let Event :: HTLCIntercepted { intercept_id : id, requested_next_hop_scid, .. } = & events[ 0 ]
137167 {
138168 assert_eq ! ( * requested_next_hop_scid, target_scid,
139- "Bitmask {flags_bitmask:#x}: Expected interception for bit {flag_bit } to target SCID {target_scid}" ) ;
169+ "Bitmask {flags_bitmask:#x}: Expected interception for bit {flag:? } to target SCID {target_scid}" ) ;
140170 intercept_id = * id;
141171 } else {
142172 panic ! ( "{events:?}" ) ;
143173 }
144174
145- if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToOfflinePrivateChannels as u8 {
175+ if flag . mask ( ) == HTLCInterceptionFlags :: ToOfflinePrivateChannels as u8 {
146176 let mut reconnect_args = ReconnectArgs :: new ( & nodes[ 1 ] , & nodes[ 2 ] ) ;
147177 reconnect_args. send_channel_ready = ( true , true ) ;
148178 reconnect_nodes ( reconnect_args) ;
@@ -165,8 +195,8 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
165195 } else {
166196 // If we were not set to intercept, check that the HTLC either failed or was
167197 // automatically forwarded as appropriate.
168- match ( modification, flag_bit ) {
169- ( ForwardingMods :: None , 2 | 3 ) => {
198+ match ( modification, flag ) {
199+ ( ForwardingMods :: None , Flag :: OnlinePrivChans | Flag :: PublicChans ) => {
170200 check_added_monitors ( & nodes[ 1 ] , 1 ) ;
171201
172202 let forward_ev = SendEvent :: from_node ( & nodes[ 1 ] ) ;
@@ -192,18 +222,18 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
192222 ForwardingMods :: None => None ,
193223 } ;
194224 let ( expected_failure_type, reason) ;
195- if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToOfflinePrivateChannels as u8 {
225+ if flag . mask ( ) == HTLCInterceptionFlags :: ToOfflinePrivateChannels as u8 {
196226 expected_failure_type = HTLCHandlingFailureType :: Forward {
197227 node_id : Some ( node_2_id) ,
198228 channel_id : target_chan_id,
199229 } ;
200230 reason = reason_from_mod. unwrap_or ( LocalHTLCFailureReason :: PeerOffline ) ;
201- } else if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToInterceptSCIDs as u8 {
231+ } else if flag . mask ( ) == HTLCInterceptionFlags :: ToInterceptSCIDs as u8 {
202232 expected_failure_type = HTLCHandlingFailureType :: InvalidForward {
203233 requested_forward_scid : target_scid,
204234 } ;
205235 reason = reason_from_mod. unwrap_or ( LocalHTLCFailureReason :: UnknownNextPeer ) ;
206- } else if ( 1 << flag_bit ) == HTLCInterceptionFlags :: ToUnknownSCIDs as u8 {
236+ } else if flag . mask ( ) == HTLCInterceptionFlags :: ToUnknownSCIDs as u8 {
207237 expected_failure_type = HTLCHandlingFailureType :: InvalidForward {
208238 requested_forward_scid : target_scid,
209239 } ;
@@ -235,17 +265,14 @@ fn do_test_htlc_interception_flags(flags_bitmask: u8, flag_bit: u8, modification
235265}
236266
237267const MAX_BITMASK : u8 = HTLCInterceptionFlags :: AllValidHTLCs as u8 ;
238- const MAX_FLAG : u8 = 4 ;
239268
240269#[ test]
241270fn test_htlc_interception_flags ( ) {
242271 // Test all 2^5 = 32 combinations of the HTLCInterceptionFlags bitmask
243272 // For each combination, test 5 different HTLC forwards and verify correct interception behavior
244- assert_eq ! ( ( 1 << MAX_FLAG + 1 ) - 1 , MAX_BITMASK ) ;
245-
246273 for flags_bitmask in 0 ..=MAX_BITMASK {
247- for flag_bit in 0 ..= MAX_FLAG {
248- do_test_htlc_interception_flags ( flags_bitmask, flag_bit , ForwardingMods :: None ) ;
274+ for flag in Flag :: all ( ) {
275+ do_test_htlc_interception_flags ( flags_bitmask, flag , ForwardingMods :: None ) ;
249276 }
250277 }
251278}
@@ -254,24 +281,18 @@ fn test_htlc_interception_flags() {
254281fn test_htlc_bad_for_chan_config ( ) {
255282 // Test that interception won't be done if an HTLC fails to meet the target channel's channel
256283 // config.
257- let have_chan_flags = [
258- HTLCInterceptionFlags :: ToOfflinePrivateChannels ,
259- HTLCInterceptionFlags :: ToOnlinePrivateChannels ,
260- HTLCInterceptionFlags :: ToPublicChannels ,
261- ] ;
284+ let have_chan_flags = [ Flag :: OfflinePrivChans , Flag :: OnlinePrivChans , Flag :: PublicChans ] ;
262285 for flag in have_chan_flags {
263- assert_eq ! ( ( flag as u8 ) . count_ones( ) , 1 ) ;
264- let bit = ( flag as u8 ) . trailing_zeros ( ) as u8 ;
265- do_test_htlc_interception_flags ( flag as u8 , bit, ForwardingMods :: FeeTooLow ) ;
266- do_test_htlc_interception_flags ( flag as u8 , bit, ForwardingMods :: CLTVBelowConfig ) ;
286+ do_test_htlc_interception_flags ( flag. mask ( ) , flag, ForwardingMods :: FeeTooLow ) ;
287+ do_test_htlc_interception_flags ( flag. mask ( ) , flag, ForwardingMods :: CLTVBelowConfig ) ;
267288 }
268289}
269290
270291#[ test]
271292fn test_htlc_bad_no_chan ( ) {
272293 // Test that setting the CLTV below the hard-coded minimum fails whether we're intercepting for
273294 // a channel or not.
274- for flag_bit in 0 ..= MAX_FLAG {
275- do_test_htlc_interception_flags ( 1 << flag_bit , flag_bit , ForwardingMods :: CLTVBelowMin ) ;
295+ for flag in Flag :: all ( ) {
296+ do_test_htlc_interception_flags ( flag . mask ( ) , flag , ForwardingMods :: CLTVBelowMin ) ;
276297 }
277298}
0 commit comments