
Overview

DeploymentSampler handles client-side tokenization via a HuggingFace tokenizer and returns structured SampledCompletion objects with token IDs, logprobs, and completion metadata. Use it in training scripts that need token-level outputs (e.g. GRPO, DPO).
from fireworks.training.sdk import DeploymentSampler

Constructor

from transformers import AutoTokenizer
from fireworks.training.sdk import DeploymentSampler, AdaptiveConcurrencyController

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)

# Adaptive concurrency (recommended) — auto-tunes based on server load
sampler = DeploymentSampler(
    inference_url="https://api.fireworks.ai",
    # deploy_mgr and deployment_id are assumed to be defined by your deployment setup code
    model=f"accounts/{deploy_mgr.account_id}/deployments/{deployment_id}",
    api_key="<FIREWORKS_API_KEY>",
    tokenizer=tokenizer,
    concurrency_controller=AdaptiveConcurrencyController(initial_window=16),
)
| Parameter | Type | Description |
| --- | --- | --- |
| inference_url | str | Gateway URL for inference completions |
| model | str | Deployment model path (accounts/<id>/deployments/<id>) |
| api_key | str | Fireworks API key |
| tokenizer | PreTrainedTokenizerBase | HuggingFace tokenizer matching the base model |
| concurrency_controller | AdaptiveConcurrencyController, FixedConcurrencyController, or None | Controls how many concurrent HTTP requests are in flight. None (default) means no limit. See Concurrency Control below. |
| max_concurrency | int or None | Deprecated; use concurrency_controller instead. If set, emits a DeprecationWarning and creates a FixedConcurrencyController internally. |

Concurrency Control

sample_with_tokens(n=K) fans out into K individual streaming requests. Without concurrency control, all K requests fire simultaneously, which can overload the server. Two controllers are available.

AdaptiveConcurrencyController (recommended)

Auto-tunes the concurrency window using AIMD (Additive Increase / Multiplicative Decrease) based on the server’s prefill_queue_duration:
from fireworks.training.sdk import AdaptiveConcurrencyController

ctrl = AdaptiveConcurrencyController(
    initial_window=16,        # starting concurrency
    min_window=1,             # minimum window
    max_window=256,           # maximum window
    prefill_queue_target=0.5, # target prefill queue latency (seconds)
)
sampler = DeploymentSampler(..., concurrency_controller=ctrl)

# Between training steps, call step_completed() to trigger window adjustment
summary = ctrl.step_completed()
print(summary)  # {"window": 20, "avg_pq": 0.08, "cache_hit_rate": 0.95, ...}
The controller reads prefill_queue_duration from server response metrics. When the queue is below target, the window grows proportionally. When above, it halves (multiplicative decrease).
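To make the AIMD rule concrete, here is a small, self-contained sketch of a window that halves when the mean prefill_queue_duration exceeds the target and otherwise grows. It is a hypothetical toy, not the SDK's implementation: ToyAIMDWindow, its record() method, and the fixed additive_step of 4 are assumptions, and it grows by a constant step rather than the proportional growth described above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyAIMDWindow:
    """Illustrative AIMD window. Not the SDK's AdaptiveConcurrencyController."""

    window: int = 16
    min_window: int = 1
    max_window: int = 256
    prefill_queue_target: float = 0.5  # seconds
    additive_step: int = 4  # hypothetical growth per step while under target
    _samples: List[float] = field(default_factory=list, init=False, repr=False)

    def record(self, prefill_queue_duration: float) -> None:
        """Store one prefill_queue_duration reading taken from a response."""
        self._samples.append(prefill_queue_duration)

    def step_completed(self) -> dict:
        """Adjust the window once per training step using the mean reading."""
        avg_pq = sum(self._samples) / len(self._samples) if self._samples else 0.0
        if avg_pq > self.prefill_queue_target:
            # Multiplicative decrease: halve the window, but never below min_window.
            self.window = max(self.min_window, self.window // 2)
        else:
            # Additive increase: grow by a fixed step, but never above max_window.
            self.window = min(self.max_window, self.window + self.additive_step)
        self._samples.clear()
        return {"window": self.window, "avg_pq": avg_pq}


ctrl = ToyAIMDWindow(window=16)
for pq in (0.25, 0.25):  # both readings are under the 0.5 s target
    ctrl.record(pq)
print(ctrl.step_completed())  # {'window': 20, 'avg_pq': 0.25}
```

Halving on congestion backs off quickly when the prefill queue builds up, while the slower growth probes for spare capacity one step at a time.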

FixedConcurrencyController

A static semaphore. Use it when you already know the right concurrency for your deployment:
from fireworks.training.sdk import FixedConcurrencyController

sampler = DeploymentSampler(
    ...,
    concurrency_controller=FixedConcurrencyController(32),
)
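FixedConcurrencyController is described as a static semaphore. The sketch below illustrates that idea with asyncio.Semaphore: K simulated requests are fanned out, but no more than limit run at once. fan_out and the sleep-based fake request are hypothetical stand-ins, not the SDK's internals.

```python
import asyncio
from typing import List, Tuple


async def fan_out(k: int, limit: int) -> Tuple[List[str], int]:
    """Issue k simulated requests with at most `limit` in flight; return results and peak."""
    sem = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def one_request(i: int) -> str:
        nonlocal in_flight, peak
        async with sem:  # waits while `limit` requests are already running
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.001)  # hypothetical stand-in for a streaming request
            in_flight -= 1
            return f"completion-{i}"

    results = await asyncio.gather(*(one_request(i) for i in range(k)))
    return list(results), peak


results, peak = asyncio.run(fan_out(k=64, limit=8))
print(len(results), peak)  # 64 8
```

All 64 requests complete, yet the peak number in flight never exceeds the semaphore size; a fixed limit trades adaptivity for predictability.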

sample_with_tokens(...)

Sample completions and return structured results with token IDs. This method is async, so call it with await or wrap it with asyncio.run(...) from synchronous code:
import asyncio

async def main():
    completions = await sampler.sample_with_tokens(
        messages=[{"role": "user", "content": "Solve: 2+2="}],
        n=4,
        max_tokens=1024,
        temperature=0.7,
    )
    for c in completions:
        print(c.full_tokens)       # prompt + completion token IDs
        print(c.prompt_len)        # number of prompt tokens
        print(c.completion_len)    # number of completion tokens
        print(c.text)              # decoded completion text
        print(c.finish_reason)     # "stop", "length", etc.

asyncio.run(main())
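A GRPO-style step usually needs a group of n completions for each of several prompts. Because sample_with_tokens is a coroutine, one option is to await the calls together with asyncio.gather. This sketch substitutes a hypothetical StubSampler so it runs without a deployment; with a real DeploymentSampler the items would be SampledCompletion objects, and the sampler's concurrency_controller is what the constructor table says controls in-flight HTTP requests.

```python
import asyncio
from typing import Dict, List


class StubSampler:
    """Hypothetical stand-in for DeploymentSampler that returns fake completion strings."""

    async def sample_with_tokens(
        self, messages: List[Dict[str, str]], n: int, **kwargs
    ) -> List[str]:
        await asyncio.sleep(0)  # yield to the event loop, as a network call would
        prompt = messages[-1]["content"]
        return [f"{prompt} -> sample {i}" for i in range(n)]


async def sample_groups(sampler, prompts: List[str], n: int) -> List[List[str]]:
    """Request n completions per prompt concurrently; one group per prompt, in order."""
    calls = [
        sampler.sample_with_tokens(
            messages=[{"role": "user", "content": p}],
            n=n,
            max_tokens=1024,
            temperature=0.7,
        )
        for p in prompts
    ]
    return list(await asyncio.gather(*calls))


groups = asyncio.run(sample_groups(StubSampler(), ["Solve: 2+2=", "Solve: 3+5="], n=4))
print([len(g) for g in groups])  # [4, 4]
```

asyncio.gather preserves input order, so groups[i] always belongs to prompts[i], which keeps per-group reward normalization straightforward.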

Retrieving inference logprobs

For GRPO importance sampling, pass logprobs=True:
import asyncio

async def main():
    completions = await sampler.sample_with_tokens(
        messages=[{"role": "user", "content": "Solve: 2+2="}],
        n=4,
        logprobs=True,
        top_logprobs=1,
    )
    for c in completions:
        print(c.inference_logprobs)  # List[float] or None

asyncio.run(main())
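As one illustration of how inference_logprobs can be used, the sketch computes per-token importance weights exp(train_lp - inference_lp) between a trainer's logprobs and the deployment's sampling logprobs, truncated at a cap. This is not an SDK function: the train_logprobs input, the cap of 2.0, and the fallback to uniform weights when inference_logprobs is None are assumptions for illustration.

```python
import math
from typing import List, Optional


def importance_weights(
    train_logprobs: List[float],
    inference_logprobs: Optional[List[float]],
    cap: float = 2.0,  # hypothetical truncation constant
) -> List[float]:
    """Per-token weights exp(train_lp - inference_lp), truncated at `cap`."""
    if inference_logprobs is None:
        # logprobs=True was not requested: fall back to uniform weights (an assumption).
        return [1.0] * len(train_logprobs)
    if len(train_logprobs) != len(inference_logprobs):
        raise ValueError("logprob sequences must be aligned token-for-token")
    return [min(math.exp(t - q), cap) for t, q in zip(train_logprobs, inference_logprobs)]


print(importance_weights([-0.5, -1.0], [-0.5, -2.0]))  # [1.0, 2.0]
```

The second token's raw ratio is e ≈ 2.718, which the cap truncates to 2.0; truncation keeps a few tokens with large train/inference mismatch from dominating the gradient.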

Sequence length filtering

sample_with_tokens supports max_seq_len for automatic filtering:
import asyncio

completions = asyncio.run(
    sampler.sample_with_tokens(
        messages=input_messages,
        n=4,
        max_tokens=1024,
        max_seq_len=8192,  # filter out sequences exceeding this length
    )
)
Two levels of filtering are applied:
  1. Prompt pre-filter: If the tokenized prompt already meets or exceeds max_seq_len, the method returns an empty list immediately — no inference call is made.
  2. Completion post-filter: After sampling, any completion whose full token sequence (prompt + completion) exceeds max_seq_len is silently dropped.
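The two filtering rules can be restated as a small pure-Python function. filtered_sample and fake_sample are hypothetical names, token lists stand in for real tokenizer output, and the sketch only mirrors the documented rules (>= for the prompt check, > for dropping completions); it is not the SDK's code.

```python
from typing import Callable, List


def filtered_sample(
    prompt_tokens: List[int],
    sample_fn: Callable[[], List[List[int]]],
    max_seq_len: int,
) -> List[List[int]]:
    """Apply the two max_seq_len filters; returns full prompt+completion sequences."""
    # 1. Prompt pre-filter: prompt already at or over the limit, so skip sampling.
    if len(prompt_tokens) >= max_seq_len:
        return []
    # 2. Completion post-filter: drop any prompt+completion longer than the limit.
    full = [prompt_tokens + completion for completion in sample_fn()]
    return [seq for seq in full if len(seq) <= max_seq_len]


calls = []


def fake_sample() -> List[List[int]]:
    """Hypothetical sampler: records the call and returns 2- and 4-token completions."""
    calls.append(1)
    return [[7, 8], [7, 8, 9, 10]]


print(filtered_sample([1, 2, 3], fake_sample, max_seq_len=6))  # [[1, 2, 3, 7, 8]]
print(filtered_sample([1, 2, 3, 4, 5, 6], fake_sample, max_seq_len=6), len(calls))  # [] 1
```

The second call returns an empty list without invoking fake_sample, which is why calls still has length 1.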

SampledCompletion

Each completion returned by sample_with_tokens has the following fields:

| Field | Type | Description |
| --- | --- | --- |
| text | str | Decoded completion text |
| full_tokens | List[int] | Prompt + completion token IDs |
| prompt_len | int | Number of prompt tokens |
| finish_reason | str | "stop", "length", etc. |
| completion_len | int | Number of completion tokens |
| inference_logprobs | List[float] or None | Per-token logprobs (when logprobs=True is passed) |
| logprobs_echoed | bool | True when echo=True was used; logprobs are then training-aligned (P+C-1 entries, with P = prompt_len and C = completion_len) |
| routing_matrices | List[str] or None | Base64-encoded per-token routing matrices for MoE Router Replay (R3) |
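A common use of these fields is slicing out the completion and building a loss mask. The sketch assumes full_tokens is laid out as prompt then completion, that echoed logprobs follow next-token alignment (entry j scores token j+1, giving the P+C-1 entries noted in the table), and that non-echoed logprobs have one entry per completion token. Completion is a hypothetical stand-in holding a subset of SampledCompletion's fields, and the helpers assume prompt_len >= 1.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Completion:
    """Hypothetical stand-in with a subset of SampledCompletion's fields."""

    full_tokens: List[int]
    prompt_len: int
    completion_len: int
    inference_logprobs: Optional[List[float]] = None
    logprobs_echoed: bool = False


def completion_tokens(c: Completion) -> List[int]:
    # Assumes full_tokens is prompt tokens followed by completion tokens.
    return c.full_tokens[c.prompt_len:]


def completion_logprobs(c: Completion) -> Optional[List[float]]:
    """Return one logprob per completion token, or None if none were requested."""
    if c.inference_logprobs is None:
        return None
    if c.logprobs_echoed:
        # Assumed alignment: P+C-1 entries where entry j scores token j+1,
        # so completion tokens (indices P..P+C-1) map to entries P-1..P+C-2.
        return c.inference_logprobs[c.prompt_len - 1:]
    # Assumed: without echo there is exactly one entry per completion token.
    return c.inference_logprobs


def shifted_loss_mask(c: Completion) -> List[int]:
    """Mask over the P+C-1 next-token targets: 0 for prompt, 1 for completion."""
    return [0] * (c.prompt_len - 1) + [1] * c.completion_len


c = Completion(
    full_tokens=[10, 11, 12, 20, 21], prompt_len=3, completion_len=2,
    inference_logprobs=[-1.0, -2.0, -0.5, -0.25], logprobs_echoed=True,
)
print(completion_tokens(c), completion_logprobs(c), shifted_loss_mask(c))
# [20, 21] [-0.5, -0.25] [0, 0, 1, 1]
```

With P = 3 and C = 2 there are 4 echoed entries, and the mask has the same length, so it can be multiplied element-wise against a per-target loss.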