Official code repository for the paper Learning Syntax Without Planting Trees: Understanding When and Why Transformers Generalize Hierarchically
- Compatible with Python 3.11 (may work with 3.9+, but not tested)
- PyTorch 2.1.0
- The necessary packages can be installed through requirements.txt

We recommend using a conda environment or a virtual environment:
conda create -n hiergenv python=3.11
conda activate hiergenv
conda install pip
conda install pytorch==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia

We make some changes to the transformers library to support pruning (adapted from DSP). To install the modified library:
cd transformers/
pip install -e .
cd ..
pip install -r requirements.txt

Open train_transformers.py and, on line 65, replace WANDB_ENTITY_NAME = "<Insert-Your-Entity-Name>" with your wandb entity name.
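Two of the objectives in the training commands that follow, LM and PrefixLM (`--is_prefix_lm`), can be contrasted by the attention mask they use: a standard LM is fully causal, while a prefix LM lets the prefix (for example, the input sentence) attend bidirectionally within itself. The sketch below illustrates only that difference in plain Python; the helpers `causal_mask` and `prefix_lm_mask` are hypothetical and are not how train_transformers.py builds its masks.

```python
from typing import List

# mask[i][j] is True when position i is allowed to attend to position j.
Mask = List[List[bool]]


def causal_mask(n: int) -> Mask:
    """LM objective: every token sees only itself and earlier tokens."""
    return [[j <= i for j in range(n)] for i in range(n)]


def prefix_lm_mask(n: int, prefix_len: int) -> Mask:
    """PrefixLM objective: the first `prefix_len` tokens attend to each other
    bidirectionally; the remaining tokens stay causal."""
    return [[j < prefix_len or j <= i for j in range(n)] for i in range(n)]


if __name__ == "__main__":
    # Toy 5-token sequence whose first 3 tokens form the prefix.
    for row in prefix_lm_mask(5, prefix_len=3):
        print("".join("1" if allowed else "0" for allowed in row))
```

Running it prints 11100 for each of the three prefix rows, then 11110 and 11111. With `prefix_len` of 0 or 1 the prefix-LM mask reduces to the causal mask.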
Training models on the question formation dataset using different training objectives:

# Training with LM objective
python train_transformers.py --encoder_n_layers 6 --callback --dataset lm --max_train_steps 300000 --max_grad_norm 1 --eval_every 1000 --seed 42 --tied-embedding
# Training with Seq2Seq objective
python train_transformers.py --not_lm --mode enc_dec --encoder_n_layers 6 --decoder_n_layers 6 --callback --dataset lm --max_train_steps 300000 --max_grad_norm 1 --eval_every 1000 --seed 42 --tied-embedding
# Training with PrefixLM objective
python train_transformers.py --is_prefix_lm --encoder_n_layers 6 --callback --dataset lm --max_train_steps 300000 --max_grad_norm 1 --eval_every 1000 --seed 42 --tied-embedding
# Training with classification objective
python train_transformers.py --mode enc --dataset lm --encoder_n_layers 6 --callback --eval_every 1000 --causal_encoder --seed 42 --max_train_steps 300000
# Training with cloze completion objective
python train_mlm.py --dataset lm --encoder_n_layers 6 --eval_every 1000 --callback --mask_strategy aux --causal_encoder --seed 42

Please check scripts/ for examples with other datasets.
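The pruning commands that follow expose `--pruning_steps`, `--pruning_lr`, and `--l0_penalty`, which suggests that each attention head receives a learnable gate trained against an L0 sparsity penalty. The sketch below is a minimal illustration assuming the hard-concrete relaxation of Louizos et al. (2018); the repository's implementation (adapted from DSP) may differ, and the constants and the helpers `sample_gate`, `expected_l0`, and `pruning_objective` are hypothetical. A real implementation would compute these with tensors so gradients flow to the per-head parameters.

```python
import math
import random
from typing import List

# Hard-concrete constants often used with this relaxation (assumed, not read from the repo).
BETA, GAMMA, ZETA = 2.0 / 3.0, -0.1, 1.1


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gate(log_alpha: float, rng: random.Random) -> float:
    """Sample one head's gate: a stretched, clipped concrete variable in [0, 1]."""
    u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)  # keep log() finite
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    s_bar = s * (ZETA - GAMMA) + GAMMA  # stretch from (0, 1) to (GAMMA, ZETA)
    return min(1.0, max(0.0, s_bar))    # clip, so exact 0 (pruned) and 1 are reachable


def expected_l0(log_alpha: float) -> float:
    """Probability that the gate is non-zero: the differentiable stand-in for L0."""
    return sigmoid(log_alpha - BETA * math.log(-GAMMA / ZETA))


def pruning_objective(task_loss: float, log_alphas: List[float], l0_penalty: float) -> float:
    """Task loss plus l0_penalty times the expected number of heads left open."""
    return task_loss + l0_penalty * sum(expected_l0(a) for a in log_alphas)


if __name__ == "__main__":
    rng = random.Random(0)
    log_alphas = [3.0, 0.0, -3.0]  # toy per-head parameters (hypothetical values)
    print([round(sample_gate(a, rng), 3) for a in log_alphas])
    print(pruning_objective(1.0, log_alphas, l0_penalty=0.015))
```

Stretching and clipping the concrete sample is what lets a gate land exactly on 0, so a head can be removed outright rather than merely down-weighted; a larger `l0_penalty` pushes more gates toward 0.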
To run the pruning experiments, first make sure you save model checkpoints while training the LM model:
python train_transformers.py --encoder_n_layers 6 --callback --dataset lm --max_train_steps 300000 --max_grad_norm 1 --eval_every 1000 --seed 42 --tied-embedding --save_every 1000 --save_dir "<PATH TO SAVE CHECKPOINTS>"

To prune a specific model checkpoint:
# For Train-Prune
python prune_heads_v2.py --model_path <SAVE_DIR>/checkpoint_<CHECKPOINT>.pkl --dataset qf --n_layer 6 --tied-embedding --split_for_pruning "train" --pruning_steps 10000 --pruning_lr 0.05 --l0_penalty 0.015
# For Gen-Prune
python prune_heads_v2.py --model_path <SAVE_DIR>/checkpoint_<CHECKPOINT>.pkl --dataset qf --n_layer 6 --tied-embedding --split_for_pruning "test" --pruning_steps 10000 --pruning_lr 0.05 --l0_penalty 0.015
# For Train\Gen Prune
python prune_heads_v2.py --model_path <SAVE_DIR>/checkpoint_<CHECKPOINT>.pkl --dataset qf --n_layer 6 --tied-embedding --split_for_pruning "train" --find_overfitted_heads --pruning_steps 10000 --pruning_lr 0.05 --l0_penalty 0.015

To analyse the full training dynamics, i.e., to run the three pruning methods on all the checkpoints:
python analyse_training_dynamics.py --n_layer 6 --model_path <SAVE_DIR> --last_ckpt 300000 --incr 1000 --tied-embeddings --pruning_steps 10000 --pruning_lr 0.05 --l0_penalty 0.015 --seed 42

To compare the posteriors of the two grammars, run pcfg.py. Here G1 refers to the CFG and G2 to the regular grammar.
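As a rough sketch of what comparing two grammars by their posteriors involves (not the interface of pcfg.py): each grammar G receives an unnormalized log posterior log P(G) + log P(D | G), combining a prior, for example one that favours simpler grammars, with the likelihood of the data D, and the two scores can then be normalized with log-sum-exp. The function names and the numbers below are hypothetical and are not results from the paper.

```python
import math
from typing import Dict


def log_posterior(log_prior: float, log_likelihood: float) -> float:
    """Unnormalized log posterior of a grammar G: log P(G) + log P(D | G)."""
    return log_prior + log_likelihood


def normalize(log_scores: Dict[str, float]) -> Dict[str, float]:
    """Turn unnormalized log posteriors into probabilities with log-sum-exp,
    which avoids underflow when the log values are large and negative."""
    m = max(log_scores.values())
    log_z = m + math.log(sum(math.exp(s - m) for s in log_scores.values()))
    return {name: math.exp(s - log_z) for name, s in log_scores.items()}


if __name__ == "__main__":
    # Made-up numbers purely for illustration; they are not results from the paper.
    scores = {
        "G1 (CFG)": log_posterior(log_prior=-120.0, log_likelihood=-900.0),
        "G2 (regular)": log_posterior(log_prior=-150.0, log_likelihood=-895.0),
    }
    print(normalize(scores))
```

Working in log space matters here: log likelihoods over many sentences are large negative numbers, and exponentiating them directly would underflow to 0 for both grammars.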
# Comparing small grammars (12 sentence types)
python pcfg.py --g1_name "agreement_hr" --g2_name "agreement_linear" --save_dir <SAVE_DIR>
# Comparing large grammars (120 sentence types)
python pcfg.py --g1_name "agreement_hr_v4" --g2_name "agreement_linear_v4" --save_dir <SAVE_DIR>

To compare the posteriors after applying Bayesian Model Merging to minimize the grammars, supply the --minimize argument, e.g.:
python pcfg.py --g1_name "agreement_hr_v4" --g2_name "agreement_linear_v4" --save_dir <SAVE_DIR> --minimize

For any clarifications, comments, or suggestions, feel free to contact me via email at kahuja@cs.washington.edu.
@article{ahuja2024learning,
title={Learning Syntax Without Planting Trees: Understanding When and Why Transformers Generalize Hierarchically},
author={Kabir Ahuja and Vidhisha Balachandran and Madhur Panwar and Tianxing He and Noah A. Smith and Navin Goyal and Yulia Tsvetkov},
year={2024},
eprint={2404.16367},
archivePrefix={arXiv},
primaryClass={cs.CL}
}

The code is adapted from the Structural Grokking codebase by Shikhar Murty. The question formation and tense reinflection datasets are from McCoy et al. 2020, and the German question formation and passivization datasets are from Mueller et al. 2022.