Capability-aware scheduler for zenmetrics-api. Sits one layer above the
opaque Metric constructor and owns four things that previously every
caller hand-rolled:
- Backend selection. Picks GPU-full / GPU-strip / cvvdp StripPair / CPU per task based on a persisted, machine-specific benchmark cache.
- OOM recovery. When the chosen backend OOMs (predicted or bubble-up at runtime), the orchestrator walks a fallback ladder instead of failing the task.
- Cached-reference dispatch. Auto-detects "many distorted, one
reference" workloads via xxhash3 and promotes them to the
set_reference+compute_with_cached_referenceAPI for the 1.5–3× speedup. Callers who want zero overhead can pre-upload references explicitly. - Concurrency. A small worker pool (one GPU worker + N CPU workers)
handles streaming
submit/polland batchrun_allAPIs without the caller spinning up threads.
Status: the orchestrator is the recommended entry point for any caller that scores more than one pair at a time. Single-shot scoring can still use
zenmetrics-apidirectly; everything else — sweeps, picker training, RD curves, anything that batches — should go through the orchestrator.
use zenmetrics_api::MetricKind;
use zenmetrics_orchestrator::{
Orchestrator, OrchestratorConfig, Task, TaskData,
};
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let mut orch = Orchestrator::new(OrchestratorConfig::default())?;
// Optional but recommended: run the quick-bench so the chooser has
// real perf+VRAM numbers for this machine. `warm()` is idempotent —
// it only runs when the cache is missing or stale.
orch.warm()?;
let task = Task {
task_id: 0,
ref_data: TaskData::Srgb8(reference_bytes),
dist_data: TaskData::Srgb8(distorted_bytes),
width: 1024,
height: 1024,
metric: MetricKind::Ssim2,
params: None,
ref_hash: 0,
};
let result = orch.run_single(task);
match result.outcome {
Ok(score) => println!("score = {}", score.value),
Err(e) => eprintln!("failed: {e}"),
}
# Ok(()) }
# fn main() {}
# let reference_bytes: Vec<u8> = vec![0; 1024 * 1024 * 3];
# let distorted_bytes: Vec<u8> = reference_bytes.clone();| Caller shape | Use |
|---|---|
One (ref, dist) per process, no fallback needed |
zenmetrics-api |
| Batch of tasks (sweep, RD curve, picker training) | orchestrator |
| Streaming workload (submit-as-you-go) | orchestrator |
| OOM-tolerant scoring (e.g. unknown image sizes) | orchestrator |
| Need to share warm references across many distorts | orchestrator |
The orchestrator's overhead on a single task is small (one chooser call
- one cache touch + a pool spawn if it's the first submit) — typically
< 5 ms above the underlying metric. For workloads with many tasks the
amortised cost is much lower because the pool keeps warm
Metricinstances across same-signature tasks.
# use zenmetrics_api::MetricKind;
# use zenmetrics_orchestrator::{Orchestrator, OrchestratorConfig, Task, TaskData};
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let mut orch = Orchestrator::new(OrchestratorConfig::default())?;
orch.warm()?;
let tasks: Vec<Task> = (0..100)
.map(|i| Task {
task_id: i as u64,
ref_data: TaskData::Srgb8(reference_bytes.clone()),
dist_data: TaskData::Srgb8(distorted_variants[i].clone()),
width: 1024,
height: 1024,
metric: MetricKind::Cvvdp,
params: None,
ref_hash: 0,
})
.collect();
// Results arrive in COMPLETION order (not submission order). Correlate
// via `task_id`.
for result in orch.run_all(tasks) {
match result.outcome {
Ok(score) => println!("task {}: {}", result.task_id, score.value),
Err(e) => eprintln!("task {} failed: {e}", result.task_id),
}
}
# Ok(()) }
# let reference_bytes: Vec<u8> = vec![0; 1024 * 1024 * 3];
# let distorted_variants: Vec<Vec<u8>> = (0..100).map(|_| reference_bytes.clone()).collect();When the caller is producing tasks as it goes (e.g. fetching distorted images from R2), the streaming API lets the orchestrator overlap I/O with compute.
# use zenmetrics_api::MetricKind;
# use zenmetrics_orchestrator::{Orchestrator, OrchestratorConfig, Task, TaskData};
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let mut orch = Orchestrator::new(OrchestratorConfig::default())?;
orch.warm()?;
let mut outstanding = 0;
for next_task in pending_tasks {
orch.submit(next_task)?;
outstanding += 1;
// Drain any completed results without blocking.
while let Some(result) = orch.poll_any() {
process_result(result);
outstanding -= 1;
}
}
// Drain the tail — block until every submitted task is done.
while outstanding > 0 {
if let Some(result) = orch.poll_any_blocking() {
process_result(result);
outstanding -= 1;
} else {
break;
}
}
# Ok(()) }
# struct StubTask;
# let pending_tasks: Vec<Task> = Vec::new();
# fn process_result(_: zenmetrics_orchestrator::TaskResult) {}Three OOM signals trigger the ladder:
- Predictive avoidance. The chooser consults
capability.metrics.<metric>.vram_mib_atand rejects backends whose predicted VRAM exceedslive_vram_free × 0.85. - Constructor OOM.
Metric::new_with_memory_modereturningError::TooBigForFull { needed, cap }. The orchestrator marks the(backend, size_pixels)cell as failed, persists the cache, and re-asks the chooser. - Runtime OOM.
Metric::compute_srgb_u8bubbling a cubecl runtime OOM. Same recovery as constructor OOM — drop the metric, record the cell, retry next backend.
Ladder order (per metric):
GpuFull → GpuStrip → (Cvvdp only: GpuStripPair) → Cpu → FullyExhausted
Each downgrade updates the persistent capability cache so the same machine never tries the failing combination twice.
When a caller absolutely needs a specific backend (e.g. a parity test
that wants to fail loudly if GPU isn't available), construct the
orchestrator with oom_retry_strategy: OomRetry::NoFallback. The first
OOM bubbles up as OrchestratorError::FullyExhausted with one entry in
backends_attempted — no ladder, no surprise CPU fallback.
Note: as of Phase 6, the
OomRetryknob lives on the design surface but the chooser exposes the same selectivity via per-metricKnownOomCellentries in the cache. A future minor release will plumbOomRetryend-to-end.
For workloads with many distorted variants of one reference (the
canonical sweep shape), the orchestrator promotes the dispatch to the
metric's set_reference + compute_with_cached_reference API. This
saves the per-task reference upload (typically 4–8 ms at 4096²) and
re-uses any preprocessed GPU-resident state inside the metric (cvvdp's
XYB-transformed reference cube, ssim2's blurred reference pyramid,
etc.).
The orchestrator hashes ref bytes via xxhash3_64 (~5–15 GB/s on a
modern CPU, ~4–8 ms at 4096²) and consults a sliding window of the last
32 (metric, w, h, hash) tuples. On a hit, the next task dispatches
through the cached-ref API.
# use zenmetrics_api::MetricKind;
# use zenmetrics_orchestrator::{Orchestrator, OrchestratorConfig, Task, TaskData};
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let mut orch = Orchestrator::new(OrchestratorConfig::default())?;
orch.warm()?;
// 100 distorted variants of the SAME reference → orchestrator
// auto-detects after the first task and switches to cached-ref
// dispatch for the rest.
let tasks: Vec<Task> = (0..100)
.map(|i| Task {
task_id: i as u64,
ref_data: TaskData::Srgb8(reference.clone()),
dist_data: TaskData::Srgb8(variants[i].clone()),
width: 1024,
height: 1024,
metric: MetricKind::Ssim2,
params: None,
ref_hash: 0,
})
.collect();
let results: Vec<_> = orch.run_all(tasks).collect();
// Audit the cached-ref auto-detect: 99 / 100 tasks should have hit.
let stats = orch.cached_ref_stats();
println!("cached-ref hits: {} / misses: {}", stats.hit_count, stats.miss_count);
# Ok(()) }
# let reference: Vec<u8> = vec![0; 1024 * 1024 * 3];
# let variants: Vec<Vec<u8>> = (0..100).map(|_| reference.clone()).collect();When the caller wants zero hashing overhead (e.g. processing 4096²
references where the 8 ms hash adds up across thousands of tasks),
upload once and pass a TaskRefHandle:
# use zenmetrics_api::MetricKind;
# use zenmetrics_orchestrator::{Orchestrator, OrchestratorConfig, Task, TaskData};
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let mut orch = Orchestrator::new(OrchestratorConfig::default())?;
orch.warm()?;
let ref_handle = orch.upload_reference(
&reference_bytes,
4096,
4096,
MetricKind::Cvvdp,
)?;
for (i, variant) in variants.iter().enumerate() {
let task = Task {
task_id: i as u64,
ref_data: TaskData::PreUploaded(ref_handle.clone()),
dist_data: TaskData::Srgb8(variant.clone()),
width: 4096,
height: 4096,
metric: MetricKind::Cvvdp,
params: None,
ref_hash: 0,
};
orch.submit(task)?;
}
// drain results …
orch.drop_reference(ref_handle);
# Ok(()) }
# let reference_bytes: Vec<u8> = vec![0; 4096 * 4096 * 3];
# let variants: Vec<Vec<u8>> = vec![reference_bytes.clone(); 4];Each per-metric CPU reference adapter is opt-in via a Cargo feature so downstream consumers don't pay for crates they don't use:
| Feature flag | Pulls | Notes |
|---|---|---|
cpu-cvvdp |
cvvdp |
In-tree port; full cached-ref via warm_reference. |
cpu-ssim2 |
ssimulacra2 |
Crates.io reference; cached-ref re-uses cached ref bytes (no separate warm-ref API upstream). |
cpu-dssim |
dssim-core |
True cached-ref via DssimImage. |
cpu-butter |
butteraugli |
Cached-ref recomputes from bytes (upstream has no separate warm path). |
cpu-zensim |
zensim |
Cached-ref recomputes. |
cpu-all |
All five | Convenience bundle for production workers. |
Iwssim has no clean upstream CPU reference — selecting it as the CPU
fallback surfaces OrchestratorError::CpuMetricUnavailable and the
ladder advances. See docs/CPU_BACKENDS.md for
the per-metric mapping, RAM characteristics, and cached-ref semantics.
The orchestrator writes a per-machine TOML profile to
$XDG_CACHE_HOME/zenmetrics/capability_<short_hash>.toml (default
~/.cache/zenmetrics/). The <short_hash> is the first 16 chars of
sha256(gpu_model || driver_version || cpu_brand || cpu_features), so
the same machine always lands on the same file but different machines
co-exist without stomping each other.
- First run on a machine with no cache file.
- Time-based: cache >
OrchestratorConfig::cache_validityold (default 7 days). Wall-clock time only — process restarts don't invalidate. - Driver / hardware change: a fresh
detect_gpu()returns a different model or driver version than the cached snapshot.
warm() runs the bench only when one of these triggers fires; it's
safe to call on every process startup.
To force a re-bench (e.g. after upgrading a metric crate), call
Orchestrator::bench() directly — it unconditionally overwrites the
metric profile table.
A single capability profile can be shared across a fleet of identical
boxes by uploading the TOML to R2 / S3 and pointing
OrchestratorConfig::cache_dir at a writable local path that's
pre-populated from the shared key at boot. The orchestrator's
hardware-change check still fires on startup, so a mismatched machine
won't trust a foreign profile — it'll re-bench locally instead.
scripts/sweep/onstart_orchestrator.sh in the zenmetrics repo
implements this exact pattern for vast.ai workers.
pub struct OrchestratorConfig {
pub cache_dir: PathBuf, // default ~/.cache/zenmetrics/
pub cache_validity: Duration, // default 7 days
}Pool-level knobs (max_parallel_cpu, vram_safety_floor_mib,
vram_sample_interval_ms, vram_stall_ms) live on the separate
PoolConfig struct — pass via Orchestrator::set_pool_config(cfg)
before the first submit / run_all / upload_reference call.
Defaults are sensible for desktop GPUs with a display compositor; bare-
metal datacenter GPUs can usually push vram_safety_floor_mib lower
than 200.
default — capability detection + TOML cache only (no scheduling).
bench — quick-bench harness + chooser. Pulls zenmetrics-api +
cvvdp-gpu + xxhash + rayon.
cuda — single-task executor + worker pool (implies bench).
cpu-cvvdp,
cpu-ssim2,
cpu-dssim,
cpu-butter,
cpu-zensim — per-metric CPU reference adapters (each implies bench).
cpu-all — every CPU adapter at once.
Production sweep workers typically build with --features cuda,cpu-all.
Light callers that just want the capability detection (e.g. a CI sanity
check that the machine has the expected GPU model) build with default
features.
This crate (and the rest of the zenmetrics GPU stack) pins cubecl to
the imazen/zenforks-cubecl
rename of the upstream
tracel-ai/cubecl v0.10.x. The
11 patched-or-transitive crates ship to crates.io under the
zenforks-cubecl-* namespace; their [lib] names stay as the
upstream cubecl_*, so source code reads use cubecl_runtime::*;
unchanged.
The renamed crates carry these patches on top of upstream v0.10.0:
- Pinned-host-buffer fast path (
zenforks-cubecl-runtime, 0.10.0+) - PTX cache widening (
zenforks-cubecl-cuda, 0.10.1+) - Metal
Atomic<f32>capability honesty (zenforks-cubecl-wgpu, 0.10.1+)
Pinned-upload (patch 1) — why. CUDA's cuMemcpyHtoDAsync from
pageable host memory caps at ~5-6 GB/s because the driver internally
stages through a hidden pinned bounce buffer. Allocating the host
buffer with cuMemAllocHost_v2 (= "pinned" / "page-locked") lets the
driver DMA directly at 12-25 GB/s on PCIe 4.0. cvvdp-gpu's 12 MP
warm-ref bench goes from 95 ms to 22 ms — a ~4.3x speedup — purely
from this patch. See docs/CUBECL_GOTCHAS.md §G6.5 in the workspace
root for the full diagnosis.
The pinned-upload patch:
- Adds
ComputeClient::create_from_slice_pinned(&[u8]) -> Handlefor hot per-call uploads that want to skip thecaller -> pageable Vec<u8> -> pinned Bytesextra memcpy. - Adds
ComputeClient::reserve_staging(&[usize]) -> Vec<Bytes>for pre-reserving pinned slabs the caller fills in place. - Adds
ComputeClient::create_tensors_from_slices_pinnedfor batch variants. - Transparently routes the existing
create_from_slice/create_tensor_from_slice/create_tensors_from_slicespaths throughComputeServer::staging, so any caller of the default API gets the pinned-upload speedup without source changes.
Upstream PR. The pinned-upload patch is drafted as a PR against
tracel-ai/cubecl (referenced as draft PR #1334). See
../zenmetrics-api/docs/PINNED_UPLOAD_UPSTREAM_PR.md
for the full diff, bench numbers, and submission steps. The PTX cache
and Metal atomic patches are tracked in the same docs directory.
Workspace pin (from the root Cargo.toml):
[workspace.dependencies]
# 11 renamed crates from zenforks-cubecl (carry the patches above)
cubecl = { package = "zenforks-cubecl", version = "0.10.1" }
cubecl-runtime = { package = "zenforks-cubecl-runtime", version = "0.10.1" }
cubecl-core = { package = "zenforks-cubecl-core", version = "0.10.1" }
cubecl-cuda = { package = "zenforks-cubecl-cuda", version = "0.10.1" }
cubecl-cpu = { package = "zenforks-cubecl-cpu", version = "0.10.1" }
cubecl-wgpu = { package = "zenforks-cubecl-wgpu", version = "0.10.1" }
cubecl-hip = { package = "zenforks-cubecl-hip", version = "0.10.1" }
cubecl-cpp = { package = "zenforks-cubecl-cpp", version = "0.10.1" }
# 5 leaf crates that don't need patching — consumed directly from upstream:
cubecl-common = "0.10.0"
cubecl-ir = "0.10.0"Downstream opt-in. If you're consuming zenmetrics-orchestrator
from a separate workspace, copy the [workspace.dependencies] block
above into your workspace Cargo.toml. All 11 patched-or-transitive
crates must point at the same 0.10.x of the rename, and the 5
non-renamed crates stay on upstream's 0.10.0. Backends without a
pinned-memory concept (cubecl-wgpu Metal/Vulkan, cubecl-cpu) ignore
staging and behave exactly as stock cubecl, so the patch is safe to
apply unconditionally even when CUDA isn't in use.
Sunset plan. When upstream cubecl ships a release that contains
these patches (e.g., once PR #1334 merges and is released), the
specific patch we carry on top of upstream goes away and our next
zenforks-cubecl-* release drops it. The renamed crates stay on
crates.io for ABI stability — downstream pins don't churn — but
become a near-zero-diff rename of upstream until the next patch lands.
See ../zenmetrics-api/docs/ZENFORKS_CUBECL_STRATEGY.md
for the full maintenance playbook.
See docs/MIGRATION_FROM_API.md for
side-by-side code samples.
Dual-licensed under AGPL-3.0 or LicenseRef-Imazen-Commercial; see the
workspace root LICENSE-AGPL3 and COMMERCIAL.md for details.