diff --git a/.gitignore b/.gitignore
index 3e36220..be4c88a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,6 @@ media
 breadbot.db
 .env
 tools/profanity_filter/bin/Words.json
-tools/profanity_filter/src/Words.json
\ No newline at end of file
+tools/profanity_filter/src/Words.json
+bin/config.json
+bin/Words.json
\ No newline at end of file
diff --git a/bin/breadbot_common.py b/bin/breadbot_common.py
new file mode 100644
index 0000000..c8ee8fa
--- /dev/null
+++ b/bin/breadbot_common.py
@@ -0,0 +1,198 @@
+import sqlite3
+import mysql.connector
+import subprocess
+import string
+import random
+from pathlib import Path
+from datetime import datetime
+
+class Database():
+    def __init__(self):
+        pass
+
+    def close(self):
+        pass
+
+    def query(self, query: str, parameters: list=None) -> tuple[int, list[tuple]]:
+        pass
+
+    def select(self, table: str, columns: list[str]=None, where: list[dict]=None, values: list=None) -> list[tuple]:
+        query_string = "SELECT {columns} FROM {table}{where}".format(
+            columns = "*" if columns is None else ",".join(columns),
+            table = table,
+            where = self.__generate_basic_where_clause(where) if not where is None else ""
+        )
+
+        return self.query(query_string, [element["value"] for element in where] if not where is None else None, values)[1]
+
+    def insert(self, table: str, columns: list[str], values: list) -> int:
+        query_string = "INSERT INTO {table} ({columns}) VALUES ({values})".format(
+            table = table,
+            columns = ",".join(columns),
+            values = ("?," * len(values))[:-1]
+        )
+
+        return self.query(query_string, values)[0]
+
+    def update(self, table: str, columns: list[str], values: list, where: list[dict]=None) -> int:
+        query_string = "UPDATE {table} SET {set_rules}{where}".format(
+            table = table,
+            set_rules = ",".join([element + "=?" for element in columns]),
+            where = self.__generate_basic_where_clause(where) if not where is None else ""
+        )
+
+        return self.query(query_string, values)[0]
+
+    def delete(self, table: str, values: list, where: list[dict]=None) -> int:
+        query_string = "DELETE FROM {table}{where}".format(
+            table = table,
+            where = self.__generate_basic_where_clause(where) if not where is None else ""
+        )
+
+        return self.query(query_string, values)[0]
+
+    # TODO This probably breaks with MySQL, because question mark bad, maybe just have MySQL class override this
+    def __generate_basic_where_clause(self, where: list[dict]):
+        return " WHERE {clauses}".format(
+            clauses = "".join([
+                element["name"] + " " + element["compare"] + " ?" + (" " + element["boolean_op"] + " " if "boolean_op" in element else "")
+                for element in where
+            ])
+        )
+
+class SQLite(Database):
+    def __init__(self, db_name: str):
+        super(Database, self).__init__()
+
+        self.db = sqlite3.connect(db_name)
+        self.db.autocommit = True
+
+    def close(self):
+        self.db.close()
+
+    def query(self, query: str, parameters: list=None) -> tuple[int, list[tuple]]:
+        if parameters is None:
+            cursor = self.db.execute(query)
+        else:
+            cursor = self.db.execute(query, parameters)
+
+        if query.casefold().startswith("SELECT".casefold()):
+            rows = list(cursor)
+
+            return (len(rows), rows)
+        elif query.casefold().startswith("INSERT".casefold()):
+            return (cursor.lastrowid, None)
+        else:
+            return (cursor.rowcount, None)
+
+class MySQL(Database):
+    def __init__(self, host: str, user: str, password: str, db_name: str):
+        super(Database, self).__init__()
+
+        self.db = mysql.connector.connect(
+            host = host,
+            user = user,
+            password = password,
+            database = db_name
+        )
+        self.db.autocommit = True
+
+    def close(self):
+        self.db.close()
+
+    def query(self, query: str, parameters: list=None) -> tuple[int, list[tuple]]:
+        with self.db.cursor() as cursor:
+            if parameters is None:
+                cursor.execute(query)
+            else:
+                cursor.execute(query, parameters)
+
+            if query.casefold().startswith("SELECT".casefold()):
+                rows = cursor.fetchall()
+
+                return (len(rows), rows)
+            elif query.casefold().startswith("INSERT".casefold()):
+                return (cursor.lastrowid, None)
+            else:
+                return (cursor.rowcount, None)
+
+# Data class (effective struct) because complex dictionary access is uggo.
+class TranscriptableFile():
+    def __init__(self, file_path: str, real_date: datetime, milliseconds_from_start: int, user_snowflake: str=None):
+        self.file_path = file_path
+        self.real_date = real_date
+        self.milliseconds_from_start = milliseconds_from_start
+        self.user_snowflake = user_snowflake
+
+def mix_audio_with_ffmpeg(files: list[TranscriptableFile], media_folder_path: str, call_id: int, is_final_pass: bool) -> TranscriptableFile:
+    filter_list = [
+        "[{input_id}]adelay={delay}|{delay}[a{input_id}]".format(
+            input_id = index,
+            delay = files[index].milliseconds_from_start
+        )
+        for index in range(len(files))
+    ]
+
+    command_list = ["ffmpeg"]
+
+    for file in files:
+        command_list.append("-i")
+        command_list.append(file.file_path)
+
+    command_list.append("-filter_complex")
+
+    filter_string = "\"" + ";".join(filter_list) + ";"
+
+    filter_string = filter_string + "".join([
+        "[a{input_id}]".format(
+            input_id = index
+        )
+        for index in range(len(files))
+    ])
+
+    filter_string = filter_string + "amix=inputs={input_count}:normalize=0[a]".format(
+        input_count = len(files)
+    )
+
+    if is_final_pass:
+        filter_string = filter_string + ";[a]volume=3[boosted]\""
+    else:
+        filter_string = filter_string + "\""
+
+    command_list.append(filter_string)
+    command_list.append("-map")
+
+    if is_final_pass:
+        command_list.append("\"[boosted]\"")
+    else:
+        command_list.append("\"[a]\"")
+
+    output_file_name = Path(
+        media_folder_path,
+        call_id,
+        "output.mp3" if is_final_pass else "intermediate-" + "".join(random.choices(string.ascii_uppercase + string.digits, k=10)) + ".mp3"
+    )
+
+    command_list.append(output_file_name)
+
+    # TODO shell = True isn't great, I don't remember the reason why it has to be this way
+    # I *think* it had something to do with me not using ffmpeg's absolute path
+    ffmpeg_process = subprocess.Popen(
+        ' '.join(command_list),
+        stdout = subprocess.PIPE,
+        stderr = subprocess.PIPE,
+        shell = True
+    )
+
+    stdout, stderr = ffmpeg_process.communicate()
+
+    if ffmpeg_process.returncode != 0:
+        print("An FFMPEG process failed")
+        print(stdout)
+        print(stderr)
+        raise Exception("An FFMPEG process broke spectacularly")
+
+    return TranscriptableFile(output_file_name, files[0].real_date, files[0].milliseconds_from_start)
+
+
+
diff --git a/bin/breadmixer.py b/bin/breadmixer.py
new file mode 100644
index 0000000..d2c3526
--- /dev/null
+++ b/bin/breadmixer.py
@@ -0,0 +1,113 @@
+import json
+import os
+import copy
+from pathlib import Path
+from datetime import datetime
+from breadbot_common import SQLite, MySQL, TranscriptableFile, mix_audio_with_ffmpeg
+from txtai.pipeline import Transcription
+
+MAX_FILES_PER_CYCLE=50
+
+script_path = Path(__file__).resolve()
+config_path = Path(script_path.parent, "config.json")
+
+with open(config_path, 'r') as config_file:
+    config_json = json.loads(config_file.read())
+
+if config_json["db"]["type"].casefold() == "SQLITE".casefold():
+    db = SQLite(Path(script_path.parent.parent, config_json["db"]["db_path"]))
+else:
+    db = MySQL(
+        config_json["db"]["host"],
+        config_json["db"]["user"],
+        config_json["db"]["password"],
+        config_json["db"]["db_name"]
+    )
+
+calls_needing_work = db.query(
+    "SELECT * FROM db_call WHERE NOT call_end_time IS NULL AND call_consolidated = 0 AND call_transcribed = 0"
+)
+
+if calls_needing_work[0] == 0:
+    print("No work to do, exiting")
+
+transcriber = Transcription("openai/whisper-base")
+
+for call in calls_needing_work[1]:
+    all_files = os.listdir(Path(
+        config_json["media_voice_folder"],
+        call[0]
+    ))
+
+    transcriptable_files = []
+
+    for file in all_files:
+        file_name_no_extension = file.split('.')[0]
+        timestamp = int(file_name_no_extension.split('-')[0])
+        user_snowflake = file_name_no_extension.split('-')[1]
+        file_stamp_as_datetime = datetime.fromtimestamp(timestamp / 1000)
+        time_diff = file_stamp_as_datetime - call[1]
+
+        transcriptable_files.append(TranscriptableFile(
+            file_path = file,
+            real_date = file_stamp_as_datetime,
+            milliseconds_from_start = int((time_diff.seconds * 1000) + (time_diff.microseconds / 1000)),
+            user_snowflake = user_snowflake
+        ))
+
+    transcriptable_files.sort(key=lambda a: a.milliseconds_from_start)
+
+    # TODO Possibly RAM abusive solution to wanting to keep the original list around
+    ffmpeg_files = copy.deepcopy(transcriptable_files)
+
+    # TODO Error handling for all ffmpeg operations
+    while len(ffmpeg_files) > MAX_FILES_PER_CYCLE:
+        ffmpeg_files = [
+            mix_audio_with_ffmpeg(
+                ffmpeg_files[index:min(index + MAX_FILES_PER_CYCLE, len(ffmpeg_files))],
+                config_json["media_voice_folder"],
+                call[0],
+                False
+            )
+            for index in range(0, len(ffmpeg_files), MAX_FILES_PER_CYCLE)
+        ]
+
+    final_pass_file = mix_audio_with_ffmpeg(
+        ffmpeg_files,
+        config_json["media_voice_folder"],
+        call[0],
+        True
+    )
+
+    db.update("db_call", ["call_consolidated"], [1, call[0]], [{
+        "name": "call_id",
+        "compare": "="
+    }])
+
+    for file in os.listdir(Path(config_json["media_voice_folder"], call[0])):
+        if file.startswith("intermediate"):
+            os.remove(Path(config_json["media_voice_folder"], call[0], file))
+
+    for file in transcriptable_files:
+        text = transcriber(file.file_path)
+
+        db.insert(
+            "db_call_transcriptions",
+            ["speaking_start_time", "text", "callCallId", "userUserSnowflake"],
+            [file.real_date, text, call[0], file.user_snowflake]
+        )
+
+    db.update("db_call", ["call_transcribed"], [1, call[0]], [{
+        "name": "call_id",
+        "compare": "="
+    }])
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/bin/profanity_regex_inserter.py b/bin/profanity_regex_inserter.py
new file mode 100644
index 0000000..079fed1
--- /dev/null
+++ b/bin/profanity_regex_inserter.py
@@ -0,0 +1,44 @@
+# The hidden file filled with profanity and first version of this program brought to you
+# by Noah Lacorazza, rewritten from Java to Python by Brad
+
+import json
+from pathlib import Path
+from breadbot_common import SQLite, MySQL
+
+script_path = Path(__file__).resolve()
+config_path = Path(script_path.parent, "config.json")
+words_path = Path(script_path.parent, "Words.json")
+
+with open(config_path, 'r') as config_file:
+    config_json = json.loads(config_file.read())
+
+with open(words_path, 'r') as words_file:
+    words_list = json.loads(words_file.read())
+
+if config_json["db"]["type"].casefold() == "SQLITE".casefold():
+    db = SQLite(Path(script_path.parent.parent, config_json["db"]["db_path"]))
+else:
+    db = MySQL(
+        config_json["db"]["host"],
+        config_json["db"]["user"],
+        config_json["db"]["password"],
+        config_json["db"]["db_name"]
+    )
+
+print(db.select("db_server", ["server_snowflake"]))
+
+for element in db.select("db_server", ["server_snowflake"]):
+    for word in words_list:
+        regex_string = "(^|\\\\W|\\\\b)"
+
+        for i in range(len(word)):
+            if word[i] in config_json["profanity"]["replacers"].keys():
+                regex_string = regex_string + config_json["profanity"]["replacers"][word[i]] + "{1,}"
+            else:
+                regex_string = regex_string + word[i] + "{1,}"
+
+        regex_string = regex_string + "($|\\\\W|\\\\b)"
+
+        db.insert("db_message_regex", ["regex", "word", "serverServerSnowflake"], [regex_string, word, element[0]])
+
+db.close()
\ No newline at end of file
diff --git a/src/utilties/discord/messages.ts b/src/utilties/discord/messages.ts
index 6bc44a3..e6afd82 100644
--- a/src/utilties/discord/messages.ts
+++ b/src/utilties/discord/messages.ts
@@ -3,6 +3,7 @@ import { Repository } from "typeorm";
 import { DBMessage } from "../storage/entities/DBMessage";
 import { DBMessageAttachments } from "../storage/entities/DBMessageAttachment";
 import { DBMessageContentChanges } from "../storage/entities/DBMessageContentChanges";
