A Synthetic Framework for Studying Chain-of-Thought Learning from In-Context Demonstrations
- A new version of the framework is coming soon!
- [2025.05] 🎉 The "CoT-ICL Lab" paper has been accepted to ACL Main 2025!
- Create a virtual environment and install the package.

```shell
$ python3.12 -m venv .venv
$ source .venv/bin/activate
(.venv) $ pip install -e .
```

- Run the unit tests as a sanity check.

```shell
(.venv) $ pytest
```

- (Development) Run ruff + isort fixes to sanitize code changes.

```shell
(.venv) $ ./beautify.sh
```

Our framework serves as a test bed for generating synthetic tokenized datasets and for training and evaluating transformer models on them. We do so using the `DAG` and `TokenProcessor` classes, which can be configured directly via the `Args` dataclass. For example:
```python
from tokenized_cot_icl.core.args import Args
from tokenized_cot_icl.core.data import TokenizedDataset

args = Args(
    vocab_size=1024,
    n_inputs=4,
    n_parents=2,
    chain_length=3,
    n_examples=1,
    enable_cot=True,
    prompt_strategy="cot",
    activation="leaky_relu",
    n_tasks=10,
)
dataset = TokenizedDataset(args=args)
print(dataset[0])
```

The above item in the dataset is as follows:
```python
{
    'adj_list': tensor([[0, 2], [4, 3], [5, 3]]),
    'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]),
    'input_ids': tensor([ 556,  197, 1002,  867,  240,  466,  217]),
    'labels': tensor([-100, -100, -100, -100,  240,  466,  217]),
    'cot_eval': {
        'attention_mask': tensor([1, 1, 1, 1]),
        'input_ids': tensor([ 556,  197, 1002,  867]),
        'last_example_cot': tensor([240, 466, 217])
    }
}
```

Let's break down the result above to understand the DAG structure. The `adj_list` tensor `[[0, 2], [4, 3], [5, 3]]` (zero-indexed) indicates that the parent tokens for the chain tokens are as follows:
| Chain Token | Parent Tokens |
|---|---|
| position 4 (`240`) | positions 0, 2 (`556`, `1002`) |
| position 5 (`466`) | positions 4, 3 (`240`, `867`) |
| position 6 (`217`) | positions 5, 3 (`466`, `867`) |
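The mapping above can be recovered mechanically from `adj_list`: with `n_inputs=4`, positions 0–3 hold the input tokens and the chain tokens occupy the positions after them. A minimal stand-alone sketch (plain Python lists stand in for tensors):

```python
# Recover the parent mapping for each chain token from the adjacency list.
# Positions 0..n_inputs-1 are the input tokens; chain tokens follow them.
input_ids = [556, 197, 1002, 867, 240, 466, 217]
adj_list = [[0, 2], [4, 3], [5, 3]]
n_inputs = 4  # as set in Args above

# chain token position -> parent positions
parent_map = {n_inputs + i: parents for i, parents in enumerate(adj_list)}

for pos, parents in parent_map.items():
    parent_str = ", ".join(f"{p} ({input_ids[p]})" for p in parents)
    print(f"chain token at position {pos} ({input_ids[pos]}) <- parents at {parent_str}")
```

Note that chain tokens can themselves be parents of later chain tokens (e.g. position 5 depends on position 4), which is what makes the structure a DAG rather than a flat lookup.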
> [!NOTE]
> The `TokenCoverage` metric introduced in the paper relies on the uniqueness of chain tokens across the entire dataset and depends heavily on `vocab_size` and `activation`, which together control the difficulty of the tasks.
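As a rough illustration (not necessarily the paper's exact definition), a coverage-style quantity can be computed as the fraction of the vocabulary that appears among the chain tokens of a dataset:

```python
# Hypothetical sketch of a coverage-style metric: the fraction of the
# vocabulary covered by the unique chain tokens across all dataset items.
def token_coverage(chain_token_lists, vocab_size):
    unique_tokens = set()
    for chain in chain_token_lists:
        unique_tokens.update(chain)
    return len(unique_tokens) / vocab_size

# Two items sharing the chain token 240 -> 5 unique tokens out of 1024.
coverage = token_coverage([[240, 466, 217], [240, 99, 12]], vocab_size=1024)
```

A small `vocab_size` (or an activation that collapses many inputs to the same token) drives this ratio up quickly, making chains easier to memorize.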
We leverage the HuggingFace `transformers` library to create custom Llama models and expose a `MODEL_REGISTRY` to register new model families.

```python
# src/tokenized_cot_icl/core/models.py
MODEL_REGISTRY = {"llama": create_llama_model}
```

> [!TIP]
> Users can register creation functions for models of their choice from the `transformers` library to explore new architectures and validate ideas.
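Registering a new family is then a one-line dictionary update. A minimal sketch with stub creation functions (the stubs below are stand-ins, not the repo's code; a real creation function would build e.g. a `transformers` config and model):

```python
# Hypothetical sketch of extending MODEL_REGISTRY with a new model family.
def create_llama_model(args):
    # Stand-in: the real function constructs a custom Llama model.
    return {"family": "llama", "vocab_size": args["vocab_size"]}

def create_gpt2_model(args):
    # Stand-in: a real implementation might build a GPT2Config / GPT2LMHeadModel.
    return {"family": "gpt2", "vocab_size": args["vocab_size"]}

MODEL_REGISTRY = {"llama": create_llama_model}
MODEL_REGISTRY["gpt2"] = create_gpt2_model  # register the new family

model = MODEL_REGISTRY["gpt2"]({"vocab_size": 1024})
```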
To make bulk-launching experiments easier, we rely on a `TASK_CARD` to collate all the args. For instance, to train a model with the args from the example above, we do:

```python
# src/tokenized_cot_icl/core/task_card.py
from typing import Dict

from tokenized_cot_icl.core.args import Args

def custom_task_card() -> Dict[int, Args]:
    """A custom task card."""
    args = Args(...)  # set as needed
    return {0: args}

# set the dictionary
TASK_CARD = custom_task_card()
```

The `TASK_CARD` allows us to index into the experimental config of our choice and launch torch distributed data parallel (DDP) training runs. For example:
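The same pattern scales naturally to sweeps: build one base config and map each task-card key to a variant. A sketch using a stand-in dataclass in place of the repo's `Args`:

```python
# Hypothetical sketch: a task card sweeping chain_length across several configs.
# ArgsStub stands in for the repo's Args dataclass.
from dataclasses import dataclass, replace
from typing import Dict

@dataclass
class ArgsStub:
    vocab_size: int = 1024
    chain_length: int = 3

def sweep_task_card() -> Dict[int, ArgsStub]:
    base = ArgsStub()
    # key i -> a copy of the base config with a different chain length
    return {i: replace(base, chain_length=c) for i, c in enumerate([1, 3, 5])}

TASK_CARD = sweep_task_card()
```

Launching with `--task_card_key 2` would then pick up the `chain_length=5` variant.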
```shell
(.venv) $ cd src
(.venv) $ export NUM_NODES=1         # change as needed
(.venv) $ export LOCAL_WORLD_SIZE=4  # change as needed
(.venv) $ torchrun --nnodes=$NUM_NODES --nproc-per-node=$LOCAL_WORLD_SIZE -m tokenized_cot_icl.core.train --task_card_key 0
```

- By default, we use `metric_logger="stdout"` in `Args` and log the metrics/params to `STDOUT`.
- We also support logging to an MLflow tracking server by setting the `MLFLOW_SERVICE_URL` environment variable and using `Args(metric_logger="mlflow")`.
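The logger switch can be pictured as a small factory keyed on `metric_logger` (a hypothetical sketch; the repo's actual interface may differ):

```python
# Hypothetical sketch of a metric-logger factory; names are illustrative.
import json

class StdoutMetricLogger:
    def log_metrics(self, step, metrics):
        # Emit one JSON line per logging step.
        print(json.dumps({"step": step, **metrics}))

def build_metric_logger(metric_logger="stdout"):
    # An "mlflow" branch would read MLFLOW_SERVICE_URL and return an
    # MLflow-backed logger instead.
    if metric_logger == "stdout":
        return StdoutMetricLogger()
    raise ValueError(f"unknown metric_logger: {metric_logger}")

logger = build_metric_logger()
logger.log_metrics(100, {"loss": 1.23})
```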
Users can also apply the Liger-Kernel optimizations to patch the Llama models by setting `Args(use_liger_kernels=True)` and speed up the training runs.

```shell
(.venv) $ pip install liger-kernel  # install a suitable version
```

In addition to using `transformers.GenerationConfig` for small-scale inference during training runs, we also support vLLM- and SGLang-based evaluation of the trained model (or model checkpoints) to analyze the predictions.

```shell
(.venv) $ pip install vllm    # install a suitable version
(.venv) $ pip install sglang  # install a suitable version
```

We provide an easy-to-extend example for calculating the answer token prediction accuracy as follows:
```shell
# for vllm
(.venv) $ cd src && python tokenized_cot_icl/inference/vllm/evaluator.py \
    --model_base_dir /opt/cot-icl-lab/run_name \
    --checkpoint final  # either final or 1000, 2000, etc.

# for sglang
(.venv) $ cd src && python tokenized_cot_icl/inference/sglang/evaluator.py \
    --model_base_dir /opt/cot-icl-lab/run_name \
    --checkpoint final  # either final or 1000, 2000, etc.
```

```bibtex
@inproceedings{Kothapalli2025CoTICLLAB,
  title={CoT-ICL Lab: A Synthetic Framework for Studying Chain-of-Thought Learning from In-Context Demonstrations},
  author={Vignesh Kothapalli and Hamed Firooz and Maziar Sanjabi},
  booktitle={Annual Meeting of the Association for Computational Linguistics},
  year={2025},
}
```