# Openenv wordle example #4357

**burtenshaw** wants to merge 13 commits into `main` from `openenv-wordle-example`

## Commits

- `67fcefd` drop in wordle script from experiment (burtenshaw)
- `ecf79b8` first draft of example for wordle (burtenshaw)
- `c7331b7` refactor wordle script for release (burtenshaw)
- `baeb233` copy latest snippets from wordle.py into example (burtenshaw)
- `c847685` add logs (burtenshaw)
- `e403cd4` Update docs/source/openenv.md (sergiopaniego)
- `e95e7b3` simplify rollout function in wordle script (burtenshaw)
- `36e2fc8` respond to feedback (burtenshaw)
- `53cf5ca` Merge branch 'openenv-wordle-example' of https://github.com/huggingfa… (burtenshaw)
- `93eafae` Merge branch 'main' into openenv-wordle-example (sergiopaniego)
- `2b1bcb4` Updated to pass code quality test (sergiopaniego)
- `3f7c2fb` Merge branch 'main' into openenv-wordle-example (sergiopaniego)
- `5595b46` Update examples/scripts/openenv/wordle.py (burtenshaw)

## Files changed

### docs/source/openenv.md

Below is the reward curve from training:

<iframe src="https://trl-lib-trackio.hf.space?project=openenv&metrics=train/rewards/reward_from_env/mean&runs=qgallouedec-1761202871&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>

To learn more about how to create custom environments, see the [OpenEnv documentation](https://github.com/meta-pytorch/OpenEnv/blob/main/src/envs/README.md).

## Advanced Example

Let's level this up a bit by training a model to interact with a more complex environment. We'll use the word-guessing game [Wordle](https://www.nytimes.com/games/wordle/index.html) from the `textarena` environment.

### The TextArena Environment

[TextArena](https://huggingface.co/papers/2504.11442) is an open-source collection of competitive text-based games designed to evaluate reasoning skills in LLMs using textual games like Wordle, Snake, Tic-Tac-Toe, and more. Research has shown that such games improve model performance on reasoning tasks.

![textarena](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textarena.png)

We will use the `textarena` environment to train a model to play Wordle. It is a simple text-based environment that lets the model interact with the game by making guesses and receiving feedback on them.

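Interaction follows OpenEnv's reset/step loop. Here is a minimal sketch of that loop: the import path and constructor are assumptions, while the `reset`/`step`/`TextArenaAction` interface matches the rollout code later in this example.

```python
# Minimal interaction sketch. The import path and constructor below are
# assumptions; the reset/step/TextArenaAction interface matches the rollout code.
from envs.textarena_env import TextArenaEnv, TextArenaAction  # assumed import path

env = TextArenaEnv(base_url="http://localhost:8001")  # assumed constructor

result = env.reset()              # start a new Wordle episode
print(result.observation.prompt)  # game instructions shown to the model

result = env.step(TextArenaAction(message="[crane]"))  # submit a guess (assumed [word] format)
print(result.observation.messages)  # feedback messages for the guess
print(result.reward, result.done)   # win signal and episode status
```
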
### Wordle

Wordle is a useful game to train a model on because it requires the model to reason about the word and the feedback provided by the environment. It is also a purely language-based game that requires no external tools or knowledge. Furthermore, we found that models from 1 billion parameters and up are able to improve on Wordle, and a guess takes only 8 tokens to generate, which makes the game a good benchmark for experimenting with Reinforcement Learning environments without significant compute requirements.

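Because a guess is so short, generation can be capped aggressively during training. A minimal sketch using TRL's `GRPOConfig`; the values here are illustrative, not the script's defaults:

```python
from trl import GRPOConfig

grpo_config = GRPOConfig(
    output_dir="wordle-grpo",  # illustrative
    max_completion_length=8,   # a bracketed 5-letter guess fits in ~8 tokens
)
```
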
> [!NOTE] How does Wordle work?
> Wordle is a word-guessing game where the player has to guess a 5-letter word. The player can make 6 guesses, and for each guess the environment provides feedback on its correctness. The player wins if they guess the word in 6 guesses or fewer. The game challenges the model to generate words that are likely to be correct, and to learn from the feedback provided by the environment.
>
> For example, if the Wordle environment returns the following feedback:
>
> ```
> G U E S S
> X G Y X X
> ```
>
> The model has guessed the word "GUESS" and the environment has provided feedback as the letters X, G, and Y, referring to the colors blank, green, and yellow in the original game. From this feedback, the model should learn that the guess "GUESS" is incorrect, that the letter "E" is in the word but in the wrong position, and that the letter "U" is correct and in the correct position.

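The feedback rule itself is easy to state in code. This is a minimal sketch of the scoring logic described above, not code from the PR:

```python
def wordle_feedback(guess: str, target: str) -> str:
    """Score a guess against the target with G/Y/X markers, as described above."""
    feedback = ["X"] * 5
    remaining = list(target)
    # First pass: greens (right letter, right position) consume their letter.
    for i, (g, t) in enumerate(zip(guess, target)):
        if g == t:
            feedback[i] = "G"
            remaining.remove(g)
    # Second pass: yellows (right letter, wrong position) from the letters left over.
    for i, g in enumerate(guess):
        if feedback[i] == "X" and g in remaining:
            feedback[i] = "Y"
            remaining.remove(g)
    return " ".join(feedback)

print(wordle_feedback("GUESS", "MUTED"))  # X G Y X X (matches the example above)
```
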
In the TextArena environment, reward is only given when the model wins the game: 1.0 for a win and 0.0 otherwise. This is a sparse reward signal for the model to learn from, so we have added a number of custom reward functions to the script that help the model learn to play the game. The extensible nature of `reward_funcs` and `rollout_func` lets you add any custom reward function you want to the script.

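If you add shaped rewards like the ones defined later in this example, you may also want to weight them relative to the true win signal. A sketch, assuming `GRPOConfig`'s `reward_weights` field (one weight per reward function, in order); the weights are illustrative:

```python
from trl import GRPOConfig

# One weight per entry in reward_funcs, in order:
# [reward_correct, reward_greens, reward_yellows, reward_repetition]
grpo_config = GRPOConfig(
    output_dir="wordle-grpo",
    reward_weights=[1.0, 0.3, 0.3, 0.2],  # emphasize the win signal (illustrative)
)
```
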
### Rollout Function

The rollout function runs one full Wordle episode, prompting the model for a guess each turn and capturing both environment rewards and auxiliary signals such as letter coverage and repetition penalties.

```python
def rollout_once(
    env: TextArenaEnv,
    tokenizer: AutoTokenizer,
    args: GRPOConfig,
    dataset_prompt: str,
    cli_args: argparse.Namespace,
    system_prompt: str,
) -> dict[str, list]:
    result = env.reset()
    observation = result.observation

    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    raw_rewards: list[float] = []
    green_scores: list[float] = []
    yellow_scores: list[float] = []
    repetition_scores: list[float] = []
    correct_scores: list[float] = []
    guess_counts: dict[str, int] = {}

    for _turn in range(cli_args.max_turns):
        # when the game is over the environment returns done=True
        if result.done:
            break

        # set up the prompt for the model
        base_prompt = observation.prompt or dataset_prompt
        user_prompt = make_user_prompt(base_prompt, observation.messages)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
            enable_thinking=False,
        )

        # generate the completion from the model using vLLM
        vllm_result = request_vllm_completion(
            prompt_text,
            args,
            endpoint=cli_args.vllm_endpoint,
            timeout=cli_args.request_timeout,
            fallback=cli_args,
        )
        prompt_ids.extend(vllm_result["prompt_ids"])
        completion_ids.extend(vllm_result["completion_ids"])
        logprobs.extend(vllm_result["logprobs"])
        completion_text = vllm_result.get("text") or tokenizer.decode(
            vllm_result["completion_ids"], skip_special_tokens=True
        )
        # extract the guess from the completion
        guess = extract_guess(completion_text)

        # step the environment with the guess
        result = env.step(TextArenaAction(message=guess))
        raw_rewards.append(float(result.reward or 0.0))
        observation = result.observation
        correct_score = float(result.reward or 0.0)
        feedback = extract_wordle_feedback(observation)

        # update guess counts (use .get so a first-time guess defaults to 0)
        previous_occurrences = guess_counts.get(guess, 0)
        repetition_score = scale_repetition_score(previous_occurrences, len(guess_counts))
        guess_counts[guess] = previous_occurrences + 1

        # calculate custom reward signals from the feedback
        if not feedback:
            green_score = 0.0
            yellow_score = 0.0
        else:
            green_count, yellow_count = extract_feedback_counts(feedback)
            green_score = green_count / 5.0
            yellow_score = yellow_count / 5.0

        repetition_scores.append(repetition_score)
        green_scores.append(green_score)
        yellow_scores.append(yellow_score)
        correct_scores.append(correct_score)

    correct_reward_value = correct_scores[-1] if correct_scores else (raw_rewards[-1] if raw_rewards else 0.0)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "raw_rewards": raw_rewards,
        "correct_reward": correct_reward_value,
        "green_reward": green_scores[-1] if green_scores else 0.0,
        "yellow_reward": yellow_scores[-1] if yellow_scores else 0.0,
        "repetition_reward": repetition_scores[-1] if repetition_scores else 0.0,
    }
```

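The rollout relies on a handful of helpers (`make_user_prompt`, `request_vllm_completion`, `extract_guess`, `extract_wordle_feedback`, `extract_feedback_counts`, `scale_repetition_score`) defined in the full script; `extract_wordle_feedback` scans `observation.messages` for the latest feedback line. The sketches below are hypothetical reconstructions of the smaller helpers, assuming TextArena's `[word]` guess format; see `examples/scripts/openenv/wordle.py` for the real implementations:

```python
import re

def extract_guess(completion: str) -> str:
    """Pull the first [bracketed] 5-letter guess out of a completion (assumed format)."""
    match = re.search(r"\[([A-Za-z]{5})\]", completion)
    return f"[{match.group(1).lower()}]" if match else completion.strip()

def extract_feedback_counts(feedback: str) -> tuple[int, int]:
    """Count green (G) and yellow (Y) markers in a feedback line like 'X G Y X X'."""
    tokens = feedback.split()
    return tokens.count("G"), tokens.count("Y")

def scale_repetition_score(previous_occurrences: int, unique_guesses: int) -> float:
    """Return 0.0 for a fresh guess and a growing penalty for repeated guesses."""
    if previous_occurrences == 0:
        return 0.0
    return -previous_occurrences / max(unique_guesses, 1)
```
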
The environment's own reward signal is based on completing the game. We found that most models struggle to ever win, so we added custom reward functions that help the model improve incrementally: at first it learns to cover new letters and avoid repeating guesses, and as it improves, it learns to win the game.

### Reward Functions

We log four reward streams that encourage the model to solve the puzzle, cover new letters, and avoid repeating guesses:

- `reward_correct`: final win/loss signal from the environment.
- `reward_greens`: density of green letters in the last feedback.
- `reward_yellows`: density of yellow letters in the last feedback.
- `reward_repetition`: penalty for guessing the same token multiple times.

```python
from typing import Dict, List, Optional


def reward_correct(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("correct_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)


def reward_greens(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("green_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)


def reward_yellows(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("yellow_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)


def reward_repetition(completions: List[str], **kwargs: Optional[Dict]) -> List[float]:
    rewards = kwargs.get("repetition_reward") if kwargs else None
    return [float(r) for r in rewards] if rewards is not None else [0.0] * len(completions)
```

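Each function reads the matching key from the extra fields the rollout returns, which the trainer is expected to forward to reward functions as keyword arguments (that is what the `kwargs.get` pattern above assumes). A quick illustrative call with made-up values:

```python
completions = ["[crane]", "[slate]"]
extras = {"correct_reward": [1.0, 0.0], "green_reward": [0.6, 0.2]}

print(reward_correct(completions, **extras))  # [1.0, 0.0]
print(reward_greens(completions, **extras))   # [0.6, 0.2]
```
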
### Training the Model

The training script wires the custom rollout and rewards into `GRPOTrainer`. The CLI exposes the configuration used during development as defaults, so you can override endpoints or hyperparameters at launch time.

```python
parser = argparse.ArgumentParser()
# ... add CLI arguments with sensible defaults ...
cli_args = parser.parse_args()

trainer = GRPOTrainer(
    model=cli_args.model_id,
    processing_class=tokenizer,
    reward_funcs=[
        reward_correct,
        reward_greens,
        reward_yellows,
        reward_repetition,
    ],
    train_dataset=dataset,
    args=grpo_config,
    rollout_func=lambda prompts, args, processing_class: rollout_func(
        env=env,
        tokenizer=tokenizer,
        prompts=prompts,
        args=args,
        cli_args=cli_args,
        system_prompt=system_prompt,
    ),
)
trainer.train()
```

### Running the Example

The example requires two GPUs: one to serve the model with vLLM, and one to run the GRPO training loop.

```bash
# Terminal 1: Start vLLM inference server
CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --host 0.0.0.0 --port 8000

# Terminal 2: Run GRPO training with OpenEnv
CUDA_VISIBLE_DEVICES=1 python examples/scripts/openenv/wordle.py
```

### Results

The resulting model improves its performance on the game, both by reducing the number of repeated guesses and by increasing the number of correct guesses. However, the Qwen3-1.7B model we trained is not able to consistently win the game. The following reward curve shows the letter coverage of the model's guesses, i.e. the proportion of correct yellow (Y) and green (G) letters.

<iframe src="https://burtenshaw-wordle-grpo.hf.space/?project=group-Qwen-Qwen3-17B&metrics=train/rewards/reward_coverage/mean&runs=run-2025-10-26_09-39-49&sidebar=hidden&navbar=hidden" style="width:600px; height:500px; border:0;"></iframe>

We experimented with larger models like `gpt-oss-20b` and found that the model was able to consistently win the game. However, training at that scale requires significantly more compute. Why not try it out yourself?