Fix Dawn nightly: select SDPA attn output by shape, not numel (#20283)
@JulianCloudNTH has exported this pull request. If you are a Meta employee, you can view the originating Diff in D108625761. |
…h#20283) Summary: The WebGPU Dawn native nightly (`webgpu_native_test`) fails deterministically on the `llama1b_prefill` SDPA config with `FAIL: ambiguous attention output: 3 tensors match numel 262144`, which fails the binary and turns the job red. `sdpa_with_kv_cache` returns three tensors `[k_cache, v_cache, attn_output]`. `test_sdpa_config` identified the attention output purely by element count (`numel == S*Hq*D`). For `llama1b_prefill` (`Hq=32, Hkv=8, D=64, S=128, Cmax=512`) the attention count `S*Hq*D = 128*32*64 = 262144` coincides exactly with each cache count `Cmax*Hkv*D = 512*8*64 = 262144`, so all three outputs match `numel` and the existing ambiguity guard correctly bails before any numeric comparison. The kernel output itself is fine -- the sibling `llama1b_decode` config (same `Hq/Hkv/D`) passes at `~1e-9`; only the test's output-selection heuristic was wrong. The colliding config and the numel selector were introduced together in D107595144. Fix: disambiguate by shape instead of flat count. The attention output is `[1, S, Hq, D]` while each cache is `[1, Cmax, Hkv, D]`; these differ in dims 1-2 even when the flat count collides. Match `dim()==4 && size(1)==S && size(2)==Hq && size(3)==D`, keeping the `attn_matches > 1` ambiguity guard as a backstop. Scope: test-only, one function (`test_sdpa_config`); no kernel, runtime, or export change. Authored with Claude Code. Reviewed By: Gasoonjia Differential Revision: D108625761
3186b53 to 63b3389
63b3389 to 717a011
Summary:
The WebGPU Dawn native nightly (`webgpu_native_test`) fails deterministically on the `llama1b_prefill` SDPA config with `FAIL: ambiguous attention output: 3 tensors match numel 262144`, which fails the binary and turns the job red.

`sdpa_with_kv_cache` returns three tensors `[k_cache, v_cache, attn_output]`. `test_sdpa_config` identified the attention output purely by element count (`numel == S*Hq*D`). For `llama1b_prefill` (`Hq=32, Hkv=8, D=64, S=128, Cmax=512`) the attention count `S*Hq*D = 128*32*64 = 262144` coincides exactly with each cache count `Cmax*Hkv*D = 512*8*64 = 262144`, so all three outputs match `numel` and the existing ambiguity guard correctly bails before any numeric comparison. The kernel output itself is fine -- the sibling `llama1b_decode` config (same `Hq/Hkv/D`) passes at `~1e-9`; only the test's output-selection heuristic was wrong. The colliding config and the numel selector were introduced together in D107595144.

Fix: disambiguate by shape instead of flat count. The attention output is `[1, S, Hq, D]` while each cache is `[1, Cmax, Hkv, D]`; these differ in dims 1-2 even when the flat count collides. Match `dim()==4 && size(1)==S && size(2)==Hq && size(3)==D`, keeping the `attn_matches > 1` ambiguity guard as a backstop.

Scope: test-only, one function (`test_sdpa_config`); no kernel, runtime, or export change.

Authored with Claude Code.
Reviewed By: Gasoonjia
Differential Revision: D108625761
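
For readers without the diff open, here is a minimal, self-contained sketch of the collision and of the shape-based selection described in the summary. It is not the code from `test_sdpa_config`: `Shape`, `SdpaConfig`, `count_by_numel`, `is_attn_shape`, and `select_attn_output` are hypothetical names, and tensors are modeled as plain shape vectors rather than real tensor objects with `dim()` / `size(i)`.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor: only its shape matters for this sketch.
using Shape = std::vector<int64_t>;

// Hypothetical config struct; field names follow the symbols in the summary.
struct SdpaConfig {
  int64_t Hq, Hkv, D, S, Cmax;
};

// Flat element count of a shape.
int64_t numel(const Shape& s) {
  int64_t n = 1;
  for (int64_t d : s) n *= d;
  return n;
}

// Old heuristic: count outputs whose flat element count equals S*Hq*D.
int count_by_numel(const std::vector<Shape>& outs, const SdpaConfig& c) {
  int matches = 0;
  for (const Shape& s : outs) {
    if (numel(s) == c.S * c.Hq * c.D) ++matches;
  }
  return matches;
}

// New predicate: mirrors dim()==4 && size(1)==S && size(2)==Hq && size(3)==D.
bool is_attn_shape(const Shape& s, const SdpaConfig& c) {
  return s.size() == 4 && s[1] == c.S && s[2] == c.Hq && s[3] == c.D;
}

// Returns the index of the unique shape match, or -1 if there is no match
// or more than one (the ambiguity guard kept as a backstop).
int select_attn_output(const std::vector<Shape>& outs, const SdpaConfig& c) {
  int found = -1;
  int attn_matches = 0;
  for (std::size_t i = 0; i < outs.size(); ++i) {
    if (is_attn_shape(outs[i], c)) {
      found = static_cast<int>(i);
      ++attn_matches;
    }
  }
  return attn_matches == 1 ? found : -1;
}
```

With the `llama1b_prefill` values (`Hq=32, Hkv=8, D=64, S=128, Cmax=512`) and outputs shaped `[1, 512, 8, 64]`, `[1, 512, 8, 64]`, `[1, 128, 32, 64]`, `count_by_numel` reports three matches while `select_attn_output` picks index 2. In this sketch the shapes would still coincide if a config had `S == Cmax` and `Hq == Hkv`, which is the case the retained ambiguity guard covers.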