model.predict_batch
K candidate rollouts sharing one starting frame, fused into one server-side forward pass. Built for visual MPC — score K candidates and pick the best.
Signature
```python
def predict_batch(
    self,
    *,
    start_frame: np.ndarray | PIL.Image.Image | bytes | str | Path,
    actions: np.ndarray (K, T, action_dim) | list | …,
    num_steps: int | None = None,
    guidance: float | None = None,
    seed: int = 0,
) -> BatchRollout
```
Input — actions
Required shape: (K, T, action_dim). That is three dimensions: K candidate
sequences, each of length T, with action_dim values per frame.
| Type | What it does |
|---|---|
| np.ndarray (K, T, action_dim) | split along K, each saved as .npy |
| nested list K × T × action_dim | converted via np.asarray |
| list of K np.ndarray (T, action_dim) | each coerced individually |
| list of K bytes | each treated as already-encoded .npy |
Wrong shape (e.g. 2-D, or K=0) raises dream.InputValidationError
client-side.
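The coercion rules in the table can be approximated as follows. This is a minimal sketch, not the SDK's code: `coerce_actions` and `_to_npy_bytes` are hypothetical helpers, the float32 dtype is an assumption, and `ValueError` stands in for dream.InputValidationError.

```python
import io

import numpy as np


def _to_npy_bytes(arr: np.ndarray) -> bytes:
    """Serialize one (T, action_dim) array in .npy format."""
    buf = io.BytesIO()
    np.save(buf, arr, allow_pickle=False)
    return buf.getvalue()


def coerce_actions(actions) -> list[bytes]:
    """Hypothetical helper: normalize the accepted input types to K .npy payloads."""
    if isinstance(actions, np.ndarray):
        if actions.ndim != 3:
            raise ValueError(f"expected (K, T, action_dim), got shape {actions.shape}")
        items = list(actions)  # split along K
    elif isinstance(actions, (list, tuple)):
        items = list(actions)
    else:
        raise ValueError(f"unsupported actions type: {type(actions).__name__}")

    if len(items) == 0:
        raise ValueError("K must be at least 1")

    payloads = []
    for item in items:
        if isinstance(item, bytes):
            payloads.append(item)  # treated as already-encoded .npy
            continue
        arr = np.asarray(item, dtype=np.float32)  # assumed dtype
        if arr.ndim != 2:
            raise ValueError(f"each candidate must be (T, action_dim), got {arr.shape}")
        payloads.append(_to_npy_bytes(arr))
    return payloads
```

A 2-D array fails the first check, and an empty list or a (0, T, action_dim) array fails the K check, which matches the two wrong-shape cases named above.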
The result — BatchRollout
```python
@dataclass
class BatchRollout:
    rollouts: list[Rollout]   # K Rollout objects, submission order preserved
    request_id: str
    batch_size: int           # K
    engine_wall_ms: float     # fused-batch server wall (single value)
    cost_usd: float           # K × per-rollout cost
    customer_id: str
```
BatchRollout is iterable and indexable:
```python
for r in batch:
    print(r.cost_usd)   # 0.0: per-rollout cost is rolled up to the batch
batch[0].mp4_bytes
batch.cost_usd          # total across K rollouts
batch.wall_s            # one engine forward, not K
```
Per-rollout cost_usd is 0.0; the full charge is on the batch
parent. Per-rollout engine_wall_ms mirrors the batch's (the K
candidates share one fused forward).
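The iteration and cost roll-up semantics can be illustrated with a small stand-in. `FakeRollout` and `FakeBatchRollout` are hypothetical classes written for this sketch, not the SDK's types, and they model only the fields needed to show the behavior described above.

```python
from dataclasses import dataclass


@dataclass
class FakeRollout:
    """Hypothetical stand-in for a single Rollout."""
    engine_wall_ms: float
    cost_usd: float = 0.0  # per-rollout cost is rolled up to the batch


@dataclass
class FakeBatchRollout:
    """Hypothetical stand-in for BatchRollout."""
    rollouts: list[FakeRollout]
    engine_wall_ms: float
    cost_usd: float

    @property
    def batch_size(self) -> int:
        return len(self.rollouts)

    def __len__(self) -> int:
        return len(self.rollouts)

    def __iter__(self):
        return iter(self.rollouts)

    def __getitem__(self, index):
        return self.rollouts[index]


batch = FakeBatchRollout(
    rollouts=[FakeRollout(engine_wall_ms=3200.0) for _ in range(8)],
    engine_wall_ms=3200.0,
    cost_usd=0.196,
)

# Summing per-rollout costs gives 0.0, so read the charge from the batch itself.
per_rollout_total = sum(r.cost_usd for r in batch)
```

Code that aggregates spend by summing over rollouts will under-report, so accounting should key on the batch's cost_usd.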
Visual MPC pattern
```python
import numpy as np
import dream

client = dream.Client()
model = client.models.get("dreamdojo-2b-gr1")

# K=8 candidate action sequences from your planner
candidates = sample_action_candidates(K=8, T=48, action_dim=384)
# shape: (8, 48, 384) float32

batch = model.predict_batch(
    start_frame=current_frame,
    actions=candidates,
    seed=0,
)

print(f"K={batch.batch_size}, total cost ${batch.cost_usd}, "
      f"wall {batch.wall_s:.2f}s")

# Score each candidate against your task reward
scores = [score(r.frames) for r in batch]
best = batch[scores.index(max(scores))]
best.save("best.mp4")
```
For DreamDojo on H100, K=8 takes ~3.2 s wall (only ~25% slower than K=1 because the forward fuses; transit is the same). Cost scales 8×.
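The latency figures imply a large saving over issuing K separate calls. The arithmetic below is a rough sketch under two assumptions that the text does not state outright: "~25% slower" is read as a 1.25x wall-time ratio, and K sequential predict calls would each take the K=1 wall time.

```python
wall_k8_s = 3.2        # quoted wall time for K=8
slowdown_vs_k1 = 1.25  # "~25% slower than K=1", read as a 1.25x ratio (assumption)
k = 8

wall_k1_s = wall_k8_s / slowdown_vs_k1  # implied single-rollout wall time
sequential_s = k * wall_k1_s            # K separate calls, back to back (assumption)
speedup = sequential_s / wall_k8_s

print(f"K=1 ~ {wall_k1_s:.2f} s, {k} sequential ~ {sequential_s:.1f} s, "
      f"batched speedup ~ {speedup:.1f}x")
```

Under these assumptions the batch finishes in roughly a sixth of the sequential wall time, while the cost stays about the same as K separate rollouts.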
Cost arithmetic
Roughly K × per-rollout cost. The per-rollout cost on GR-1 is
$0.0245 (49 frames × $0.0005, including the conditioning frame
the engine bills for). For K=8 that's about $0.196.
The server amortizes the shared start frame across the K candidates,
so the exact batch.cost_usd may come in slightly lower than the
back-of-envelope number. Trust the value the SDK returns.
Stripe meters the batch as a single event.
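The back-of-envelope estimate can be reproduced in a few lines. This is only a sketch of the arithmetic quoted above: the variable names are illustrative, and the split of 49 billed frames into 48 generated frames plus one conditioning frame is an interpretation based on T=48 in the earlier example.

```python
from decimal import Decimal

price_per_frame = Decimal("0.0005")  # quoted per-frame price on GR-1
frames_billed = 48 + 1               # T=48 frames plus the conditioning frame (interpretation)
k = 8

per_rollout = price_per_frame * frames_billed  # 0.0245
batch_estimate = per_rollout * k               # about 0.196

print(f"per rollout ${per_rollout}, K={k} estimate ${batch_estimate}")
```

Treat the estimate as an upper bound for budgeting; the billed figure is whatever batch.cost_usd reports.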
Errors
Same set as predict, plus:
- dream.InputValidationError if actions.shape[2] != model.action_dim
- dream.InputValidationError for K=0 or empty list
- dream.InputValidationError if actions is 2-D (use predict for that)
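A planner can run the same three checks on the shape before submitting, for example to log a clearer message. The sketch below is hypothetical: `check_actions_shape` is not part of the SDK, and the local `InputValidationError` class only stands in for dream.InputValidationError.

```python
class InputValidationError(ValueError):
    """Local stand-in for dream.InputValidationError (hypothetical)."""


def check_actions_shape(shape: tuple[int, ...], action_dim: int) -> None:
    """Raise if `shape` would be rejected under the three rules listed above."""
    if len(shape) == 2:
        raise InputValidationError("actions is 2-D; use predict for a single sequence")
    if len(shape) != 3:
        raise InputValidationError(f"expected 3 dimensions, got {len(shape)}")
    k, _t, dim = shape
    if k == 0:
        raise InputValidationError("K must be at least 1")
    if dim != action_dim:
        raise InputValidationError(
            f"action_dim mismatch: got {dim}, model expects {action_dim}"
        )


check_actions_shape((8, 48, 384), action_dim=384)  # passes silently
```

With a NumPy array, the call would look like `check_actions_shape(candidates.shape, model.action_dim)`.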