Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
.history*
.idea
__pycache__
cache/
cache/
.env
*.code-workspace
.vscode/
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,25 @@ export OPENAI_API_KEY=

It assumes two files with the same number of lines. It prints the score for each line pair:

### Source-only evaluation (no reference)

```
python main.py --source=source.txt --hypothesis=hypothesis.txt --source_lang=English --target_lang=Czech --method="GEMBA-MQM" --model="gpt-4"
```

The main recommended methods: `GEMBA-MQM` and `GEMBA-DA` with the model `gpt-4`.
### Reference-based evaluation

For reference-based evaluation, use the `_ref` methods and provide a reference file:

```
python main.py --source=source.txt --hypothesis=hypothesis.txt --reference=reference.txt --source_lang=English --target_lang=Czech --method="GEMBA-DA_ref" --model="gpt-4"
```

**Available methods:**
- Source-only: `GEMBA-MQM`, `GEMBA-DA`, `GEMBA-SQM`, `GEMBA-stars`, `GEMBA-classes`
- Reference-based: `GEMBA-DA_ref`, `GEMBA-SQM_ref`, `GEMBA-stars_ref`, `GEMBA-classes_ref`

The main recommended methods: `GEMBA-MQM` and `GEMBA-DA` (source-only) or `GEMBA-DA_ref` (reference-based) with the model `gpt-4`.

## Collecting and evaluating experiments for GEMBA-DA

Expand Down
20 changes: 19 additions & 1 deletion gemba/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,26 @@
from gemba.prompt import prompts, validate_number


def get_gemba_scores(source, hypothesis, source_lang, target_lang, method, model):
def get_gemba_scores(source, hypothesis, source_lang, target_lang, method, model, reference=None):
# Validate reference usage
method_uses_ref = method.endswith('_ref')
if method_uses_ref and reference is None:
raise ValueError(f"Method '{method}' requires a reference, but none was provided. "
f"Please provide a reference file using the --reference flag.")
if not method_uses_ref and reference is not None:
print(f"Warning: Reference provided but method '{method}' does not use references. "
f"Consider using '{method}_ref' to utilize the reference in evaluation.")

# Build DataFrame with source and hypothesis
df = pd.DataFrame({'source_seg': source, 'target_seg': hypothesis})

# Add reference if provided
if reference is not None:
if len(reference) != len(source):
raise ValueError(f"Reference has {len(reference)} lines but source has {len(source)} lines. "
f"All files must have the same number of lines.")
df['reference_seg'] = reference

df['source_lang'] = source_lang
df['target_lang'] = target_lang

Expand Down
19 changes: 18 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
flags.DEFINE_string('model', "gpt-4", 'OpenAI model')
flags.DEFINE_string('source', None, 'Filepath to the source file.')
flags.DEFINE_string('hypothesis', None, 'Filepath to the translation file.')
flags.DEFINE_string('reference', None, 'Filepath to the reference file (optional, required for *_ref methods).')
flags.DEFINE_string('source_lang', None, 'Source language name.')
flags.DEFINE_string('target_lang', None, 'Target language name.')

Expand All @@ -27,6 +28,9 @@ def main(argv):
if not os.path.isfile(FLAGS.hypothesis):
print(f"Hypothesis file {FLAGS.hypothesis} does not exist.")
sys.exit(1)
if FLAGS.reference is not None and not os.path.isfile(FLAGS.reference):
print(f"Reference file {FLAGS.reference} does not exist.")
sys.exit(1)

assert FLAGS.source_lang is not None, "Source language name must be provided."
assert FLAGS.target_lang is not None, "Target language name must be provided."
Expand All @@ -39,9 +43,22 @@ def main(argv):
hypothesis = f.readlines()
hypothesis = [x.strip() for x in hypothesis]

# load reference file if provided
reference = None
if FLAGS.reference is not None:
with open(FLAGS.reference, 'r') as f:
reference = f.readlines()
reference = [x.strip() for x in reference]

# validate that reference has the same number of lines
if len(reference) != len(source):
print(f"Error: Reference file has {len(reference)} lines but source has {len(source)} lines.")
print("All files must have the same number of lines.")
sys.exit(1)

assert len(source) == len(hypothesis), "Source and hypothesis files must have the same number of lines."

answers = get_gemba_scores(source, hypothesis, FLAGS.source_lang, FLAGS.target_lang, FLAGS.method, FLAGS.model)
answers = get_gemba_scores(source, hypothesis, FLAGS.source_lang, FLAGS.target_lang, FLAGS.method, FLAGS.model, reference=reference)

for answer in answers:
print(answer)
Expand Down