banking-breakdown/banking_breakdown/ui/pandas_model.py

173 lines
5.7 KiB
Python

import typing
import numpy
import pandas as pd
from PyQt6 import QtCore
from PyQt6.QtCore import Qt, QModelIndex, QSortFilterProxyModel
def _get_str_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""Return a given dataframe with all values turned into strings.
When the data given to the PandasModel class contains non-strings,
an attached QTableView seems to respond rather slowly. This function
turns all data in the DataFrame into strings, yielding a better experience.
"""
return df.astype(str)
class PandasModel(QtCore.QAbstractTableModel):
    """Qt table model backed by a pandas DataFrame.

    Two frames are kept in sync:

    * ``_data`` holds the values with their real dtypes and is served for
      ``EditRole``.
    * ``_data_str`` is a string-only mirror of ``_data`` and is served for
      ``DisplayRole``.

    Every method that changes ``_data`` rebuilds ``_data_str`` and notifies
    attached views through the appropriate Qt signal.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self._data = pd.DataFrame()
        self._data_str = pd.DataFrame()

    #
    # Internal helpers
    #

    def _refresh_str_data(self):
        """Rebuild the string mirror from the current ``_data``."""
        self._data_str = _get_str_dataframe(self._data)

    def _emit_column_changed(self, column):
        """Emit ``dataChanged`` for every row of the column named ``column``."""
        row_count = len(self._data.index)
        if row_count == 0:
            # No valid cell exists, so there is nothing to announce.
            return
        column_index = self._data.columns.get_loc(column)
        # The bottom-right index must be the *last* row (row_count - 1);
        # row_count itself is out of range and yields an invalid index.
        self.dataChanged.emit(self.index(0, column_index),
                              self.index(row_count - 1, column_index))

    #
    # Overloaded functions
    #

    def rowCount(self, parent=None):
        """Return the number of rows; 0 for any valid (child) parent."""
        if parent is not None and parent.isValid():
            # Table models are flat: cells have no children.
            return 0
        return len(self._data_str.index)

    def columnCount(self, parent=None):
        """Return the number of columns; 0 for any valid (child) parent."""
        if parent is not None and parent.isValid():
            return 0
        return self._data_str.columns.size

    def data(self, index, role=Qt.ItemDataRole.DisplayRole):
        """Return the cell value for ``index`` under ``role``.

        ``DisplayRole`` yields the string representation, ``EditRole`` the
        underlying value (numpy floats are converted to Python ``float``).
        Any other role, or an invalid index, yields an empty ``QVariant``.
        """
        if not index.isValid():
            return QtCore.QVariant()

        if role == Qt.ItemDataRole.DisplayRole:
            item = self._data_str.iloc[index.row(), index.column()]
            return QtCore.QVariant(item)

        if role == Qt.ItemDataRole.EditRole:
            item = self._data.iloc[index.row(), index.column()]
            if isinstance(item, numpy.floating):
                # Hand Qt a plain Python float rather than a numpy scalar.
                item = float(item)
            return QtCore.QVariant(item)

        return QtCore.QVariant()

    def headerData(self, section, orientation,
                   role=Qt.ItemDataRole.DisplayRole):
        """Return the column label for horizontal display headers.

        All other orientation/role combinations are delegated to the base
        class implementation.
        """
        is_column_label = (orientation == Qt.Orientation.Horizontal
                           and role == Qt.ItemDataRole.DisplayRole)
        if not is_column_label:
            return super().headerData(section, orientation, role)
        return self._data_str.columns[section]

    #
    # Manipulate categories
    #

    def assign_category(self, category, row_indices):
        """Set the 'category' column to ``category`` for ``row_indices``.

        The 'category' column is created first if it does not exist yet.
        ``dataChanged`` is emitted once per affected row.
        """
        if 'category' not in self._data.columns:
            self.create_column('category')

        # NOTE(review): row_indices are used as index *labels* here (.loc)
        # and as row *positions* below (self.index). These agree only when
        # the frame has a default RangeIndex -- confirm against callers.
        self._data.loc[row_indices, 'category'] = category
        self._refresh_str_data()

        last_column = len(self._data.columns) - 1
        for row_index in row_indices:
            self.dataChanged.emit(self.index(row_index, 0),
                                  self.index(row_index, last_column))

    def delete_category(self, category):
        """Reset every row currently tagged ``category`` to the blank ' '."""
        if 'category' not in self._data.columns:
            self.create_column('category')
        row_indices = self._data.loc[self._data['category'] == category].index
        self.assign_category(' ', row_indices)

    def get_categories(self) -> typing.List[str]:
        """Return the distinct values of the 'category' column as a list.

        The 'category' column is created first if it does not exist yet.
        """
        if 'category' not in self._data.columns:
            self.create_column('category')
        # unique() returns an array; convert so the result matches the
        # annotated return type.
        return list(self._data['category'].unique())

    #
    # Manipulate columns
    #

    def create_column(self, column, initial_value=' '):
        """Add (or overwrite) ``column``, filling each row with
        ``initial_value``."""
        # Views must be told *before* the model changes so that they can
        # snapshot their persistent indexes against the old layout.
        self.layoutAboutToBeChanged.emit()
        self._data[column] = [initial_value] * len(self._data.index)
        self._refresh_str_data()
        self.layoutChanged.emit()

    def delete_column_by_index(self, column_index):
        """Remove the column at position ``column_index``."""
        keep = [j for j in range(len(self._data.columns))
                if j != column_index]
        # Build the new frames first so a failure leaves the model untouched
        # and no unbalanced layout signal is emitted. The copy detaches the
        # result from the previous frame.
        remaining = self._data.iloc[:, keep].copy()
        remaining_str = _get_str_dataframe(remaining)

        self.layoutAboutToBeChanged.emit()
        self._data = remaining
        self._data_str = remaining_str
        self.layoutChanged.emit()

    def rename_column(self, old_name, new_name):
        """Rename column ``old_name`` to ``new_name``.

        Raises:
            ValueError: if a column called ``new_name`` already exists.
        """
        if new_name in self._data.columns:
            raise ValueError(
                f"A column with the name '{new_name}' already exists.")
        self._data = self._data.rename(columns={old_name: new_name})
        self._refresh_str_data()
        column_index = self._data.columns.get_loc(new_name)
        self.headerDataChanged.emit(Qt.Orientation.Horizontal,
                                    column_index, column_index)

    def switch_columns(self, column1, column2):
        """Swap the positions of the columns named ``column1`` and
        ``column2``."""
        column_titles = list(self._data.columns)
        index1 = column_titles.index(column1)
        index2 = column_titles.index(column2)
        column_titles[index1], column_titles[index2] \
            = column_titles[index2], column_titles[index1]
        # Compute the reordered frames before signalling, so an unknown
        # column name raises without leaving views mid-layout-change.
        reordered = self._data.reindex(columns=column_titles)
        reordered_str = _get_str_dataframe(reordered)

        self.layoutAboutToBeChanged.emit()
        self._data = reordered
        self._data_str = reordered_str
        self.layoutChanged.emit()

    def get_columns(self) -> typing.List[str]:
        """Return the column labels in their current order."""
        return list(self._data.columns)

    def assign_float_column(self, column, decimal_sep):
        """Convert ``column`` to float.

        When ``decimal_sep`` is ',', commas in the (string) values are
        replaced by '.' before the conversion.
        """
        if decimal_sep == ',':
            self._data[column] = (
                self._data[column]
                .str.replace(',', '.', regex=False)
                .astype(float)
            )
        else:
            self._data[column] = self._data[column].astype(float)
        self._refresh_str_data()
        self._emit_column_changed(column)

    def assign_date_column(self, column, date_format):
        """Parse ``column`` as datetimes using ``date_format``."""
        self._data[column] \
            = pd.to_datetime(self._data[column], format=date_format)
        self._refresh_str_data()
        # Notify views, consistent with assign_float_column; without this the
        # table keeps showing the old text until something else repaints it.
        self._emit_column_changed(column)

    #
    # Directly access dataframe
    #

    def set_dataframe(self, df):
        """Replace the model's contents with ``df``.

        ``df`` is stored by reference (not copied), so later in-place column
        edits made through this model are visible on the caller's frame.
        """
        # Build the string mirror first: if it fails, the model is unchanged
        # and no layout signal has been emitted.
        df_str = _get_str_dataframe(df)

        self.layoutAboutToBeChanged.emit()
        self._data = df
        self._data_str = df_str
        self.layoutChanged.emit()

    def get_dataframe(self) -> pd.DataFrame:
        """Return the underlying (typed) DataFrame, not a copy."""
        return self._data