Skip to content

Commit c1a8141

Browse files
committed
revised logic
1 parent 4e60a1f commit c1a8141

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

api/src/shared/db_models/gtfs_rt_feed_impl.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,27 @@ def from_orm(cls, feed: Gtfsrealtimefeed | None, db_session: Session) -> GtfsRTF
2727
gtfs_rt_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations] if feed.locations else []
2828
gtfs_rt_feed.entity_types = [item.name for item in feed.entitytypes] if feed.entitytypes else []
2929

30-
# gtfs_rt_feed.feed_references = [item.stable_id for item in feed.gtfs_feeds] if feed.gtfs_feeds else []
31-
gtfs_rt_location_ids = {location.id for location in feed.locations}
30+
# Base query: same provider_id, but not the same feed
3231
query = (
3332
db_session.query(GtfsFeedOrm)
34-
.filter(GtfsFeedOrm.provider == feed.provider, GtfsFeedOrm.stable_id != feed.stable_id)
33+
.filter(
34+
GtfsFeedOrm.provider_id == feed.provider_id,
35+
GtfsFeedOrm.stable_id != feed.stable_id,
36+
)
3537
.options(joinedload(GtfsFeedOrm.locations))
3638
)
3739

40+
# If the GtfsRT feed has locations, require overlap
41+
rt_location_ids = {loc.id for loc in feed.locations} if feed.locations else set()
3842
feed_references = []
3943
for gtfs_feed in query.all():
40-
gtfs_location_ids = {location.id for location in gtfs_feed.locations}
41-
# Check if there is any overlap in locations.
42-
if not gtfs_location_ids.isdisjoint(gtfs_rt_location_ids):
43-
feed_references.append(gtfs_feed.stable_id)
44+
if rt_location_ids:
45+
gtfs_location_ids = {loc.id for loc in gtfs_feed.locations}
46+
if gtfs_location_ids.isdisjoint(rt_location_ids):
47+
continue
48+
49+
feed_references.append(gtfs_feed.stable_id)
50+
4451
gtfs_rt_feed.feed_references = feed_references
4552

4653
return gtfs_rt_feed

0 commit comments

Comments
 (0)