diff --git a/utilities/database/database.py b/utilities/database/database.py index d40b68b..28d4794 100644 --- a/utilities/database/database.py +++ b/utilities/database/database.py @@ -2,10 +2,13 @@ class Database(): def __init__(self): pass - def query(self, query: str, parameters: list=None) -> tuple[int, list[dict]]: + def close(self): pass - def select(self, table: str, columns: list[str]=None, where: list[dict]=None) -> list[dict]: + def query(self, query: str, parameters: list=None) -> tuple[int, list[tuple]]: + pass + + def select(self, table: str, columns: list[str]=None, where: list[dict]=None) -> list[tuple]: query_string = "SELECT {columns} FROM {table}{where}".format( columns = "*" if columns is None else ",".join(columns), table = table, diff --git a/utilities/database/sqlite.py b/utilities/database/sqlite.py index e69de29..824fa70 100644 --- a/utilities/database/sqlite.py +++ b/utilities/database/sqlite.py @@ -0,0 +1,25 @@ +from database import Database +import sqlite3 + +class SQLite(Database): + def __init__(self, db_name: str): + super(self, Database).__init__() + + self.db = sqlite3.connect(db_name) + + def close(self): + self.db.close() + + def query(self, query: str, parameters: list=None) -> tuple[int, list[tuple]]: + if parameters is None: + cursor = self.db.execute(query) + else: + cursor = self.db.execute(query, parameters) + + if query.casefold().startswith("SELECT".casefold()): + rows = list(cursor) + + return (len(rows), rows) + else: + return (cursor.rowcount, None) +