StableLM-3B-4E1T
Technical report for StableLM-3B-4E1T
Jonathan Tow, Marco Bellagente, Dakota Mahan, Carlos Riquelme Ruiz
StableLM-3B-4E1T is a 3 billion (3B) parameter language model pre-trained under the multi-epoch regime to study the impact of repeated tokens on downstream performance. Given prior success in this area (Taylor et al., 2022; Tay et al., 2023, https://arxiv.org/pdf/2205.05131.pdf), we train on 1 trillion (1T) tokens for 4 epochs following the observations of Muennighoff et al. (2023) in "Scaling Data-Constrained Language Models", in which they find that "training with up to 4 epochs of repeated data yields negligible changes to loss compared to having unique data." Further inspiration for the token count is taken from "Go smol or go home" (De Vries, 2023), which suggests that a 2.96B model trained for 2.85 trillion tokens achieves a similar loss to a Chinchilla compute-optimal 9.87B language model (k_n = 0.3).

Tracking issue: https://github.com/orgs/Stability-AI/projects/8?pane=issue&itemId=36926940
Model Architecture
Checkpoint: stabilityai/stablelm-3b-4e1t
The model is a decoder-only transformer similar to the LLaMA architecture (Touvron et al., 2023, https://arxiv.org/abs/2307.09288) with the following modifications:
| Parameters | Hidden Size | Layers | Heads | Sequence Length |
|---|---|---|---|---|
| 2,795,443,200 | 2560 | 32 | 32 | 4096 |
- Position Embeddings: Rotary Position Embeddings (Su et al., 2021) applied to the first 25% of head embedding dimensions for improved throughput following Black et al. (2022); see the sketch below the list.
- Normalization: LayerNorm (Ba et al., 2016) with learned bias terms as opposed to RMSNorm (Zhang & Sennrich, 2019).
- Tokenizer: GPT-NeoX (Black et al., 2022).
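To make the partial rotary application concrete, here is a minimal PyTorch sketch (not the gpt-neox kernel) that rotates only the first 25% of each head's dimensions, i.e. 20 of the 80 dimensions per head for this model, and passes the remainder through unchanged:

```python
import torch

def rotate_half(x):
    # Standard GPT-NeoX-style rotation helper.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(q, rotary_pct=0.25, base=10000.0):
    """q: (batch, heads, seq_len, head_dim). Only the first
    rotary_pct * head_dim dimensions receive rotary embeddings."""
    head_dim, seq_len = q.shape[-1], q.shape[-2]
    rotary_dim = int(head_dim * rotary_pct)          # 80 * 0.25 = 20 here
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]

    # Rotation frequencies for the rotary subspace only.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)          # (seq_len, rotary_dim)
    cos, sin = emb.cos(), emb.sin()

    q_rot = q_rot * cos + rotate_half(q_rot) * sin   # rotate the first 25%
    return torch.cat((q_rot, q_pass), dim=-1)        # rest passes through
```

The same transform is applied to the key projections; in practice the cos/sin tables are precomputed once and cached.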
Training Data
The dataset comprises a filtered mixture of open-source large-scale datasets available on the HuggingFace Hub: Falcon RefinedWeb extract (Penedo et al., 2023), RedPajama-Data (Together Computer, 2023) and The Pile (Gao et al., 2020), both without the Books3 subset, and StarCoder (Li et al., 2023). The complete list is provided in Table 1.

Table 1: Open-source datasets used for multi-epoch training. Note that the total token count does not account for the reduced size after downsampling C4, Common Crawl (2023), and GitHub to obtain 1T tokens.
Given the large amount of web data, we recommend fine-tuning the base StableLM-3B-4E1T for your downstream tasks.
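As a rough illustration of the downsampling noted under Table 1, the hypothetical helper below scales a chosen set of subsets so the overall mixture sums to a 1T-token budget while every other dataset keeps one full epoch; the subset names and logic are placeholders, not the actual mixture weights:

```python
def downsample_to_budget(token_counts, budget=1_000_000_000_000,
                         downsample=("C4", "CommonCrawl", "GitHub")):
    """token_counts: dict of dataset name -> raw token count.
    Returns adjusted counts where only the `downsample` subsets are scaled
    so the total equals `budget`. Purely illustrative."""
    fixed = sum(n for name, n in token_counts.items() if name not in downsample)
    flexible = sum(n for name, n in token_counts.items() if name in downsample)
    scale = max(0.0, (budget - fixed) / flexible)
    return {name: int(n * scale) if name in downsample else n
            for name, n in token_counts.items()}
```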
Training Procedure
The model is trained for 972k steps in bfloat16 precision with a global context length of 4096, instead of the multi-stage ramp-up from 2048 to 4096 used for StableLM-Alpha v2. The batch size is set to 1024 sequences (4,194,304 tokens). We optimize with AdamW (Loshchilov and Hutter, 2017) and use linear warmup for the first 4.8k steps, followed by a cosine decay schedule to 4% of the peak learning rate. Early instabilities are attributed to extended periods in high learning rate regions. We do not incorporate dropout (Srivastava et al., 2014) due to the model's relatively small size. Detailed hyperparameters are provided in the model config here.

During training, we evaluate natural language benchmarks and observe steady improvements over the course of training until the tail end of the learning rate decay schedule. For this reason, we decided to linearly cool down the learning rate towards 0, similar to Zhai et al. (2021), in hopes of squeezing out performance. We plan to explore alternative schedules in future work.
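The resulting schedule (linear warmup, cosine decay to 4% of peak, then a linear cool-down toward 0) can be sketched as follows; the peak learning rate and the cool-down start step are placeholders rather than values from the released config:

```python
import math

def lr_at_step(step, peak_lr, warmup_steps=4_800, total_steps=972_000,
               min_ratio=0.04, cooldown_start=None):
    """Piecewise sketch: linear warmup -> cosine decay toward min_ratio * peak
    -> linear cool-down to 0 over the remaining steps. peak_lr and
    cooldown_start are placeholders, not values from the released config."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Nominal cosine decay from peak to min_ratio * peak over the full run.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_lr = peak_lr * (min_ratio + (1.0 - min_ratio)
                           * 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0))))
    if cooldown_start is None or step < cooldown_start:
        return cosine_lr
    # Cool-down: from the LR reached at cooldown_start, anneal linearly to 0.
    start_lr = lr_at_step(cooldown_start, peak_lr, warmup_steps, total_steps, min_ratio)
    tail = (step - cooldown_start) / (total_steps - cooldown_start)
    return start_lr * max(0.0, 1.0 - tail)
```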
Furthermore, our initial stage of pre-training relies on the flash-attention API (Dao, 2023) with its out-of-the-box triangular causal masking support. This forces the model to attend across unrelated documents within a packed sequence. In the cool-down stage, we instead reset position IDs and attention masks at EOD tokens for all packed sequences, after empirically observing improved sample quality (read: less repetition) in a concurrent experiment. We hypothesize that this late adjustment leads to the notable degradation in byte-length normalized accuracies on ARC Easy (Clark et al., 2018) and SciQ (Welbl et al., 2017).
Figure 1: Toy demonstration of attention mask resetting.
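A minimal sketch of the reset logic shown in Figure 1 (an illustration, not the training code): given a packed sequence and the EOD token ID, position IDs restart at each document boundary and the causal mask becomes block-diagonal so tokens cannot attend across documents.

```python
import torch

def reset_positions_and_mask(input_ids, eod_token_id):
    """input_ids: (seq_len,) packed token IDs for several documents.
    Returns position IDs that restart after each EOD token and a boolean
    (seq_len, seq_len) mask that is causal *and* blocks cross-document attention."""
    seq_len = input_ids.shape[0]
    # Document index per token: the counter bumps on the token *after* each EOD.
    is_eod = (input_ids == eod_token_id).long()
    doc_ids = torch.zeros(seq_len, dtype=torch.long)
    doc_ids[1:] = torch.cumsum(is_eod[:-1], dim=0)
    # Position IDs restart at 0 at every document boundary.
    position_ids = torch.zeros(seq_len, dtype=torch.long)
    for d in doc_ids.unique():
        idx = (doc_ids == d).nonzero().flatten()
        position_ids[idx] = torch.arange(idx.numel())
    # Causal mask restricted to same-document pairs (True = attention allowed).
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_doc = doc_ids.unsqueeze(0) == doc_ids.unsqueeze(1)
    return position_ids, causal & same_doc
```

In practice the dense mask is only for illustration; the same effect is typically obtained by passing per-document sequence boundaries to a variable-length attention kernel rather than materializing a full (seq_len, seq_len) mask.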
Data composition was modified during the cool-down. Specifically, we remove Ubuntu IRC, OpenWebText, HackerNews, and FreeLaw for quality control and further NSFW filtering while upsampling C4. The distribution shift is likely responsible for the increased loss (+0.02 nats) from the initial stage.
See the plots below for validation dynamics across our hold-out set and common NLP benchmarks.
Note: The released checkpoint is taken from step 970k according to validation loss and average downstream performance.
Downstream Results
The following zero-shot evaluations are performed with EleutherAI's lm-evaluation-harness using the lm-bench branch of Stability AI's fork.

Table 2: Zero-shot performance across popular language modeling and common sense reasoning benchmarks. lm-eval results JSONs can be found in the evals directory of the StableLM repo.
StableLM-3B-4E1T achieves state-of-the-art performance (September 2023) at the 3B parameter scale for open-source models and is competitive with many of the popular contemporary 7B models, even outperforming our most recent 7B StableLM-Base-Alpha-v2.
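For reference, a zero-shot run with the upstream lm-evaluation-harness Python API looks roughly like the sketch below; the exact entry points and task names in the lm-bench branch of the fork may differ, and the task list here is illustrative:

```python
# Sketch only: upstream EleutherAI lm-evaluation-harness API (entry points in
# the lm-bench fork may differ). Task names are illustrative.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf-causal",
    model_args="pretrained=stabilityai/stablelm-3b-4e1t",
    tasks=["arc_easy", "arc_challenge", "sciq", "winogrande", "hellaswag"],
    num_fewshot=0,
)
print(results["results"])  # per-task metrics, e.g. acc and acc_norm
```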
System Details
- Hardware: StableLM-3B-4E1T was trained on the Stability AI cluster across 256 NVIDIA A100 40GB GPUs (AWS P4d instances). Training began on August 23, 2023, and took approximately 30 days to complete.
- Software: We use a fork of gpt-neox (EleutherAI, 2021), train under 2D parallelism (Data and Tensor Parallel) with ZeRO-1 (Rajbhandari et al., 2019), and rely on flash-attention as well as SwiGLU and Rotary Embedding kernels from FlashAttention-2 (Dao, 2023).
Note: TFLOPs are estimated using GPT-NeoX's get_flops function.
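For intuition, a common back-of-the-envelope model-FLOPs estimate (an approximation, not the exact get_flops implementation) counts roughly 6 FLOPs per parameter per token for the dense matmuls plus a quadratic attention term:

```python
def estimate_tflops_per_gpu(num_params, tokens_per_step, step_time_s, num_gpus,
                            num_layers=32, hidden_size=2560, seq_len=4096):
    """Rough throughput estimate in TFLOP/s per GPU. This approximates, rather
    than reproduces, GPT-NeoX's get_flops: 6*N*T covers forward+backward dense
    matmuls; the second term adds the attention score and context matmuls."""
    dense_flops = 6 * num_params * tokens_per_step
    # ~4*seq_len*hidden FLOPs per token per layer forward, x3 with backward.
    attn_flops = 12 * num_layers * seq_len * hidden_size * tokens_per_step
    return (dense_flops + attn_flops) / (step_time_s * num_gpus) / 1e12
```

With the 4,194,304-token global batch and 256 GPUs, this yields a rough per-GPU figure that can be compared against the A100's peak bfloat16 throughput.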