This is an educational repo demonstrating how to build a real-time Snake game using a diffusion model. It was inspired by several great papers:
The goal was to create a similar implementation using the Snake game because of its simple logic. It took nearly 2 months of experiments to get a ready-to-play model.
If you don't have a GPU, you can use runpod.io (a paid service).
After several experiments, I selected the EDM diffusion model for its high performance with few sampling steps. DDIM requires many more steps to achieve comparable quality.
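To make the "few sampling steps" point concrete, here is a minimal sketch of the noise schedule from the EDM paper (Karras et al., 2022) driving a plain Euler sampler. It is not the sampler used in this repo: the defaults `sigma_min=0.002`, `sigma_max=80.0` and `rho=7.0` are the paper's, and `toy_denoise` is a hypothetical stand-in for the trained network that always predicts the same clean value.

```python
import random


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from high to low, followed by a final 0.0 (EDM schedule)."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    lo = sigma_min ** (1.0 / rho)
    hi = sigma_max ** (1.0 / rho)
    sigmas = [(hi + i / (n_steps - 1) * (lo - hi)) ** rho for i in range(n_steps)]
    return sigmas + [0.0]


def euler_sample(denoise, x, sigmas):
    """Integrate dx/dsigma = (x - D(x, sigma)) / sigma with Euler steps."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma  # slope at the current noise level
        x = x + (sigma_next - sigma) * d     # step down to the next noise level
    return x


def toy_denoise(x, sigma):
    # Hypothetical denoiser: pretends every clean sample equals 0.5.
    return 0.5


sigmas = karras_sigmas(10)                   # 10 noise levels plus the final 0.0
x_init = random.gauss(0.0, 1.0) * sigmas[0]  # start from pure noise
sample = euler_sample(toy_denoise, x_init, sigmas)
```

With a real model, `denoise` would be the network call and `x` an image tensor; the loop structure stays the same, and the step count is simply the length of the schedule.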
Install required dependencies:
pip install -r requirements.txt

First, obtain the training dataset using one of these methods:
- Download the prepared dataset:
bash scripts/download-dataset.sh
- Or generate it manually:
python src/generate_dataset.py --model agent.pth --dataset training_data --record

Then start the training:
python src/train.py --model-type edm --output-prefix models/model --dataset training_data --gen-val-images

The model was trained on runpod.io for 32 epochs, taking approximately 27 hours at a cost of $10.
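The reported figures work out to roughly $0.37 per hour and about 50 minutes per epoch. As a rough planning aid, the sketch below scales them to a different epoch count. It assumes time and cost grow linearly with epochs on the same hardware, which is an assumption rather than something measured here; `estimate_run` is a hypothetical helper, not part of the repo.

```python
# Figures reported above for the published model.
REPORTED_EPOCHS = 32
REPORTED_HOURS = 27.0
REPORTED_COST_USD = 10.0


def estimate_run(epochs):
    """Return (hours, cost in USD), assuming linear scaling with epoch count."""
    scale = epochs / REPORTED_EPOCHS
    return REPORTED_HOURS * scale, REPORTED_COST_USD * scale


hours, cost = estimate_run(8)
print(f"8 epochs: about {hours:.2f} h, about ${cost:.2f}")
```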
Download the pre-trained model:
git clone https://huggingface.co/juramoshkov/snake-diffusion models

To play the game, either:
- Run Play.ipynb locally to play Snake at about 1 FPS (the actual rate depends on your GPU) 🤓
- Use runpod.io:
  - Deploy a Pod (RTX 4090 recommended for best performance)
  - Copy and run the contents of scripts/runpod.sh
  - Open Play.ipynb

