Skip to content

Commit 8eeda83

Browse files
authored
Add dump arrow schema and batchrecord (#1103)
* add dump arrow schema and batchrecord * docs and clippy * add doc type * add type spec for keyword params
1 parent 12a3fd1 commit 8eeda83

File tree

7 files changed

+374
-1
lines changed

7 files changed

+374
-1
lines changed

lib/explorer/backend/data_frame.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ defmodule Explorer.Backend.DataFrame do
3333
@type query_frame :: Explorer.Backend.QueryFrame.t()
3434
@type lazy_series :: Explorer.Backend.LazySeries.t()
3535

36+
@type compact_level :: option(:oldest | :newest)
3637
@type compression :: {algorithm :: option(atom()), level :: option(integer())}
3738
@type columns_for_io :: list(column_name()) | list(pos_integer()) | nil
3839

@@ -122,6 +123,9 @@ defmodule Explorer.Backend.DataFrame do
122123
contents :: binary(),
123124
columns :: columns_for_io()
124125
) :: io_result(df)
126+
@callback dump_ipc_schema(df, compact_level()) :: io_result(binary())
127+
@callback dump_ipc_record_batch(df, integer(), compression(), compact_level()) ::
128+
io_result(list(binary()))
125129

126130
# IO: IPC Stream
127131
@callback from_ipc_stream(

lib/explorer/data_frame.ex

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,35 @@ defmodule Explorer.DataFrame do
269269
dtypes: %{String.t() => Explorer.Series.dtype()}
270270
}
271271

272+
@typedoc """
273+
Represents the max chunk size for a given batch record.
274+
"""
275+
@type max_chunk_size :: nil | integer()
276+
277+
@typedoc """
278+
Represents the compression algorithm to use when writing files.
279+
"""
280+
@type compression :: nil | :zstd | :lz4
281+
282+
@typedoc """
283+
Represents the compact level used by the backend.
284+
"""
285+
@type compact_level :: nil | :oldest | :newest
286+
287+
@typedoc """
288+
Represents function keyword options
289+
"""
290+
@type dump_ipc_schema_options :: [compact_level: compact_level()]
291+
292+
@typedoc """
293+
Represents function keyword options
294+
"""
295+
@type dump_ipc_record_batch_options :: [
296+
max_chunk_size: max_chunk_size(),
297+
compression: compression(),
298+
compact_level: compact_level()
299+
]
300+
272301
@default_infer_schema_length 1000
273302
@default_sample_nrows 5
274303
@integer_types Explorer.Shared.integer_types()
@@ -1221,6 +1250,87 @@ defmodule Explorer.DataFrame do
12211250
end
12221251
end
12231252

