diff --git a/textractor/entities/table.py b/textractor/entities/table.py
index e8fccdf..bf7a5d2 100644
--- a/textractor/entities/table.py
+++ b/textractor/entities/table.py
@@ -497,28 +497,23 @@ def __getitem__(self, key):
 
         return new_table
 
-    def to_pandas(self, use_columns=False, config: TextLinearizationConfig = TextLinearizationConfig()):
+    def _process_table(self, use_columns=False, config: TextLinearizationConfig = TextLinearizationConfig()):
         """
-        Converts the table to a pandas DataFrame
+        Processes the table into a list of rows for consumption by to_pandas and to_list.
 
-        :param use_columns: If the first row of the table is made of column headers, use them for the pandas dataframe. Only supports single row header.
+        :param use_columns: If the first row of the table is made of column headers, extract them as columns. Only supports single row header.
         :param config: Text linearization configuration object for the table content
-        :return:
+        :return: Tuple (table, columns); columns is None when use_columns is False or the header row is rejected.
+        :rtype: Tuple[List[List[str]], Optional[List[str]]]
         """
-        try:
-            from pandas import DataFrame
-        except ImportError:
-            raise MissingDependencyException(
-                "pandas library is required for exporting tables to DataFrame objects or markdown"
-            )
-
         rows = sorted([(key, list(group)) for key, group in itertools.groupby(
             self.table_cells, key=lambda cell: cell.row_index
         )], key=lambda r: r[0])
 
         row_offset = 0
         columns = None
         processed_cells = set()
+        table = []
         if use_columns:
             # Try to automatically get the columns if they are in the first row
             columns = [[] for _ in range(self.column_count)]
@@ -565,6 +560,8 @@ def to_pandas(self, use_columns=False, config: TextLinearizationConfig = TextLin
                 logger.info(
                     f"The number of column header cell do not match the column count, ignoring them, {len(columns)} vs {self.column_count}"
                 )
+                columns = None
+                row_offset = 0
 
         if columns and any([c for c in columns]) and config.table_flatten_headers:
             columns = ["".join(c) for c in columns]
@@ -574,9 +571,6 @@ def to_pandas(self, use_columns=False, config: TextLinearizationConfig = TextLin
             columns = [c[0] for c in columns]
             table = [columns]
             row_offset = 1
-        else:
-            table = []
-
         for _, row in rows[row_offset:]:
             table.append([])
             for cell in row:
@@ -600,11 +594,38 @@ def to_pandas(self, use_columns=False, config: TextLinearizationConfig = TextLin
                 else:
                     text = cell.get_text(config)
                 table[-1][cell.col_index - 1] = text if text or not config.table_cell_empty_cell_placeholder else config.table_cell_empty_cell_placeholder
+        return table, columns
 
-        return DataFrame(
-            table[1:] if use_columns else table,
-            columns=columns if use_columns else None,
-        )
+    def to_list(self, config: TextLinearizationConfig = TextLinearizationConfig()):
+        """
+        Converts the table to a list of lists.
+
+        :param config: Text linearization configuration object for the table content
+        :return: List of rows representing the table.
+        :rtype: List[List[str]]
+        """
+        table, _ = self._process_table(use_columns=False, config=config)
+        return table
+
+    def to_pandas(self, use_columns=False, config: TextLinearizationConfig = TextLinearizationConfig()):
+        """
+        Converts the table to a pandas DataFrame
+
+        :param use_columns: If the first row of the table is made of column headers, use them for the pandas dataframe. Only supports single row header.
+        :param config: Text linearization configuration object for the table content
+        :return: pandas DataFrame holding the table content
+        """
+        try:
+            from pandas import DataFrame
+        except ImportError:
+            raise MissingDependencyException(
+                "pandas library is required for exporting tables to DataFrame objects or markdown"
+            )
+        table, columns = self._process_table(use_columns=use_columns, config=config)
+        if columns is not None and use_columns:
+            return DataFrame(table[1:], columns=columns)
+        else:
+            return DataFrame(table)
 
     def to_csv(self, use_columns = False, config: TextLinearizationConfig = TextLinearizationConfig()) -> str:
         """Returns the table in the Comma-Separated-Value (CSV) format