Skip to content

Commit 6b05df4

Browse files
committed
fix: handle multi-byte characters in Python IO streams
Python's `TextIO.read(n)` returns up to `n` characters, which can result in more than `n` bytes when dealing with multi-byte UTF-8 sequences. Similarly, when writing to a `TextIO` object, a buffer may end in the middle of a multi-byte character. This adds buffering to both `PyReader` and `PyWriter` to correctly manage excess bytes during reads and incomplete UTF-8 sequences during writes. Closes VirusTotal#633
1 parent 99fcfd1 commit 6b05df4

2 files changed

Lines changed: 103 additions & 20 deletions

File tree

py/src/lib.rs

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ mod consts {
208208
struct PyReader {
209209
obj: Py<PyAny>,
210210
is_text_io: bool,
211+
// Buffer to store excess bytes read from Python streams. This is necessary
212+
// because when reading from a TextIO object, Python's `read(n)` returns up
213+
// to `n` characters, which can be more than `n` bytes if there are
214+
// multibyte characters.
215+
buffer: Vec<u8>,
211216
}
212217

213218
impl PyReader {
@@ -224,33 +229,64 @@ impl PyReader {
224229
let is_text_io =
225230
obj_bound.is_instance(consts::text_io_base(py)?)?;
226231

227-
Ok(Self { obj, is_text_io })
232+
Ok(Self { obj, is_text_io, buffer: Vec::new() })
228233
})
229234
}
230235
}
231236

232237
impl Read for PyReader {
233-
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
238+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
239+
if buf.is_empty() {
240+
return Ok(0);
241+
}
242+
243+
// If we have leftover bytes from a previous read, consume them first.
244+
if !self.buffer.is_empty() {
245+
let n = std::cmp::min(buf.len(), self.buffer.len());
246+
buf[..n].copy_from_slice(&self.buffer[..n]);
247+
self.buffer.drain(..n);
248+
return Ok(n);
249+
}
250+
234251
Python::attach(|py| {
252+
// Call Python `read` method. We request `buf.len()` units.
253+
// For text streams, this means `buf.len()` characters.
235254
let data =
236255
self.obj.call_method1(py, consts::read(py), (buf.len(),))?;
237256

238-
if self.is_text_io {
239-
let bytes = data.extract::<Cow<str>>(py).unwrap();
240-
buf.write_all(bytes.as_bytes())?;
241-
Ok(bytes.len())
257+
let bytes = if self.is_text_io {
258+
let s = data.extract::<Cow<str>>(py).unwrap();
259+
s.as_bytes().to_vec()
242260
} else {
243-
let bytes = data.extract::<Cow<[u8]>>(py).unwrap();
244-
buf.write_all(bytes.as_ref())?;
245-
Ok(bytes.len())
261+
data.extract::<Cow<[u8]>>(py).unwrap().to_vec()
262+
};
263+
264+
if bytes.is_empty() {
265+
return Ok(0);
266+
}
267+
268+
// Copy as many bytes as fit in `buf`.
269+
let n = std::cmp::min(buf.len(), bytes.len());
270+
buf[..n].copy_from_slice(&bytes[..n]);
271+
272+
// If Python returned more bytes than fit in `buf` (due to multi-byte
273+
// characters in text mode), store the excess in our buffer.
274+
if n < bytes.len() {
275+
self.buffer.extend_from_slice(&bytes[n..]);
246276
}
277+
278+
Ok(n)
247279
})
248280
}
249281
}
250282

251283
struct PyWriter {
252284
obj: Py<PyAny>,
253285
is_text_io: bool,
286+
// Buffer to store incomplete UTF-8 sequences at the end of chunks.
287+
// This is necessary because the formatter writes data in chunks, and a
288+
// chunk boundary can fall in the middle of a multi-byte UTF-8 character.
289+
buffer: Vec<u8>,
254290
}
255291

256292
impl PyWriter {
@@ -267,26 +303,59 @@ impl PyWriter {
267303
let is_text_io =
268304
obj_bound.is_instance(consts::text_io_base(py)?)?;
269305

270-
Ok(Self { obj, is_text_io })
306+
Ok(Self { obj, is_text_io, buffer: Vec::new() })
271307
})
272308
}
273309
}
274310

275311
impl Write for PyWriter {
276312
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
277313
Python::attach(|py| {
278-
let arg = if self.is_text_io {
279-
let s = std::str::from_utf8(buf).expect(
280-
"tried to write non-utf8 data to a TextIO object.",
281-
);
282-
PyString::new(py, s).into_any()
283-
} else {
284-
PyBytes::new(py, buf).into_any()
285-
};
314+
if !self.is_text_io {
315+
let arg = PyBytes::new(py, buf).into_any();
316+
let n =
317+
self.obj.call_method1(py, consts::write(py), (arg,))?;
318+
return n.extract(py).map_err(io::Error::from);
319+
}
286320

287-
let n = self.obj.call_method1(py, consts::write(py), (arg,))?;
321+
// Append new data to buffer.
322+
self.buffer.extend_from_slice(buf);
288323

289-
n.extract(py).map_err(io::Error::from)
324+
// Try to convert the buffered data to a valid UTF-8 string.
325+
match std::str::from_utf8(&self.buffer) {
326+
Ok(s) => {
327+
let arg = PyString::new(py, s).into_any();
328+
self.obj.call_method1(py, consts::write(py), (arg,))?;
329+
self.buffer.clear();
330+
Ok(buf.len())
331+
}
332+
Err(e) => {
333+
let valid_len = e.valid_up_to();
334+
if e.error_len().is_some() {
335+
// Real UTF-8 error in the middle of the data.
336+
return Err(io::Error::new(
337+
io::ErrorKind::InvalidData,
338+
e,
339+
));
340+
}
341+
// Incomplete UTF-8 sequence at the end of the buffer.
342+
// Write the valid part and keep the incomplete part in the buffer.
343+
if valid_len > 0 {
344+
let s = std::str::from_utf8(&self.buffer[..valid_len])
345+
.unwrap();
346+
let arg = PyString::new(py, s).into_any();
347+
self.obj.call_method1(
348+
py,
349+
consts::write(py),
350+
(arg,),
351+
)?;
352+
self.buffer.drain(..valid_len);
353+
}
354+
// We return `buf.len()` because we accepted all bytes (either
355+
// wrote them or buffered them).
356+
Ok(buf.len())
357+
}
358+
}
290359
})
291360
}
292361

py/tests/test_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,20 @@ def test_format():
346346
assert result == expected_output
347347

348348

349+
def test_format_non_ascii_long():
350+
import io
351+
# Create a long string with non-ASCII characters.
352+
# 5000 characters of 'ê' will be 10000 bytes.
353+
rule_content = 'rule test { strings: $a = "' + 'ê' * 5000 + '" condition: $a }'
354+
inp = io.StringIO(rule_content)
355+
output = io.StringIO()
356+
fmt = yara_x.Formatter()
357+
# This should not raise "ValueError: read error: failed to write whole buffer"
358+
fmt.format(inp, output)
359+
result = output.getvalue()
360+
assert 'ê' * 5000 in result
361+
362+
349363
def test_module():
350364
with pytest.raises(ValueError):
351365
yara_x.Module('AXS')

0 commit comments

Comments
 (0)