diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt
index ac46f0021fb..15b4d005f9d 100644
--- a/extension/llm/runner/test/CMakeLists.txt
+++ b/extension/llm/runner/test/CMakeLists.txt
@@ -17,13 +17,10 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
 
-set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp)
+set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp
+    test_text_prefiller.cpp
+)
 
 et_cxx_test(
-  test_runner
-  SOURCES
-  ${_test_srcs}
-  EXTRA_LIBS
-  executorch
-  extension_llm_runner
+  test_runner SOURCES ${_test_srcs} EXTRA_LIBS executorch extension_llm_runner
 )
diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl
index a5c8be7b6de..8bc3d4cc100 100644
--- a/extension/llm/runner/test/targets.bzl
+++ b/extension/llm/runner/test/targets.bzl
@@ -27,3 +27,12 @@ def define_common_targets():
             "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
         ],
     )
+
+    runtime.cxx_test(
+        name = "test_text_prefiller",
+        srcs = ["test_text_prefiller.cpp"],
+        deps = [
+            "//executorch/extension/llm/runner:runner_lib",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+    )
diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp
new file mode 100644
index 00000000000..b786fc71978
--- /dev/null
+++ b/extension/llm/runner/test/test_text_prefiller.cpp
@@ -0,0 +1,306 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
+ */
+
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/runner/text_prefiller.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/platform/runtime.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using executorch::extension::llm::TextDecoderRunner;
+using executorch::extension::llm::TextPrefiller;
+using executorch::runtime::Error;
+using executorch::runtime::Result;
+using executorch::runtime::testing::TensorFactory;
+
+// Mock class for TextDecoderRunner
+class MockTextDecoderRunner : public TextDecoderRunner {
+ public:
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+  MOCK_METHOD(
+      Result<executorch::aten::Tensor>,
+      step,
+      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+      ());
+  MOCK_METHOD(bool, is_method_loaded, (), ());
+  MOCK_METHOD(
+      Result<uint64_t>,
+      prefill,
+      (std::vector<uint64_t>&, int64_t),
+      ());
+  MOCK_METHOD(::executorch::runtime::Error, load, (), ());
+};
+
+// Test fixture for TextPrefiller tests
+class TextPrefillerTest : public Test {
+ protected:
+  void SetUp() override {
+    executorch::runtime::runtime_init();
+    // Set up default behavior for the text decoder runner
+    ON_CALL(text_decoder_runner_, is_method_loaded())
+        .WillByDefault(Return(true));
+    ON_CALL(text_decoder_runner_, step)
+        .WillByDefault([&](executorch::extension::TensorPtr&,
+                           executorch::extension::TensorPtr&) {
+          return Result<executorch::aten::Tensor>(tensor);
+        });
+  }
+
+  // Helper method to create a TextPrefiller with specific parameters
+  std::unique_ptr<TextPrefiller> createTextPrefiller(
+      int64_t max_seq_len,
+      bool use_kv_cache = true,
+      bool enable_parallel_prefill = false) {
+    return std::make_unique<TextPrefiller>(
+        &text_decoder_runner_,
+        use_kv_cache,
+        enable_parallel_prefill,
+        max_seq_len);
+  }
+
+  // Create a mock TextPrefiller that allows us to mock prefill_chunk calls
+  class MockTextPrefiller : public TextPrefiller {
+   public:
+    MockTextPrefiller(
+        TextDecoderRunner* text_decoder_runner,
+        bool use_kv_cache,
+        bool enable_parallel_prefill,
+        int64_t max_seq_len)
+        : TextPrefiller(
+              text_decoder_runner,
+              use_kv_cache,
+              enable_parallel_prefill,
+              max_seq_len) {}
+
+    MOCK_METHOD(
+        ::executorch::runtime::Result<uint64_t>,
+        prefill_chunk,
+        (std::vector<uint64_t>&, int64_t&),
+        ());
+  };
+
+  // Create a mock TextPrefiller
+  std::unique_ptr<MockTextPrefiller> createMockTextPrefiller(
+      int64_t max_seq_len,
+      bool use_kv_cache = true,
+      bool enable_parallel_prefill = false) {
+    return std::make_unique<MockTextPrefiller>(
+        &text_decoder_runner_,
+        use_kv_cache,
+        enable_parallel_prefill,
+        max_seq_len);
+  }
+
+  MockTextDecoderRunner text_decoder_runner_;
+  std::vector<float> return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f};
+  TensorFactory<executorch::aten::ScalarType::Float> tf;
+  executorch::aten::Tensor tensor = tf.make({1, 4}, return_logits_);
+};
+
+// Test that prefill() calls prefill_chunk() once when prompt tokens <=
+// max_seq_len
+TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
+  // Create a spy TextPrefiller with max_seq_len = 10
+  auto prefiller = createMockTextPrefiller(10);
+
+  // Create prompt tokens with size <= max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5};
+  int64_t start_pos = 0;
+
+  // Expect prefill_chunk to be called exactly once with the entire prompt
+  EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+      .Times(1)
+      .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+        // Verify the tokens passed to prefill_chunk
+        EXPECT_EQ(tokens.size(), prompt_tokens.size());
+        for (size_t i = 0; i < tokens.size(); i++) {
+          EXPECT_EQ(tokens[i], prompt_tokens[i]);
+        }
+        // Verify the position
+        EXPECT_EQ(pos, start_pos);
+        return Result<uint64_t>(42);
+      });
+
+  // Call prefill
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+  EXPECT_EQ(result.get(), 42);
+}
+
+// Test that prefill() calls prefill_chunk() multiple times when prompt tokens >
+// max_seq_len
+TEST_F(
+    TextPrefillerTest,
+    PrefillCallsPrefillChunkMultipleTimesWhenPromptExceedsMaxLen) {
+  // Create a spy TextPrefiller with max_seq_len = 3
+  const int64_t max_seq_len = 3;
+  auto prefiller = createMockTextPrefiller(max_seq_len);
+
+  // Create prompt tokens with size > max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
+  int64_t start_pos = 0;
+
+  // Set up expectations for prefill_chunk calls
+  {
+    InSequence seq; // Ensure calls happen in the expected order
+
+    // First chunk: tokens [1, 2, 3]
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          EXPECT_EQ(tokens.size(), 3);
+          EXPECT_EQ(tokens[0], 1);
+          EXPECT_EQ(tokens[1], 2);
+          EXPECT_EQ(tokens[2], 3);
+          EXPECT_EQ(pos, 0);
+          return Result<uint64_t>(10);
+        });
+
+    // Second chunk: tokens [4, 5, 6]
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          EXPECT_EQ(tokens.size(), 3);
+          EXPECT_EQ(tokens[0], 4);
+          EXPECT_EQ(tokens[1], 5);
+          EXPECT_EQ(tokens[2], 6);
+          EXPECT_EQ(pos, 3);
+          return Result<uint64_t>(20);
+        });
+
+    // Third chunk: tokens [7, 8]
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          EXPECT_EQ(tokens.size(), 2);
+          EXPECT_EQ(tokens[0], 7);
+          EXPECT_EQ(tokens[1], 8);
+          EXPECT_EQ(pos, 6);
+          return Result<uint64_t>(30);
+        });
+  }
+
+  // Call prefill
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+  EXPECT_EQ(result.get(), 30); // Should return the token from the last chunk
+
+  // Verify that start_pos has been updated correctly
+  EXPECT_EQ(start_pos, prompt_tokens.size());
+}
+
+// Test that prefill() handles edge cases correctly
+TEST_F(TextPrefillerTest, PrefillHandlesEdgeCasesCorrectly) {
+  // Create a spy TextPrefiller with max_seq_len = 1
+  const int64_t max_seq_len = 1;
+  auto prefiller = createMockTextPrefiller(max_seq_len);
+
+  // Create prompt tokens with size > max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 5; // Non-zero starting position
+
+  // Set up expectations for prefill_chunk calls
+  {
+    InSequence seq;
+
+    // First chunk: token [1]
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          EXPECT_EQ(tokens.size(), 1);
+          EXPECT_EQ(tokens[0], 1);
+          EXPECT_EQ(pos, 5);
+          return Result<uint64_t>(10);
+        });
+
+    // Second chunk: token [2]
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          EXPECT_EQ(tokens.size(), 1);
+          EXPECT_EQ(tokens[0], 2);
+          EXPECT_EQ(pos, 6);
+          return Result<uint64_t>(20);
+        });
+
+    // Third chunk: token [3]
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          EXPECT_EQ(tokens.size(), 1);
+          EXPECT_EQ(tokens[0], 3);
+          EXPECT_EQ(pos, 7);
+          return Result<uint64_t>(30);
+        });
+  }
+
+  // Call prefill
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+  EXPECT_EQ(result.get(), 30);
+
+  // Verify that start_pos has been updated correctly
+  EXPECT_EQ(start_pos, 8); // 5 (initial) + 3 (tokens)
+}
+
+// Test that prefill() handles errors from prefill_chunk correctly
+TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
+  // Create a spy TextPrefiller with max_seq_len = 3
+  const int64_t max_seq_len = 3;
+  auto prefiller = createMockTextPrefiller(max_seq_len);
+
+  // Create prompt tokens with size > max_seq_len
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5};
+  int64_t start_pos = 0;
+
+  // Set up expectations for prefill_chunk calls
+  {
+    InSequence seq;
+
+    // First chunk: tokens [1, 2, 3] - succeeds
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          return Result<uint64_t>(10);
+        });
+
+    // Second chunk: tokens [4, 5] - fails
+    EXPECT_CALL(*prefiller, prefill_chunk(_, _))
+        .WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
+          return Result<uint64_t>(Error::InvalidArgument);
+        });
+  }
+
+  // Call prefill
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify that the error is propagated
+  EXPECT_EQ(result.error(), Error::InvalidArgument);
+}
+
+// Test that prefill_chunk() works correctly with parallel prefill enabled
+TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
+  // Create a TextPrefiller with parallel prefill enabled
+  auto prefiller = createTextPrefiller(10, true, true);
+
+  // Set up expectations for the text decoder runner
+  EXPECT_CALL(text_decoder_runner_, step(_, _))
+      .Times(1)
+      .WillOnce(Return(Result<executorch::aten::Tensor>(tensor)));
+
+  // Create prompt tokens
+  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
+  int64_t start_pos = 0;
+
+  // Call prefill
+  auto result = prefiller->prefill(prompt_tokens, start_pos);
+
+  // Verify the result
+  EXPECT_EQ(result.error(), Error::Ok);
+
+  // Verify that start_pos has been updated correctly
+  EXPECT_EQ(start_pos, prompt_tokens.size());
+}
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 19c260f5be6..d02c337451d 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -24,8 +24,8 @@ TextPrefiller::TextPrefiller(
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
       enable_parallel_prefill_(enable_parallel_prefill),
-      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {
-} // -1 because for some reason tracing results in this upperbound
+      max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {
+}
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
@@ -56,21 +56,22 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
         prompt_tokens_to_process.begin());
 
     // Process this chunk
-    auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+    auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos);
     ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
     cur_token = chunk_result.get();
 
+    start_pos += num_tokens_to_prefill_with;
     num_tokens_to_process += num_tokens_to_prefill_with;
   }
 
   return cur_token;
   } else {
     // If prompt tokens don't exceed max_seq_len_, process them directly
-    return prefillChunk(prompt_tokens, start_pos);
+    return prefill_chunk(prompt_tokens, start_pos);
   }
 }
 
-::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
     std::vector<uint64_t>& prompt_tokens,
     int64_t& start_pos) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index 49b2c867167..ce12506a05c 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -45,7 +45,7 @@ class ET_EXPERIMENTAL TextPrefiller {
    * Module.
    * @return The next token of the LLM Module after prefilling this chunk.
    */
-  ::executorch::runtime::Result<uint64_t> prefillChunk(
+  virtual ::executorch::runtime::Result<uint64_t> prefill_chunk(
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);
 