Skip to content

Commit 30f338c

Browse files
authored
Merge pull request #277 from RasmusOrsoe/parquet_to_sqlite_converter
Parquet to sqlite converter
2 parents d109faf + 0849615 commit 30f338c

File tree

3 files changed

+214
-7
lines changed

3 files changed

+214
-7
lines changed

examples/parquet_to_sqlite.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from graphnet.data.utilities.parquet_to_sqlite import ParquetToSQLiteConverter
2+
3+
if __name__ == "__main__":
4+
# path to parquet file or directory containing parquet files
5+
parquet_path = "/my_file.parquet"
6+
# path to where you want the database to be stored
7+
outdir = "/home/my_databases/"
8+
# name of the database. Will be saved in outdir/database_name/data/database_name.db
9+
database_name = "my_database_from_parquet"
10+
11+
converter = ParquetToSQLiteConverter(
12+
mc_truth_table="mc_truth", parquet_path=parquet_path
13+
)
14+
converter.run(outdir=outdir, database_name=database_name)

src/graphnet/data/sqlite/sqlite_utilities.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,42 @@ def save_to_sql(df: pd.DataFrame, table_name: str, database: str):
3131
engine = sqlalchemy.create_engine("sqlite:///" + database)
3232
df.to_sql(table_name, con=engine, index=False, if_exists="append")
3333
engine.dispose()
34+
35+
36+
def attach_index(database: str, table_name: str):
37+
"""Attaches the table index. Important for query times!"""
38+
code = (
39+
"PRAGMA foreign_keys=off;\n"
40+
"BEGIN TRANSACTION;\n"
41+
f"CREATE INDEX event_no_{table_name} ON {table_name} (event_no);\n"
42+
"COMMIT TRANSACTION;\n"
43+
"PRAGMA foreign_keys=on;"
44+
)
45+
run_sql_code(database, code)
3446
return
3547

3648

37-
def create_table(database, table_name, df):
49+
def create_table(
50+
df: pd.DataFrame,
51+
table_name: str,
52+
database_path: str,
53+
is_pulse_map: bool = False,
54+
):
3855
"""Creates a table.
56+
3957
Args:
40-
pipeline_database (str): path to the pipeline database
41-
df (str): pandas.DataFrame of combined predictions
58+
database (str): path to the database
59+
table_name (str): name of the table
60+
columns (str): the names of the columns of the table
61+
is_pulse_map (bool, optional): whether or not this is a pulse map table. Defaults to False.
4262
"""
4363
query_columns = list()
4464
for column in df.columns:
4565
if column == "event_no":
46-
type_ = "INTEGER PRIMARY KEY NOT NULL"
47-
else:
48-
type_ = "FLOAT"
66+
if not is_pulse_map:
67+
type_ = "INTEGER PRIMARY KEY NOT NULL"
68+
else:
69+
type_ = "NOT NULL"
4970
query_columns.append(f"{column} {type_}")
5071
query_columns = ", ".join(query_columns)
5172

