- **Trainer**: Builds the simulation environment. `Trainer` will spawn multiple `TrainNode` instances using PyTorch Distributed. The local instances are connected together with `_build_connection`, and `TrainNode.train()` is executed on each rank.
- **TrainNode**: A single node (rank) running its own training loop. At each train step, instead of calling `optim.step()`, it calls `strategy.step()`.
- **Strategy**: An abstract class for an optimization strategy, which defines both how the nodes communicate with each other and how model weights are updated. Typically, a gradient strategy will include an optimizer as well as a communication step. Sometimes (e.g. DeMo), the optimizer step is commingled with the communication.
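As a rough sketch of how a strategy might look, here is a minimal gradient strategy that averages gradients across ranks before applying a local optimizer step. The base-class shape and the name `AllReduceGradStrategy` are illustrative assumptions, not EXO Gym's actual API:

```python
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist


class Strategy(ABC):
    """Sketch of the Strategy interface: a node calls strategy.step()
    where a plain training loop would call optim.step()."""

    @abstractmethod
    def step(self) -> None:
        ...


class AllReduceGradStrategy(Strategy):
    """Illustrative strategy: average gradients across ranks, then take a
    local optimizer step. The collective is only issued when a process
    group is initialized, so the sketch also runs single-process."""

    def __init__(self, model: torch.nn.Module, lr: float = 0.1):
        self.model = model
        self.optim = torch.optim.SGD(model.parameters(), lr=lr)

    def step(self) -> None:
        if dist.is_available() and dist.is_initialized():
            world = dist.get_world_size()
            for p in self.model.parameters():
                if p.grad is not None:
                    dist.all_reduce(p.grad)  # sum grads across ranks
                    p.grad /= world          # then average
        self.optim.step()
        self.optim.zero_grad()
```

A DeMo-style strategy would instead fold the communication into the update itself rather than keeping the two phases separate.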
EXO Gym uses PyTorch multiprocessing to spawn a subprocess per node; the subprocesses communicate with each other using regular collective operations such as `all_reduce`.
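A toy version of that pattern, assuming a CPU `gloo` backend (the port number and the use of `fork` here are conveniences for the demo, not how EXO Gym itself is configured):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, queue) -> None:
    # Each subprocess joins the process group, then uses a regular
    # collective (all_reduce) to sum a tensor across all ranks.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29641"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    t = torch.tensor([float(rank + 1)])
    dist.all_reduce(t)  # with 2 ranks: 1.0 + 2.0 on every rank
    queue.put((rank, t.item()))
    dist.destroy_process_group()


def run_demo(world_size: int = 2) -> dict:
    ctx = mp.get_context("fork")
    queue = ctx.Queue()
    procs = [
        ctx.Process(target=worker, args=(r, world_size, queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    results = dict(queue.get() for _ in range(world_size))
    for p in procs:
        p.join()
    return results
```

Every rank ends up with the same reduced value, which is what makes collectives a natural fit for simulating distributed optimizers locally.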
The model is expected in a form that takes a batch (in the same format as the dataset outputs) and returns a scalar loss over the entire batch. This keeps the model agnostic to the format of the data (e.g. masked LM training doesn't have a clear x/y split).
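For instance, a model in this form might consume a batch dict and compute the loss internally; `LossModel` and the `input_ids` batch key are illustrative, not part of EXO Gym:

```python
import torch
import torch.nn as nn


class LossModel(nn.Module):
    """forward() takes a batch (same format the dataset yields) and
    returns a scalar loss, so the training loop never needs to know
    the data layout."""

    def __init__(self, vocab_size: int = 100, dim: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, batch: dict) -> torch.Tensor:
        # Next-token prediction here; a masked-LM batch would work
        # equally well behind this interface.
        ids = batch["input_ids"]                      # (B, T) token ids
        logits = self.head(self.embed(ids[:, :-1]))   # (B, T-1, V)
        return nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            ids[:, 1:].reshape(-1),
        )
```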
Recall that when we call `trainer.fit()`, we pass in the training data. There are two ways to do so:
Instantiate a single `Dataset`. The dataset object is passed to every subprocess, and a `DistributedSampler` selects which datapoints are sampled per node (ensuring each datapoint is seen by only one node). If the dataset is entirely loaded into memory, this memory will be duplicated per node - be careful not to run out of memory! Larger datasets should be lazily loaded.
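A sketch of how a per-node loader could be built from the shared dataset (the helper name `build_loader` is ours; `rank` and `num_nodes` are assumed to come from the spawning code):

```python
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


def build_loader(
    dataset: Dataset, rank: int, num_nodes: int, batch_size: int = 4
) -> DataLoader:
    # Passing num_replicas/rank explicitly partitions the dataset across
    # nodes without needing an initialized process group.
    sampler = DistributedSampler(dataset, num_replicas=num_nodes, rank=rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

Across all ranks, the samplers together cover every datapoint exactly once per epoch.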
In place of the dataset object, pass a function with the following signature:

```python
def dataset_factory(rank: int, num_nodes: int, train_dataset: bool) -> torch.utils.data.Dataset
```

This will be called within each rank to build the dataset. Instead of each node storing the whole dataset and subsampling datapoints, each node loads only the datapoints it needs.
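One possible factory, sketched with a round-robin shard; `ShardDataset`, the split sizes, and the `input_ids` format are all placeholder assumptions:

```python
import torch
from torch.utils.data import Dataset


class ShardDataset(Dataset):
    """Hypothetical lazily-sharded dataset: each node materializes only
    its own slice of the (conceptual) full index range instead of
    holding the whole dataset in memory."""

    def __init__(self, rank: int, num_nodes: int, total: int):
        # Round-robin shard: rank r owns indices r, r+num_nodes, ...
        self.indices = list(range(rank, total, num_nodes))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> dict:
        # A real factory would read one record from disk on demand here.
        idx = self.indices[i]
        return {"input_ids": torch.tensor([idx])}


def dataset_factory(rank: int, num_nodes: int, train_dataset: bool) -> Dataset:
    # train_dataset selects between train and eval splits; the sizes
    # below are arbitrary placeholders.
    total = 1000 if train_dataset else 100
    return ShardDataset(rank, num_nodes, total)
```

Because each rank constructs only its shard, memory use stays flat as the node count grows, unlike the shared-`Dataset` option above.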