1253+
@doc """
1254+
Writes a dataframe schema to a binary representation of an IPC schema message.
1255+
1256+
## Options
1257+
1258+
* `:compact_level` - The compact level used by the backend.
1259+
For the Polars backend it indicates to use view types on newest
1260+
or use non-view types on oldest.
1261+
Supported options are:
1262+
1263+
* `nil` (oldest, default)
1264+
* `:oldest`
1265+
* `:newest`.
1266+
1267+
"""
1268+
@doc type: :io
1269+
@spec dump_ipc_schema(df :: DataFrame.t(), opts :: dump_ipc_schema_options()) ::
1270+
{:ok, binary()} | {:error, Exception.t()}
1271+
def dump_ipc_schema(df, opts \\ []) do
1272+
opts = Keyword.validate!(opts, compact_level: nil)
1273+
compact_level = ipc_compact_level(opts[:compact_level])
1274+
1275+
Shared.apply_dataframe(df, :dump_ipc_schema, [compact_level], false)
1276+
end
1277+
1278+
@doc """
1279+
Writes a dataframe to a list of binary, each binary is a representation of a chunked record batch.
1280+
1281+
## Options
1282+
* `:max_chunk_size` - The max size of each binary.
1283+
Supported options are:
1284+
1285+
* `nil` (10mb, default)
1286+
* a value in bytes
1287+
1288+
* `:compression` - The compression algorithm to use when writing files.
1289+
Supported options are:
1290+
1291+
* `nil` (uncompressed, default)
1292+
* `:zstd`
1293+
* `:lz4`.
1294+
1295+
* `:compact_level` - The compact level used by the backend.
1296+
For the Polars backend it indicates to use view types on newest
1297+
or use non-view types on oldest.
1298+
Supported options are:
1299+
1300+
* `nil` (oldest, default)
1301+
* `:oldest`
1302+
* `:newest`.
1303+
1304+
"""
1305+
@doc type: :io
1306+
@spec dump_ipc_record_batch(df :: DataFrame.t(), opts :: dump_ipc_record_batch_options()) ::
1307+
{:ok, list(binary())} | {:error, Exception.t()}
1308+
def dump_ipc_record_batch(df, opts \\ []) do
1309+
opts = Keyword.validate!(opts, max_chunk_size: nil, compression: nil, compact_level: nil)
1310+
max_chunck_size = ipc_max_chunk_size(opts[:max_chunk_size])
1311+
compression = ipc_compression(opts[:compression])
1312+
compact_level = ipc_compact_level(opts[:compact_level])
1313+
1314+
Shared.apply_dataframe(
1315+
df,
1316+
:dump_ipc_record_batch,
1317+
[max_chunck_size, compression, compact_level],
1318+
false
1319+
)
1320+
end
1321+
1322+
# Normalizes the `:max_chunk_size` option.
#
# `nil` delegates the default (10MB) to the backend; a positive integer is
# passed through as a byte count. Zero and negative values are rejected here
# because the native layer divides the frame's estimated size by this value,
# so a non-positive chunk size is never meaningful (0 would divide by zero).
defp ipc_max_chunk_size(nil), do: nil
defp ipc_max_chunk_size(byte_size) when is_integer(byte_size) and byte_size > 0, do: byte_size

defp ipc_max_chunk_size(other),
  do: raise(ArgumentError, "unsupported :max_chunk_size #{inspect(other)}")

# Normalizes the `:compact_level` option; `nil` lets the backend default apply.
defp ipc_compact_level(nil), do: nil
defp ipc_compact_level(level) when level in ~w(oldest newest)a, do: level

defp ipc_compact_level(other),
  do: raise(ArgumentError, "unsupported :compact_level #{inspect(other)}")
