[ExecuTorch][WebGPU] Add et_vk.apply_rotary_emb (interleaved RoPE) + ValueList multi-output #20264
JulianCloudNTH wants to merge 4 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20264
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending
As of commit 9b09f89 with merge base 0378fc4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@claude review
Claude finished @JulianCloudNTH's task in 2m 34s. Code Review:
@claude review
Claude finished @JulianCloudNTH's task in 1m 55s. Code Review:
Stack from ghstack (oldest at bottom):
Adds the WebGPU backend handler for `et_vk.apply_rotary_emb.default` (interleaved Llama rotary positional embedding) plus the `ValueList` graph-value support its multi-output signature requires.

The op rotates the query and key tensors by a shared `freqs_cos`/`freqs_sin` pair and is composed of two dispatches of one WGSL kernel: each thread handles one (even, odd) element pair of a head row (`out[2i] = x[2i]*cos - x[2i+1]*sin`, `out[2i+1] = x[2i]*sin + x[2i+1]*cos`), one dispatch writing `xq_out` and one writing `xk_out`, mirroring the Vulkan `apply_rotary_emb` reference (buffer-only, fp32, the interleaved `.default` variant). Each dispatch owns a distinct compute pipeline (the graph destructor releases per dispatch, so a shared handle would double-free); the workgroup size is a `wg_size` pipeline-override constant clamped to the device limit, both 1D dispatch counts go through `WebGPUUtils::compute_1d_workgroup_count` and are validated before any GPU-object allocation, and the embedded WGSL header is generated by `gen_wgsl_headers.py`.

The two outputs (`xq_out`, `xk_out`) are serialized by the Vulkan exporter as a single `ValueList` graph value, which the runtime did not previously model. This adds the `ValueType::ValueList` value kind, a `value_lists_` table populated during `build()`, and a `get_value_list` accessor the handler uses to resolve the output ids. While in that code path it also closes a latent gap: a constant tensor whose `constant_id` is set but whose constants table is missing or out of range now throws (fail-loud) rather than silently leaving the buffer uninitialized.

@exported-using-ghexport
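
For readers who want to check the per-pair math from the description, here is a minimal CPU-side sketch in plain Python. It is not the WGSL kernel or the Vulkan reference. `rotate_pairs` and `workgroup_count_1d` are hypothetical names, the flat one-row layout is an assumption, and the clamp-then-ceil-divide sizing is only a guess at what the 1D dispatch sizing amounts to, not the behavior of `WebGPUUtils::compute_1d_workgroup_count`.

```python
def rotate_pairs(x, cos, sin):
    """Rotate one head row stored as interleaved (even, odd) pairs.

    x has 2*n elements; cos and sin have n elements (one per pair).
    Assumed layout for illustration only.
    """
    if not (len(x) == 2 * len(cos) == 2 * len(sin)):
        raise ValueError("x must hold exactly one (even, odd) pair per cos/sin entry")
    out = [0.0] * len(x)
    for i in range(len(cos)):  # in the real kernel, one thread handles one pair
        even, odd = x[2 * i], x[2 * i + 1]
        out[2 * i] = even * cos[i] - odd * sin[i]
        out[2 * i + 1] = even * sin[i] + odd * cos[i]
    return out


def workgroup_count_1d(num_pairs, wg_size, max_wg_size):
    """Hypothetical sizing helper: clamp the workgroup size, then ceil-divide.

    Validation happens before anything else, echoing the description's
    "validated before any GPU-object allocation".
    """
    if num_pairs <= 0 or wg_size <= 0 or max_wg_size <= 0:
        raise ValueError("sizes must be positive")
    clamped = min(wg_size, max_wg_size)
    return clamped, (num_pairs + clamped - 1) // clamped


# Two "dispatches" over the same cos/sin pair, one for the query row and one for the key row.
freqs_cos, freqs_sin = [0.0, 1.0], [1.0, 0.0]
xq_out = rotate_pairs([1.0, 0.0, 0.0, 2.0], freqs_cos, freqs_sin)
xk_out = rotate_pairs([0.0, 1.0, 3.0, 0.0], freqs_cos, freqs_sin)
```

With `cos = 0, sin = 1` the first pair is rotated by a quarter turn, and with `cos = 1, sin = 0` the second pair passes through unchanged, which makes the two formulas easy to verify by hand.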
Differential Revision: D108428756
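
The `ValueList` handling and the fail-loud constant check from the summary can be pictured with a small Python model. This is a sketch under stated assumptions, not the runtime's C++ code: `Graph`, `add_tensor`, `add_value_list`, and `load_constant` are hypothetical names, and treating `constant_id is None` as "not set" is an illustrative convention only.

```python
from dataclasses import dataclass, field
from enum import Enum, auto


class ValueType(Enum):
    TENSOR = auto()
    VALUE_LIST = auto()


@dataclass
class Graph:
    # value id -> kind of value stored at that id
    value_types: dict = field(default_factory=dict)
    # value id -> member value ids (only for VALUE_LIST entries)
    value_lists: dict = field(default_factory=dict)

    def add_tensor(self, value_id):
        self.value_types[value_id] = ValueType.TENSOR

    def add_value_list(self, value_id, member_ids):
        self.value_types[value_id] = ValueType.VALUE_LIST
        self.value_lists[value_id] = list(member_ids)

    def get_value_list(self, value_id):
        """Return the member ids, refusing ids that are not value lists."""
        if self.value_types.get(value_id) is not ValueType.VALUE_LIST:
            raise ValueError(f"value {value_id} is not a ValueList")
        return self.value_lists[value_id]


def load_constant(constants, constant_id):
    """Fail loudly when a constant id is set but cannot be resolved."""
    if constant_id is None:  # assumed convention: None means "not a constant"
        return None
    if constants is None or not (0 <= constant_id < len(constants)):
        raise ValueError(f"constant_id {constant_id} has no backing data")
    return constants[constant_id]


# A handler would resolve its two outputs through the single list value.
graph = Graph()
graph.add_tensor(3)              # stands in for xq_out
graph.add_tensor(4)              # stands in for xk_out
graph.add_value_list(5, [3, 4])  # the op's single serialized output
xq_out_id, xk_out_id = graph.get_value_list(5)
```

The point of raising in both helpers is the same as in the description: a malformed graph surfaces as an error at build time instead of as an uninitialized buffer at run time.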