diff --git a/banking_breakdown/ui/custom_ui_items.py b/banking_breakdown/ui/custom_ui_items.py index a1dceaf..08dfc11 100644 --- a/banking_breakdown/ui/custom_ui_items.py +++ b/banking_breakdown/ui/custom_ui_items.py @@ -35,16 +35,16 @@ class HeaderContextMenu(QMenu): """Context menu appearing when right-clicking the header of the QTableView. """ - def __init__(self, column, pandas_model: PandasModel, callback=None, + def __init__(self, column_index, pandas_model: PandasModel, callback=None, parent=None): super(HeaderContextMenu, self).__init__() - self._column = column self._pandas_model = pandas_model self._callback = callback + self._column_index = column_index self._column_text \ - = self._pandas_model.headerData(self._column, + = self._pandas_model.headerData(self._column_index, Qt.Orientation.Horizontal) # Define assign action @@ -85,9 +85,11 @@ class HeaderContextMenu(QMenu): return if (new_name != self._column_text) and (new_name != ''): - df = self._pandas_model.get_dataframe() - df = df.rename(columns={self._column_text: new_name}) - self._pandas_model.set_dataframe(df) + try: + self._pandas_model.rename_column(self._column_text, new_name) + except: + QMessageBox.warning(self, "No action performed", + "An error occurred.") if self._callback: self._callback() @@ -98,10 +100,7 @@ class HeaderContextMenu(QMenu): f" column '{self._column_text}'?") if button == QMessageBox.StandardButton.Yes: - df = self._pandas_model.get_dataframe() - df = df.iloc[:, [j for j, c - in enumerate(df.columns) if j != self._column]] - self._pandas_model.set_dataframe(df) + self._pandas_model.delete_column_by_index(self._column_index) if self._callback: self._callback() @@ -112,21 +111,14 @@ class HeaderContextMenu(QMenu): if column != self._column_text] other_name, flag = QInputDialog.getItem(self, "Switch column position", - f"Switch position of colum" - f" '{self._column_text}' with:", + f"Switch position of colum " + f"'{self._column_text}' with:", columns, editable=False) if not flag: return - column_titles = list(df.columns) - index1, index2 = column_titles.index( - self._column_text), column_titles.index(other_name) - column_titles[index1], column_titles[index2] \ - = column_titles[index2], column_titles[index1] - - df = df.reindex(columns=column_titles) - self._pandas_model.set_dataframe(df) + self._pandas_model.switch_columns(self._column_text, other_name) if self._callback: self._callback() @@ -139,14 +131,12 @@ class HeaderContextMenu(QMenu): if not flag: return - df = self._pandas_model.get_dataframe() try: - df[self._column_text] \ - = pd.to_datetime(df[self._column_text], format=date_format) + self._pandas_model.assign_date_column(self._column_text, + date_format) except: QMessageBox.warning(self, "No action performed", "An error occurred.") - self._pandas_model.set_dataframe(df) if self._callback: self._callback() @@ -160,19 +150,12 @@ class HeaderContextMenu(QMenu): if not flag: return - df = self._pandas_model.get_dataframe() - try: - if decimal_sep == ',': - df[self._column_text] \ - = df[self._column_text].str.replace(',', '.').astype(float) - else: - df[self._column_text] = df[self._column_text].astype(float) + self._pandas_model.assign_float_column(self._column_text, + decimal_sep) except: QMessageBox.warning(self, "No action performed", "An error occurred.") - self._pandas_model.set_dataframe(df) - if self._callback: self._callback() diff --git a/banking_breakdown/ui/main_window.py b/banking_breakdown/ui/main_window.py index 7c2a032..85a293d 100644 --- a/banking_breakdown/ui/main_window.py +++ b/banking_breakdown/ui/main_window.py @@ -45,6 +45,7 @@ class MainWindow(QMainWindow): self._proxy_model.setSourceModel(self._pandas_model) self._table_view.setModel(self._proxy_model) self._proxy_model.setSortRole(Qt.ItemDataRole.EditRole) + self._proxy_model.setDynamicSortFilter(False) # Set event handlers @@ -98,6 +99,13 @@ class MainWindow(QMainWindow): len(col)) self._table_view.setColumnWidth(i, max_char * 10) + def _assign_category_to_selected_transactions(self, category: str): + indexes = self._table_view.selectionModel().selectedRows() + row_indices = [self._table_view.model().mapToSource(index).row() + for index in indexes] + + self._pandas_model.assign_category(category, row_indices) + # # List data updates # @@ -107,16 +115,11 @@ class MainWindow(QMainWindow): self._list_widget.addItem(category) def _update_categories_from_dataframe(self): - df = self._pandas_model.get_dataframe() - - if 'category' not in df.columns: - df['category'] = [' '] * len(df.index) - - df_categories = df['category'].unique() + df_categories = self._pandas_model.get_categories() current_categories = [self._list_widget.item(x).text() for x in range(self._list_widget.count())] - missing = list(set(df_categories) - set(current_categories)) + self._add_categories([category for category in missing if category != ' ']) @@ -135,19 +138,19 @@ class MainWindow(QMainWindow): warning_item.hide() self._warning_layout.removeItem(warning_item) - df = self._pandas_model.get_dataframe() + columns = self._pandas_model.get_columns() - if 't' not in df.columns: + if 't' not in columns: self._add_warning_item( "The column 't' does not exist. Please rename the column" " containing the dates of the transactions to 't'.") - if 'value' not in df.columns: + if 'value' not in columns: self._add_warning_item( "The column 'value' does not exist. Please rename the column" " containing the values of the transactions to 'value'.") - if 'balance' not in df.columns: + if 'balance' not in columns: self._add_warning_item( "The column 'balance' does not exist. Please rename the column" " containing the balance after each transaction to 'balance'") @@ -159,7 +162,7 @@ class MainWindow(QMainWindow): def _handle_header_right_click(self, pos): column = self._table_view.horizontalHeader().logicalIndexAt(pos) - context = HeaderContextMenu(parent=self, column=column, + context = HeaderContextMenu(parent=self, column_index=column, pandas_model=self._pandas_model, callback=self._dataframe_update_callback) context.exec(self.sender().mapToGlobal(pos)) @@ -187,35 +190,19 @@ class MainWindow(QMainWindow): f"Are you sure you want to delete" f" category '{selected_item.text()}'?") - df = self._pandas_model.get_dataframe() - - if 'category' not in df.columns: - df['category'] = [' '] * len(df.index) - if button == QMessageBox.StandardButton.Yes: - df.loc[df['category'] == selected_item.text(), 'category'] = ' ' + self._pandas_model.delete_category(selected_item.text()) self._list_widget.takeItem(self._list_widget.row(selected_item)) - self._pandas_model.set_dataframe(df) def _handle_clear_click(self): - self._assign_category(' ') - - def _assign_category(self, category: str): - indexes = self._table_view.selectionModel().selectedRows() - - row_indices = [self._table_view.model().mapToSource(index).row() - for index in indexes] - - df = self._pandas_model.get_dataframe() - df.loc[row_indices, 'category'] = category - self._pandas_model.set_dataframe(df) + self._assign_category_to_selected_transactions(' ') def _handle_apply_click(self): category = self._list_widget.selectedItems()[0].text() - self._assign_category(category) + self._assign_category_to_selected_transactions(category) def _handle_item_double_click(self, item): - self._assign_category(item.text()) + self._assign_category_to_selected_transactions(item.text()) def _handle_save(self): filename, _ = QFileDialog.getSaveFileName(self, 'Save File') diff --git a/banking_breakdown/ui/pandas_model.py b/banking_breakdown/ui/pandas_model.py index f8371bf..61d6825 100644 --- a/banking_breakdown/ui/pandas_model.py +++ b/banking_breakdown/ui/pandas_model.py @@ -1,3 +1,5 @@ +import typing + import numpy import pandas as pd from PyQt6 import QtCore @@ -20,9 +22,10 @@ class PandasModel(QtCore.QAbstractTableModel): self._data = pd.DataFrame() self._data_str = pd.DataFrame() - self._horizontalHeaders = None + # # Overloaded functions + # def rowCount(self, parent=None): return len(self._data_str.values) @@ -55,14 +58,112 @@ class PandasModel(QtCore.QAbstractTableModel): and (role == Qt.ItemDataRole.DisplayRole)): return super().headerData(section, orientation, role) - return self._horizontalHeaders[section] + return self._data_str.columns[section] - # Other functions + # + # Manipulate categories + # + + def assign_category(self, category, row_indices): + if 'category' not in self._data.columns: + self.create_column('category') + + self._data.loc[row_indices, 'category'] = category + self._data_str = _get_str_dataframe(self._data) + + for row_index in row_indices: + start_index = self.index(row_index, 0) + stop_index = self.index(row_index, len(self._data.columns) - 1) + self.dataChanged.emit(start_index, stop_index) + + def delete_category(self, category): + if 'category' not in self._data.columns: + self.create_column('category') + + row_indices = self._data.loc[self._data['category'] == category].index + self.assign_category(' ', row_indices) + + def get_categories(self) -> typing.List[str]: + if 'category' not in self._data.columns: + self.create_column('category') + + return self._data['category'].unique() + + # + # Manipulate columns + # + + def create_column(self, column, initial_value=' '): + self._data[column] = [initial_value] * len(self._data.index) + self._data_str = _get_str_dataframe(self._data) + self.layoutAboutToBeChanged.emit() + self.layoutChanged.emit() + + def delete_column_by_index(self, column_index): + self._data \ + = self._data.iloc[:, [j for j, c in enumerate(self._data.columns) + if j != column_index]] + self._data_str = _get_str_dataframe(self._data) + + self.layoutAboutToBeChanged.emit() + self.layoutChanged.emit() + + def rename_column(self, old_name, new_name): + if new_name in self._data.columns: + raise Exception( + f"A column with the name '{new_name}' already exists.") + + self._data = self._data.rename(columns={old_name: new_name}) + self._data_str = _get_str_dataframe(self._data) + + column_index = self._data.columns.get_loc(new_name) + self.headerDataChanged.emit(Qt.Orientation.Horizontal, + column_index, column_index) + + def switch_columns(self, column1, column2): + column_titles = list(self._data.columns) + + index1, index2 \ + = column_titles.index(column1), column_titles.index(column2) + + column_titles[index1], column_titles[index2] \ + = column_titles[index2], column_titles[index1] + + self._data = self._data.reindex(columns=column_titles) + self._data_str = _get_str_dataframe(self._data) + + self.layoutAboutToBeChanged.emit() + self.layoutChanged.emit() + + def get_columns(self) -> typing.List[str]: + return list(self._data.columns) + + def assign_float_column(self, column, decimal_sep): + if decimal_sep == ',': + self._data[column] \ + = self._data[column].str.replace(',', '.').astype(float) + else: + self._data[column] = self._data[column].astype(float) + + self._data_str = _get_str_dataframe(self._data) + + column_index = self._data.columns.get_loc(column) + start_index = self.index(0, column_index) + stop_index = self.index(len(self._data.index), column_index) + self.dataChanged.emit(start_index, stop_index) + + def assign_date_column(self, column, date_format): + self._data[column] \ + = pd.to_datetime(self._data[column], format=date_format) + self._data_str = _get_str_dataframe(self._data) + + # + # Directly access dataframe + # def set_dataframe(self, df): self._data = df self._data_str = _get_str_dataframe(df) - self._horizontalHeaders = list(df.columns) self.layoutAboutToBeChanged.emit() self.layoutChanged.emit()