Adding a New Attention Backend¶
FastVideo makes it straightforward to integrate new attention mechanisms. This guide walks you through adding a new backend (e.g., MyNewAttn).
1. Implement the Backend (Python)¶
Create a new file in fastvideo/attention/backends/ (e.g., mynew_attn.py).
Your implementation should inherit from AttentionBackend defined in abstract.py.
# fastvideo/attention/backends/mynew_attn.py
import torch

from .abstract import AttentionBackend
# Import the helper to access forward-context metadata (optional)
from fastvideo.forward_context import get_forward_context

# Import the compiled kernel if applicable (see Section 3)
try:
    # Import from the top-level package
    from fastvideo_kernel import my_compiled_attn_func
except ImportError:
    my_compiled_attn_func = None


class MyNewAttnBackend(AttentionBackend):

    def process_inputs(self, q, k, v, **kwargs):
        # Pre-process inputs if necessary
        return q, k, v

    def forward(self, q, k, v, **kwargs):
        # Optional: access extra metadata passed via ForwardContext.
        # Only needed if your backend requires global state (e.g. window_size).
        try:
            context = get_forward_context()
            metadata = context.attn_metadata
            # Example: window_size = metadata.window_size
        except (AssertionError, AttributeError):
            # Handle the case where no context is set (e.g. standard inference)
            pass

        if my_compiled_attn_func is not None:
            return my_compiled_attn_func(q, k, v)
        else:
            # Fallback implementation (e.g., Triton or pure PyTorch)
            return self.fallback_impl(q, k, v)
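Once the file is in place, a quick smoke test is to instantiate the backend and run a forward pass on random tensors. This is only a sketch: it assumes a no-argument constructor and a (batch, heads, seq, head_dim) layout, neither of which is dictated by the abstract base class.

# Quick smoke test (sketch): assumes a no-argument constructor and a
# (batch, heads, seq, head_dim) layout; adapt to the real AttentionBackend API.
import torch

from fastvideo.attention.backends.mynew_attn import MyNewAttnBackend

backend = MyNewAttnBackend()
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

q, k, v = backend.process_inputs(q, k, v)
out = backend.forward(q, k, v)
print(out.shape)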
2. Passing Extra Information via ForwardContext (Optional)¶
FastVideo uses a ForwardContext to pass global metadata (like current timestep, batch info, or custom attention configurations) to attention backends without changing the forward signature of every layer. This is optional and only required if your backend needs dynamic per-step information.
To use this:
1. Set Context: In your pipeline or generation loop, use the set_forward_context context manager.
2. Access Context: Inside your attention backend, use get_forward_context().
See docs/attention/sta/index.md (Sliding Tile Attention) for an example of how complex configuration (window sizes) is passed this way.
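In code, the pattern looks roughly like the sketch below. The set_forward_context keyword argument and the MyAttnMetadata container are assumptions for illustration; check fastvideo/forward_context.py for the actual interface and metadata fields.

# Sketch only: the set_forward_context signature and metadata fields are
# assumptions; check fastvideo/forward_context.py for the real interface.
from dataclasses import dataclass

from fastvideo.forward_context import get_forward_context, set_forward_context


@dataclass
class MyAttnMetadata:  # hypothetical metadata container
    window_size: tuple = (3, 3, 3)


# 1. In the pipeline / generation loop: set the context for this step.
with set_forward_context(attn_metadata=MyAttnMetadata()):
    # ... run the transformer forward pass here ...
    # 2. Inside the attention backend: read the metadata back.
    context = get_forward_context()
    window_size = context.attn_metadata.window_size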
3. Adding Compiled Kernels (C++/CUDA)¶
If your backend requires custom CUDA kernels, you need to add them to the fastvideo-kernel package.
A. Add Source Files¶
Place your kernel implementation files in fastvideo-kernel/csrc/attention/.
* mynew_attn.cu (CUDA implementation)
* mynew_attn.h (Optional headers)
B. Register in Extension¶
Update fastvideo-kernel/csrc/common_extension.cpp to expose your function to Python.
// 1. Declare external function
#ifdef COMPILE_MYNEW_ATTN
extern torch::Tensor mynew_attn_forward(torch::Tensor q, torch::Tensor k, torch::Tensor v);
#endif

// 2. Register in module
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // ... other kernels ...
#ifdef COMPILE_MYNEW_ATTN
  m.def("mynew_attn_fwd", torch::wrap_pybind_function(mynew_attn_forward),
        "My New Attention Forward");
#endif
}
C. Update CMakeLists.txt¶
Update fastvideo-kernel/CMakeLists.txt to compile your new files.
Case 1: General CUDA Kernel (Runs on all GPUs)
Add your source file directly to EXTENSION_SOURCES and define the compilation flag.
# Add to EXTENSION_SOURCES
list(APPEND EXTENSION_SOURCES csrc/attention/mynew_attn.cu)
# Add compilation definition for common_extension.cpp
list(APPEND COMPILE_DEFS COMPILE_MYNEW_ATTN)
Case 2: ThunderKittens Kernel (Hopper H100 Only)
If your kernel uses ThunderKittens (TK), it requires specific architecture flags (sm_90a). Add it inside the ENABLE_TK_KERNELS block.
if(ENABLE_TK_KERNELS)
    # Add source only if TK is enabled
    list(APPEND EXTENSION_SOURCES csrc/attention/mynew_attn_tk.cu)
    # Add definition to guard registration
    list(APPEND COMPILE_DEFS TK_COMPILE_MYNEW_ATTN)
endif()
D. Expose in Python Ops¶
Update fastvideo-kernel/python/fastvideo_kernel/ops.py to make the function importable and handle fallbacks gracefully.
# fastvideo-kernel/python/fastvideo_kernel/ops.py

# Try to load C++ extension symbols
try:
    from fastvideo_kernel._C import fastvideo_kernel_ops
    mynew_attn_fwd = getattr(fastvideo_kernel_ops, "mynew_attn_fwd", None)
except ImportError:
    mynew_attn_fwd = None


def my_compiled_attn_func(q, k, v):
    # Runtime check: use the C++ kernel if available, else fall back
    if mynew_attn_fwd is not None:
        return mynew_attn_fwd(q, k, v)
    else:
        # Call the Triton/Python fallback
        return mynew_attn_triton(q, k, v)
E. Expose in Package Init¶
Update fastvideo-kernel/python/fastvideo_kernel/__init__.py to export the function.
from fastvideo_kernel.ops import (
    my_compiled_attn_func,
    # ...
)

__all__ = [
    "my_compiled_attn_func",
    # ...
]
4. Register the Backend¶
Update fastvideo/attention/backends/__init__.py to export your new class.
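A minimal sketch (follow whatever export or registry pattern the existing backends in that file already use):

# fastvideo/attention/backends/__init__.py (sketch; mirror the existing entries)
from .mynew_attn import MyNewAttnBackend

__all__ = [
    # ... existing backends ...
    "MyNewAttnBackend",
]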
5. Platform Integration¶
If your backend requires specific platform checks (e.g., checking for H100 support), handle that in fastvideo/platforms/cuda.py or within your backend's __init__.
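For example, a Hopper-only backend could fail fast with a capability check like the sketch below. The helper name and where it is wired in are assumptions; fastvideo/platforms/cuda.py may already provide an equivalent check you can reuse.

# Sketch: fail fast if the GPU cannot run the kernel. Call this from
# MyNewAttnBackend.__init__ or from the platform checks in cuda.py.
import torch


def _require_hopper() -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("MyNewAttnBackend requires a CUDA device")
    major, _minor = torch.cuda.get_device_capability()
    if major < 9:
        raise RuntimeError(
            "MyNewAttnBackend requires compute capability 9.0+ (e.g. H100)")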
6. Add Documentation¶
Create a new documentation page for your backend to explain its usage, installation (if custom kernels are needed), and features.
- Create Directory: docs/attention/mynew_attn/
- Create Index: docs/attention/mynew_attn/index.md
- Update Navigation: Add an entry to mkdocs.yml under the "Attention" tab.
Checklist¶
- Created fastvideo/attention/backends/mynew_attn.py.
- (Optional) Added CUDA kernels in fastvideo-kernel/csrc/attention/.
- (Optional) Updated common_extension.cpp and CMakeLists.txt.
- (Optional) Exposed kernel in fastvideo-kernel/python/fastvideo_kernel/ops.py.
- (Optional) Exported kernel in fastvideo-kernel/python/fastvideo_kernel/__init__.py.
- Registered the backend in fastvideo/attention/backends/__init__.py.
- Implemented the forward method respecting the standard signature.
- Added unit tests in tests/.
- Added documentation in docs/attention/ and updated mkdocs.yml.
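For the unit-test item, a minimal sketch is shown below. The shapes, tolerances, and the comparison against torch.nn.functional.scaled_dot_product_attention are illustrative assumptions, not FastVideo's actual test conventions.

# tests/test_mynew_attn.py (sketch; shapes and tolerances are assumptions)
import pytest
import torch

from fastvideo.attention.backends.mynew_attn import MyNewAttnBackend


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_mynew_attn_matches_sdpa():
    torch.manual_seed(0)
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    backend = MyNewAttnBackend()
    out = backend.forward(q, k, v)

    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)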