Skip to content

Commit 65086d1

Browse files
larryliu0820 authored and facebook-github-bot committed
Fix start_pos not being updated in prefill_chunk() (#11592)
Summary: This PR fixes `TextPrefiller` class `prefill_chunk` API. We should update `start_pos` at the end of the while loop. ### Enhancements to `TextPrefiller` functionality: * Renamed `prefillChunk` to `prefill_chunk` for consistency across method names in `TextPrefiller` class. Updates were made in both `text_prefiller.cpp` and `text_prefiller.h`. [[1]](diffhunk://#diff-a990c73be85dcbe7ef7e2ca02082fc0beb9951c1a8973587575a3e285f134cedL59-R74) [[2]](diffhunk://#diff-d8c18164204bbba3d0e0ccdf0c6e7f776f8865d66ec3b4129fcba641dcd7aed4L48-R48) * Adjusted the `start_pos` increment logic in the `prefill` method to ensure accurate position updates after processing each chunk. ### Comprehensive unit tests for `TextPrefiller`: * Added a new test file `test_text_prefiller.cpp` with detailed test cases for `TextPrefiller`, including edge cases, error handling, and scenarios for parallel and sequential prefill modes. Mock classes and spy implementations were introduced to facilitate testing. Differential Revision: D76483572 Pulled By: larryliu0820
1 parent db96aba commit 65086d1

File tree

4 files changed

+322
-6
lines changed

4 files changed

+322
-6
lines changed

extension/llm/runner/test/targets.bzl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,12 @@ def define_common_targets():
2727
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
2828
],
2929
)
30+
31+
runtime.cxx_test(
32+
name = "test_text_prefiller",
33+
srcs = ["test_text_prefiller.cpp"],
34+
deps = [
35+
"//executorch/extension/llm/runner:runner_lib",
36+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
37+
],
38+
)
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
 */

// Unit tests for TextPrefiller. Verifies that prefill() splits long prompts
// into max_seq_len-sized chunks, advances start_pos after every chunk,
// propagates errors from prefill_chunk(), and works in parallel-prefill mode.

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/platform/runtime.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

using namespace ::testing;
using executorch::extension::llm::TextDecoderRunner;
using executorch::extension::llm::TextPrefiller;
using executorch::runtime::Error;
using executorch::runtime::Result;
using executorch::runtime::testing::TensorFactory;

// Mock class for TextDecoderRunner so tests never touch a real Module.
class MockTextDecoderRunner : public TextDecoderRunner {
 public:
  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
  MOCK_METHOD(
      Result<executorch::aten::Tensor>,
      step,
      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
      ());
  MOCK_METHOD(bool, is_method_loaded, (), ());
  MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
  MOCK_METHOD(::executorch::runtime::Error, load, (), ());
};

// Test fixture for TextPrefiller tests.
class TextPrefillerTest : public Test {
 protected:
  void SetUp() override {
    executorch::runtime::runtime_init();
    // Set up default behavior for the text decoder runner: report loaded,
    // and return a fixed logits tensor from every step() call.
    ON_CALL(text_decoder_runner_, is_method_loaded())
        .WillByDefault(Return(true));
    ON_CALL(text_decoder_runner_, step)
        .WillByDefault([&](executorch::extension::TensorPtr&,
                           executorch::extension::TensorPtr&) {
          return Result<executorch::aten::Tensor>(tensor);
        });
  }

  // Helper method to create a TextPrefiller with specific parameters.
  std::unique_ptr<TextPrefiller> createTextPrefiller(
      int64_t max_seq_len,
      bool use_kv_cache = true,
      bool enable_parallel_prefill = false) {
    return std::make_unique<TextPrefiller>(
        &text_decoder_runner_,
        use_kv_cache,
        enable_parallel_prefill,
        max_seq_len);
  }

  // Mock TextPrefiller that lets tests intercept prefill_chunk() calls and
  // observe the exact token slices and positions passed to each chunk.
  class MockTextPrefiller : public TextPrefiller {
   public:
    MockTextPrefiller(
        TextDecoderRunner* text_decoder_runner,
        bool use_kv_cache,
        bool enable_parallel_prefill,
        int64_t max_seq_len)
        : TextPrefiller(
              text_decoder_runner,
              use_kv_cache,
              enable_parallel_prefill,
              max_seq_len) {}

