11from typing import List
2- from typing import Union
32
43import evaluate
54import pandas as pd
65
7- from evidently .base_metric import ColumnName
86from evidently .base_metric import InputData
97from evidently .base_metric import Metric
108from evidently .base_metric import MetricResult
119from evidently .core import IncludeTags
1210from evidently .model .widget import BaseWidgetInfo
11+ from evidently .options .base import AnyOptions
1312from evidently .renderers .base_renderer import MetricRenderer
1413from evidently .renderers .base_renderer import default_renderer
1514from evidently .renderers .html_widgets import header_text
1615from evidently .renderers .html_widgets import table_data
16+ from evidently .renderers .html_widgets import text_widget
1717
1818
1919class ROUGESummaryMetricResult (MetricResult ):
2020 class Config :
2121 type_alias = "evidently:metric_result:ROUGESummaryMetricResult"
2222 field_tags = {
23+ "current" : {IncludeTags .Current },
24+ "reference" : {IncludeTags .Reference },
2325 "rouge_type" : {IncludeTags .Parameter },
24- "value" : {IncludeTags .Parameter },
26+ "per_row_scores" : {IncludeTags .Parameter },
27+ "summary_score" : {IncludeTags .Parameter },
2528 }
2629
30+ current : list
31+ reference : list
2732 rouge_type : str
28- score : dict
33+ per_row_scores : list
34+ summary_score : float
2935
3036
3137class ROUGESummaryMetric (Metric [ROUGESummaryMetricResult ]):
@@ -36,50 +42,62 @@ class Config:
3642 column_name : str
3743 rouge_n : int
3844
39- def __init__ (self , column_name : Union [ str , ColumnName ], rouge_n : int ):
45+ def __init__ (self , column_name : str , rouge_n : int , options : AnyOptions = None ):
4046 self .column_name = column_name
4147 self .rouge_n = rouge_n
42- super ().__init__ ()
48+ super ().__init__ (options = options )
4349
44- def _calculate_summary_rouge (self , current_data : pd .Series , reference_data : pd .Series ):
50+ def _calculate_summary_rouge (self , current : pd .Series , reference : pd .Series ):
4551 rouge_evaluator = evaluate .load ("rouge" )
4652
47- predictions = current_data .astype (str ).tolist ()
48- references = reference_data .astype (str ).tolist ()
53+ current = current .astype (str ).tolist ()
54+ reference = reference .astype (str ).tolist ()
4955
5056 rouge_scores = rouge_evaluator .compute (
51- rouge_types = [f"rouge{ self .rouge_n } " ], predictions = predictions , references = references , use_aggregator = False
57+ rouge_types = [f"rouge{ self .rouge_n } " ], predictions = current , references = reference , use_aggregator = False
5258 )
5359
5460 per_row_rouge_scores = rouge_scores [f"rouge{ self .rouge_n } " ]
5561
5662 summary_rouge_score = sum (per_row_rouge_scores ) / len (per_row_rouge_scores )
5763
58- return per_row_rouge_scores , summary_rouge_score
64+ return per_row_rouge_scores , summary_rouge_score , current , reference
5965
60- def calculate (self , data : InputData ) -> MetricResult :
66+ def calculate (self , data : InputData ) -> ROUGESummaryMetricResult :
67+ if data .current_data is None or data .reference_data is None :
68+ raise ValueError ("The current data or the reference data is None." )
6169 if len (data .current_data [self .column_name ]) == 0 or len (data .reference_data [self .column_name ]) == 0 :
6270 raise ValueError ("The current data or the reference data is empty." )
6371
64- per_row_rouge_scores , summary_rouge_score = self ._calculate_summary_rouge (
72+ per_row_rouge_scores , summary_rouge_score , current , reference = self ._calculate_summary_rouge (
6573 data .current_data [self .column_name ], data .reference_data [self .column_name ]
6674 )
6775
6876 result = ROUGESummaryMetricResult (
6977 rouge_type = f"ROUGE-{ self .rouge_n } " ,
70- score = {"per_row_scores" : per_row_rouge_scores , "summary_score" : summary_rouge_score },
78+ per_row_scores = per_row_rouge_scores ,
79+ summary_score = summary_rouge_score ,
80+ current = current ,
81+ reference = reference ,
7182 )
7283 return result
7384
7485
7586@default_renderer (wrap_type = ROUGESummaryMetric )
7687class ROUGESummaryMetricRenderer (MetricRenderer ):
7788 @staticmethod
78- def _get_table (metric , n : int = 2 ) -> BaseWidgetInfo :
79- column_names = ["Metric" , "Value" ]
80- rows = ([metric .rouge_type , metric .score ],)
89+ def _get_table (metric ) -> BaseWidgetInfo :
90+ column_names = ["Metric" , "current" , "reference" , "score" ]
91+ rows = []
92+ for i in range (len (metric .current )):
93+ rows .append ([metric .rouge_type , metric .current [i ], metric .reference [i ], metric .per_row_scores [i ]])
94+ # rows.append(["metric.rouge_type", 1, "metric.current[i]", "metric.reference[i]", 2.4])
8195 return table_data (title = "" , column_names = column_names , data = rows )
8296
83- def render_html (self , obj : ROUGESummaryMetricResult ) -> List [BaseWidgetInfo ]:
97+ def render_html (self , obj : ROUGESummaryMetric ) -> List [BaseWidgetInfo ]:
8498 metric = obj .get_result ()
85- return [header_text (label = "ROUGE Metric" ), self ._get_table (metric )]
99+ return [
100+ header_text (label = "ROUGE Metric" ),
101+ self ._get_table (metric ),
102+ text_widget (text = f"{ metric .summary_score } " , title = "Overall ROUGE score" ),
103+ ]
0 commit comments