Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .example.env
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Database Configuration
DB_URL=postgresql://postgres:postgres@localhost:5432/neso_solar

# country code for fetching data. Other options are "nl", "de"
COUNTRY="gb"
# country code for fetching data (ISO 3166-1 alpha-3, optionally followed by a region suffix). Options: "gbr_gb", "nld", "deu", "bel", "ind_rj"
COUNTRY="gbr_gb"

# ways to store the data. Other options are "csv", "site-db"
SAVE_METHOD="db"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ The package provides three main functionalities:
### Environment Variables: (Can be found in the .example.env / .env file)

- `DB_URL=postgresql://postgres:postgres@localhost:5432/neso_solar` : Database Configuration
- `COUNTRY="gb"` : Country code for fetching data. Currently, other options are ["be", "ind_rajasthan", "nl"]
- `COUNTRY="gbr_gb"` : Country code for fetching data (ISO 3166-1 alpha-3, optionally followed by a region suffix). Options are ["gbr_gb", "bel", "ind_rj", "nld", "deu"]
- `SAVE_METHOD`: Ways to store the data. Options are ["db", "csv", "site-db"].
`site-db` is supported for NL, DE, and India (RUVNL).
- `CSV_DIR=None` : Directory to save CSV files if `SAVE_METHOD="csv"`.
Expand Down
15 changes: 8 additions & 7 deletions solar_consumer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def app(
db_url: str,
save_method: str,
csv_dir: str = None,
country: str = "gb",
country: str = "gbr_gb",
historic_or_forecast: str = "generation",
):
"""
Expand All @@ -44,22 +44,23 @@ async def app(
db_url (str): Database connection URL from an environment variable.
save_method (str): Method to save the forecast data. Options are "db" or "csv".
csv_dir (str, optional): Directory to save CSV files if save_method is "csv".
country (str): Country code for fetching data. Default is "gb".
country (str): Country code for fetching data. Default is "gbr_gb".
historic_or_forecast: (str): Type of data to fetch. Default is "generation".
"""
logger.info(f"Starting the NESO Solar Forecast pipeline (version: {__version__}).")

# Use the `Neso` class for hardcoded configuration
if country == "gb":
if country == "gbr_gb":
model_tag = "neso-solar-forecast"
elif country == "nl":
elif country == "nld":
model_tag = "ned-nl-national"
elif country == "de":
elif country == "deu":
model_tag = "entsoe-de"
elif country == "be":
elif country == "bel":
model_tag = "elia-be-forecast"



