Skip to content

Commit bd031f0

Browse files
authored
fix: add IsActive checks to Writer methods to prevent panic after Close (#416)
1 parent 81277e4 commit bd031f0

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

connection_impl.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,9 @@ func (c *connection) ReadByte() (b byte, err error) {
243243

244244
// Malloc implements Connection.
245245
func (c *connection) Malloc(n int) (buf []byte, err error) {
246+
if !c.IsActive() {
247+
return nil, Exception(ErrConnClosed, "when malloc")
248+
}
246249
return c.outputBuffer.Malloc(n)
247250
}
248251

@@ -273,31 +276,49 @@ func (c *connection) Flush() error {
273276

274277
// MallocAck implements Connection.
275278
func (c *connection) MallocAck(n int) (err error) {
279+
if !c.IsActive() {
280+
return Exception(ErrConnClosed, "when malloc ack")
281+
}
276282
return c.outputBuffer.MallocAck(n)
277283
}
278284

279285
// Append implements Connection.
280286
func (c *connection) Append(w Writer) (err error) {
287+
if !c.IsActive() {
288+
return Exception(ErrConnClosed, "when append")
289+
}
281290
return c.outputBuffer.Append(w)
282291
}
283292

284293
// WriteString implements Connection.
285294
func (c *connection) WriteString(s string) (n int, err error) {
295+
if !c.IsActive() {
296+
return 0, Exception(ErrConnClosed, "when write string")
297+
}
286298
return c.outputBuffer.WriteString(s)
287299
}
288300

289301
// WriteBinary implements Connection.
290302
func (c *connection) WriteBinary(b []byte) (n int, err error) {
303+
if !c.IsActive() {
304+
return 0, Exception(ErrConnClosed, "when write binary")
305+
}
291306
return c.outputBuffer.WriteBinary(b)
292307
}
293308

294309
// WriteDirect implements Connection.
295310
func (c *connection) WriteDirect(p []byte, remainCap int) (err error) {
311+
if !c.IsActive() {
312+
return Exception(ErrConnClosed, "when write direct")
313+
}
296314
return c.outputBuffer.WriteDirect(p, remainCap)
297315
}
298316

299317
// WriteByte implements Connection.
300318
func (c *connection) WriteByte(b byte) (err error) {
319+
if !c.IsActive() {
320+
return Exception(ErrConnClosed, "when write byte")
321+
}
301322
return c.outputBuffer.WriteByte(b)
302323
}
303324

connection_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,45 @@ func TestConnectionServerClose(t *testing.T) {
858858
wg.Wait()
859859
}
860860

861+
func TestWriterAfterClose(t *testing.T) {
862+
r, w := GetSysFdPairs()
863+
rconn, wconn := &connection{}, &connection{}
864+
rconn.init(&netFD{fd: r}, nil)
865+
wconn.init(&netFD{fd: w}, nil)
866+
867+
err := wconn.Close()
868+
MustNil(t, err)
869+
870+
for wconn.IsActive() {
871+
runtime.Gosched()
872+
}
873+
874+
methods := []struct {
875+
name string
876+
fn func() error
877+
}{
878+
{"Malloc", func() error { _, err := wconn.Malloc(1); return err }},
879+
{"MallocAck", func() error { return wconn.MallocAck(0) }},
880+
{"WriteBinary", func() error { _, err := wconn.WriteBinary([]byte("hi")); return err }},
881+
{"WriteString", func() error { _, err := wconn.WriteString("hi"); return err }},
882+
{"WriteByte", func() error { return wconn.WriteByte('a') }},
883+
{"WriteDirect", func() error { return wconn.WriteDirect([]byte("hi"), 0) }},
884+
{"Flush", func() error { return wconn.Flush() }},
885+
}
886+
for _, tc := range methods {
887+
t.Run(tc.name, func(t *testing.T) {
888+
defer func() {
889+
if r := recover(); r != nil {
890+
t.Fatalf("Writer.%s panicked after Close: %v", tc.name, r)
891+
}
892+
}()
893+
err := tc.fn()
894+
Assert(t, err != nil, fmt.Sprintf("Writer.%s should return error after Close", tc.name))
895+
})
896+
}
897+
rconn.Close()
898+
}
899+
861900
func TestConnectionDailTimeoutAndClose(t *testing.T) {
862901
ln := createTestTCPListener(t)
863902
defer ln.Close()

mux/shard_queue.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ func (q *ShardQueue) foreach() {
172172

173173
// deal is used to get deal of netpoll.Writer.
174174
func (q *ShardQueue) deal(gts []WriterGetter) {
175+
if !q.conn.IsActive() {
176+
return
177+
}
175178
writer := q.conn.Writer()
176179
for _, gt := range gts {
177180
buf, isNil := gt()

netpoll_unix_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,6 @@ func TestServerReadAndClose(t *testing.T) {
436436
runtime.Gosched() // wait for poller close connection
437437
}
438438
_, err = conn.Writer().WriteBinary(sendMsg)
439-
MustNil(t, err)
440-
err = conn.Writer().Flush()
441439
Assert(t, errors.Is(err, ErrConnClosed), err)
442440

443441
err = loop.Shutdown(context.Background())

0 commit comments

Comments
 (0)