@@ -27,20 +27,27 @@ def from_orm(cls, feed: Gtfsrealtimefeed | None, db_session: Session) -> GtfsRTF
2727 gtfs_rt_feed .locations = [LocationImpl .from_orm (item ) for item in feed .locations ] if feed .locations else []
2828 gtfs_rt_feed .entity_types = [item .name for item in feed .entitytypes ] if feed .entitytypes else []
2929
30- # gtfs_rt_feed.feed_references = [item.stable_id for item in feed.gtfs_feeds] if feed.gtfs_feeds else []
31- gtfs_rt_location_ids = {location .id for location in feed .locations }
30+ # Base query: same provider_id, but not the same feed
3231 query = (
3332 db_session .query (GtfsFeedOrm )
34- .filter (GtfsFeedOrm .provider == feed .provider , GtfsFeedOrm .stable_id != feed .stable_id )
33+ .filter (
34+ GtfsFeedOrm .provider_id == feed .provider_id ,
35+ GtfsFeedOrm .stable_id != feed .stable_id ,
36+ )
3537 .options (joinedload (GtfsFeedOrm .locations ))
3638 )
3739
40+ # If the GtfsRT feed has locations, require overlap
41+ rt_location_ids = {loc .id for loc in feed .locations } if feed .locations else set ()
3842 feed_references = []
3943 for gtfs_feed in query .all ():
40- gtfs_location_ids = {location .id for location in gtfs_feed .locations }
41- # Check if there is any overlap in locations.
42- if not gtfs_location_ids .isdisjoint (gtfs_rt_location_ids ):
43- feed_references .append (gtfs_feed .stable_id )
44+ if rt_location_ids :
45+ gtfs_location_ids = {loc .id for loc in gtfs_feed .locations }
46+ if gtfs_location_ids .isdisjoint (rt_location_ids ):
47+ continue
48+
49+ feed_references .append (gtfs_feed .stable_id )
50+
4451 gtfs_rt_feed .feed_references = feed_references
4552
4653 return gtfs_rt_feed
0 commit comments