Skip to content

Commit 5f09ae8

Browse files
committed
[fix] return wait_writable on non-blocking reads
1 parent 736ac5f commit 5f09ae8

File tree

2 files changed

+89
-17
lines changed

2 files changed

+89
-17
lines changed

src/main/java/org/jruby/ext/openssl/SSLSocket.java

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private static CallSite callSite(final CallSite[] sites, final CallSiteIndex ind
141141
return sites[ index.ordinal() ];
142142
}
143143

144-
private SSLContext sslContext;
144+
SSLContext sslContext;
145145
private SSLEngine engine;
146146
private RubyIO io;
147147

@@ -209,7 +209,7 @@ private IRubyObject fallback_set_io_nonblock_checked(ThreadContext context, Ruby
209209

210210
private static final String SESSION_SOCKET_ID = "socket_id";
211211

212-
private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean server) {
212+
SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean server) {
213213
SSLEngine engine = this.engine;
214214
if ( engine != null ) return engine;
215215

@@ -574,7 +574,7 @@ private IRubyObject doHandshake(final boolean blocking, final boolean exception)
574574
doTasks();
575575
break;
576576
case NEED_UNWRAP:
577-
if (readAndUnwrap(blocking) == -1 && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
577+
if (readAndUnwrap(blocking, exception) == -1 && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
578578
throw new SSLHandshakeException("Socket closed");
579579
}
580580
// during initialHandshake, calling readAndUnwrap that results UNDERFLOW does not mean writable.
@@ -721,10 +721,6 @@ private int read(final ByteBuffer dst, final boolean blocking, final boolean exc
721721
return limit;
722722
}
723723

724-
private int readAndUnwrap(final boolean blocking) throws IOException {
725-
return readAndUnwrap(blocking, true);
726-
}
727-
728724
private int readAndUnwrap(final boolean blocking, final boolean exception) throws IOException {
729725
final int bytesRead = socketChannelImpl().read(netReadData);
730726
if ( bytesRead == -1 ) {
@@ -813,10 +809,10 @@ private void doShutdown() throws IOException {
813809
}
814810

815811
/**
816-
* @return the (@link RubyString} buffer or :wait_readable / :wait_writeable {@link RubySymbol}
812+
* @return the {@link RubyString} buffer or :wait_readable / :wait_writeable {@link RubySymbol}
817813
*/
818-
private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject len, final IRubyObject buff,
819-
final boolean blocking, final boolean exception) {
814+
private IRubyObject sysreadImpl(final ThreadContext context,
815+
final IRubyObject len, final IRubyObject buff, final boolean blocking, final boolean exception) {
820816
final Ruby runtime = context.runtime;
821817

822818
final int length = RubyNumeric.fix2int(len);
@@ -836,11 +832,12 @@ private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject l
836832
}
837833

838834
try {
839-
// flush any pending encrypted write data before reading; after write_nonblock,
840-
// encrypted bytes may remain in the buffer that haven't been sent, if we read wout flushing,
841-
// server may not have received the complete request (e.g. net/http POST body) and will not respond
835+
// Flush pending write data before reading (after write_nonblock encrypted bytes may still be buffered)
842836
if ( engine != null && netWriteData.hasRemaining() ) {
843-
flushData(blocking);
837+
if ( flushData(blocking) && ! blocking ) {
838+
if ( exception ) throw newSSLErrorWaitWritable(runtime, "write would block");
839+
return runtime.newSymbol("wait_writable");
840+
}
844841
}
845842

846843
// So we need to make sure to only block when there is no data left to process
@@ -851,7 +848,7 @@ private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject l
851848

852849
final ByteBuffer dst = ByteBuffer.allocate(length);
853850
int read = -1;
854-
// ensure >0 bytes read; sysread is blocking read.
851+
// ensure > 0 bytes read; sysread is blocking read
855852
while ( read <= 0 ) {
856853
if ( engine == null ) {
857854
read = socketChannelImpl().read(dst);
@@ -1238,7 +1235,7 @@ public IRubyObject ssl_version(ThreadContext context) {
12381235
return context.runtime.newString( engine.getSession().getProtocol() );
12391236
}
12401237

1241-
private transient SocketChannelImpl socketChannel;
1238+
transient SocketChannelImpl socketChannel;
12421239

12431240
private SocketChannelImpl socketChannelImpl() {
12441241
if ( socketChannel != null ) return socketChannel;
@@ -1253,7 +1250,7 @@ private SocketChannelImpl socketChannelImpl() {
12531250
throw new IllegalStateException("unknow channel impl: " + channel + " of type " + channel.getClass().getName());
12541251
}
12551252

1256-
private interface SocketChannelImpl {
1253+
interface SocketChannelImpl {
12571254

12581255
boolean isOpen() ;
12591256

src/test/java/org/jruby/ext/openssl/SSLSocketTest.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package org.jruby.ext.openssl;
22

3+
import java.io.IOException;
34
import java.nio.ByteBuffer;
5+
import java.nio.channels.SelectionKey;
6+
import java.nio.channels.Selector;
7+
import javax.net.ssl.SSLEngine;
48

9+
import org.jruby.Ruby;
510
import org.jruby.RubyArray;
611
import org.jruby.RubyFixnum;
12+
import org.jruby.RubyHash;
713
import org.jruby.RubyInteger;
814
import org.jruby.RubyString;
915
import org.jruby.exceptions.RaiseException;
@@ -12,6 +18,7 @@
1218
import org.junit.After;
1319
import org.junit.Before;
1420
import org.junit.Test;
21+
1522
import static org.junit.Assert.*;
1623

1724
public class SSLSocketTest extends OpenSSLHelper {
@@ -173,4 +180,72 @@ private void closeQuietly(final RubyArray sslPair) {
173180
}
174181
}
175182
}
183+
184+
// ----------
185+
186+
/**
187+
* MRI's ossl_ssl_read_internal returns :wait_writable (or raises SSLErrorWaitWritable / "write would block")
188+
* when SSL_read hits SSL_ERROR_WANT_WRITE. Pending netWriteData is JRuby's equivalent state.
189+
*/
190+
@Test
191+
public void sysreadNonblockReturnsWaitWritableWhenPendingEncryptedBytesRemain() {
192+
final SSLSocket socket = newSSLSocket(runtime, partialWriteChannel(1));
193+
final SSLEngine engine = socket.ossl_ssl_setup(currentContext(), false);
194+
engine.setUseClientMode(true);
195+
196+
socket.netWriteData = ByteBuffer.wrap(new byte[] { 1, 2 });
197+
198+
final RubyHash opts = RubyHash.newKwargs(runtime, "exception", runtime.getFalse()); // exception: false
199+
final IRubyObject result = socket.sysread_nonblock(currentContext(), runtime.newFixnum(1), opts);
200+
201+
assertEquals("wait_writable", result.asJavaString());
202+
assertEquals(1, socket.netWriteData.remaining());
203+
}
204+
205+
@Test
206+
public void sysreadNonblockRaisesWaitWritableWhenPendingEncryptedBytesRemain() {
207+
final SSLSocket socket = newSSLSocket(runtime, partialWriteChannel(1));
208+
final SSLEngine engine = socket.ossl_ssl_setup(currentContext(), false);
209+
engine.setUseClientMode(true);
210+
211+
socket.netWriteData = ByteBuffer.wrap(new byte[] { 1, 2 });
212+
213+
try {
214+
socket.sysread_nonblock(currentContext(), runtime.newFixnum(1));
215+
fail("expected SSLErrorWaitWritable");
216+
}
217+
catch (RaiseException ex) {
218+
assertEquals("OpenSSL::SSL::SSLErrorWaitWritable", ex.getException().getMetaClass().getName());
219+
assertTrue(ex.getMessage().contains("write would block"));
220+
assertEquals(1, socket.netWriteData.remaining());
221+
}
222+
}
223+
224+
private static SSLSocket newSSLSocket(final Ruby runtime, final SSLSocket.SocketChannelImpl socketChannel) {
225+
final SSLContext sslContext = new SSLContext(runtime);
226+
sslContext.doSetup(runtime.getCurrentContext());
227+
final SSLSocket sslSocket = new SSLSocket(runtime, runtime.getObject());
228+
sslSocket.sslContext = sslContext;
229+
sslSocket.socketChannel = socketChannel;
230+
return sslSocket;
231+
}
232+
233+
private static SSLSocket.SocketChannelImpl partialWriteChannel(final int bytesPerWrite) {
234+
return new SSLSocket.SocketChannelImpl() {
235+
public boolean isOpen() { return true; }
236+
public int read(final ByteBuffer dst) { return 0; }
237+
public int write(final ByteBuffer src) {
238+
final int written = Math.min(bytesPerWrite, src.remaining());
239+
src.position(src.position() + written);
240+
return written;
241+
}
242+
public int getRemotePort() { return 443; }
243+
public boolean isSelectable() { return false; }
244+
public boolean isBlocking() { return false; }
245+
public void configureBlocking(final boolean block) { }
246+
public SelectionKey register(final Selector selector, final int ops) throws IOException {
247+
throw new UnsupportedOperationException();
248+
}
249+
};
250+
}
176251
}

0 commit comments

Comments
 (0)