Skip to content

Commit fcabc6c

Browse files
committed
New: feature importance visualizer
1 parent 1ed6b68 commit fcabc6c

File tree

3 files changed

+137
-1
lines changed

3 files changed

+137
-1
lines changed

wqf/interface/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@
5858
VID_DEP = "deptho"
5959
"""The sea floor depth (m) variable identifier."""
6060

61+
VID_MDT = "mdt"
62+
"""The mean dynamic topography (m) variable identifier."""
63+
64+
"""The sea floor depth (m) variable identifier."""
6165
VID_NO3 = "no3"
6266
"""
6367
The concentration of nitrates (mmol m-3) variable identifier.

wqf/interface/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Plot(metaclass=ABCMeta):
1818
@abstractmethod
1919
def plot(
2020
self,
21-
data: DataArray | tuple[DataArray, DataArray],
21+
data: DataArray | tuple[DataArray, DataArray] | None,
2222
xlabel: str | None = None,
2323
ylabel: str | None = None,
2424
xlim: tuple[Any, Any] | None = None,

wqf/val/importanceviz.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) Brockmann Consult GmbH, 2024
2+
# License: MIT
3+
4+
"""This module provides a feature importance visualizer."""
5+
6+
from typing import Any
7+
from typing import Literal
8+
from typing import Sequence
9+
10+
import xgboost
11+
from matplotlib import pyplot as plt
12+
from matplotlib.figure import Figure
13+
from xgboost import Booster
14+
15+
from wqf.interface.constants import VID_DEP
16+
from wqf.interface.constants import VID_MDT
17+
from wqf.interface.plot import Plot
18+
from wqf.xgb import registry
19+
20+
21+
class ImportanceVisualizer(Plot):
22+
"""A feature importance visualizer."""
23+
24+
_model: Booster
25+
"""The forecast model."""
26+
27+
def __init__(self, name: str = "default"):
28+
"""
29+
Creates a new analyzer.
30+
31+
:param name: The name of the forecast model to analyze.
32+
"""
33+
self._name = name
34+
self._model = registry().model(name)
35+
self._model.feature_names = self._feature_names()
36+
37+
def visualize(self, feature_count: int | None = None):
38+
"""
39+
Visualizes feature importance.
40+
41+
:param feature_count: The number of features to visualize.
42+
"""
43+
self._plot("cover", (0.0, 1.0e07), feature_count)
44+
self._plot("gain", (0.0, 1.0e06), feature_count)
45+
self._plot("total_gain", (0.0, 1.0e07), feature_count)
46+
self._plot("total_cover", (0.0, 1.0e09), feature_count)
47+
self._plot("weight", (0.0, 1.0e03), feature_count)
48+
49+
def plot(
50+
self,
51+
data: None = None,
52+
xlabel: str | None = None,
53+
ylabel: str | None = "feature",
54+
xlim: tuple[Any, Any] | None = None,
55+
ylim: tuple[Any, Any] | None = None,
56+
title: str | None = "Feature importance",
57+
fn: str | None = None,
58+
show: bool = False,
59+
*,
60+
bar_height: Any = 0.4,
61+
feature_count: int = 12,
62+
importance_type: Literal[
63+
"cover", "gain", "total_cover", "total_gain", "weight"
64+
] = "total_gain",
65+
show_grid: bool = False,
66+
show_values: bool = True,
67+
values_format: str = "{v:,.1f}",
68+
) -> Figure:
69+
fig, ax = plt.subplots()
70+
xgboost.plot_importance(
71+
self._model,
72+
ax=ax,
73+
height=bar_height,
74+
title=title,
75+
xlim=xlim,
76+
xlabel=xlabel,
77+
ylabel=ylabel,
78+
importance_type=importance_type,
79+
max_num_features=feature_count,
80+
grid=show_grid,
81+
show_values=show_values,
82+
values_format=values_format,
83+
)
84+
if fn is not None:
85+
fig.savefig(f"{fn}.pdf", bbox_inches="tight")
86+
if show:
87+
fig.show()
88+
plt.close()
89+
return fig
90+
91+
def _plot(
92+
self,
93+
importance_type: Literal[
94+
"cover", "gain", "total_cover", "total_gain", "weight"
95+
],
96+
xlim: tuple | None = None,
97+
feature_count: int | None = None,
98+
):
99+
"""Generates an importance plot."""
100+
fig = self.plot(
101+
xlabel=f"{importance_type.replace('_', ' ')} (arbitrary units)",
102+
ylabel="feature",
103+
xlim=xlim,
104+
fn=f"{self._name}_{importance_type}",
105+
feature_count=feature_count,
106+
importance_type=importance_type,
107+
)
108+
fig.clear()
109+
110+
def _feature_names(
111+
self, static: Sequence[str] = (VID_DEP, VID_MDT)
112+
) -> Sequence[str]:
113+
"""Returns feature names suitable for an importance plot."""
114+
labels: list[str] = []
115+
for feature_name in self._model.feature_names:
116+
day, v = feature_name.split("_")
117+
118+
if v in static:
119+
labels.append(f"{v}")
120+
else:
121+
labels.append(f"{v}[{ImportanceVisualizer._t(day)}]")
122+
return labels
123+
124+
@staticmethod
125+
def _t(day: str) -> str:
126+
"""Returns the time string for a given day string."""
127+
d = day[1:] if "-" in day and day[2:] != "0" else day[2:]
128+
return r"$t_{" + d + "}$"
129+
130+
131+
if __name__ == "__main__":
132+
ImportanceVisualizer().visualize(feature_count=12)

0 commit comments

Comments
 (0)