Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions backends/webgpu/test/test_webgpu_native.cpp
Original file line number Diff line number Diff line change
@@ -662,29 +662,38 @@ static bool test_sdpa_config(
 }
 
 const auto& outputs = result.get();
-// The mutating op returns [k_cache, v_cache, attn_output]; select the
-// attention output (numel == S*Hq*D), not a mutated cache (numel Cmax*Hkv*D).
-// Count matches and fail if ambiguous: a cache could share the same numel.
+// Select the attention output [1,S,Hq,D] by shape; the op returns
+// [k_cache, v_cache, attn_output] and a cache [1,Cmax,Hkv,D] can share numel.
 int attn_idx = -1;
 int attn_matches = 0;
 for (size_t i = 0; i < outputs.size(); i++) {
-if (outputs[i].isTensor() && outputs[i].toTensor().numel() == on) {
+if (!outputs[i].isTensor()) {
+continue;
+}
+const auto& t = outputs[i].toTensor();
+if (t.dim() == 4 && static_cast<int>(t.size(1)) == cfg.s &&
+static_cast<int>(t.size(2)) == cfg.hq &&
+static_cast<int>(t.size(3)) == cfg.d) {
 attn_idx = static_cast<int>(i);
 attn_matches++;
 }
 }
 if (attn_idx < 0) {
 printf(
-"FAIL: no attention output (numel %d) among %zu outputs\n",
-on,
+"FAIL: no attention output [1,%d,%d,%d] among %zu outputs\n",
+cfg.s,
+cfg.hq,
+cfg.d,
 outputs.size());
 return false;
 }
 if (attn_matches > 1) {
 printf(
-"FAIL: ambiguous attention output: %d tensors match numel %d\n",
+"FAIL: ambiguous attention output: %d tensors match shape [1,%d,%d,%d]\n",
 attn_matches,
-on);
+cfg.s,
+cfg.hq,
+cfg.d);
 return false;
 }
 const auto& out_tensor = outputs[attn_idx].toTensor();
Expand Down
Loading