From 717a0116cefa6bc87b393536ff252702949784cc Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Mon, 15 Jun 2026 09:45:22 -0700 Subject: [PATCH] Fix Dawn nightly: select SDPA attn output by shape, not numel (#20283) Summary: The WebGPU Dawn native nightly (`webgpu_native_test`) fails deterministically on the `llama1b_prefill` SDPA config with `FAIL: ambiguous attention output: 3 tensors match numel 262144`, which fails the binary and turns the job red. `sdpa_with_kv_cache` returns three tensors `[k_cache, v_cache, attn_output]`. `test_sdpa_config` identified the attention output purely by element count (`numel == S*Hq*D`). For `llama1b_prefill` (`Hq=32, Hkv=8, D=64, S=128, Cmax=512`) the attention count `S*Hq*D = 128*32*64 = 262144` coincides exactly with each cache count `Cmax*Hkv*D = 512*8*64 = 262144`, so all three outputs match `numel` and the existing ambiguity guard correctly bails before any numeric comparison. The kernel output itself is fine -- the sibling `llama1b_decode` config (same `Hq/Hkv/D`) passes at `~1e-9`; only the test's output-selection heuristic was wrong. The colliding config and the numel selector were introduced together in D107595144. Fix: disambiguate by shape instead of flat count. The attention output is `[1, S, Hq, D]` while each cache is `[1, Cmax, Hkv, D]`; these differ in dims 1-2 even when the flat count collides. Match `dim()==4 && size(1)==S && size(2)==Hq && size(3)==D`, keeping the `attn_matches > 1` ambiguity guard as a backstop. Scope: test-only, one function (`test_sdpa_config`); no kernel, runtime, or export change. Authored with Claude Code. Reviewed By: Gasoonjia Differential Revision: D108625761 --- backends/webgpu/test/test_webgpu_native.cpp | 25 ++++++++++++++------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 338ecb39913..ef643d33482 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -662,29 +662,38 @@ static bool test_sdpa_config( } const auto& outputs = result.get(); - // The mutating op returns [k_cache, v_cache, attn_output]; select the - // attention output (numel == S*Hq*D), not a mutated cache (numel Cmax*Hkv*D). - // Count matches and fail if ambiguous: a cache could share the same numel. + // Select the attention output [1,S,Hq,D] by shape; the op returns + // [k_cache, v_cache, attn_output] and a cache [1,Cmax,Hkv,D] can share numel. int attn_idx = -1; int attn_matches = 0; for (size_t i = 0; i < outputs.size(); i++) { - if (outputs[i].isTensor() && outputs[i].toTensor().numel() == on) { + if (!outputs[i].isTensor()) { + continue; + } + const auto& t = outputs[i].toTensor(); + if (t.dim() == 4 && static_cast(t.size(1)) == cfg.s && + static_cast(t.size(2)) == cfg.hq && + static_cast(t.size(3)) == cfg.d) { attn_idx = static_cast(i); attn_matches++; } } if (attn_idx < 0) { printf( - "FAIL: no attention output (numel %d) among %zu outputs\n", - on, + "FAIL: no attention output [1,%d,%d,%d] among %zu outputs\n", + cfg.s, + cfg.hq, + cfg.d, outputs.size()); return false; } if (attn_matches > 1) { printf( - "FAIL: ambiguous attention output: %d tensors match numel %d\n", + "FAIL: ambiguous attention output: %d tensors match shape [1,%d,%d,%d]\n", attn_matches, - on); + cfg.s, + cfg.hq, + cfg.d); return false; } const auto& out_tensor = outputs[attn_idx].toTensor();