Offline Engine API#

SGLang provides a direct inference engine without the need for an HTTP server, which is useful for cases where an additional HTTP server adds unnecessary complexity or overhead. Here are two general use cases:

  • Offline Batch Inference

  • Custom Server on Top of the Engine

This document focuses on offline batch inference, demonstrating four different inference modes:

  • Non-streaming synchronous generation

  • Streaming synchronous generation

  • Non-streaming asynchronous generation

  • Streaming asynchronous generation

Additionally, you can easily build a custom server on top of the SGLang offline engine. A detailed working example in a Python script can be found in custom_server.
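
As a rough illustration of that second use case, the sketch below wraps the engine in a small FastAPI app. The /generate route, the request schema, and the module name are illustrative assumptions, not part of SGLang's API; refer to custom_server for the official example.

# Minimal custom-server sketch (assumes fastapi and uvicorn are installed).
# The route name and request format below are illustrative, not SGLang API.
from fastapi import FastAPI

import sglang as sgl

app = FastAPI()
llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")


@app.post("/generate")
async def generate(request: dict):
    # async_generate with a single prompt returns a single output dict
    output = await llm.async_generate(
        request["prompt"], request.get("sampling_params", {})
    )
    return {"text": output["text"]}


# Run with: uvicorn custom_server:app --port 8000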

Nest Asyncio#

Note that if you want to use the offline engine in IPython or any other code that already runs an event loop, you need to add the following code:

import nest_asyncio

nest_asyncio.apply()

Advanced Usage#

The engine supports VLM (vision-language model) inference as well as extracting hidden states.

Please see the examples for further use cases.
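
For instance, hidden-state extraction is gated by the enable_return_hidden_states server argument (it appears as enable_return_hidden_states=False in the ServerArgs log printed when the engine launches below). The sketch that follows is a minimal, hedged illustration; the exact key under which hidden states appear in the output dict is an assumption, so consult the hidden_states example in the repository for the authoritative API.

# Sketch: extracting hidden states. enable_return_hidden_states is a real
# server argument (see the ServerArgs launch log below); the output field
# name under meta_info is an assumption -- inspect the output dict or the
# repository's hidden_states example for your version.
import sglang as sgl

llm = sgl.Engine(
    model_path="qwen/qwen2.5-0.5b-instruct",
    enable_return_hidden_states=True,
)
outputs = llm.generate(
    ["The capital of France is"],
    {"temperature": 0.8, "max_new_tokens": 8},
)
for output in outputs:
    print(output["meta_info"].get("hidden_states"))
llm.shutdown()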

Offline Batch Inference#

The SGLang offline engine supports batch inference with efficient scheduling.

[1]:
# launch the offline engine
import asyncio

import sglang as sgl
import sglang.test.doc_patch  # doc-build helper that trims log output; not needed in your own scripts
from sglang.utils import async_stream_and_merge, stream_and_merge

llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")
[2025-12-31 08:20:25] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-12-31 08:20:25] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-12-31 08:20:25] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-12-31 08:20:28] INFO server_args.py:1602: Attention backend not specified. Use fa3 backend by default.
[2025-12-31 08:20:28] INFO server_args.py:2481: Set soft_watchdog_timeout since in CI
[2025-12-31 08:20:28] INFO engine.py:153: server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, fastapi_root_path='', grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, rl_quant_profile=None, mem_fraction_static=0.835, max_running_requests=128, max_queued_requests=None, max_total_tokens=20480, chunked_prefill_size=8192, enable_dynamic_chunking=False, max_prefill_tokens=16384, prefill_max_requests=None, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, pp_async_batch_depth=0, stream_interval=1, stream_output=False, random_seed=999448156, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, soft_watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, custom_sigquit_handler=None, log_level='error', log_level_http=None, log_requests=False, log_requests_level=2, log_requests_format='text', crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', export_metrics_to_file=False, export_metrics_to_file_dir=None, api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, fp8_gemm_runner_backend='auto', nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', 
disable_flashinfer_autotune=False, speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_draft_attention_backend=None, speculative_moe_runner_backend='auto', speculative_moe_a2a_backend=None, speculative_draft_model_quantization=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, enable_multi_layer_eagle=False, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm=None, init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, mamba_scheduler_strategy='no_buffer', mamba_track_interval=256, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method=None, kt_cpuinfer=None, kt_threadpool_count=None, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, dllm_algorithm=None, dllm_algorithm_config=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_layerwise_nvtx_marker=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, enable_torch_compile_debug_mode=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=8192, 
piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096, 4352, 4608, 4864, 5120, 5376, 5632, 5888, 6144, 6400, 6656, 6912, 7168, 7424, 7680, 7936, 8192], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, enable_draft_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, enable_return_routed_experts=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_attn_tp_input_scattered=False, enable_nsa_prefill_context_parallel=False, enable_fused_qk_norm_rope=False, enable_precise_embedding_interpolation=False, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, disaggregation_decode_enable_fake_auto=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, encoder_only=False, language_only=False, encoder_transfer_backend='zmq_to_scheduler', encoder_urls=[], custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, remote_instance_weight_loader_backend='nccl', remote_instance_weight_loader_start_seed_via_transfer_engine=False, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, enable_broadcast_mm_inputs_process=False, enable_prefix_mm_cache=False, mm_enable_dp_encoder=False, mm_process_config={}, limit_mm_data_per_request=None, decrypted_config_file=None, decrypted_draft_config_file=None, forward_hooks=None)
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  5.56it/s]

Capturing batches (bs=1 avail_mem=76.73 GB): 100%|██████████| 20/20 [00:00<00:00, 20.99it/s]

Non-streaming Synchronous Generation#

[2]:
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Hello, my name is
Generated text:  Louise. I have a certain amount of money and I spend 2.5 times the amount of money as the previous month. If I spend $500 in the previous month, how much will I have left after 3 months?

Let's break down the problem step by step.

1. In the previous month, Louise spent $500.
2. In the first month, she spent 2.5 times the amount she spent in the previous month. So, she spent $500 \times 2.5$.
3. In the second month, she spent 2.5 times the amount
===============================
Prompt: The president of the United States is
Generated text:  32 years older than the president of the Senate. If the president of the Senate is currently 17 years old, what is the total of their ages now?

To determine the total age of the president of the United States and the president of the Senate, we need to follow these steps:

1. Identify the age of the president of the United States.
2. Identify the age of the president of the Senate.
3. Calculate the total age by adding the ages of both individuals.

Given:
- The president of the Senate is currently 17 years old.
- The president of the United States is 32 years
===============================
Prompt: The capital of France is
Generated text:  ______.
A. Paris
B. London
C. Berlin
D. Madrid
Answer:

A

China has an extremely rich cultural heritage, with over 7,000 known major ancient buildings, 2,000 known significant ancient buildings, 10,000 known minor ancient buildings, and a total of more than ____ ancient buildings.
A. 2,000
B. 20,000
C. 50,000
D. 5,000
Answer:

C

The most common type of hernia is ____

===============================
Prompt: The future of AI is
Generated text:  here, and it’s not just about the future of coding languages and machine learning algorithms. In fact, there are so many other areas where AI is going to shape the future, including healthcare, transportation, and education.
In this blog post, we will explore what AI is and how it is transforming the future of technology. We will also discuss some of the key areas where AI is currently developing, and what the future might hold.
So, let’s dive in!
What is AI?
AI is the application of computers and algorithms to perform tasks that would normally require human intelligence, such as speech recognition, natural language processing, and machine learning
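
Besides the generated text, each output dict carries per-request metadata under meta_info. A small hedged sketch follows; the specific field names completion_tokens and finish_reason are assumptions, so print output["meta_info"] to see exactly what your version reports.

# Inspect per-request metadata. The field names below are assumptions --
# print output["meta_info"] to see the exact keys your version returns.
for prompt, output in zip(prompts, outputs):
    meta = output["meta_info"]
    print(
        f"{prompt!r}: {meta.get('completion_tokens')} tokens, "
        f"finish reason: {meta.get('finish_reason')}"
    )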

Streaming Synchronous Generation#

[3]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {
    "temperature": 0.2,
    "top_p": 0.9,
}

print("\n=== Testing synchronous streaming generation with overlap removal ===\n")

for prompt in prompts:
    print(f"Prompt: {prompt}")
    merged_output = stream_and_merge(llm, prompt, sampling_params)
    print("Generated text:", merged_output)
    print()

=== Testing synchronous streaming generation with overlap removal ===

Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text:  [Name], and I'm a [job title] at [company name]. I'm excited to meet you and learn more about you. What can you tell me about yourself? I'm a [insert a short, neutral self-introduction about your personality, skills, or interests]. I'm always looking for new challenges and opportunities to grow and learn. What's your background? I have a [insert a short, neutral self-introduction about your background, such as education, work experience, or hobbies]. I'm always looking for new ways to improve myself and make a positive impact in the world. What's your favorite hobby or activity

Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text:  Paris, the city known for its iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum. It is also home to the French Parliament and the French Academy of Sciences. Paris is a bustling metropolis with a rich cultural heritage and is a major economic and political center in Europe. It is also known for its fashion industry, art scene, and its role in hosting major international events such as the Olympics and the World Cup. The city is also home to many famous landmarks and attractions, including the Palace of Versailles, the Louvre, and the Champs-Élysées. Paris is

Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text:  likely to be characterized by rapid advancements in areas such as machine learning, natural language processing, and computer vision. Some possible future trends in AI include:

1. Increased use of AI in healthcare: AI is already being used in healthcare to improve patient outcomes, reduce costs, and increase efficiency. As AI technology continues to improve, we can expect to see even more widespread use of AI in healthcare in the coming years.

2. AI in finance: AI is already being used in finance to improve fraud detection, risk assessment, and trading algorithms. As AI technology continues to improve, we can expect to see even more widespread use of AI in finance
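
Under the hood, stream_and_merge iterates over llm.generate(prompt, sampling_params, stream=True) and trims overlapping text between consecutive chunks. Below is a minimal sketch of the raw streaming call, assuming each chunk is a dict whose "text" field may overlap with the previous one.

# Raw synchronous streaming, without the overlap-removal helper.
# Consecutive chunk "text" fields may overlap -- that is exactly what
# stream_and_merge cleans up.
for chunk in llm.generate(prompts[0], sampling_params, stream=True):
    print(chunk["text"], end="", flush=True)
print()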

Non-streaming Asynchronous Generation#

[4]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

print("\n=== Testing asynchronous batch generation ===")


async def main():
    outputs = await llm.async_generate(prompts, sampling_params)

    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")


asyncio.run(main())

=== Testing asynchronous batch generation ===

Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text:  [insert name], and I'm a [insert character's profession, such as "scientist," "teacher," "writer," "artist," etc.] Here's how I came to be: When I was a [insert age range], I was [insert profession or hobby, such as "reading," "traveling," "singing," etc.]. I loved exploring the world around me and learning about different cultures and beliefs. I became fascinated by [insert the subject or subject matter you would like to discuss] and decided to study it. It was a dream that never came true, but I still feel a strong sense of satisfaction

Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text:  Paris, known for its iconic Notre-Dame Cathedral, Eiffel Tower, and world-renowned food, fashion, and music scenes.

Some additional facts about Paris:
- It is one of the most visited cities in the world, with over 75 million tourists annually.
- It has been a major hub for international affairs and diplomacy for centuries, with the French Republic founded in 1792 and the first formal constitution of a country in 1791.
- The city is known for its rich history and cultural heritage, including numerous historical sites such as the Louvre Museum and Champs-Élysées.


Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text:  likely to be dominated by four key trends:

1. Increased integration of AI into human decision-making: As AI technology continues to improve and become more integrated into everyday life, we can expect to see more AI-powered systems being integrated into various decision-making processes. This could include decision-making in healthcare, finance, transportation, and other industries, as well as in areas like education and security.

2. Greater emphasis on ethical considerations in AI development and deployment: As AI systems become more advanced and complex, it is likely that we will see more scrutiny and consideration given to their ethical implications. This could involve policies and regulations that promote transparency, accountability,
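
async_generate also accepts a single prompt, so instead of one batched call you can issue independent requests concurrently and let the engine's scheduler batch them internally. A hedged sketch using asyncio.gather:

# Sketch: issuing per-prompt requests concurrently instead of as one
# batched call. Assumes async_generate accepts a single prompt and
# returns a single output dict.
async def run_concurrent():
    results = await asyncio.gather(
        *(llm.async_generate(p, sampling_params) for p in prompts)
    )
    for prompt, output in zip(prompts, results):
        print(f"\nPrompt: {prompt}\nGenerated text: {output['text']}")


asyncio.run(run_concurrent())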

Streaming Asynchronous Generation#

[5]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

print("\n=== Testing asynchronous streaming generation (no repeats) ===")


async def main():
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        print("Generated text: ", end="", flush=True)

        # Use the overlap-aware helper instead of calling async_generate directly
        async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):
            print(cleaned_chunk, end="", flush=True)

        print()  # New line after each prompt


asyncio.run(main())

=== Testing asynchronous streaming generation (no repeats) ===

Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text:  [insert your name] and I am a [insert your profession or title]. I'm really excited to meet you and I really hope to learn all about you. What brings you here today? I can't wait to hear from you and we can start a great conversation. [insert a brief question or comment about the meeting]. Hi, I'm [insert your name] and I'm a [insert your profession or title]. I've been reading your profile for a while now and I really like the way you describe yourself. I'm really excited to meet you and I hope to learn all about you. [insert a brief question or

Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text:  Paris, the most populous city, and is renowned for its historical landmarks, art, and cuisine. It has a rich cultural history, including aspects of French, Renaissance, Baroque, and Romanticism. The city is also known for its annual festivals and celebrations, including the famous Louvre Museum, the Eiffel Tower, and the Musée d'Orsay. Paris is a UNESCO World Heritage Site and a major financial and economic center. The city is known for its iconic landmarks such as Notre-Dame Cathedral, the Palace of Versailles, and the Louvre Museum, which are UNESCO World Heritage Sites. It is also known

Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text:  likely to involve significant changes in both the types of tasks it can perform and the methods used to perform them. Here are some possible future trends in AI:

1. Increased integration with human intelligence: AI is likely to become even more integrated with human intelligence, with machines being able to understand and respond to human emotions and physical sensations in new ways.

2. Deep learning and neural networks: Deep learning and neural networks are already being used in many AI applications, but there is a growing trend towards more sophisticated models that can handle complex and high-dimensional data.

3. Enhanced privacy and security: As AI is increasingly used in everyday life, there is
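
As in the synchronous case, async_stream_and_merge is a thin wrapper over the engine's native streaming interface. A minimal sketch, assuming async_generate(..., stream=True) resolves to an async generator of chunk dicts whose "text" fields may overlap:

# Raw asynchronous streaming, without the overlap-removal helper.
async def raw_stream(prompt):
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    async for chunk in generator:
        print(chunk["text"], end="", flush=True)
    print()


asyncio.run(raw_stream(prompts[0]))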
[6]:
llm.shutdown()