    MOCK_METHOD(
        ::executorch::runtime::Result<uint64_t>,
        prefill_chunk,
        (std::vector<uint64_t>&, int64_t&),
        ());
  };

  // Create a mock TextPrefiller.
  std::unique_ptr<MockTextPrefiller> createMockTextPrefiller(
      int64_t max_seq_len,
      bool use_kv_cache = true,
      bool enable_parallel_prefill = false) {
    return std::make_unique<MockTextPrefiller>(
        &text_decoder_runner_,
        use_kv_cache,
        enable_parallel_prefill,
        max_seq_len);
  }

  MockTextDecoderRunner text_decoder_runner_;
  std::vector<float> return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f};
  TensorFactory<executorch::aten::ScalarType::Float> tf;
  executorch::aten::Tensor tensor = tf.make({1, 4}, return_logits_);
};

// Test that prefill() calls prefill_chunk() once when prompt tokens <=
// max_seq_len.
TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
  // Create a spy TextPrefiller with max_seq_len = 10.
  auto prefiller = createMockTextPrefiller(10);

  // Create prompt tokens with size <= max_seq_len.
  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5};
  int64_t start_pos = 0;

  // Expect prefill_chunk to be called exactly once with the entire prompt.
  EXPECT_CALL(*prefiller, prefill_chunk(_, _))
      .Times(1)
      .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
        // Verify the tokens passed to prefill_chunk.
        EXPECT_EQ(tokens.size(), prompt_tokens.size());
        for (size_t i = 0; i < tokens.size(); i++) {
          EXPECT_EQ(tokens[i], prompt_tokens[i]);
        }
        // Verify the position.
        EXPECT_EQ(pos, start_pos);
        return Result<uint64_t>(42);
      });

  // Call prefill.
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result.
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 42);
}

// Test that prefill() calls prefill_chunk() multiple times when prompt tokens >
// max_seq_len.
TEST_F(
    TextPrefillerTest,
    PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
  // Create a spy TextPrefiller with max_seq_len = 3.
  const int64_t max_seq_len = 3;
  auto prefiller = createMockTextPrefiller(max_seq_len);

  // Create prompt tokens with size > max_seq_len.
  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
  int64_t start_pos = 0;

  // Set up expectations for prefill_chunk calls.
  {
    InSequence seq; // Ensure calls happen in the expected order.

    // First chunk: tokens [1, 2, 3]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 3);
          EXPECT_EQ(tokens[0], 1);
          EXPECT_EQ(tokens[1], 2);
          EXPECT_EQ(tokens[2], 3);
          EXPECT_EQ(pos, 0);
          return Result<uint64_t>(10);
        });

    // Second chunk: tokens [4, 5, 6]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 3);
          EXPECT_EQ(tokens[0], 4);
          EXPECT_EQ(tokens[1], 5);
          EXPECT_EQ(tokens[2], 6);
          EXPECT_EQ(pos, 3);
          return Result<uint64_t>(20);
        });

    // Third chunk: tokens [7, 8]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 2);
          EXPECT_EQ(tokens[0], 7);
          EXPECT_EQ(tokens[1], 8);
          EXPECT_EQ(pos, 6);
          return Result<uint64_t>(30);
        });
  }

  // Call prefill.
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result.
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk.

  // Verify that start_pos has been updated correctly.
  // Cast avoids a signed/unsigned comparison warning (int64_t vs size_t).
  EXPECT_EQ(start_pos, static_cast<int64_t>(prompt_tokens.size()));
}

