English | 简体中文
Online Documentation: https://datawhalechina.github.io/torch-rechub/
Torch-RecHub โโ Build production-grade recommender systems in 10 lines of code. 30+ mainstream models out-of-the-box, one-click ONNX deployment, letting you focus on business instead of engineering.
- Modular Design: Easy to add new models, datasets, and evaluation metrics.
- Based on PyTorch: Leverages PyTorch's dynamic graph and GPU acceleration capabilities. Supports NVIDIA GPU and Huawei Ascend NPU.
- Rich Model Library: Covers 30+ classic and cutting-edge recommendation algorithms (Matching, Ranking, Multi-task, Generative Recommendation, etc.).
- Standardized Pipeline: Provides unified data loading, training, and evaluation workflows.
- Easy Configuration: Adjust experiment settings via config files or command-line arguments.
- Reproducibility: Designed to ensure reproducible experimental results.
- ONNX Export: Export trained models to ONNX format for seamless production deployment.
- Cross-engine Data Processing: Support for PySpark-based data processing and transformation, facilitating deployment in big data pipelines.
- Experiment Visualization & Tracking: Built-in unified integration for WandB, SwanLab, and TensorBoardX.
- ๐ฅ Torch-RecHub - A Lightweight, Efficient, and Easy-to-use PyTorch Recommender Framework
- Python 3.9+
- PyTorch 1.7+ (CUDA-enabled version recommended for GPU acceleration)
- NumPy
- Pandas
- SciPy
- Scikit-learn
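A quick way to sanity-check the environment against the requirements above is a small stdlib sketch (it only reports what is installed and assumes nothing about Torch-RecHub itself):

```python
import sys
from importlib import metadata

# Torch-RecHub requires Python 3.9+
if sys.version_info < (3, 9):
    print("Warning: Python 3.9 or newer is required")

# Report installed versions of the core dependencies, if present
for pkg in ("torch", "numpy", "pandas", "scipy", "scikit-learn"):
    try:
        print(f"{pkg}: {metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")
```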
Stable Version (Recommended):

```bash
# Install PyTorch matching your device
pip install torch  # CPU
pip install torch --index-url https://download.pytorch.org/whl/cu121  # GPU (CUDA 12.1)
pip install torch torch-npu  # NPU (Huawei Ascend, requires torch-npu >= 2.5.1)

pip install torch-rechub
```

Latest Version:
```bash
# Install uv first (if not already installed)
pip install uv

# Clone and install
git clone https://github.com/datawhalechina/torch-rechub.git
cd torch-rechub

# Install PyTorch matching your device
uv pip install torch  # CPU
uv pip install torch --index-url https://download.pytorch.org/whl/cu121  # GPU (CUDA 12.1)
uv pip install torch torch-npu  # NPU (Huawei Ascend, requires torch-npu >= 2.5.1)

uv sync
```

Install an extra group with `uv sync --extra <name>` or `pip install "torch-rechub[<name>]"`:

- `annoy`: Adds Annoy-based approximate nearest neighbor indexing for retrieval serving.
- `faiss`: Adds FAISS-based vector indexing for high-performance retrieval experiments.
- `milvus`: Adds Milvus client support for external vector database serving workflows.
- `bigdata`: Adds PyArrow support for Parquet-based data loading and big-data preprocessing.
- `onnx`: Adds ONNX export, runtime inference, and model conversion dependencies.
- `visualization`: Adds model graph visualization support with TorchView and Graphviz.
- `tracking`: Adds WandB, SwanLab, and TensorBoardX integrations for experiment tracking.
- `dev`: Adds testing, linting, typing, and local development tooling.
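Since these extras are optional, code that uses them should check availability first. A small stdlib sketch (the mapping from extra name to importable module, e.g. the `annoy` extra providing the `annoy` module, is an assumption, not documented API):

```python
from importlib import util

def extra_available(module_name: str) -> bool:
    """Return True if the given optional module can be imported."""
    return util.find_spec(module_name) is not None

# Example: fall back gracefully when an optional backend is missing
if extra_available("faiss"):
    print("FAISS backend available")
else:
    print("FAISS not installed; falling back to exact search")
```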
Here's a simple example of training a model (e.g., DSSM) on the MovieLens dataset:
```bash
# Clone the repository (if using the latest version)
git clone https://github.com/datawhalechina/torch-rechub.git
cd torch-rechub
uv sync

# Run a matching example (cd into the script directory first, as scripts use relative data paths)
cd examples/matching
python run_ml_dssm.py

# Or with custom parameters:
python run_ml_dssm.py --model_name dssm --device 'cuda:0' --learning_rate 0.001 --epoch 50 --batch_size 4096 --weight_decay 0.0001 --save_dir 'saved/dssm_ml-100k'

# Run a ranking example
cd ../ranking
python run_criteo.py
```

After training, model files will be saved in the `saved/dssm_ml-100k` directory (or your configured directory).
```
torch-rechub/                      # Root directory
├── README.md                      # Project documentation
├── pyproject.toml                 # Project configuration and dependencies
├── torch_rechub/                  # Core library
│   ├── basic/                     # Basic components
│   │   ├── activation.py          # Activation functions
│   │   ├── features.py            # Feature engineering
│   │   ├── layers.py              # Neural network layers
│   │   ├── loss_func.py           # Loss functions
│   │   └── metric.py              # Evaluation metrics
│   ├── models/                    # Recommendation model implementations
│   │   ├── matching/              # Matching models (DSSM/MIND/GRU4Rec etc.)
│   │   ├── ranking/               # Ranking models (WideDeep/DeepFM/DIN etc.)
│   │   └── multi_task/            # Multi-task models (MMoE/ESMM etc.)
│   ├── trainers/                  # Training frameworks
│   │   ├── ctr_trainer.py         # CTR prediction trainer
│   │   ├── match_trainer.py       # Matching model trainer
│   │   └── mtl_trainer.py         # Multi-task learning trainer
│   └── utils/                     # Utility functions
│       ├── data.py                # Data processing utilities
│       ├── match.py               # Matching utilities
│       ├── mtl.py                 # Multi-task utilities
│       └── onnx_export.py         # ONNX export utilities
├── examples/                      # Example scripts
│   ├── matching/                  # Matching task examples
│   ├── ranking/                   # Ranking task examples
│   └── generative/                # Generative recommendation examples (HSTU, HLLM, etc.)
├── docs/                          # Documentation (VitePress, multi-language)
├── tutorials/                     # Jupyter tutorials
├── tests/                         # Unit tests
├── config/                        # Configuration files
└── scripts/                       # Utility scripts
```
The framework currently supports 30+ mainstream recommendation models:
**Ranking Models**
| Model | Paper | Description |
|---|---|---|
| DeepFM | IJCAI 2017 | FM + Deep joint training |
| Wide&Deep | DLRS 2016 | Memorization + Generalization |
| DCN | KDD 2017 | Explicit feature crossing |
| DCN-v2 | WWW 2021 | Enhanced cross network |
| DIN | KDD 2018 | Attention for user interest |
| DIEN | AAAI 2019 | Interest evolution modeling |
| BST | DLP-KDD 2019 | Transformer for sequences |
| AFM | IJCAI 2017 | Attentional FM |
| AutoInt | CIKM 2019 | Auto feature interaction learning |
| FiBiNET | RecSys 2019 | Feature importance + Bilinear |
| DeepFFM | RecSys 2019 | Field-aware FM |
| EDCN | KDD 2021 | Enhanced DCN |
**Matching Models**
| Model | Paper | Description |
|---|---|---|
| DSSM | CIKM 2013 | Classic two-tower model |
| YoutubeDNN | RecSys 2016 | YouTube deep retrieval |
| YoutubeSBC | RecSys 2019 | Sampling bias correction |
| MIND | CIKM 2019 | Multi-interest dynamic routing |
| SINE | WSDM 2021 | Sparse interest network |
| GRU4Rec | ICLR 2016 | GRU for sequences |
| SASRec | ICDM 2018 | Self-attentive sequential |
| NARM | CIKM 2017 | Neural attentive session |
| STAMP | KDD 2018 | Short-term attention memory priority |
| ComiRec | KDD 2020 | Controllable multi-interest |
**Multi-task Models**
| Model | Paper | Description |
|---|---|---|
| ESMM | SIGIR 2018 | Entire space multi-task |
| MMoE | KDD 2018 | Multi-gate Mixture-of-Experts |
| PLE | RecSys 2020 | Progressive Layered Extraction |
| AITM | KDD 2021 | Adaptive Information Transfer |
| SharedBottom | - | Classic shared bottom |
**Generative Models**
| Model | Paper | Description |
|---|---|---|
| HSTU | Meta 2024 | Hierarchical Sequential Transduction Units, powering Meta's trillion-parameter RecSys |
| HLLM | 2024 | Hierarchical LLM for recommendation, combining LLM semantic understanding |
| TIGER | NeurIPS 2023 | T5-based generative retrieval for recommendation with semantic ID generation |
The framework provides built-in support or preprocessing scripts for the following common datasets:
- MovieLens
- Amazon
- Criteo
- Avazu
- Census-Income
- BookCrossing
- Ali-ccp
- Yidian
- ...
The expected data format is typically an interaction file containing:
- User ID
- Item ID
- Rating (optional)
- Timestamp (optional)
For specific format requirements, please refer to the example code in the tutorials directory. The examples/ directory already includes sample datasets in each scenario subdirectory, which you can use directly for quick experimentation and debugging.
You can easily integrate your own datasets by ensuring they conform to the framework's data format requirements or by writing custom data loaders.
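To illustrate the interaction format described above, here is a minimal pure-Python sketch of the preprocessing a custom loader typically performs: mapping raw user/item IDs to contiguous integers (as embedding tables expect) and deriving a binary label from the rating. The raw values and the `rating >= 4` threshold are illustrative assumptions; the framework's own feature utilities may differ in detail.

```python
# Hypothetical raw interactions: (user_id, item_id, rating, timestamp)
raw = [
    ("u_alice", "i_1001", 5.0, 1615000000),
    ("u_bob",   "i_1002", 3.0, 1615000100),
    ("u_alice", "i_1002", 4.0, 1615000200),
]

def encode_column(values):
    """Map raw IDs to contiguous integers starting at 0."""
    mapping = {}
    for v in values:
        mapping.setdefault(v, len(mapping))
    return mapping

user_map = encode_column(u for u, _, _, _ in raw)
item_map = encode_column(i for _, i, _, _ in raw)

# x: encoded features; y: binary label derived from the rating (rating >= 4 -> positive)
x = [(user_map[u], item_map[i], t) for u, i, r, t in raw]
y = [1 if r >= 4 else 0 for _, _, r, _ in raw]

print(x)  # [(0, 0, 1615000000), (1, 1, 1615000100), (0, 1, 1615000200)]
print(y)  # [1, 0, 1]
```

The resulting `x` and `y` are what you would then hand to the framework's data generators (after converting to arrays/tensors as the examples show).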
All model usage examples can be found in `/examples`.

```python
from torch_rechub.models.ranking import DeepFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator

dg = DataGenerator(x, y)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)

model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})

ctr_trainer = CTRTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
ctr_trainer.export_onnx("deepfm.onnx")
```

```python
from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer

task_types = ["classification", "classification"]
model = MMOE(features, task_types, 8, expert_params={"dims": [32, 16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])

mtl_trainer = MTLTrainer(model)
mtl_trainer.fit(train_dataloader, val_dataloader)
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
mtl_trainer.export_onnx("mmoe.onnx")
```

```python
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import MatchDataGenerator

dg = MatchDataGenerator(x, y)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)

model = DSSM(user_features, item_features, temperature=0.02,
             user_params={
                 "dims": [256, 128, 64],
                 "activation": 'prelu',
             },
             item_params={
                 "dims": [256, 128, 64],
                 "activation": 'prelu',
             })

match_trainer = MatchTrainer(model)
match_trainer.fit(train_dl)
match_trainer.export_onnx("dssm.onnx")
# For two-tower models, you can export the user and item towers separately:
# match_trainer.export_onnx("user_tower.onnx", mode="user")
# match_trainer.export_onnx("item_tower.onnx", mode="item")
```

```python
# Visualize the model architecture (requires: pip install torch-rechub[visualization])
graph = ctr_trainer.visualization(depth=4)  # Generate the computation graph
ctr_trainer.visualization(save_path="model.pdf", dpi=300)  # Save as a high-resolution PDF
```

Thanks to all contributors!
We welcome contributions in all forms! Please refer to CONTRIBUTING.md for detailed contribution guidelines.
We also welcome bug reports and feature suggestions through Issues.
This project is licensed under the MIT License.
If you use this framework in your research or work, please consider citing:
```bibtex
@misc{torch_rechub,
  title = {Torch-RecHub},
  author = {Datawhale},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/datawhalechina/torch-rechub}},
  note = {A PyTorch-based recommender system framework providing easy-to-use and extensible solutions}
}
```

- Project Lead: 1985312383
- GitHub Discussions
Last updated: 2026-03-20