@@ -54,5 +75,9 @@ def create_table(database, table_name, df):
5475
f"CREATE TABLE {table_name} ({query_columns});\n"
5576
"PRAGMA foreign_keys=on;"
5677
)
57-
run_sql_code(database, code)
78+
run_sql_code(
79+
database_path,
80+
code,
81+
)
82+
5883
return
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import pandas as pd
2+
import os
3+
import sqlite3
4+
import awkward as ak
5+
6+
import glob
7+
from typing import List, Optional, Union
8+
from tqdm.auto import trange
9+
import numpy as np
10+
import sqlalchemy
11+
from graphnet.data.sqlite.sqlite_utilities import (
12+
run_sql_code,
13+
save_to_sql,
14+
attach_index,
15+
create_table,
16+
)
17+
18+
from graphnet.utilities.logging import LoggerMixin
19+
20+
21+
class ParquetToSQLiteConverter(LoggerMixin):
22+
"""Converts Parquet files to a SQLite database. Each event in the parquet file(s) are assigned a unique event id.
23+
By default, every field in the parquet file(s) are extracted. One can choose to exclude certain fields by using the argument exclude_fields.
24+
"""
25+
26+
def __init__(
27+
self,
28+
parquet_path: Union[str, List[str]],
29+
mc_truth_table: str = "mc_truth",
30+
excluded_fields: Optional[Union[str, List[str]]] = None,
31+
):
32+
# checks
33+
if isinstance(parquet_path, str):
34+
pass
35+
elif isinstance(parquet_path, list):
36+
assert isinstance(
37+
parquet_path[0], str
38+
), "Argument `parquet_path` must be a string or list of strings"
39+
else:
40+
assert isinstance(
41+
parquet_path, str
42+
), "Argument `parquet_path` must be a string or list of strings"
43+
44+
assert isinstance(
45+
mc_truth_table, str
46+
), "Argument `mc_truth_table` must be a string"
47+
self._parquet_files = self._find_parquet_files(parquet_path)
48+
if excluded_fields is not None:
49+
self._excluded_fields = excluded_fields
50+
else:
51+
self._excluded_fields = []
52+
self._mc_truth_table = mc_truth_table
53+
self._event_counter = 0
54+
self._created_tables = []
55+
56+
def _find_parquet_files(self, paths: Union[str, List[str]]) -> List[str]:
57+
if isinstance(paths, str):
58+
if paths.endswith(".parquet"):
59+
files = [paths]
60+
else:
61+
files = glob.glob(f"{paths}/*.parquet")
62+
elif isinstance(paths, list):
63+
files = []
64+
for path in paths:
65+
files.extend(self._find_parquet_files(path))
66+
assert len(files) > 0, f"No files found in {paths}"
67+
return files
68+
69+
def run(self, outdir: str, database_name: str):
70+
self._create_output_directories(outdir, database_name)
71+
database_path = os.path.join(
72+
outdir, database_name, "data", database_name + ".db"
73+
)
74+
for i in trange(
75+
len(self._parquet_files), desc="Main", colour="#0000ff", position=0
76+
):
77+
parquet_file = ak.from_parquet(self._parquet_files[i])
78+
n_events_in_file = self._count_events(parquet_file)
79+
for j in trange(
80+
len(parquet_file.fields),
81+
desc="%s" % (self._parquet_files[i].split("/")[-1]),
82+
colour="#ffa500",
83+
position=1,
84+
leave=False,
85+
):
86+
if parquet_file.fields[j] not in self._excluded_fields:
87+
self._save_to_sql(
88+
database_path,
89+
parquet_file,
90+
parquet_file.fields[j],
91+
n_events_in_file,
92+
)
93+
self._event_counter += n_events_in_file
94+
self._save_config(outdir, database_name)
95+
print(
96+
f"Database saved at: \n{outdir}/{database_name}/data/{database_name}.db"
97+
)
98+
99+
def _count_events(self, open_parquet_file: ak.Array) -> int:
100+
return len(open_parquet_file[self._mc_truth_table])
101+
102+
def _save_to_sql(
103+
self,
104+
database_path: str,
105+
ak_array: ak.Array = None,
106+
field_name: str = None,
107+
n_events_in_file: int = None,
108+
):
109+
df = self._make_df(ak_array, field_name, n_events_in_file)
110+
if field_name in self._created_tables:
111+
save_to_sql(
112+
database_path,
113+
field_name,
114+
df,
115+
)
116+
else:
117+
if len(df) > n_events_in_file:
118+
is_pulse_map = True
119+
else:
120+
is_pulse_map = False
121+
create_table(df, field_name, database_path, is_pulse_map)
122+
if is_pulse_map:
123+
attach_index(database_path, table_name=field_name)
124+
self._created_tables.append(field_name)
125+
save_to_sql(
126+
database_path,
127+
field_name,
128+
df,
129+
)
130+
131+
def _convert_to_dataframe(
132+
self,
133+
ak_array: ak.Array,
134+
field_name: str,
135+
n_events_in_file: int,
136+
) -> pd.DataFrame:
137+
df = pd.DataFrame(ak.to_pandas(ak_array[field_name]))
138+
if len(df.columns) == 1:
139+
if df.columns == ["values"]:
140+
df.columns = [field_name]
141+
if (
142+
len(df) != n_events_in_file
143+
): # if true, the dataframe contains more than 1 row pr. event (e.g. Pulsemap).
144+
event_nos = []
145+
c = 0
146+
for event_no in range(
147+
self._event_counter, self._event_counter + n_events_in_file, 1
148+
):
149+
try:
150+
event_nos.extend(
151+
np.repeat(event_no, len(df[df.columns[0]][c])).tolist()
152+
)
153+
except KeyError: # KeyError indicates that this df has no entry for event_no (e.g. an event with no detector response)
154+
pass
155+
c += 1
156+
else:
157+
event_nos = np.arange(0, n_events_in_file, 1) + self._event_counter
158+
df["event_no"] = event_nos
159+
return df
160+
161+
def _create_output_directories(self, outdir: str, database_name: str):
162+
os.makedirs(outdir + "/" + database_name + "/data", exist_ok=True)
163+
os.makedirs(outdir + "/" + database_name + "/config", exist_ok=True)
164+
165+
def _save_config(self, outdir: str, database_name: str):
166+
"""Save the list of converted Parquet files to a CSV file."""
167+
df = pd.DataFrame(data=self._parquet_files, columns=["files"])
168+
df.to_csv(outdir + "/" + database_name + "/config/files.csv")

0 commit comments

Comments
 (0)