llm_runner: plumb prefill temperature #20244
Pull request overview
This PR threads sampling temperature through TextPrefiller so the first sampled token (produced during prefill in session-based serving) uses the same sampling inputs as subsequent decode steps, and exposes TextTokenGenerator’s logit-processor application to keep decode paths consistent.
Changes:
- Expose `TextTokenGenerator::apply_logit_processors()` (and `is_eos()`) so token-step callers can reuse the same logit-processing logic as `generate()`.
- Extend `TextPrefiller::prefill()` / `prefill_chunk()` to accept an optional `temperature`, applied only to the final chunk's sampled token.
- Update `TextPrefiller` unit tests for the new `prefill_chunk` signature (but currently without asserting temperature behavior).
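The point of exposing the logit-processor helper is that a single-step caller and the multi-step loop prepare logits through one code path. A minimal sketch of that sharing pattern, under assumptions: `SketchGenerator`, `add_processor`, `apply_processors`, and `pick_next` are hypothetical names chosen for illustration and do not reflect the actual `TextTokenGenerator` signatures.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using Logits = std::vector<float>;
using LogitProcessor = std::function<void(Logits&)>;

// Hypothetical generator; only the sharing pattern is the point here.
class SketchGenerator {
 public:
  void add_processor(LogitProcessor processor) {
    processors_.push_back(std::move(processor));
  }

  // Public so a single-step caller can prepare logits the same way the
  // multi-step loop does.
  void apply_processors(Logits& logits) const {
    for (const auto& processor : processors_) {
      processor(logits);
    }
  }

  // Stand-in for one iteration of a generate loop: it calls the shared
  // helper instead of carrying its own copy of the processing logic.
  std::size_t pick_next(Logits logits) const {
    assert(!logits.empty());
    apply_processors(logits);
    return argmax(logits);
  }

  static std::size_t argmax(const Logits& logits) {
    return static_cast<std::size_t>(
        std::max_element(logits.begin(), logits.end()) - logits.begin());
  }

 private:
  std::vector<LogitProcessor> processors_;
};
```

A token-step caller would call `apply_processors()` on its own logits and then sample, which yields the same choice as `pick_next()` for identical inputs.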
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| extension/llm/runner/text_token_generator.h | Adds public helpers for applying logit processors and EOS checking; generate() now calls the helper. |
| extension/llm/runner/text_prefiller.h | Adds temperature parameter to prefill APIs (currently via virtual signature change + default arg). |
| extension/llm/runner/text_prefiller.cpp | Threads temperature into logits_to_token() and ensures only the last chunk uses non-greedy sampling. |
| extension/llm/runner/test/test_text_prefiller.cpp | Updates mocks/expectations for new prefill_chunk signature; does not yet assert temperature forwarding/last-chunk behavior. |
7fde821 to 878f15c
@claude Review this PR
Claude finished @mergennachin's task in 5m 57s.
| } | ||
|
|
||
| /// Whether `token` is an end-of-sequence token (used by single-step decode). | ||
| inline bool is_eos(uint64_t token) const { |
removing for now
| EXPECT_EQ(result.get(), 42); | ||
| } | ||
|
|
||
| TEST_F(TextPrefillerTest, TwoArgumentPrefillUsesGreedyTemperature) { |
Please add tests here to guard the temperature range. Normally it should be in [0, 1], and we should raise an error if the temperature is out of range.
Session-based serving drives generation as prefill plus token steps instead of one monolithic generate call. For that path to be correct, the first sampled token produced during prefill must honor the same sampling inputs as the rest of the decode loop; otherwise requests using temperature can silently start greedily and then switch behavior on later tokens.
This threads optional temperature through TextPrefiller and exposes the existing TextTokenGenerator logit-processor application so token-step callers can reuse the same sampling preparation as generate(). The goal is to remove a divergence point before session-backed serving starts depending on these primitives.
Default behavior remains greedy, so existing callers that do not pass temperature keep the same semantics. The added tests focus on the new non-default path and on sharing the logit-processor logic rather than duplicating it.
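To make the divergence concrete, here is a minimal sketch of temperature-dependent sampling. It assumes the sampler is greedy (argmax) at temperature 0 and otherwise draws from softmax(logits / temperature). `sample_token` and its RNG parameter are hypothetical and are not the ExecuTorch sampler API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical stand-in for a sampler; not the ExecuTorch API.
// temperature == 0 selects the argmax (greedy); any other value samples
// from softmax(logits / temperature).
inline uint64_t sample_token(
    const std::vector<float>& logits,
    float temperature,
    std::mt19937& rng) {
  assert(!logits.empty());
  const auto max_it = std::max_element(logits.begin(), logits.end());
  if (temperature == 0.0f) {
    return static_cast<uint64_t>(max_it - logits.begin());
  }
  // Subtract the max logit before exponentiating for numerical stability.
  std::vector<double> weights(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    weights[i] = std::exp((logits[i] - *max_it) / temperature);
  }
  // discrete_distribution normalizes the weights internally.
  std::discrete_distribution<std::size_t> dist(weights.begin(), weights.end());
  return static_cast<uint64_t>(dist(rng));
}
```

If prefill always took the greedy branch while later decode steps took the sampling branch, the first token of a response would be drawn from a different distribution than the rest, which is the inconsistency described above.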
878f15c to 4a38f8f
| ET_CHECK_OR_RETURN_ERROR( | ||
| temperature >= 0.0f && temperature <= 1.0f, | ||
| InvalidArgument, | ||
| "Temperature must be in [0, 1], got %f", | ||
| static_cast<double>(temperature)); |
| ET_CHECK_OR_RETURN_ERROR( | ||
| temperature >= 0.0f && temperature <= 1.0f, | ||
| InvalidArgument, | ||
| "Temperature must be in [0, 1], got %f", | ||
| static_cast<double>(temperature)); |
| ET_CHECK_OR_RETURN_ERROR( | ||
| temperature >= 0.0f && temperature <= 1.0f, | ||
| InvalidArgument, | ||
| "Temperature must be in [0, 1], got %f", | ||
| static_cast<double>(temperature)); |
| std::vector<uint64_t>& prompt_tokens, | ||
| int64_t& start_pos, | ||
| float temperature) { | ||
| ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null"); |
| // Only the final chunk samples the first generated token. | ||
| const bool is_last_chunk = | ||
| num_tokens_to_process + num_tokens_to_prefill_with >= | ||
| num_prompt_tokens; | ||
| auto chunk_result = prefill_chunk( | ||
| prompt_tokens_to_process, | ||
| start_pos, | ||
| is_last_chunk ? temperature : 0.0f); |
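The last-chunk rule in the hunk above can be modeled without a decoder. A minimal sketch, assuming the prompt is split into chunks of at most `max_chunk_len` tokens; `per_chunk_temperatures` is a hypothetical helper written for illustration and is not ExecuTorch code. It returns the temperature each chunk would receive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of the chunk loop. Every chunk except the final one is
// given 0.0f (greedy), because only the final chunk's sampled token is kept
// as the first generated token.
inline std::vector<float> per_chunk_temperatures(
    std::size_t num_prompt_tokens,
    std::size_t max_chunk_len,
    float temperature) {
  assert(max_chunk_len > 0);
  std::vector<float> result;
  std::size_t processed = 0;
  while (processed < num_prompt_tokens) {
    const std::size_t chunk_len =
        std::min(num_prompt_tokens - processed, max_chunk_len);
    const bool is_last_chunk = processed + chunk_len >= num_prompt_tokens;
    result.push_back(is_last_chunk ? temperature : 0.0f);
    processed += chunk_len;
  }
  return result;
}
```

For a 10-token prompt with chunks of 4 and a temperature of 0.8, the chunks receive 0, 0, and 0.8 in that order; a prompt that fits in one chunk receives the requested temperature directly.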
| | `max_new_tokens` | `int32_t` | `-1` | Maximum new tokens to generate (-1 = use available context) | | ||
| | `seq_len` | `int32_t` | `1024` | Total sequence length including prompt | | ||
| | `temperature` | `float` | `0.8f` | Sampling temperature (0.0 = deterministic, 1.0+ = creative) | | ||
| | `temperature` | `float` | `0.8f` | Sampling temperature in [0.0, 1.0] (0.0 = deterministic) | | ||
| | `echo` | `bool` | `true` | Whether to echo the input prompt | |
| GenerationConfig config; | ||
| config.temperature = 0.1f; // Very deterministic | ||
| runner->generate(factual_prompt, config, callback); | ||
|
|
||
| config.temperature = 1.2f; // Very creative | ||
| config.temperature = 1.0f; // Highest supported temperature | ||
| runner->generate(creative_prompt, config, callback); |
| TEST_F(RunnerTest, TextTokenGeneratorRejectsTemperatureOutOfRange) { | ||
| auto tokenizer = createMockTokenizer(); | ||
| auto text_decoder_runner = createMockTextDecoderRunner(); | ||
| Stats stats; | ||
| auto generator = createTextTokenGenerator( | ||
| tokenizer.get(), text_decoder_runner.get(), &stats); | ||
|
|
||
| std::vector<uint64_t> tokens = {1, 2, 3}; | ||
| EXPECT_CALL(*text_decoder_runner, step(_, _)).Times(0); | ||
|
|
||
| EXPECT_EQ( | ||
| generator->generate(tokens, 3, 3, -0.1f, [](const std::string&) {}) | ||
| .error(), | ||
| Error::InvalidArgument); | ||
| EXPECT_EQ( | ||
| generator->generate(tokens, 3, 3, 1.1f, [](const std::string&) {}) | ||
| .error(), | ||
| Error::InvalidArgument); | ||
| } |
| EXPECT_EQ( | ||
| prefiller->prefill(prompt_tokens, start_pos, -0.1f).error(), | ||
| Error::InvalidArgument); | ||
| EXPECT_EQ( | ||
| prefiller->prefill(prompt_tokens, start_pos, 1.1f).error(), | ||
| Error::InvalidArgument); |
| EXPECT_EQ( | ||
| prefiller->prefill_chunk(prompt_tokens, start_pos, -0.1f).error(), | ||
| Error::InvalidArgument); | ||
| EXPECT_EQ( | ||
| prefiller->prefill_chunk(prompt_tokens, start_pos, 1.1f).error(), | ||
| Error::InvalidArgument); |
Summary: Forward fix for D108707577 (llm_runner: plumb prefill temperature, #20244).

The new test RunnerTest.TextTokenGeneratorRejectsTemperatureOutOfRange aborts:

    [ RUN ] RunnerTest.TextTokenGeneratorRejectsTemperatureOutOfRange
    ExecuTorch PAL must be initialized before call to et_pal_current_ticks()
    *** Signal 6 (SIGABRT) ***

The new test drives the temperature-rejection path, which emits an ET_LOG (and thus calls et_pal_current_ticks()) before any model load. The RunnerTest fixture never initializes the ExecuTorch runtime, so the timer call aborts. The valid-temperature tests in the same fixture pass because their happy path does not log.

Fix: initialize the runtime in RunnerTest::SetUp(), matching the established pattern used by the sibling tests in this same directory (test_text_prefiller, test_util, test_wav_loader all call executorch::runtime::runtime_init() in SetUp()). Applied to both the fbcode and xplat copies.

Differential Revision: D108831322
#20001