55
66from sqlalchemy import (
77 ARRAY ,
8- CTE ,
98 CompoundSelect ,
109 DateTime ,
1110 Integer ,
2120 tuple_ ,
2221)
2322
23+ from data_rentgen .db .models .dataset_symlink import DatasetSymlink
2424from data_rentgen .db .models .input import Input
2525from data_rentgen .db .models .job_dependency import JobDependency
2626from data_rentgen .db .models .output import Output
@@ -104,13 +104,14 @@ async def get_dependencies(
104104 infer_from_lineage : bool = False ,
105105 ) -> list [JobDependency ]:
106106 core_query = self ._get_core_hierarchy_query (include_indirect = infer_from_lineage )
107+ core_subquery = core_query .subquery ()
107108
108- query : Select | CompoundSelect
109+ query : Select
109110 match direction :
110111 case "UPSTREAM" :
111- query = select (core_query ).where (core_query .c .to_job_id == any_ (bindparam ("job_ids" )))
112+ query = select (core_subquery ).where (core_subquery .c .to_job_id == any_ (bindparam ("job_ids" )))
112113 case "DOWNSTREAM" :
113- query = select (core_query ).where (core_query .c .from_job_id == any_ (bindparam ("job_ids" )))
114+ query = select (core_subquery ).where (core_subquery .c .from_job_id == any_ (bindparam ("job_ids" )))
114115
115116 result = await self ._session .execute (
116117 query ,
@@ -125,37 +126,58 @@ def _get_core_hierarchy_query(
125126 self ,
126127 * ,
127128 include_indirect : bool = False ,
128- ) -> CTE :
129+ ) -> Select | CompoundSelect :
129130 query : Select | CompoundSelect
130131 query = select (
131132 JobDependency .from_job_id ,
132133 JobDependency .to_job_id ,
133134 JobDependency .type ,
134135 )
135136 if include_indirect :
136- query = query .union (
137- select (
138- Output .job_id .label ("from_job_id" ),
139- Input .job_id .label ("to_job_id" ),
140- literal ("INFERRED_FROM_LINEAGE" ).label ("type" ),
141- )
142- .distinct ()
137+ # The WHERE clauses and the selected columns are shared by all UNION branches below
138+ where_clauses = [
139+ Input .created_at >= bindparam ("since" ),
140+ Output .created_at >= bindparam ("since" ),
141+ Output .created_at >= Input .created_at ,
142+ Output .job_id != Input .job_id ,
143+ or_ (
144+ bindparam ("until" , type_ = DateTime (timezone = True )).is_ (None ),
145+ and_ (
146+ Input .created_at <= bindparam ("until" ),
147+ Output .created_at <= bindparam ("until" ),
148+ ),
149+ ),
150+ ]
151+ inferred_columns = select (
152+ Output .job_id .label ("from_job_id" ),
153+ Input .job_id .label ("to_job_id" ),
154+ literal ("INFERRED_FROM_LINEAGE" ).label ("type" ),
155+ ).distinct ()
156+
157+ # IO connections via same dataset
158+ direct_connection = inferred_columns .join (
159+ Input ,
160+ Output .dataset_id == Input .dataset_id ,
161+ ).where (* where_clauses )
162+ # IO connections via symlink: Output.dataset_id == DatasetSymlink.to_dataset_id and DatasetSymlink.from_dataset_id == Input.dataset_id
163+ via_symlinks_from_output = (
164+ inferred_columns .join (DatasetSymlink , Output .dataset_id == DatasetSymlink .to_dataset_id )
143165 .join (
144166 Input ,
145- Output . dataset_id == Input .dataset_id ,
167+ DatasetSymlink . from_dataset_id == Input .dataset_id ,
146168 )
147- .where (
148- Input .created_at >= bindparam ("since" ),
149- Output .created_at >= bindparam ("since" ),
150- Output .created_at >= Input .created_at ,
151- Output .job_id != Input .job_id ,
152- or_ (
153- bindparam ("until" , type_ = DateTime (timezone = True )).is_ (None ),
154- and_ (
155- Input .created_at <= bindparam ("until" ),
156- Output .created_at <= bindparam ("until" ),
157- ),
158- ),
169+ .where (* where_clauses )
170+ )
171+ # IO connections via symlink: Input.dataset_id == DatasetSymlink.to_dataset_id and DatasetSymlink.from_dataset_id == Output.dataset_id
172+ via_symlinks_from_input = (
173+ inferred_columns .join (DatasetSymlink , Input .dataset_id == DatasetSymlink .to_dataset_id )
174+ .join (
175+ Output ,
176+ DatasetSymlink .from_dataset_id == Output .dataset_id ,
159177 )
178+ .where (* where_clauses )
160179 )
161- return query .cte ("jobs_hierarchy_core_query" ).prefix_with ("NOT MATERIALIZED" , dialect = "postgresql" )
180+
181+ query = query .union (direct_connection , via_symlinks_from_input , via_symlinks_from_output )
182+
183+ return query
0 commit comments