+import { DBMessageRegex } from "../storage/entities/DBMessageRegex";
 
 // TODO Do partial messages affect other functionality elsewhere?
 
@@ -99,4 +100,8 @@ export async function markMessageDeleted(db: Repository<DBMessage>, message: Omi
         console.log(err)
         return null
     }
+}
+
+export async function checkYourProfanity(messageDB: Repository<DBMessage>, regexDB: Repository<DBMessageRegex>, message: OmitPartialGroupDMChannel<Message<boolean>> | PartialMessage) : Promise<boolean> {
+    return null
 }
\ No newline at end of file
diff --git a/src/utilties/discord/regex_matching.ts b/src/utilties/discord/regex_matching.ts
index 08094c7..fee339b 100644
--- a/src/utilties/discord/regex_matching.ts
+++ b/src/utilties/discord/regex_matching.ts
@@ -1,7 +1,9 @@
-import { Guild } from "discord.js";
+import { Guild, Message, OmitPartialGroupDMChannel, PartialMessage } from "discord.js";
 import { DBMessageRegex } from "../storage/entities/DBMessageRegex";
 import { Repository } from "typeorm";
 import { DBServer } from "../storage/entities/DBServer";
+import { DBMessageRegexMatches } from "../storage/entities/DBMessageRegexMatches";
+import { DBMessage } from "../storage/entities/DBMessage";
 
 export async function getRegexesForGuild(db: Repository<DBServer>, guild: Guild): Promise<DBMessageRegex[]> {
     return (await db.findOne({
@@ -15,4 +17,37 @@ export async function getRegexesForGuild(db: Repository<DBServer>, guild: Guild)
             server_snowflake: guild.id
         }
     }))?.regexes