# Step 1: Fetch forecast data (returns as pd.Dataframe)
logger.info(f"Fetching {historic_or_forecast} data for {country}.")
data = fetch_data(country=country, historic_or_forecast=historic_or_forecast)
Expand Down Expand Up @@ -146,7 +147,7 @@ async def app(
if __name__ == "__main__":
# Step 1: Fetch the database URL from the environment variable
db_url = os.getenv("DB_URL") # Change from "DATABASE_URL" to "DB_URL"
country = os.getenv("COUNTRY", "gb")
country = os.getenv("COUNTRY", "gbr_gb")
save_method = os.getenv("SAVE_METHOD", "db").lower() # Default to "db"
csv_dir = os.getenv("CSV_DIR")
historic_or_forecast = os.getenv("HISTORIC_OR_FORECAST", "generation").lower()
Expand Down
16 changes: 8 additions & 8 deletions solar_consumer/fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@
from solar_consumer.data.fetch_ind_rajasthan_data import fetch_ind_rajasthan_data


def fetch_data(country: str = "gb", historic_or_forecast: str = "forecast") -> pd.DataFrame:
def fetch_data(country: str = "gbr_gb", historic_or_forecast: str = "forecast") -> pd.DataFrame:
"""
Get data from different countries

:param country: "gb", "nl", "de", "ind_rajasthan", or "be"
:param country: "gbr_gb", "nld", "deu", "ind_rj", or "bel"
:param historic_or_forecast: "generation" or "forecast"
:return: Pandas dataframe with the following columns:
target_datetime_utc: Combined date and time in UTC.
solar_generation_kw: Solar generation in kW. Can be a forecast, or historic values
"""

country_data_functions = {
"gb": fetch_gb_data,
"nl": fetch_nl_data,
"de": fetch_de_data,
"ind_rajasthan": fetch_ind_rajasthan_data,
"be": fetch_be_data
"gbr_gb": fetch_gb_data,
"nld": fetch_nl_data,
"deu": fetch_de_data,
"ind_rj": fetch_ind_rajasthan_data,
"bel": fetch_be_data
}

if country in country_data_functions:
Expand All @@ -48,7 +48,7 @@ def fetch_data(country: str = "gb", historic_or_forecast: str = "forecast") -> p
raise Exception(f"An error occurred while fetching data for {country}: {e}") from e

else:
print("Only UK (gb), Netherlands (nl), Germany (de), Belgium (be), and Rajasthan India (ind_rajasthan) data can be fetched at the moment")
print("Only Great Britain (gbr_gb), Netherlands (nld), Germany (deu), Belgium (bel), and Rajasthan India (ind_rj) data can be fetched at the moment")

return pd.DataFrame() # Always return a DataFrame (never None)

Expand Down
22 changes: 11 additions & 11 deletions solar_consumer/save/save_data_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,33 @@


async def save_generation_to_data_platform(
data_df: pd.DataFrame, client: dp.DataPlatformDataServiceStub, country: str = "gb"
data_df: pd.DataFrame, client: dp.DataPlatformDataServiceStub, country: str = "gbr_gb"
) -> None:
"""
Saves model data via the data platform.

Incoming data is enriched with location information from the data platform. Anything with zero
capacity, or without a corresponding entry in the data platform, is ignored.

For GB: Data is joined via the gsp_id, which is a column in the incoming data, and has to be
For GBR_GB: Data is joined via the gsp_id, which is a column in the incoming data, and has to be
extracted from the metadata field in the data platform location data.

For NL: Data is joined via the region_id.
For NLD: Data is joined via the region_id.

Args:
data_df: DataFrame containing the generation data
client: Data platform client stub
location: Location identifier ('gb' or 'nl')
country: Country code ('gbr_gb' or 'nld')
"""
tasks: list[asyncio.Task] = []

# 0. Create the observers required if they don't exist already
if country == "nl":
if country == "nld":
required_observers = {"nednl"}
id_key = "region_id"
capacity_col = "capacity_kw"
capacity_multiplier = 1000
else: # gb
else: # gbr_gb
required_observers = {"pvlive_in_day", "pvlive_day_after"}
id_key = "gsp_id"
capacity_col = "capacity_mwp"
Expand All @@ -74,7 +74,7 @@ async def save_generation_to_data_platform(
raise exc

# 1. Get locations and join to the incoming data.
if country == "nl":
if country == "nld":
# Get NL locations (NATION only)
list_locations_request = dp.ListLocationsRequest(
location_type_filter=dp.LocationType.NATION,
Expand Down Expand Up @@ -139,7 +139,7 @@ async def save_generation_to_data_platform(
)
.assign(target_datetime_utc=lambda df: pd.to_datetime(df["target_datetime_utc"]))
)
else: # gb
else: # gbr_gb
# Get UK GSP locations, as well as national
tasks = [
asyncio.create_task(
Expand Down Expand Up @@ -236,7 +236,7 @@ async def save_generation_to_data_platform(
logging.info("updating %d %s location capacities", len(tasks), country.upper())
update_results = await asyncio.gather(*tasks, return_exceptions=True)
for exc in filter(lambda x: isinstance(x, Exception), update_results):
if country != "nl": # NL was previously ignoring these exceptions
if country != "nld": # NLD was previously ignoring these exceptions
raise exc

# 3. Generate the CreateObservationRequest objects from the DataFrame.
Expand All @@ -251,9 +251,9 @@ async def save_generation_to_data_platform(
)

# Determine observer name based on country
if country == "nl":
if country == "nld":
observer_name = "nednl"
else: # gb
else: # gbr_gb
regime: str = data_df["regime"].values[0]
observer_name = f"pvlive_{regime.replace('-', '_')}"

Expand Down
41 changes: 20 additions & 21 deletions solar_consumer/save/save_site_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_or_create_pvsite(
Parameters:
session (Session): CurrentSQLAlchemy session
pvsite (PVSite): Pydantic model with site metadata
country (str): Country code ('nl' or 'de')
country (str): Country code ('nld' or 'deu')
capacity_override_kw (Optional[int]): Force a specific capacity on creation

Returns:
Expand All @@ -88,12 +88,12 @@ def get_or_create_pvsite(
except Exception:
logger.info(f"Creating site {pvsite.client_site_name} in the database.")

# Choose capacity based on country; per-TSO for de; nl only has 20GW hard‑coded
# Choose capacity based on country; per-TSO for deu; nld only has 20GW hard‑coded
if capacity_override_kw is not None:
capacity = capacity_override_kw
elif country == "de":
elif country == "deu":
capacity = DE_TSO_CAPACITY[pvsite.client_site_name]
elif country == "nl":
elif country == "nld":
capacity = 20_000_000
else: #in
capacity = capacity_override_kw or 0
Expand Down Expand Up @@ -137,7 +137,7 @@ def update_capacity(


def save_generation_to_site_db(
generation_data: pd.DataFrame, session: Session, country: str = "nl"
generation_data: pd.DataFrame, session: Session, country: str = "nld"
):
"""Save generation data to the database.

Expand All @@ -147,9 +147,9 @@ def save_generation_to_site_db(
- solar_generation_kw
- target_datetime_utc
- capacity_kw (optional, used when present)
- tso_zone (only when country="de")
- tso_zone (only when country="deu")
session (Session): SQLAlchemy session for database access.
country: (str): Country code for the generation data ('nl', 'de', 'ind_rajasthan')
country: (str): Country code for the generation data ('nld', 'deu', 'ind_rj')


Return:
Expand All @@ -162,27 +162,27 @@ def save_generation_to_site_db(
return

# Determine country
if country == "nl":
if country == "nld":
country_sites = NL_NATIONAL_AND_REGIONS
elif country == "de":
elif country == "deu":
country_sites = DE_TSO_SITES
elif country == "ind_rajasthan":
elif country == "ind_rj":
country_sites = IND_RAJASTHAN_SITES
else:
raise Exception(
"Only generation data from the following countries is supported "
"when saving: 'nl', 'de', 'ind_rajasthan'"
"when saving: 'nld', 'deu', 'ind_rj'"
)

# Loop per site
for key, pvsite in country_sites.items():

# Filter by TSO for Germany, or use all data for NL
if country == "de":
# Filter by TSO zone for Germany, by region_id for the Netherlands, or by energy_type for Rajasthan
if country == "deu":
generation_data_tso_df = generation_data[generation_data["tso_zone"] == key].copy()
elif country == "nl":
elif country == "nld":
generation_data_tso_df = generation_data[generation_data["region_id"] == int(key)].copy()
elif country == "ind_rajasthan":
elif country == "ind_rj":
generation_data_tso_df = generation_data[generation_data["energy_type"] == key].copy()
else:
generation_data_tso_df = generation_data.copy()
Expand All @@ -207,7 +207,7 @@ def save_generation_to_site_db(
)

generation_data_tso_df = generation_data_tso_df.copy()
if country == "ind_rajasthan":
if country == "ind_rj":
generation_data_tso_df["energy_type"] = key
else:
generation_data_tso_df["energy_type"] = "solar"
Expand Down Expand Up @@ -242,7 +242,7 @@ def save_forecasts_to_site_db(
session: Session,
model_tag: str,
model_version: str,
country: str = "nl",
country: str = "nld",
):
"""Save forecast data to the database.

Expand All @@ -252,15 +252,14 @@ def save_forecasts_to_site_db(
session (Session): SQLAlchemy session for database access.
model_tag (str): Model tag to fetch model metadata.
model_version (str): Model version to fetch model metadata.
country: (str): Country code for the generation data. Currently only 'nl' is supported.
country: (str): Country code for the forecast data. Currently only 'nld' is supported.

Return:
None
"""

if country != "nl":
raise Exception("Only NL forecast data is supported when saving (atm).")

if country != "nld":
raise Exception("Only NLD forecast data is supported when saving (atm).")
site = get_or_create_pvsite(session, nl_national, country)

timestamp_utc = pd.Timestamp.now(tz="UTC").floor("15min")
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_save_nl_to_data_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def test_save_nl_to_data_platform(client):
"capacity_kw": [80_000_000, 80_000_000, 60_000_000, 60_000_000],
}
)
_ = await save_generation_to_data_platform(fake_data, client=client, country="nl")
_ = await save_generation_to_data_platform(fake_data, client=client, country="nld")

# read from the data platform to check national data was saved
get_observations_request = dp.GetObservationsAsTimeseriesRequest(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_gb_historic_inday():
os.environ["UK_PVLIVE_REGIME"] = "in-day"
os.environ["UK_PVLIVE_N_GSPS"] = "10"

df = fetch_data(country = "gb", historic_or_forecast = "historic")
df = fetch_data(country = "gbr_gb", historic_or_forecast = "historic")

# 10 GSPs for 2 hours should give between 30 and 40 rows
assert 30<=len(df) <=40
Expand All @@ -206,7 +206,7 @@ def test_gb_historic_day_after():
os.environ["UK_PVLIVE_REGIME"] = "day-after"
os.environ["UK_PVLIVE_N_GSPS"] = "10"

df = fetch_data(country = "gb", historic_or_forecast = "historic")
df = fetch_data(country = "gbr_gb", historic_or_forecast = "historic")

# 10 GSPs for 24 hours is at 30 minutes periods, including the extra two at end
assert len(df) == 10*48
8 changes: 4 additions & 4 deletions tests/unit/test_save_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ class TestCase:

with self.subTest(case.name):
if not case.should_error:
await save_generation_to_data_platform(case.input_df, client_mock, country="nl")
await save_generation_to_data_platform(case.input_df, client_mock, country="nld")
# Assert the data platform functions were called the expected number of times
self.assertEqual(
client_mock.update_location.call_count,
Expand Down Expand Up @@ -449,7 +449,7 @@ class TestCase:
self.assertEqual(call.args[0].observer_name, "nednl")
else:
with self.assertRaises(Exception):
await save_generation_to_data_platform(case.input_df, client_mock, country="nl")
await save_generation_to_data_platform(case.input_df, client_mock, country="nld")

@patch("dp_sdk.ocf.dp.DataPlatformDataServiceStub")
async def test_save_nl_generation_creates_locations_when_none_exist(self, client_mock):
Expand Down Expand Up @@ -503,7 +503,7 @@ def mock_list_observers(req: dp.ListObserversRequest) -> dp.ListObserversRespons
"target_datetime_utc": [np.datetime64('2023-01-01T00:00:00')],
})

await save_generation_to_data_platform(input_df, client_mock, country="nl")
await save_generation_to_data_platform(input_df, client_mock, country="nld")

# Verify create_location was called for each location in the CSV (13 locations)
self.assertEqual(client_mock.create_location.call_count, 13)
Expand All @@ -524,7 +524,7 @@ def test_save_generation_to_site_db_ind_rajasthan(db_site_session):
save_generation_to_site_db(
generation_data=generation_df,
session=db_site_session,
country="ind_rajasthan",
country="ind_rj",
)

saved_data = db_site_session.query(GenerationSQL).all()
Expand Down