diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 338ecb39913..ef643d33482 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -662,29 +662,38 @@ static bool test_sdpa_config( } const auto& outputs = result.get(); - // The mutating op returns [k_cache, v_cache, attn_output]; select the - // attention output (numel == S*Hq*D), not a mutated cache (numel Cmax*Hkv*D). - // Count matches and fail if ambiguous: a cache could share the same numel. + // Select the attention output [1,S,Hq,D] by shape; the op returns + // [k_cache, v_cache, attn_output] and a cache [1,Cmax,Hkv,D] can share numel. int attn_idx = -1; int attn_matches = 0; for (size_t i = 0; i < outputs.size(); i++) { - if (outputs[i].isTensor() && outputs[i].toTensor().numel() == on) { + if (!outputs[i].isTensor()) { + continue; + } + const auto& t = outputs[i].toTensor(); + if (t.dim() == 4 && static_cast(t.size(1)) == cfg.s && + static_cast(t.size(2)) == cfg.hq && + static_cast(t.size(3)) == cfg.d) { attn_idx = static_cast(i); attn_matches++; } } if (attn_idx < 0) { printf( - "FAIL: no attention output (numel %d) among %zu outputs\n", - on, + "FAIL: no attention output [1,%d,%d,%d] among %zu outputs\n", + cfg.s, + cfg.hq, + cfg.d, outputs.size()); return false; } if (attn_matches > 1) { printf( - "FAIL: ambiguous attention output: %d tensors match numel %d\n", + "FAIL: ambiguous attention output: %d tensors match shape [1,%d,%d,%d]\n", attn_matches, - on); + cfg.s, + cfg.hq, + cfg.d); return false; } const auto& out_tensor = outputs[attn_idx].toTensor();