+}
+
+export async function checkMatchingRegexes(regexes: DBMessageRegex[], testString: string) : Promise<DBMessageRegex[]> {
+    let matchedRegexes: DBMessageRegex[] = []
+
+    regexes.forEach((regex) => {
+        const regexObj = new RegExp(regex.regex, 'gmi')
+
+        if(regexObj.test(testString)) {
+            matchedRegexes.push(regex)
+        }
+    })
+
+    if(matchedRegexes.length != 0) {
+        return matchedRegexes
+    } else {
+        return null
+    }
+}
+
+export async function insertAnyRegexMatches(regexes: DBMessageRegex[], db: Repository<DBMessageRegexMatches>, message: OmitPartialGroupDMChannel<Message<boolean>> | PartialMessage) {
+    let matches : DBMessageRegexMatches[] = []
+
+    regexes.forEach(async (regex) => {
+        matches.push(db.create({
+            message: {
+                message_snowflake: message.id
+            },
+            regex: regex
+        }))
+    })
+
+    await db.save(matches)
 }
\ No newline at end of file
diff --git a/src/utilties/storage/entities/DBMessageRegex.ts b/src/utilties/storage/entities/DBMessageRegex.ts
index 4971c7b..7766209 100644
--- a/src/utilties/storage/entities/DBMessageRegex.ts
+++ b/src/utilties/storage/entities/DBMessageRegex.ts
@@ -10,6 +10,9 @@ export class DBMessageRegex {
     @ManyToOne(() => DBServer, (server: DBServer) => server.regexes)
     server: DBServer
 
+    @Column()
+    word: string
+
     @Column()
     regex: string