Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@
from chainlit.sidebar import ElementSidebar
from chainlit.step import Step, step
from chainlit.sync import make_async, run_sync
from chainlit.types import ChatProfile, InputAudioChunk, OutputAudioChunk, Starter
from chainlit.types import (
ChatProfile,
InputAudioChunk,
OutputAudioChunk,
Starter,
StarterCategory,
)
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_session
from chainlit.utils import make_module_getattr
Expand Down Expand Up @@ -84,6 +90,7 @@
password_auth_callback,
send_window_message,
set_chat_profiles,
set_starter_categories,
set_starters,
)

Expand Down Expand Up @@ -161,6 +168,7 @@ def acall(self):
"Pyplot",
"SemanticKernelFilter",
"Starter",
"StarterCategory",
"Step",
"Task",
"TaskList",
Expand Down Expand Up @@ -203,6 +211,7 @@ def acall(self):
"run_sync",
"send_window_message",
"set_chat_profiles",
"set_starter_categories",
"set_starters",
"sleep",
"step",
Expand Down
33 changes: 32 additions & 1 deletion backend/chainlit/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from chainlit.message import Message
from chainlit.oauth_providers import get_configured_oauth_providers
from chainlit.step import Step, step
from chainlit.types import ChatProfile, Starter, ThreadDict
from chainlit.types import ChatProfile, Starter, StarterCategory, ThreadDict
from chainlit.user import User
from chainlit.utils import wrap_user_function

Expand Down Expand Up @@ -277,6 +277,37 @@ def set_starters(func):
return func


@overload
def set_starter_categories(
func: Callable[[Optional["User"]], Awaitable[List["StarterCategory"]]],
) -> Callable[[Optional["User"]], Awaitable[List["StarterCategory"]]]: ...


@overload
def set_starter_categories(
func: Callable[
[Optional["User"], Optional["str"]], Awaitable[List["StarterCategory"]]
],
) -> Callable[
[Optional["User"], Optional["str"]], Awaitable[List["StarterCategory"]]
]: ...


def set_starter_categories(func):
"""
Programmatic declaration of starter categories with grouped starters.

Args:
func (Callable[[Optional["User"], Optional["str"]], Awaitable[List["StarterCategory"]]]): The function declaring the starter categories with optional user and language arguments.

Returns:
Callable[[Optional["User"], Optional["str"]], Awaitable[List["StarterCategory"]]]: The decorated function.
"""

config.code.set_starter_categories = wrap_user_function(func)
return func


