@@ -54,7 +54,12 @@ pub(crate) use self::unix::{ensure_daemon, open_socket, print_socket, run_daemon
5454/// [WorkspaceTransport] instance if the socket is currently active
5555pub fn open_transport ( runtime : Runtime ) -> io:: Result < Option < impl WorkspaceTransport > > {
5656 match runtime. block_on ( open_socket ( ) ) {
57- Ok ( Some ( ( read, write) ) ) => Ok ( Some ( SocketTransport :: open ( runtime, read, write) ) ) ,
57+ Ok ( Some ( ( read, write) ) ) => Ok ( Some ( SocketTransport :: open_with_timeout (
58+ runtime,
59+ read,
60+ write,
61+ DEFAULT_REQUEST_TIMEOUT ,
62+ ) ) ) ,
5863 Ok ( None ) => Ok ( None ) ,
5964 Err ( err) => Err ( err) ,
6065 }
@@ -99,8 +104,11 @@ pub struct SocketTransport {
99104 runtime : Runtime ,
100105 write_send : Sender < ( Vec < u8 > , bool ) > ,
101106 pending_requests : PendingRequests ,
107+ request_timeout : Duration ,
102108}
103109
110+ const DEFAULT_REQUEST_TIMEOUT : Duration = Duration :: from_secs ( 15 ) ;
111+
104112/// Stores a handle to the map of pending requests, and clears the map
105113/// automatically when the handle is dropped
106114#[ derive( Clone , Default ) ]
@@ -131,7 +139,12 @@ impl Drop for PendingRequests {
131139}
132140
133141impl SocketTransport {
134- pub fn open < R , W > ( runtime : Runtime , socket_read : R , socket_write : W ) -> Self
142+ pub fn open_with_timeout < R , W > (
143+ runtime : Runtime ,
144+ socket_read : R ,
145+ socket_write : W ,
146+ request_timeout : Duration ,
147+ ) -> Self
135148 where
136149 R : AsyncRead + Unpin + Send + ' static ,
137150 W : AsyncWrite + Unpin + Send + ' static ,
@@ -172,6 +185,7 @@ impl SocketTransport {
172185 runtime,
173186 write_send,
174187 pending_requests : pending_requests_2,
188+ request_timeout,
175189 }
176190 }
177191}
@@ -185,27 +199,30 @@ impl WorkspaceTransport for SocketTransport {
185199 P : Serialize ,
186200 R : DeserializeOwned ,
187201 {
202+ let request_id = request. id ;
188203 let ( send, recv) = oneshot:: channel ( ) ;
189-
190- self . pending_requests . insert ( request. id , send) ;
191-
192204 let is_shutdown = request. method == "pgls/shutdown" ;
193205
194206 let request = JsonRpcRequest {
195207 jsonrpc : Cow :: Borrowed ( "2.0" ) ,
196- id : request . id ,
208+ id : request_id ,
197209 method : Cow :: Borrowed ( request. method ) ,
198210 params : request. params ,
199211 } ;
200212
201- let request = to_vec ( & request) . map_err ( |err| {
202- TransportError :: SerdeError ( format ! (
203- "failed to serialize {} into byte buffer: {err}" ,
204- type_name:: <P >( )
205- ) )
206- } ) ?;
213+ let request = match to_vec ( & request) {
214+ Ok ( request) => request,
215+ Err ( err) => {
216+ return Err ( TransportError :: SerdeError ( format ! (
217+ "failed to serialize {} into byte buffer: {err}" ,
218+ type_name:: <P >( )
219+ ) ) ) ;
220+ }
221+ } ;
207222
208- let response = self . runtime . block_on ( async move {
223+ self . pending_requests . insert ( request_id, send) ;
224+
225+ let response = match self . runtime . block_on ( async move {
209226 self . write_send
210227 . send ( ( request, is_shutdown) )
211228 . await
@@ -219,11 +236,17 @@ impl WorkspaceTransport for SocketTransport {
219236 Err ( _) => Err ( TransportError :: ChannelClosed ) ,
220237 }
221238 }
222- _ = sleep( Duration :: from_secs ( 15 ) ) => {
239+ _ = sleep( self . request_timeout ) => {
223240 Err ( TransportError :: Timeout )
224241 }
225242 }
226- } ) ?;
243+ } ) {
244+ Ok ( response) => response,
245+ Err ( err) => {
246+ self . pending_requests . remove ( & request_id) ;
247+ return Err ( err) ;
248+ }
249+ } ;
227250
228251 let response = response. get ( ) ;
229252 let result = from_str ( response) . map_err ( |err| {
@@ -472,3 +495,86 @@ impl FromStr for TransportHeader {
472495 }
473496 }
474497}
498+
499+ #[ cfg( test) ]
500+ mod tests {
501+ use std:: fmt;
502+ use std:: time:: Duration ;
503+
504+ use pgls_workspace:: TransportError ;
505+ use pgls_workspace:: workspace:: { TransportRequest , WorkspaceTransport } ;
506+ use serde:: Serialize ;
507+ use serde:: ser:: { Error as SerError , Serializer } ;
508+ use serde_json:: Value ;
509+ use tokio:: io:: { duplex, split} ;
510+ use tokio:: runtime:: Runtime ;
511+
512+ use super :: SocketTransport ;
513+
514+ struct FailingParams ;
515+
516+ impl Serialize for FailingParams {
517+ fn serialize < S > ( & self , _serializer : S ) -> Result < S :: Ok , S :: Error >
518+ where
519+ S : Serializer ,
520+ {
521+ Err ( S :: Error :: custom ( "expected serialization failure" ) )
522+ }
523+ }
524+
525+ impl fmt:: Debug for FailingParams {
526+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
527+ f. write_str ( "FailingParams" )
528+ }
529+ }
530+
531+ fn disconnected_transport ( ) -> SocketTransport {
532+ let runtime = Runtime :: new ( ) . expect ( "failed to create tokio runtime" ) ;
533+ let ( stream, peer) = duplex ( 1024 ) ;
534+ drop ( peer) ;
535+ let ( read, write) = split ( stream) ;
536+ SocketTransport :: open_with_timeout ( runtime, read, write, Duration :: from_millis ( 50 ) )
537+ }
538+
539+ #[ test]
540+ fn request_does_not_retain_pending_entries_when_serialization_fails ( ) {
541+ let transport = disconnected_transport ( ) ;
542+
543+ let result: Result < Value , TransportError > = transport. request ( TransportRequest {
544+ id : 1 ,
545+ method : "pgls/get_file_content" ,
546+ params : FailingParams ,
547+ } ) ;
548+
549+ assert ! ( matches!( result, Err ( TransportError :: SerdeError ( _) ) ) ) ;
550+ assert_eq ! (
551+ transport. pending_requests. len( ) ,
552+ 0 ,
553+ "pending request should be cleaned up on serialization failure"
554+ ) ;
555+ }
556+
557+ #[ test]
558+ fn request_does_not_retain_pending_entries_on_timeout_or_channel_close ( ) {
559+ let transport = disconnected_transport ( ) ;
560+
561+ let result: Result < Value , TransportError > = transport. request ( TransportRequest {
562+ id : 2 ,
563+ method : "pgls/get_file_content" ,
564+ params : ( ) ,
565+ } ) ;
566+
567+ assert ! (
568+ matches!(
569+ result,
570+ Err ( TransportError :: Timeout | TransportError :: ChannelClosed )
571+ ) ,
572+ "expected timeout or channel-closed error, got {result:?}"
573+ ) ;
574+ assert_eq ! (
575+ transport. pending_requests. len( ) ,
576+ 0 ,
577+ "pending request should be cleaned up on timeout/channel-close"
578+ ) ;
579+ }
580+ }
0 commit comments