Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 143 additions & 1 deletion cometx/cli/admin_gpu_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def process_metrics_with_aggregation(
start_date: str = None,
end_date: str = None,
aggregate_metric: str = "avg",
selected_projects: List[str] = None,
) -> Tuple[Dict, Dict]:
"""
Process metric data with specified aggregation and time unit.
Expand All @@ -132,6 +133,7 @@ def process_metrics_with_aggregation(
start_date: Optional start date filter (YYYY-MM-DD)
end_date: Optional end date filter (YYYY-MM-DD)
aggregate_metric: Aggregation method for time series - "max" or "avg" (default: "avg")
selected_projects: Optional list of "workspace/project" strings to filter by

Returns:
tuple: (avg_data, time_series_data)
Expand Down Expand Up @@ -165,6 +167,15 @@ def process_metrics_with_aggregation(
continue

exp_info = experiment_map[exp_key]

# Apply project filter
if selected_projects is not None:
ws = exp_info.get("workspace", "Unknown")
proj = exp_info.get("project_name", "Unknown")
project_key = f"{ws}/{proj}"
if project_key not in selected_projects:
continue

server_timestamp = exp_info.get("server_timestamp")

# Apply date filter
Expand Down Expand Up @@ -271,13 +282,19 @@ def create_aggregation_avg_chart(
keys = [k for k, _ in sorted_keys]
values = [v for _, v in sorted_keys]

# Assign colors by alphabetical order to match the time series line chart
all_keys_sorted = sorted(data.keys())
colors_list = get_distinct_colors(len(all_keys_sorted))
color_map = {k: colors_list[i] for i, k in enumerate(all_keys_sorted)}
bar_colors = [color_map[k] for k in keys]

# Create bar chart
fig, ax = plt.subplots(figsize=(14, 8))

ax.bar(
range(len(keys)),
values,
color="steelblue",
color=bar_colors,
edgecolor="navy",
alpha=0.7,
width=0.8,
Expand Down Expand Up @@ -449,6 +466,21 @@ def create_time_series_chart(
return fig


def get_available_workspaces_and_projects(data: Dict) -> Dict[str, List[str]]:
"""Extract unique workspaces and their projects from experiment_map.

Returns:
Dict mapping workspace names to sorted list of project names.
"""
experiment_map = data.get("experiment_map", {})
workspace_projects = defaultdict(set)
for exp_info in experiment_map.values():
ws = exp_info.get("workspace", "Unknown")
proj = exp_info.get("project_name", "Unknown")
workspace_projects[ws].add(proj)
return {ws: sorted(list(projects)) for ws, projects in sorted(workspace_projects.items())}


def get_available_dates(data: Dict) -> Tuple[List[str], str, str]:
"""Extract available date range from the data."""
experiment_map = data.get("experiment_map", {})
Expand Down Expand Up @@ -496,6 +528,19 @@ def get_available_dates(data: Dict) -> Tuple[List[str], str, str]:

def main():
"""Main Streamlit app entry point."""
# Set theme primary color to blue (overrides default red/coral)
# This must be done before any Streamlit elements are rendered
try:
import streamlit.config as _config
_config.set_option("theme.primaryColor", "#1f77b4")
except Exception:
pass # Fallback gracefully if internal API changes

# Force rerun on first load to ensure theme takes effect
if "theme_initialized" not in st.session_state:
st.session_state.theme_initialized = True
st.rerun()

st.set_page_config(
page_title="Comet GPU Report",
page_icon="🖥️",
Expand Down Expand Up @@ -552,11 +597,107 @@ def main():
# Get available dates
available_dates, min_date, max_date = get_available_dates(data)

# Get available workspaces and projects for filtering
workspace_projects_map = get_available_workspaces_and_projects(data)
all_workspaces = list(workspace_projects_map.keys())

# Custom CSS for professional styling (override Streamlit's default red/pink accent)
st.markdown("""
<style>
/* Override checkbox checked color - target the role=checkbox div */
div[role="checkbox"][aria-checked="true"] {
background-color: #1f77b4 !important;
border-color: #1f77b4 !important;
}
/* Also target via data-baseweb attribute */
[data-baseweb="checkbox"] div[aria-checked="true"] {
background-color: #1f77b4 !important;
border-color: #1f77b4 !important;
}
/* Style multi-select pills */
span[data-baseweb="tag"] {
background-color: #1f77b4 !important;
}
span[data-baseweb="tag"] span {
color: white !important;
}
</style>
""", unsafe_allow_html=True)

# Sidebar configuration
with st.sidebar:
st.divider()
st.header("Configuration")

# Workspace/Project Filter Section
st.subheader("Data Filter")

# Get all available projects across all workspaces
all_projects = []
for ws in all_workspaces:
for proj in workspace_projects_map.get(ws, []):
all_projects.append(f"{ws}/{proj}")
all_projects = sorted(all_projects)

# Initialize session state for project selection
# Use individual checkbox keys as the source of truth
if "projects_initialized" not in st.session_state:
for proj_path in all_projects:
st.session_state[f"proj_{proj_path}"] = True
st.session_state.projects_initialized = True

# Track whether the expander should stay open
if "projects_expander_open" not in st.session_state:
st.session_state.projects_expander_open = True

# Select All / Clear All buttons
btn_col1, btn_col2 = st.columns(2)
with btn_col1:
if st.button("Select All", use_container_width=True):
for proj_path in all_projects:
st.session_state[f"proj_{proj_path}"] = True
st.session_state.projects_expander_open = True
st.rerun()
with btn_col2:
if st.button("Clear All", use_container_width=True):
for proj_path in all_projects:
st.session_state[f"proj_{proj_path}"] = False
st.session_state.projects_expander_open = True
st.rerun()

# Expandable project selector for fine-grained control
with st.expander("Select specific projects", expanded=st.session_state.projects_expander_open):
# Group projects by workspace for cleaner display
for ws in all_workspaces:
ws_projects = [f"{ws}/{p}" for p in workspace_projects_map.get(ws, [])]
if not ws_projects:
continue

st.markdown(f"**{ws}**")

for proj_path in sorted(ws_projects):
proj_name = proj_path.split("/")[-1]
st.checkbox(
proj_name,
key=f"proj_{proj_path}",
)

# Build selected projects list from checkbox states
selected_projects_set = set()
for proj_path in all_projects:
if st.session_state.get(f"proj_{proj_path}", True):
selected_projects_set.add(proj_path)

# Show project count
num_selected = len(selected_projects_set)
num_total = len(all_projects)
st.caption(f"**{num_selected}** of **{num_total}** projects selected")

# Convert set to list for processing
selected_projects = sorted(selected_projects_set)

st.divider()

# Aggregation selection
st.subheader("Aggregation")
aggregation = st.selectbox(
Expand Down Expand Up @@ -640,6 +781,7 @@ def main():
start_date_str,
end_date_str,
aggregate_metric=aggregate_metric,
selected_projects=selected_projects,
)

# Display statistics
Expand Down
Loading