// Test that prefill() handles edge cases correctly.
TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
  // Create a spy TextPrefiller with max_seq_len = 1.
  const int64_t max_seq_len = 1;
  auto prefiller = createMockTextPrefiller(max_seq_len);

  // Create prompt tokens with size > max_seq_len.
  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
  int64_t start_pos = 5; // Non-zero starting position.

  // Set up expectations for prefill_chunk calls.
  {
    InSequence seq;

    // First chunk: token [1]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 1);
          EXPECT_EQ(tokens[0], 1);
          EXPECT_EQ(pos, 5);
          return Result<uint64_t>(10);
        });

    // Second chunk: token [2]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 1);
          EXPECT_EQ(tokens[0], 2);
          EXPECT_EQ(pos, 6);
          return Result<uint64_t>(20);
        });

    // Third chunk: token [3]
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          EXPECT_EQ(tokens.size(), 1);
          EXPECT_EQ(tokens[0], 3);
          EXPECT_EQ(pos, 7);
          return Result<uint64_t>(30);
        });
  }

  // Call prefill.
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result.
  EXPECT_EQ(result.error(), Error::Ok);
  EXPECT_EQ(result.get(), 30);

  // Verify that start_pos has been updated correctly.
  EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
}

// Test that prefill() handles errors from prefill_chunk correctly.
TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
  // Create a spy TextPrefiller with max_seq_len = 3.
  const int64_t max_seq_len = 3;
  auto prefiller = createMockTextPrefiller(max_seq_len);

  // Create prompt tokens with size > max_seq_len.
  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5};
  int64_t start_pos = 0;

  // Set up expectations for prefill_chunk calls.
  {
    InSequence seq;

    // First chunk: tokens [1, 2, 3] - succeeds.
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          return Result<uint64_t>(10);
        });

    // Second chunk: tokens [4, 5] - fails.
    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
          return Result<uint64_t>(Error::InvalidArgument);
        });
  }

  // Call prefill.
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify that the error is propagated.
  EXPECT_EQ(result.error(), Error::InvalidArgument);
}

// Test that prefill_chunk() works correctly with parallel prefill enabled.
TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
  // Create a TextPrefiller with parallel prefill enabled.
  auto prefiller = createTextPrefiller(10, true, true);

  // Set up expectations for the text decoder runner: parallel prefill should
  // process the whole chunk in a single step() call.
  EXPECT_CALL(text_decoder_runner_, step(_, _))
      .Times(1)
      .WillOnce(Return(Result<executorch::aten::Tensor>(tensor)));

  // Create prompt tokens.
  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
  int64_t start_pos = 0;

  // Call prefill.
  auto result = prefiller->prefill(prompt_tokens, start_pos);

  // Verify the result.
  EXPECT_EQ(result.error(), Error::Ok);

  // Verify that start_pos has been updated correctly.
  // Cast avoids a signed/unsigned comparison warning (int64_t vs size_t).
  EXPECT_EQ(start_pos, static_cast<int64_t>(prompt_tokens.size()));
}

extension/llm/runner/text_prefiller.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ TextPrefiller::TextPrefiller(
2424
: text_decoder_runner_(text_decoder_runner),
2525
use_kv_cache_(use_kv_cache),
2626
enable_parallel_prefill_(enable_parallel_prefill),
27-
max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {
28-
} // -1 because for some reason tracing results in this upperbound
27+
max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {
28+
}
2929

3030
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
3131
std::vector<uint64_t>& prompt_tokens,
@@ -56,21 +56,22 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
5656
prompt_tokens_to_process.begin());
5757

5858
// Process this chunk
59-
auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
59+
auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos);
6060
ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
6161
cur_token = chunk_result.get();
6262

63+
start_pos += num_tokens_to_prefill_with;
6364
num_tokens_to_process += num_tokens_to_prefill_with;
6465
}
6566

6667
return cur_token;
6768
} else {
6869
// If prompt tokens don't exceed max_seq_len_, process them directly
69-
return prefillChunk(prompt_tokens, start_pos);
70+
return prefill_chunk(prompt_tokens, start_pos);
7071
}
7172
}
7273

73-
::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
74+
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
7475
std::vector<uint64_t>& prompt_tokens,
7576
int64_t& start_pos) {
7677
// enable_parallel_prefill_ maybe set even when not using kv cache

extension/llm/runner/text_prefiller.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class ET_EXPERIMENTAL TextPrefiller {
4545
* Module.
4646
* @return The next token of the LLM Module after prefilling this chunk.
4747
*/
48-
::executorch::runtime::Result<uint64_t> prefillChunk(
48+
virtual ::executorch::runtime::Result<uint64_t> prefill_chunk(
4949
std::vector<uint64_t>& prompt_tokens,
5050
int64_t& start_pos);
5151

0 commit comments

Comments
 (0)