1333+
12241334
@doc """
12251335
Reads a binary representing an IPC file into a dataframe.
12261336

lib/explorer/polars_backend/data_frame.ex

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,32 @@ defmodule Explorer.PolarsBackend.DataFrame do
447447
end
448448
end
449449

450+
@impl true
def dump_ipc_schema(%DataFrame{data: polars_df}, compact_level) do
  # Delegate to the NIF, wrapping native error strings in an exception.
  polars_df
  |> Native.df_dump_ipc_schema(maybe_atom_to_string(compact_level))
  |> case do
    {:ok, binary} -> {:ok, binary}
    {:error, reason} -> {:error, RuntimeError.exception(reason)}
  end
end
457+
458+
@impl true
def dump_ipc_record_batch(
      %DataFrame{data: polars_df},
      max_chunk_size,
      {algorithm, _level},
      compact_level
    ) do
  # Only the algorithm half of the compression tuple is used; the level is
  # not forwarded to the native IPC writer.
  native_result =
    Native.df_dump_ipc_record_batch(
      polars_df,
      max_chunk_size,
      maybe_atom_to_string(algorithm),
      maybe_atom_to_string(compact_level)
    )

  case native_result do
    {:ok, binaries} -> {:ok, binaries}
    {:error, reason} -> {:error, RuntimeError.exception(reason)}
  end
end
475+
450476
@impl true
451477
def from_ipc_stream(%module{} = entry, columns) when module in [S3.Entry, HTTP.Entry] do
452478
path = Shared.build_path_for_entry(entry)

lib/explorer/polars_backend/lazy_frame.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,8 @@ defmodule Explorer.PolarsBackend.LazyFrame do
642642
dump_csv: 4,
643643
dump_ipc: 2,
644644
dump_ipc_stream: 2,
645+
dump_ipc_schema: 2,
646+
dump_ipc_record_batch: 4,
645647
dump_ndjson: 1,
646648
dump_parquet: 2,
647649
mask: 2,

lib/explorer/polars_backend/native.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ defmodule Explorer.PolarsBackend.Native do
8181
def df_dump_parquet(_df, _compression), do: err()
8282
def df_dump_ipc(_df, _compression), do: err()
8383
def df_dump_ipc_stream(_df, _compression), do: err()
84+
# NIF stubs: `err/0` raises unless the native (Rust) implementations have
# been loaded in place of these functions.
def df_dump_ipc_schema(_df, _compact_level), do: err()
def df_dump_ipc_record_batch(_df, _max_chunk_size, _compression, _compact_level), do: err()
8486

8587
def df_from_csv(
8688
_filename,

native/explorer/src/dataframe/io.rs

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,18 @@
88
//
99
// Today we have the following formats: CSV, NDJSON, Parquet, Apache Arrow and Apache Arrow Stream.
1010
//
11+
use polars::frame::chunk_df_for_writing;
1112
use polars::prelude::*;
13+
use polars::{io::schema_to_arrow_checked, prelude::CompatLevel};
14+
use polars_arrow::io::ipc::write::{
15+
default_ipc_fields, encode_record_batch, schema_to_bytes, EncodedData, WriteOptions,
16+
};
1217
use std::num::NonZeroUsize;
1318

1419
use rustler::{Binary, Env, NewBinary};
1520
use std::fs::File;
16-
use std::io::{BufReader, BufWriter, Cursor};
21+
use std::io::{BufReader, BufWriter, Cursor, Write};
22+
use std::sync::Arc;
1723

1824
use crate::datatypes::{ExParquetCompression, ExQuoteStyle, ExS3Entry, ExSeriesDtype};
1925
use crate::{ExDataFrame, ExplorerError};
@@ -443,6 +449,160 @@ fn decode_ipc_compression(compression: &str) -> Result<IpcCompression, ExplorerE
443449
}
444450
}
445451

452+
fn decode_compact_level(compact_level: &str) -> Result<CompatLevel, ExplorerError> {
453+
match compact_level {
454+
"oldest" => Ok(CompatLevel::oldest()),
455+
"newest" => Ok(CompatLevel::newest()),
456+
other => Err(ExplorerError::Other(format!(
457+
"the compact level {other} is not supported"
458+
))),
459+
}
460+
}
461+
462+
#[rustler::nif(schedule = "DirtyCpu")]
463+
pub fn df_dump_ipc_schema<'a>(
464+
env: Env<'a>,
465+
df: ExDataFrame,
466+
compact_level: Option<&str>,
467+
) -> Result<Binary<'a>, ExplorerError> {
468+
let compact_level = match compact_level {
469+
Some(level) => decode_compact_level(level)?,
470+
None => CompatLevel::oldest(),
471+
};
472+
let schema = schema_to_arrow_checked(df.schema(), compact_level, "ipc")?;
473+
let ipc_fields = default_ipc_fields(schema.iter_values());
474+
let schema_bytes = schema_to_bytes(&schema, &ipc_fields, None);
475+
let encoded_message = EncodedData {
476+
ipc_message: schema_bytes,
477+
arrow_data: Vec::new(),
478+
};
479+
480+
let mut buf = vec![];
481+
write_message(&mut buf, &encoded_message)?;
482+
483+
let mut values_binary = NewBinary::new(env, buf.len());
484+
values_binary.copy_from_slice(&buf);
485+
486+
Ok(values_binary.into())
487+
}
488+
489+
#[rustler::nif(schedule = "DirtyCpu")]
490+
pub fn df_dump_ipc_record_batch<'a>(
491+
env: Env<'a>,
492+
df: ExDataFrame,
493+
max_chunk_size: Option<usize>,
494+
compression: Option<&str>,
495+
compact_level: Option<&str>,
496+
) -> Result<Vec<Binary<'a>>, ExplorerError> {
497+
let data = &mut df.clone();
498+
499+
let max_request_bytes = if let Some(max_chunk_size) = max_chunk_size {
500+
max_chunk_size
501+
} else {
502+
let base: usize = 2;
503+
10 * base.pow(20) // 10 MB
504+
};
505+
let chunk_num = data.estimated_size() / max_request_bytes + 1;
506+
let chunk_size = data.fields().len() / chunk_num;
507+
508+
chunk_df_for_writing(data, chunk_size)?;
509+
510+
let compact_level = match compact_level {
511+
Some(level) => decode_compact_level(level)?,
512+
None => CompatLevel::oldest(),
513+
};
514+
let iter = data.iter_chunks(compact_level, true);
515+
516+
let compression = match compression {
517+
Some(algo) => Some(decode_ipc_compression(algo)?.into()),
518+
None => None,
519+
};
520+
let options = WriteOptions { compression };
521+
522+
let mut result = Vec::new();
523+
524+
for batch in iter {
525+
let mut encoded_message = Default::default();
526+
encode_record_batch(&batch, &options, &mut encoded_message);
527+
let encoded_message = std::mem::take(&mut encoded_message);
528+
529+
let mut buf = vec![];
530+
write_message(&mut buf, &encoded_message)?;
531+
let mut values_binary = NewBinary::new(env, buf.len());
532+
values_binary.copy_from_slice(&buf);
533+
534+
result.push(values_binary.into());
535+
}
536+
537+
Ok(result)
538+
}
539+
540+
/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
/// code from https://github.com/pola-rs/polars/blob/main/crates/polars-arrow/src/io/ipc/write/common_sync.rs
/// the original code is not public for external crates to use it
pub fn write_message<W: Write>(
    writer: &mut W,
    encoded: &EncodedData,
) -> PolarsResult<(usize, usize)> {
    let arrow_data_len = encoded.arrow_data.len();

    // 8-byte alignment mask: the metadata section (continuation marker +
    // length prefix + flatbuffer + padding) must end on an 8-byte boundary.
    let a = 8 - 1;
    let buffer = &encoded.ipc_message;
    let flatbuf_size = buffer.len();
    let prefix_size = 8; // 4-byte continuation marker + 4-byte length prefix
    let aligned_size = (flatbuf_size + prefix_size + a) & !a;
    let padding_bytes = aligned_size - flatbuf_size - prefix_size;

    write_continuation(writer, (aligned_size - prefix_size) as i32)?;

    // write the flatbuf
    if flatbuf_size > 0 {
        writer.write_all(buffer)?;
    }
    // write padding
    // aligned to a 8 byte boundary, so maximum is [u8;8]
    const PADDING_MAX: [u8; 8] = [0u8; 8];
    writer.write_all(&PADDING_MAX[..padding_bytes])?;

    // write arrow data
    let body_len = if arrow_data_len > 0 {
        write_body_buffers(writer, &encoded.arrow_data)?
    } else {
        0
    };

    Ok((aligned_size, body_len))
}
576+
577+
/// Writes the message body followed by zero padding up to the next 64-byte
/// boundary, returning the total bytes written (body + padding).
fn write_body_buffers<W: Write>(mut writer: W, data: &[u8]) -> PolarsResult<usize> {
    // pad_to_64 never returns more than 63, so a fixed zero block suffices.
    const ZEROS: [u8; 64] = [0u8; 64];

    let body_len = data.len();
    let pad_len = pad_to_64(body_len);

    writer.write_all(data)?;
    if pad_len > 0 {
        writer.write_all(&ZEROS[..pad_len])?;
    }

    Ok(body_len + pad_len)
}
590+
591+
/// Writes the IPC encapsulated-message prefix: the 4-byte continuation
/// marker (0xFFFFFFFF) followed by the little-endian metadata length.
/// Always writes exactly 8 bytes and returns that count.
fn write_continuation<W: Write>(writer: &mut W, total_len: i32) -> PolarsResult<usize> {
    const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];
    writer.write_all(&CONTINUATION_MARKER)?;
    writer.write_all(&total_len.to_le_bytes()[..])?;
    Ok(8)
}
599+
600+
/// Returns the number of zero bytes needed to pad `len` up to the next
/// 64-byte boundary (0 when `len` is already 64-byte aligned).
#[inline]
fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}
605+
446606
// ============ IPC Streaming ============ //
447607

448608
#[rustler::nif(schedule = "DirtyIo")]

0 commit comments

Comments
 (0)