def on_chat_end(func: Callable) -> Callable:
"""
Hook to react to the user websocket disconnect event.
Expand Down
8 changes: 7 additions & 1 deletion backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@
Feedback,
InputAudioChunk,
Starter,
StarterCategory,
ThreadDict,
)
from chainlit.user import User
else:
# Pydantic needs to resolve forward annotations. Because all of these are used
# within `typing.Callable`, alias to `Any` as Pydantic does not perform validation
# of callable argument/return types anyway.
Request = Response = Action = Message = ChatProfile = InputAudioChunk = Starter = ThreadDict = User = Feedback = Any # fmt: off
Request = Response = Action = Message = ChatProfile = InputAudioChunk = Starter = StarterCategory = ThreadDict = User = Feedback = Any # fmt: off

BACKEND_ROOT = os.path.dirname(__file__)
PACKAGE_ROOT = os.path.dirname(os.path.dirname(BACKEND_ROOT))
Expand Down Expand Up @@ -403,6 +404,11 @@ class CodeSettings(BaseModel):
set_starters: Optional[
Callable[[Optional["User"], Optional["str"]], Awaitable[List["Starter"]]]
] = None
set_starter_categories: Optional[
Callable[
[Optional["User"], Optional["str"]], Awaitable[List["StarterCategory"]]
]
] = None
on_shared_thread_view: Optional[
Callable[["ThreadDict", Optional["User"]], Awaitable[bool]]
] = None
Expand Down
60 changes: 60 additions & 0 deletions backend/chainlit/sample/starters_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Optional

import chainlit as cl


@cl.set_starter_categories
async def starter_categories(user: Optional[cl.User] = None):
return [
cl.StarterCategory(
label="Creative",
icon="https://cdn-icons-png.flaticon.com/512/3094/3094837.png",
starters=[
cl.Starter(
label="Write a poem about nature",
message="Write a poem about nature"
),
cl.Starter(
label="Create a short story",
message="Create a short story about adventure"
),
cl.Starter(
label="Generate a creative name",
message="Generate creative names for a tech startup"
),
],
),
cl.StarterCategory(
label="Learning",
icon="https://cdn-icons-png.flaticon.com/512/3976/3976625.png",
starters=[
cl.Starter(
label="Explain a complex topic",
message="Explain quantum computing in simple terms"
),
cl.Starter(
label="Help me learn a language",
message="Teach me basic French phrases"
),
],
),
cl.StarterCategory(
label="Productivity",
icon="https://cdn-icons-png.flaticon.com/512/1055/1055646.png",
starters=[
cl.Starter(
label="Summarize a topic",
message="Summarize the key points of machine learning"
),
cl.Starter(
label="Create a plan",
message="Help me create a weekly study plan"
),
],
),
]


@cl.on_message
async def on_message(msg: cl.Message):
await cl.Message(f"You said: {msg.content}").send()
7 changes: 7 additions & 0 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,12 @@ async def project_settings(
if s:
starters = [it.to_dict() for it in s]

starter_categories = []
if config.code.set_starter_categories:
sc = await config.code.set_starter_categories(current_user, effective_language)
if sc:
starter_categories = [it.to_dict() for it in sc]

data_layer = get_data_layer()
debug_url = (
await data_layer.build_debug_url() if data_layer and config.run.debug else None
Expand Down Expand Up @@ -865,6 +871,7 @@ async def project_settings(
"markdown": markdown,
"chatProfiles": profiles,
"starters": starters,
"starterCategories": starter_categories,
"debugUrl": debug_url,
}
)
Expand Down
11 changes: 11 additions & 0 deletions backend/chainlit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from chainlit.element import ElementDict
from chainlit.step import StepDict

from dataclasses import field

from dataclasses_json import DataClassJsonMixin
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
Expand Down Expand Up @@ -298,6 +300,15 @@ class Starter(DataClassJsonMixin):
icon: Optional[str] = None


@dataclass
class StarterCategory(DataClassJsonMixin):
"""A category/group of starters with an optional icon."""

label: str
icon: Optional[str] = None
starters: List[Starter] = field(default_factory=list)


@dataclass
class ChatProfile(DataClassJsonMixin):
"""Specification for a chat profile that can be chosen by the user at the thread start."""
Expand Down
52 changes: 52 additions & 0 deletions backend/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,58 @@ async def get_starters(user, language):
assert result[0].message == "Message de test"


async def test_set_starter_categories(
mock_chainlit_context, test_config: config.ChainlitConfig
):
from chainlit.callbacks import set_starter_categories
from chainlit.types import Starter, StarterCategory

async with mock_chainlit_context:

@set_starter_categories
async def get_starter_categories(user, language):
return [
StarterCategory(
label="Creative",
icon="https://example.com/creative.png",
starters=[
Starter(label="Write a poem", message="Write a poem"),
Starter(label="Write a story", message="Write a story"),
],
),
StarterCategory(
label="Educational",
starters=[
Starter(label="Explain concept", message="Explain it"),
],
),
]

assert test_config.code.set_starter_categories is not None

result = await test_config.code.set_starter_categories(None, None)

assert result is not None
assert isinstance(result, list)
assert len(result) == 2

assert result[0].label == "Creative"
assert result[0].icon == "https://example.com/creative.png"
assert len(result[0].starters) == 2
assert result[0].starters[0].label == "Write a poem"

assert result[1].label == "Educational"
assert result[1].icon is None
assert len(result[1].starters) == 1

category_dict = result[0].to_dict()
assert category_dict["label"] == "Creative"
assert category_dict["icon"] == "https://example.com/creative.png"
starters_list = category_dict["starters"]
assert isinstance(starters_list, list)
assert len(starters_list) == 2


async def test_on_shared_thread_view_allow(
mock_chainlit_context, test_config: config.ChainlitConfig
):
Expand Down
27 changes: 27 additions & 0 deletions cypress/e2e/starters_categories/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Optional

import chainlit as cl


@cl.set_starter_categories
async def starter_categories(user: Optional[cl.User] = None):
return [
cl.StarterCategory(
label="Creative",
starters=[
cl.Starter(label="poem", message="Write a poem"),
cl.Starter(label="story", message="Write a story"),
],
),
cl.StarterCategory(
label="Educational",
starters=[
cl.Starter(label="explain", message="Explain something"),
],
),
]


@cl.on_message
async def on_message(msg: cl.Message):
await cl.Message(msg.content).send()
33 changes: 33 additions & 0 deletions cypress/e2e/starters_categories/spec.cy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
describe('Starters with Categories', () => {
it('should display category buttons', () => {
cy.wait(1000);
cy.get('#starters').should('exist');

cy.contains('button', 'Creative').should('exist');
cy.contains('button', 'Educational').should('exist');
});

it('should show starters when category is clicked', () => {
cy.wait(1000);
cy.contains('button', 'Creative').click();
cy.get('#starter-poem').should('exist');
cy.get('#starter-story').should('exist');
});

it('should be able to use a starter from a category', () => {
cy.wait(1000);
cy.contains('button', 'Creative').click();
cy.get('#starter-poem').should('exist').click();
cy.get('.step').should('have.length', 2);
cy.get('.step').eq(0).contains('Write a poem');
});

it('should toggle category selection', () => {
cy.wait(1000);
cy.contains('button', 'Creative').click();
cy.get('#starter-poem').should('exist');

cy.contains('button', 'Creative').click();
cy.get('#starter-poem').should('not.exist');
});
});
24 changes: 24 additions & 0 deletions frontend/src/components/chat/StarterCategory.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { IStarterCategory } from '@chainlit/react-client';

import { Button } from '@/components/ui/button';

interface Props {
category: IStarterCategory;
isSelected: boolean;
onClick: () => void;
}

export default function StarterCategory({ category, isSelected, onClick }: Props) {
return (
<Button
variant={isSelected ? 'default' : 'outline'}
className="rounded-full gap-2"
onClick={onClick}
>
{category.icon && (
<img className="h-4 w-4" src={category.icon} alt="" />
)}
{category.label}
</Button>
);
}
Loading