@@ -8,12 +8,18 @@
 //
 // Today we have the following formats: CSV, NDJSON, Parquet, Apache Arrow and Apache Arrow Stream.
 //
+use polars::frame::chunk_df_for_writing;
 use polars::prelude::*;
+use polars::{io::schema_to_arrow_checked, prelude::CompatLevel};
+use polars_arrow::io::ipc::write::{
+    default_ipc_fields, encode_record_batch, schema_to_bytes, EncodedData, WriteOptions,
+};
 use std::num::NonZeroUsize;
 
 use rustler::{Binary, Env, NewBinary};
 use std::fs::File;
-use std::io::{BufReader, BufWriter, Cursor};
+use std::io::{BufReader, BufWriter, Cursor, Write};
+use std::sync::Arc;
 
 use crate::datatypes::{ExParquetCompression, ExQuoteStyle, ExS3Entry, ExSeriesDtype};
 use crate::{ExDataFrame, ExplorerError};
@@ -443,6 +449,160 @@ fn decode_ipc_compression(compression: &str) -> Result<IpcCompression, ExplorerE |
 }
 }
 
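+/// Decodes the Arrow compatibility level: "oldest" keeps maximum consumer
+/// compatibility, "newest" allows the newest Arrow types polars can emit.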
+fn decode_compat_level(compat_level: &str) -> Result<CompatLevel, ExplorerError> {
+    match compat_level {
+        "oldest" => Ok(CompatLevel::oldest()),
+        "newest" => Ok(CompatLevel::newest()),
+        other => Err(ExplorerError::Other(format!(
+            "the compat level {other} is not supported"
+        ))),
+    }
+}
+
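+/// Dumps only the encapsulated schema message of the Arrow IPC stream format,
+/// so callers can assemble a stream incrementally.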
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn df_dump_ipc_schema<'a>(
+    env: Env<'a>,
+    df: ExDataFrame,
+    compat_level: Option<&str>,
+) -> Result<Binary<'a>, ExplorerError> {
+    let compat_level = match compat_level {
+        Some(level) => decode_compat_level(level)?,
+        None => CompatLevel::oldest(),
+    };
+    let schema = schema_to_arrow_checked(df.schema(), compat_level, "ipc")?;
+    let ipc_fields = default_ipc_fields(schema.iter_values());
+    let schema_bytes = schema_to_bytes(&schema, &ipc_fields, None);
+    // A schema message carries only flatbuffer metadata and has no body.
+    let encoded_message = EncodedData {
+        ipc_message: schema_bytes,
+        arrow_data: Vec::new(),
+    };
+
+    let mut buf = vec![];
+    write_message(&mut buf, &encoded_message)?;
+
+    let mut values_binary = NewBinary::new(env, buf.len());
+    values_binary.copy_from_slice(&buf);
+
+    Ok(values_binary.into())
+}
+
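+/// Dumps each record batch as its own encapsulated IPC message, sizing the
+/// batches so each encoded message stays roughly under `max_chunk_size`
+/// bytes (default 10 MiB).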
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn df_dump_ipc_record_batch<'a>(
+    env: Env<'a>,
+    df: ExDataFrame,
+    max_chunk_size: Option<usize>,
+    compression: Option<&str>,
+    compat_level: Option<&str>,
+) -> Result<Vec<Binary<'a>>, ExplorerError> {
+    let data = &mut df.clone();
+
+    let max_request_bytes = max_chunk_size.unwrap_or(10 * (1 << 20)); // 10 MiB
+    let chunk_num = data.estimated_size() / max_request_bytes + 1;
+    // Rows per chunk, so each encoded batch stays under the byte budget.
+    let chunk_size = (data.height() / chunk_num).max(1);
+
+    chunk_df_for_writing(data, chunk_size)?;
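+
+    // Worked example of the sizing above (illustrative): a 25 MiB frame under
+    // the default 10 MiB budget gives chunk_num = 25 / 10 + 1 = 3; at 300_000
+    // rows that is 100_000 rows per chunk.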
+
+    let compat_level = match compat_level {
+        Some(level) => decode_compat_level(level)?,
+        None => CompatLevel::oldest(),
+    };
+    let iter = data.iter_chunks(compat_level, true);
+
+    let compression = match compression {
+        Some(algo) => Some(decode_ipc_compression(algo)?.into()),
+        None => None,
+    };
+    let options = WriteOptions { compression };
+
+    let mut result = Vec::new();
+
+    for batch in iter {
+        let mut encoded_message = EncodedData::default();
+        encode_record_batch(&batch, &options, &mut encoded_message);
+
+        let mut buf = vec![];
+        write_message(&mut buf, &encoded_message)?;
+        let mut values_binary = NewBinary::new(env, buf.len());
+        values_binary.copy_from_slice(&buf);
+
+        result.push(values_binary.into());
+    }
+
+    Ok(result)
+}
+
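+// A minimal sketch (illustrative; not an API in this module) of how the two
+// dumps compose into a complete Arrow IPC stream:
+//
+//     let mut stream = vec![];
+//     stream.extend_from_slice(&schema_message); // from df_dump_ipc_schema
+//     for batch in &batch_messages {
+//         // each from df_dump_ipc_record_batch
+//         stream.extend_from_slice(batch);
+//     }
+//     // end-of-stream marker: the continuation bytes with a zero length
+//     stream.extend_from_slice(&[0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0]);
+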
+/// Writes an encapsulated IPC message: the continuation marker, the
+/// little-endian metadata length, the flatbuffer metadata padded to an
+/// 8-byte boundary, then the body buffers padded to 64 bytes. Returns the
+/// metadata and body lengths written.
+///
+/// Adapted from https://github.com/pola-rs/polars/blob/main/crates/polars-arrow/src/io/ipc/write/common_sync.rs
+/// because the original is not public to external crates.
+pub fn write_message<W: Write>(
+    writer: &mut W,
+    encoded: &EncodedData,
+) -> PolarsResult<(usize, usize)> {
+    let arrow_data_len = encoded.arrow_data.len();
+
+    // The 8-byte prefix plus the flatbuffer metadata must end on an
+    // 8-byte boundary; `a` is the alignment mask.
+    let a = 8 - 1;
+    let buffer = &encoded.ipc_message;
+    let flatbuf_size = buffer.len();
+    let prefix_size = 8;
+    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
+    let padding_bytes = aligned_size - flatbuf_size - prefix_size;
+
+    write_continuation(writer, (aligned_size - prefix_size) as i32)?;
+
+    // write the flatbuffer metadata
+    if flatbuf_size > 0 {
+        writer.write_all(buffer)?;
+    }
+    // write padding; the alignment is to an 8-byte boundary, so at most
+    // 8 bytes are needed
+    const PADDING_MAX: [u8; 8] = [0u8; 8];
+    writer.write_all(&PADDING_MAX[..padding_bytes])?;
+
+    // write the arrow body buffers, if any
+    let body_len = if arrow_data_len > 0 {
+        write_body_buffers(writer, &encoded.arrow_data)?
+    } else {
+        0
+    };
+
+    Ok((aligned_size, body_len))
+}
+
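+// Worked example for write_message (illustrative): a 20-byte flatbuffer gives
+// aligned_size = (20 + 8 + 7) & !7 = 32, so the length prefix written is 24,
+// followed by the 20 metadata bytes and 4 bytes of padding.
+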
+fn write_body_buffers<W: Write>(mut writer: W, data: &[u8]) -> PolarsResult<usize> {
+    let len = data.len();
+    let pad_len = pad_to_64(len);
+    let total_len = len + pad_len;
+
+    // write the body buffer followed by its zero padding
+    writer.write_all(data)?;
+    if pad_len > 0 {
+        writer.write_all(&vec![0u8; pad_len][..])?;
+    }
+
+    Ok(total_len)
+}
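+
+// Bodies are padded to 64 bytes since the Arrow spec recommends 64-byte
+// buffer alignment, keeping reads SIMD- and cache-friendly.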
+
+/// Writes the encapsulated-message prefix: the 0xFFFFFFFF continuation marker
+/// followed by the metadata length as a little-endian i32. Returns the 8
+/// bytes written.
+fn write_continuation<W: Write>(writer: &mut W, total_len: i32) -> PolarsResult<usize> {
+    const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
+    writer.write_all(&CONTINUATION_MARKER)?;
+    writer.write_all(&total_len.to_le_bytes()[..])?;
+    Ok(8)
+}
+
+/// Returns the number of padding bytes needed to round `len` up to the next
+/// 64-byte boundary, e.g. `pad_to_64(100) == 28`.
+#[inline]
+fn pad_to_64(len: usize) -> usize {
+    ((len + 63) & !63) - len
+}
+
 // ============ IPC Streaming ============ //
 
 #[rustler::nif(schedule = "DirtyIo")]