Skip to content

Commit 9af0db4

Browse files
committed
f refactor forward-checking api
1 parent ea6f661 commit 9af0db4

File tree

1 file changed

+46
-55
lines changed

1 file changed

+46
-55
lines changed

lightning/src/ln/channelmanager.rs

Lines changed: 46 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4993,55 +4993,53 @@ where
49934993
}
49944994
}
49954995

4996-
fn forward_needs_intercept(
4997-
&self, outbound_chan: Option<&FundedChannel<SP>>, outgoing_scid: u64,
4998-
) -> bool {
4996+
fn forward_needs_intercept_to_known_chan(&self, outbound_chan: &FundedChannel<SP>) -> bool {
49994997
let intercept_flags = self.config.read().unwrap().htlc_interception_flags;
5000-
if let Some(chan) = outbound_chan {
5001-
if !chan.context.should_announce() {
5002-
if chan.context.is_connected() {
5003-
if intercept_flags & (HTLCInterceptionFlags::ToOnlinePrivateChannels as u8) != 0
5004-
{
5005-
return true;
5006-
}
5007-
} else {
5008-
if intercept_flags & (HTLCInterceptionFlags::ToOfflinePrivateChannels as u8)
5009-
!= 0
5010-
{
5011-
return true;
5012-
}
4998+
if !outbound_chan.context.should_announce() {
4999+
if outbound_chan.context.is_connected() {
5000+
if intercept_flags & (HTLCInterceptionFlags::ToOnlinePrivateChannels as u8) != 0 {
5001+
return true;
50135002
}
50145003
} else {
5015-
if intercept_flags & (HTLCInterceptionFlags::ToPublicChannels as u8) != 0 {
5004+
if intercept_flags & (HTLCInterceptionFlags::ToOfflinePrivateChannels as u8) != 0 {
50165005
return true;
50175006
}
50185007
}
50195008
} else {
5020-
if fake_scid::is_valid_intercept(
5021-
&self.fake_scid_rand_bytes,
5022-
outgoing_scid,
5023-
&self.chain_hash,
5024-
) {
5025-
if intercept_flags & (HTLCInterceptionFlags::ToInterceptSCIDs as u8) != 0 {
5026-
return true;
5027-
}
5028-
} else if fake_scid::is_valid_phantom(
5029-
&self.fake_scid_rand_bytes,
5030-
outgoing_scid,
5031-
&self.chain_hash,
5032-
) {
5033-
// Handled as a normal forward
5034-
} else if intercept_flags & (HTLCInterceptionFlags::ToUnknownSCIDs as u8) != 0 {
5009+
if intercept_flags & (HTLCInterceptionFlags::ToPublicChannels as u8) != 0 {
50355010
return true;
50365011
}
50375012
}
50385013
false
50395014
}
50405015

5016+
fn forward_needs_intercept_to_unknown_chan(&self, outgoing_scid: u64) -> bool {
5017+
let intercept_flags = self.config.read().unwrap().htlc_interception_flags;
5018+
if fake_scid::is_valid_intercept(
5019+
&self.fake_scid_rand_bytes,
5020+
outgoing_scid,
5021+
&self.chain_hash,
5022+
) {
5023+
if intercept_flags & (HTLCInterceptionFlags::ToInterceptSCIDs as u8) != 0 {
5024+
return true;
5025+
}
5026+
} else if fake_scid::is_valid_phantom(
5027+
&self.fake_scid_rand_bytes,
5028+
outgoing_scid,
5029+
&self.chain_hash,
5030+
) {
5031+
// Handled as a normal forward
5032+
} else if intercept_flags & (HTLCInterceptionFlags::ToUnknownSCIDs as u8) != 0 {
5033+
return true;
5034+
}
5035+
false
5036+
}
5037+
50415038
#[rustfmt::skip]
50425039
fn can_forward_htlc_to_outgoing_channel(
5043-
&self, chan: &mut FundedChannel<SP>, msg: &msgs::UpdateAddHTLC, next_packet: &NextPacketDetails
5044-
) -> Result<bool, LocalHTLCFailureReason> {
5040+
&self, chan: &mut FundedChannel<SP>, msg: &msgs::UpdateAddHTLC,
5041+
next_packet: &NextPacketDetails, will_intercept: bool,
5042+
) -> Result<(), LocalHTLCFailureReason> {
50455043
if !chan.context.should_announce()
50465044
&& !self.config.read().unwrap().accept_forwards_to_priv_channels
50475045
{
@@ -5050,15 +5048,13 @@ where
50505048
// we don't allow forwards outbound over them.
50515049
return Err(LocalHTLCFailureReason::PrivateChannelForward);
50525050
}
5053-
let intercepted;
50545051
if let HopConnector::ShortChannelId(outgoing_scid) = next_packet.outgoing_connector {
50555052
if chan.funding.get_channel_type().supports_scid_privacy() && outgoing_scid != chan.context.outbound_scid_alias() {
50565053
// `option_scid_alias` (referred to in LDK as `scid_privacy`) means
50575054
// "refuse to forward unless the SCID alias was used", so we pretend
50585055
// we don't have the channel here.
50595056
return Err(LocalHTLCFailureReason::RealSCIDForward);
50605057
}
5061-
intercepted = self.forward_needs_intercept(Some(chan), outgoing_scid);
50625058
} else {
50635059
return Err(LocalHTLCFailureReason::InvalidTrampolineForward);
50645060
}
@@ -5068,7 +5064,7 @@ where
50685064
// around to doing the actual forward, but better to fail early if we can and
50695065
// hopefully an attacker trying to path-trace payments cannot make this occur
50705066
// on a small/per-node/per-channel scale.
5071-
if !intercepted && !chan.context.is_live() {
5067+
if !will_intercept && !chan.context.is_live() {
50725068
if !chan.context.is_enabled() {
50735069
return Err(LocalHTLCFailureReason::ChannelDisabled);
50745070
} else if !chan.context.is_connected() {
@@ -5080,9 +5076,7 @@ where
50805076
if next_packet.outgoing_amt_msat < chan.context.get_counterparty_htlc_minimum_msat() {
50815077
return Err(LocalHTLCFailureReason::AmountBelowMinimum);
50825078
}
5083-
chan.htlc_satisfies_config(msg, next_packet.outgoing_amt_msat, next_packet.outgoing_cltv_value)?;
5084-
5085-
Ok(intercepted)
5079+
chan.htlc_satisfies_config(msg, next_packet.outgoing_amt_msat, next_packet.outgoing_cltv_value)
50865080
}
50875081

50885082
/// Executes a callback `C` that returns some value `X` on the channel found with the given
@@ -5109,31 +5103,32 @@ where
51095103
}
51105104

51115105
fn can_forward_htlc_intercepted(
5112-
&self, msg: &msgs::UpdateAddHTLC, next_packet_details: &NextPacketDetails,
5106+
&self, msg: &msgs::UpdateAddHTLC, next_hop: &NextPacketDetails,
51135107
) -> Result<bool, LocalHTLCFailureReason> {
5114-
let outgoing_scid = match next_packet_details.outgoing_connector {
5108+
let outgoing_scid = match next_hop.outgoing_connector {
51155109
HopConnector::ShortChannelId(scid) => scid,
51165110
HopConnector::Trampoline(_) => {
51175111
return Err(LocalHTLCFailureReason::InvalidTrampolineForward);
51185112
},
51195113
};
51205114
// TODO: We do the fake SCID namespace check a bunch of times here (and indirectly via
5121-
// `forward_needs_intercept`, including as called in
5115+
// `forward_needs_intercept_*`, including as called in
51225116
// `can_forward_htlc_to_outgoing_channel`), we should find a way to reduce the number of
51235117
// times we do it.
51245118
let intercept =
51255119
match self.do_funded_channel_callback(outgoing_scid, |chan: &mut FundedChannel<SP>| {
5126-
self.can_forward_htlc_to_outgoing_channel(chan, msg, next_packet_details)
5120+
let intercept = self.forward_needs_intercept_to_known_chan(chan);
5121+
self.can_forward_htlc_to_outgoing_channel(chan, msg, next_hop, intercept)?;
5122+
Ok(intercept)
51275123
}) {
51285124
Some(Ok(intercept)) => intercept,
51295125
Some(Err(e)) => return Err(e),
51305126
None => {
51315127
// Perform basic sanity checks on the amounts and CLTV being forwarded
5132-
if next_packet_details.outgoing_amt_msat > msg.amount_msat {
5128+
if next_hop.outgoing_amt_msat > msg.amount_msat {
51335129
return Err(LocalHTLCFailureReason::FeeInsufficient);
51345130
}
5135-
let cltv_delta =
5136-
msg.cltv_expiry.saturating_sub(next_packet_details.outgoing_cltv_value);
5131+
let cltv_delta = msg.cltv_expiry.saturating_sub(next_hop.outgoing_cltv_value);
51375132
if cltv_delta < MIN_CLTV_EXPIRY_DELTA.into() {
51385133
return Err(LocalHTLCFailureReason::IncorrectCLTVExpiry);
51395134
}
@@ -5144,7 +5139,7 @@ where
51445139
&self.chain_hash,
51455140
) {
51465141
false
5147-
} else if self.forward_needs_intercept(None, outgoing_scid) {
5142+
} else if self.forward_needs_intercept_to_unknown_chan(outgoing_scid) {
51485143
true
51495144
} else {
51505145
return Err(LocalHTLCFailureReason::UnknownNextPeer);
@@ -5153,11 +5148,7 @@ where
51535148
};
51545149

51555150
let cur_height = self.best_block.read().unwrap().height + 1;
5156-
check_incoming_htlc_cltv(
5157-
cur_height,
5158-
next_packet_details.outgoing_cltv_value,
5159-
msg.cltv_expiry,
5160-
)?;
5151+
check_incoming_htlc_cltv(cur_height, next_hop.outgoing_cltv_value, msg.cltv_expiry)?;
51615152

51625153
Ok(intercept)
51635154
}
@@ -15982,9 +15973,9 @@ where
1598215973

1598315974
let should_intercept = self
1598415975
.do_funded_channel_callback(next_hop_scid, |chan| {
15985-
self.forward_needs_intercept(Some(chan), next_hop_scid)
15976+
self.forward_needs_intercept_to_known_chan(chan)
1598615977
})
15987-
.unwrap_or_else(|| self.forward_needs_intercept(None, next_hop_scid));
15978+
.unwrap_or_else(|| self.forward_needs_intercept_to_unknown_chan(next_hop_scid));
1598815979

1598915980
if should_intercept {
1599015981
let intercept_id = InterceptId::from_htlc_id_and_chan_id(

0 commit comments

Comments
 (0)