Skip to content

[Model][Frontend] Adding timeseries modality support and Qwen2.5-ChatTS model support #16852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

chemeris
Copy link

@chemeris chemeris commented Apr 18, 2025

This pull request has two parts:

  1. Adds generic infrastructure for handling time series as a modality, including an OpenAI API server.
  2. Adds support for ChatTS model inference that relies on the above change for both offline inference and online serving using the OpenAI API server.

Please refer to the official ChatTS documentation for details about the model architecture: https://github.com/NetManAIOps/ChatTS/
This code is based on the original ChatTS code, but works with the latest vllm code, and adds support for V1 vLLM engine and OpenAI API serving.

To use the current version of ChatTS requires --trust-remote-code and --hf-overrides in order to load config and processing classes from the ChatTS HF repo, but use the vllm implementation of the model itself.

Example script to serve ChatTS via an OpenAI API server with vLLM:

vllm serve ../ChatTS-model \
    --served-model-name chatts \
    --trust-remote-code \
    --hf-overrides '{"model_type":"chatts"}' \
    --max-model-len 6000 \
    --gpu-memory-utilization 0.8 \
    --limit-mm-per-prompt timeseries=50 \
    --allowed-local-media-path $(pwd) \
    --host 0.0.0.0 \
    --port 8090 

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added frontend multi-modality Related to multi-modality (#4194) labels Apr 18, 2025
Copy link

mergify bot commented Apr 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @chemeris.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this model to vLLM and expanding the multi-modality code! Some initial comments.

@chemeris chemeris force-pushed the timeseries branch 3 times, most recently from 1572d8e to d27a6d6 Compare April 19, 2025 17:40
@DarkLight1337
Copy link
Member

Please verify your model by following the guide in https://docs.vllm.ai/en/latest/contributing/model/tests.html

Also make sure to add this model to the Supported Models page in the docs!

@chemeris
Copy link
Author

@DarkLight1337 Thank you. I'll read the guide and see what's required to add the model.

Copy link

mergify bot commented Apr 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @chemeris.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@chemeris
Copy link
Author

chemeris commented Jun 5, 2025

@DarkLight1337 @ywang96 A kind ping about this. Please, could we merge this PR?

@DarkLight1337
Copy link
Member

DarkLight1337 commented Jun 5, 2025

Sorry for the delay, @Isotr0py @jeejeelee are you two able to help? I am quite busy lately.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Jun 5, 2025

Regarding the prefix caching issue, maybe @heheda12345 can help as well?

@chemeris
Copy link
Author

chemeris commented Jun 5, 2025

@Isotr0py All changes made as you suggested, thank you. Especially for catching the hardcoded float16.

@chemeris chemeris force-pushed the timeseries branch 2 times, most recently from 93577f7 to 9dfe24e Compare June 5, 2025 20:34
@heheda12345
Copy link
Collaborator

(a) why is it not caching these 9 tokens

Because we only cache full blocks that won't be further modified. For example, with block_size 4, if we cache [E] of a request [ABCD,E], and there are two new requests [ABCD, EF] and [ABCD, EG] that reuse [E], they will modify block [E] with different values.

(b) why is the result different when it processes them the second time?

I'm not sure. Maybe you can print the block_ids and the kv_cache tensor of these blocks to see if there is any problem.

@chemeris
Copy link
Author

chemeris commented Jun 6, 2025

(a) why is it not caching these 9 tokens

Because we only cache full blocks that won't be further modified. For example, with block_size 4, if we cache [E] of a request [ABCD,E], and there are two new requests [ABCD, EF] and [ABCD, EG] that reuse [E], they will modify block [E] with different values.

Thank you for the explanation. I'm not sure this matches my observations, though.

From my memory, when I was sending the exact same prompt four times, I saw:

  1. Output X (full prompt is processed)
  2. Output Y (only the tail of the prompt is re-processed)
  3. Output Y (nothing is re-processed)
  4. Output Y (nothing is re-processed)

So it looked like the tail had been cached, but only after the second try.

I did try printing token IDs and vectors, but couldn't see anything obviously wrong - without the full understanding of the underlying caching machinery at least.

I'm happy to look again if you could give me a hand with a bit more detailed insight into what exactly to debug.

@chemeris
Copy link
Author

chemeris commented Jun 6, 2025

@Isotr0py @DarkLight1337 Looks like the tests are passing now, and comments by @Isotr0py have been implemented. Is it possible to merge the PR while we're looking at the caching issue, as it seems to be unrelated to this specific PR?

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry again for the delay! Overall the PR looks good to me!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) June 6, 2025 09:45
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 6, 2025
@Isotr0py
Copy link
Collaborator

Isotr0py commented Jun 6, 2025

Please take a look to the failing basic model tests and multimodal tests.

And nearly forgotten, can you update the supported_models documentation to include this model?

Alexander Chemeris added 2 commits June 7, 2025 19:30
auto-merge was automatically disabled June 7, 2025 23:33

Head branch was pushed to by a user without write access

@chemeris chemeris requested a review from aarnphm as a code owner June 7, 2025 23:33
chemeris and others added 2 commits June 8, 2025 10:44
Signed-off-by: Alexander Chemeris <[email protected]>
Signed-off-by: Alexander Chemeris <[email protected]>
valid_lengths = mask.sum(dim=1).long() # Shape: (batch_size)

patch_cnt = (valid_lengths + self.patch_size -
1) // self.patch_size # 向上取整
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
1) // self.patch_size # 向上取整
1) // self.patch_size

@@ -0,0 +1,442 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

dummy_inputs=Qwen3TSDummyInputsBuilder,
)
class Qwen3TSForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to add LoRA for multimodal models, you also need to implement get_mm_mapping. Please refer to get_mm_mapping, or you can remove SupportsLoRA

@@ -0,0 +1,442 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build frontend multi-modality Related to multi-modality (#4194) qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

6 participants