import typing

import numpy
import pandas as pd
from PyQt6 import QtCore
from PyQt6.QtCore import Qt, QModelIndex, QSortFilterProxyModel


def _get_str_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Return a given dataframe with all values turned into strings.

    When the data given to the PandasModel class contains non-strings, an
    attached QTableView seems to respond rather slowly. This function turns
    all data in the DataFrame into strings, yielding a better experience.
    """
    return df.astype(str)


class PandasModel(QtCore.QAbstractTableModel):
    """Flat Qt table model backed by a pandas DataFrame.

    Two copies of the data are kept:

    * ``_data`` holds the original values. It is returned for ``EditRole``
      and by :meth:`get_dataframe`.
    * ``_data_str`` is a string-converted copy (see ``_get_str_dataframe``).
      It is used for ``DisplayRole``, for the header and for row/column
      counts.

    Every mutating method rebuilds ``_data_str`` and notifies attached views
    through the appropriate Qt model signals.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = pd.DataFrame()
        self._data_str = pd.DataFrame()

    #
    # Overloaded functions
    #

    def rowCount(self, parent=None):
        """Return the number of rows, or 0 when asked about a valid parent."""
        # A table model has no child items; Qt requires 0 for valid parents.
        if parent is not None and parent.isValid():
            return 0
        # len(index) is O(1); len(df.values) would build the value array.
        return len(self._data_str.index)

    def columnCount(self, parent=None):
        """Return the number of columns, or 0 when asked about a valid parent."""
        if parent is not None and parent.isValid():
            return 0
        return self._data_str.columns.size

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the cell value for ``DisplayRole`` or ``EditRole``.

        ``DisplayRole`` yields the stringified value. ``EditRole`` yields the
        original value, with numpy scalars converted to plain Python scalars.
        Any other role, or an invalid index, yields an empty ``QVariant``.
        """
        if not index.isValid():
            return QtCore.QVariant()
        row, column = index.row(), index.column()
        if role == Qt.ItemDataRole.DisplayRole:
            return QtCore.QVariant(self._data_str.iloc[row, column])
        if role == Qt.ItemDataRole.EditRole:
            item = self._data.iloc[row, column]
            # numpy scalars (float64, int64, bool_, ...) are not native Qt
            # types. Hand editors the equivalent Python scalar instead of an
            # opaque wrapped object.
            if isinstance(item, (numpy.integer, numpy.floating, numpy.bool_)):
                item = item.item()
            return QtCore.QVariant(item)
        return QtCore.QVariant()

    def headerData(self, section, orientation,
                   role=Qt.ItemDataRole.DisplayRole):
        """Return the column name for horizontal display headers.

        All other headers are delegated to the base class.
        """
        if (orientation == Qt.Orientation.Horizontal
                and role == Qt.ItemDataRole.DisplayRole):
            return self._data_str.columns[section]
        return super().headerData(section, orientation, role)

    #
    # Internal helpers
    #

    def _emit_column_changed(self, column_index):
        """Notify views that any cell in column ``column_index`` may have changed."""
        row_count = self.rowCount()
        if row_count == 0:
            # No valid cell indexes exist for an empty table.
            return
        self.dataChanged.emit(self.index(0, column_index),
                              self.index(row_count - 1, column_index))

    def _ensure_category_column(self):
        """Create the 'category' column via ``create_column`` if it is missing."""
        if 'category' not in self._data.columns:
            self.create_column('category')

    #
    # Manipulate categories
    #

    def assign_category(self, category, row_indices):
        """Set ``category`` on the rows whose index labels are ``row_indices``.

        ``row_indices`` is used with ``DataFrame.loc``, so it holds index
        labels. Labels equal row positions only for a default RangeIndex.
        """
        self._ensure_category_column()
        self._data.loc[row_indices, 'category'] = category
        self._data_str = _get_str_dataframe(self._data)
        # Labels cannot safely be used as view row numbers, so signal the
        # whole category column with one dataChanged instead of per row.
        self._emit_column_changed(self._data.columns.get_loc('category'))

    def delete_category(self, category):
        """Reset every row in ``category`` to the blank placeholder ' '."""
        self._ensure_category_column()
        row_indices = self._data.loc[self._data['category'] == category].index
        self.assign_category(' ', row_indices)

    def rename_category(self, old_name, new_name):
        """Move every row in category ``old_name`` to ``new_name``."""
        self._ensure_category_column()
        row_indices = self._data.loc[self._data['category'] == old_name].index
        self.assign_category(new_name, row_indices)

    def get_categories(self) -> numpy.ndarray:
        """Return the distinct values of the 'category' column.

        The result is the array returned by ``Series.unique()``.
        """
        self._ensure_category_column()
        return self._data['category'].unique()

    #
    # Manipulate columns
    #

    def create_column(self, column, initial_value=' '):
        """Create (or overwrite) ``column`` filled with ``initial_value``."""
        values = [initial_value] * len(self._data.index)
        if column in self._data.columns:
            # Overwriting an existing column: the table shape is unchanged.
            self._data[column] = values
            self._data_str = _get_str_dataframe(self._data)
            self._emit_column_changed(self._data.columns.get_loc(column))
            return
        # pandas appends new columns at the end of the frame.
        position = len(self._data.columns)
        self.beginInsertColumns(QModelIndex(), position, position)
        try:
            self._data[column] = values
            self._data_str = _get_str_dataframe(self._data)
        finally:
            self.endInsertColumns()

    def delete_column_by_index(self, column_index):
        """Remove the column at position ``column_index``.

        Out-of-range positions are a no-op.
        """
        column_count = len(self._data.columns)
        if not 0 <= column_index < column_count:
            return
        keep = [j for j in range(column_count) if j != column_index]
        self.beginRemoveColumns(QModelIndex(), column_index, column_index)
        try:
            self._data = self._data.iloc[:, keep]
            self._data_str = _get_str_dataframe(self._data)
        finally:
            self.endRemoveColumns()

    def rename_column(self, old_name, new_name):
        """Rename column ``old_name`` to ``new_name``.

        Raises:
            ValueError: a column named ``new_name`` already exists.
        """
        if new_name in self._data.columns:
            raise ValueError(
                f"A column with the name '{new_name}' already exists.")
        self._data = self._data.rename(columns={old_name: new_name})
        self._data_str = _get_str_dataframe(self._data)
        column_index = self._data.columns.get_loc(new_name)
        self.headerDataChanged.emit(Qt.Orientation.Horizontal,
                                    column_index, column_index)

    def switch_columns(self, column1, column2):
        """Swap the positions of the columns named ``column1`` and ``column2``."""
        column_titles = list(self._data.columns)
        index1 = column_titles.index(column1)
        index2 = column_titles.index(column2)
        column_titles[index1], column_titles[index2] = (
            column_titles[index2], column_titles[index1])
        # Views must be warned *before* the model changes.
        self.layoutAboutToBeChanged.emit()
        try:
            self._data = self._data.reindex(columns=column_titles)
            self._data_str = _get_str_dataframe(self._data)
            # Keep persistent indexes (selections, editors) on the same data.
            old_indexes = self.persistentIndexList()
            new_indexes = []
            for old in old_indexes:
                column = old.column()
                if column == index1:
                    column = index2
                elif column == index2:
                    column = index1
                new_indexes.append(self.index(old.row(), column))
            self.changePersistentIndexList(old_indexes, new_indexes)
        finally:
            self.layoutChanged.emit()

    def get_columns(self) -> typing.List[str]:
        """Return the column names in display order."""
        return list(self._data.columns)

    def assign_float_column(self, column, decimal_sep):
        """Convert ``column`` to float.

        When ``decimal_sep`` is ',', commas are first replaced with '.'.
        """
        if decimal_sep == ',':
            # regex=False: ',' is a literal separator, not a pattern.
            self._data[column] = (
                self._data[column].str.replace(',', '.', regex=False)
                .astype(float))
        else:
            self._data[column] = self._data[column].astype(float)
        self._data_str = _get_str_dataframe(self._data)
        # Last valid row is rowCount() - 1; the helper also guards empty data.
        self._emit_column_changed(self._data.columns.get_loc(column))

    def assign_date_column(self, column, date_format):
        """Parse ``column`` with ``pd.to_datetime`` using ``date_format``."""
        self._data[column] = pd.to_datetime(self._data[column],
                                            format=date_format)
        self._data_str = _get_str_dataframe(self._data)
        # Previously missing: views were never told the column changed.
        self._emit_column_changed(self._data.columns.get_loc(column))

    #
    # Directly access dataframe
    #

    def set_dataframe(self, df):
        """Replace the model's data with ``df``.

        ``df`` is stored by reference, not copied.
        """
        # Convert first so a failure leaves the model untouched.
        data_str = _get_str_dataframe(df)
        self.beginResetModel()
        try:
            self._data = df
            self._data_str = data_str
        finally:
            self.endResetModel()

    def get_dataframe(self) -> pd.DataFrame:
        """Return the underlying DataFrame (the stored object, not a copy)."""
        return self._data