diff --git a/cometx/cli/admin_gpu_app.py b/cometx/cli/admin_gpu_app.py index 3b8013b..2eb6d67 100644 --- a/cometx/cli/admin_gpu_app.py +++ b/cometx/cli/admin_gpu_app.py @@ -119,6 +119,7 @@ def process_metrics_with_aggregation( start_date: str = None, end_date: str = None, aggregate_metric: str = "avg", + selected_projects: List[str] = None, ) -> Tuple[Dict, Dict]: """ Process metric data with specified aggregation and time unit. @@ -132,6 +133,7 @@ def process_metrics_with_aggregation( start_date: Optional start date filter (YYYY-MM-DD) end_date: Optional end date filter (YYYY-MM-DD) aggregate_metric: Aggregation method for time series - "max" or "avg" (default: "avg") + selected_projects: Optional list of "workspace/project" strings to filter by Returns: tuple: (avg_data, time_series_data) @@ -165,6 +167,15 @@ def process_metrics_with_aggregation( continue exp_info = experiment_map[exp_key] + + # Apply project filter + if selected_projects is not None: + ws = exp_info.get("workspace", "Unknown") + proj = exp_info.get("project_name", "Unknown") + project_key = f"{ws}/{proj}" + if project_key not in selected_projects: + continue + server_timestamp = exp_info.get("server_timestamp") # Apply date filter @@ -271,13 +282,19 @@ def create_aggregation_avg_chart( keys = [k for k, _ in sorted_keys] values = [v for _, v in sorted_keys] + # Assign colors by alphabetical order to match the time series line chart + all_keys_sorted = sorted(data.keys()) + colors_list = get_distinct_colors(len(all_keys_sorted)) + color_map = {k: colors_list[i] for i, k in enumerate(all_keys_sorted)} + bar_colors = [color_map[k] for k in keys] + # Create bar chart fig, ax = plt.subplots(figsize=(14, 8)) ax.bar( range(len(keys)), values, - color="steelblue", + color=bar_colors, edgecolor="navy", alpha=0.7, width=0.8, @@ -449,6 +466,21 @@ def create_time_series_chart( return fig +def get_available_workspaces_and_projects(data: Dict) -> Dict[str, List[str]]: + """Extract unique workspaces and their projects from experiment_map. + + Returns: + Dict mapping workspace names to sorted list of project names. + """ + experiment_map = data.get("experiment_map", {}) + workspace_projects = defaultdict(set) + for exp_info in experiment_map.values(): + ws = exp_info.get("workspace", "Unknown") + proj = exp_info.get("project_name", "Unknown") + workspace_projects[ws].add(proj) + return {ws: sorted(list(projects)) for ws, projects in sorted(workspace_projects.items())} + + def get_available_dates(data: Dict) -> Tuple[List[str], str, str]: """Extract available date range from the data.""" experiment_map = data.get("experiment_map", {}) @@ -496,6 +528,19 @@ def get_available_dates(data: Dict) -> Tuple[List[str], str, str]: def main(): """Main Streamlit app entry point.""" + # Set theme primary color to blue (overrides default red/coral) + # This must be done before any Streamlit elements are rendered + try: + import streamlit.config as _config + _config.set_option("theme.primaryColor", "#1f77b4") + except Exception: + pass # Fallback gracefully if internal API changes + + # Force rerun on first load to ensure theme takes effect + if "theme_initialized" not in st.session_state: + st.session_state.theme_initialized = True + st.rerun() + st.set_page_config( page_title="Comet GPU Report", page_icon="🖥️", @@ -552,11 +597,107 @@ def main(): # Get available dates available_dates, min_date, max_date = get_available_dates(data) + # Get available workspaces and projects for filtering + workspace_projects_map = get_available_workspaces_and_projects(data) + all_workspaces = list(workspace_projects_map.keys()) + + # Custom CSS for professional styling (override Streamlit's default red/pink accent) + st.markdown(""" + + """, unsafe_allow_html=True) + # Sidebar configuration with st.sidebar: st.divider() st.header("Configuration") + # Workspace/Project Filter Section + st.subheader("Data Filter") + + # Get all available projects across all workspaces + all_projects = [] + for ws in all_workspaces: + for proj in workspace_projects_map.get(ws, []): + all_projects.append(f"{ws}/{proj}") + all_projects = sorted(all_projects) + + # Initialize session state for project selection + # Use individual checkbox keys as the source of truth + if "projects_initialized" not in st.session_state: + for proj_path in all_projects: + st.session_state[f"proj_{proj_path}"] = True + st.session_state.projects_initialized = True + + # Track whether the expander should stay open + if "projects_expander_open" not in st.session_state: + st.session_state.projects_expander_open = True + + # Select All / Clear All buttons + btn_col1, btn_col2 = st.columns(2) + with btn_col1: + if st.button("Select All", use_container_width=True): + for proj_path in all_projects: + st.session_state[f"proj_{proj_path}"] = True + st.session_state.projects_expander_open = True + st.rerun() + with btn_col2: + if st.button("Clear All", use_container_width=True): + for proj_path in all_projects: + st.session_state[f"proj_{proj_path}"] = False + st.session_state.projects_expander_open = True + st.rerun() + + # Expandable project selector for fine-grained control + with st.expander("Select specific projects", expanded=st.session_state.projects_expander_open): + # Group projects by workspace for cleaner display + for ws in all_workspaces: + ws_projects = [f"{ws}/{p}" for p in workspace_projects_map.get(ws, [])] + if not ws_projects: + continue + + st.markdown(f"**{ws}**") + + for proj_path in sorted(ws_projects): + proj_name = proj_path.split("/")[-1] + st.checkbox( + proj_name, + key=f"proj_{proj_path}", + ) + + # Build selected projects list from checkbox states + selected_projects_set = set() + for proj_path in all_projects: + if st.session_state.get(f"proj_{proj_path}", True): + selected_projects_set.add(proj_path) + + # Show project count + num_selected = len(selected_projects_set) + num_total = len(all_projects) + st.caption(f"**{num_selected}** of **{num_total}** projects selected") + + # Convert set to list for processing + selected_projects = sorted(selected_projects_set) + + st.divider() + # Aggregation selection st.subheader("Aggregation") aggregation = st.selectbox( @@ -640,6 +781,7 @@ def main(): start_date_str, end_date_str, aggregate_metric=aggregate_metric, + selected_projects=selected_projects, ) # Display statistics