@@ -208,6 +208,11 @@ mod consts {
208208struct PyReader {
209209 obj : Py < PyAny > ,
210210 is_text_io : bool ,
211+ // Buffer to store excess bytes read from Python streams. This is necessary
212+ // because when reading from a TextIO object, Python's `read(n)` returns up
213+ // to `n` characters, which can be more than `n` bytes if there are
214+ // multibyte characters.
215+ buffer : Vec < u8 > ,
211216}
212217
213218impl PyReader {
@@ -224,33 +229,64 @@ impl PyReader {
224229 let is_text_io =
225230 obj_bound. is_instance ( consts:: text_io_base ( py) ?) ?;
226231
227- Ok ( Self { obj, is_text_io } )
232+ Ok ( Self { obj, is_text_io, buffer : Vec :: new ( ) } )
228233 } )
229234 }
230235}
231236
232237impl Read for PyReader {
233- fn read ( & mut self , mut buf : & mut [ u8 ] ) -> io:: Result < usize > {
238+ fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
239+ if buf. is_empty ( ) {
240+ return Ok ( 0 ) ;
241+ }
242+
243+ // If we have leftover bytes from a previous read, consume them first.
244+ if !self . buffer . is_empty ( ) {
245+ let n = std:: cmp:: min ( buf. len ( ) , self . buffer . len ( ) ) ;
246+ buf[ ..n] . copy_from_slice ( & self . buffer [ ..n] ) ;
247+ self . buffer . drain ( ..n) ;
248+ return Ok ( n) ;
249+ }
250+
234251 Python :: attach ( |py| {
252+ // Call Python `read` method. We request `buf.len()` units.
253+ // For text streams, this means `buf.len()` characters.
235254 let data =
236255 self . obj . call_method1 ( py, consts:: read ( py) , ( buf. len ( ) , ) ) ?;
237256
238- if self . is_text_io {
239- let bytes = data. extract :: < Cow < str > > ( py) . unwrap ( ) ;
240- buf. write_all ( bytes. as_bytes ( ) ) ?;
241- Ok ( bytes. len ( ) )
257+ let bytes = if self . is_text_io {
258+ let s = data. extract :: < Cow < str > > ( py) . unwrap ( ) ;
259+ s. as_bytes ( ) . to_vec ( )
242260 } else {
243- let bytes = data. extract :: < Cow < [ u8 ] > > ( py) . unwrap ( ) ;
244- buf. write_all ( bytes. as_ref ( ) ) ?;
245- Ok ( bytes. len ( ) )
261+ data. extract :: < Cow < [ u8 ] > > ( py) . unwrap ( ) . to_vec ( )
262+ } ;
263+
264+ if bytes. is_empty ( ) {
265+ return Ok ( 0 ) ;
266+ }
267+
268+ // Copy as many bytes as fit in `buf`.
269+ let n = std:: cmp:: min ( buf. len ( ) , bytes. len ( ) ) ;
270+ buf[ ..n] . copy_from_slice ( & bytes[ ..n] ) ;
271+
272+ // If Python returned more bytes than fit in `buf` (due to multi-byte
273+ // characters in text mode), store the excess in our buffer.
274+ if n < bytes. len ( ) {
275+ self . buffer . extend_from_slice ( & bytes[ n..] ) ;
246276 }
277+
278+ Ok ( n)
247279 } )
248280 }
249281}
250282
251283struct PyWriter {
252284 obj : Py < PyAny > ,
253285 is_text_io : bool ,
286+ // Buffer to store incomplete UTF-8 sequences at the end of chunks.
287+ // This is necessary because the formatter writes data in chunks, and a
288+ // chunk boundary can fall in the middle of a multi-byte UTF-8 character.
289+ buffer : Vec < u8 > ,
254290}
255291
256292impl PyWriter {
@@ -267,26 +303,59 @@ impl PyWriter {
267303 let is_text_io =
268304 obj_bound. is_instance ( consts:: text_io_base ( py) ?) ?;
269305
270- Ok ( Self { obj, is_text_io } )
306+ Ok ( Self { obj, is_text_io, buffer : Vec :: new ( ) } )
271307 } )
272308 }
273309}
274310
275311impl Write for PyWriter {
276312 fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
277313 Python :: attach ( |py| {
278- let arg = if self . is_text_io {
279- let s = std:: str:: from_utf8 ( buf) . expect (
280- "tried to write non-utf8 data to a TextIO object." ,
281- ) ;
282- PyString :: new ( py, s) . into_any ( )
283- } else {
284- PyBytes :: new ( py, buf) . into_any ( )
285- } ;
314+ if !self . is_text_io {
315+ let arg = PyBytes :: new ( py, buf) . into_any ( ) ;
316+ let n =
317+ self . obj . call_method1 ( py, consts:: write ( py) , ( arg, ) ) ?;
318+ return n. extract ( py) . map_err ( io:: Error :: from) ;
319+ }
286320
287- let n = self . obj . call_method1 ( py, consts:: write ( py) , ( arg, ) ) ?;
321+ // Append new data to buffer.
322+ self . buffer . extend_from_slice ( buf) ;
288323
289- n. extract ( py) . map_err ( io:: Error :: from)
324+ // Try to convert the buffered data to a valid UTF-8 string.
325+ match std:: str:: from_utf8 ( & self . buffer ) {
326+ Ok ( s) => {
327+ let arg = PyString :: new ( py, s) . into_any ( ) ;
328+ self . obj . call_method1 ( py, consts:: write ( py) , ( arg, ) ) ?;
329+ self . buffer . clear ( ) ;
330+ Ok ( buf. len ( ) )
331+ }
332+ Err ( e) => {
333+ let valid_len = e. valid_up_to ( ) ;
334+ if e. error_len ( ) . is_some ( ) {
335+ // Real UTF-8 error in the middle of the data.
336+ return Err ( io:: Error :: new (
337+ io:: ErrorKind :: InvalidData ,
338+ e,
339+ ) ) ;
340+ }
341+ // Incomplete UTF-8 sequence at the end of the buffer.
342+ // Write the valid part and keep the incomplete part in the buffer.
343+ if valid_len > 0 {
344+ let s = std:: str:: from_utf8 ( & self . buffer [ ..valid_len] )
345+ . unwrap ( ) ;
346+ let arg = PyString :: new ( py, s) . into_any ( ) ;
347+ self . obj . call_method1 (
348+ py,
349+ consts:: write ( py) ,
350+ ( arg, ) ,
351+ ) ?;
352+ self . buffer . drain ( ..valid_len) ;
353+ }
354+ // We return `buf.len()` because we accepted all bytes (either
355+ // wrote them or buffered them).
356+ Ok ( buf. len ( ) )
357+ }
358+ }
290359 } )
291360 }
292361
0 commit comments