commit 715be9728971bc387c590ad6ee15efb62acbb8d3 Author: katboi01 Date: Thu Jan 8 22:07:03 2026 +0100 fresh start diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..79e98a5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +Classify/* +token.json +google.json +*__pycache__* +test*.py +session.json +database.db +token.txt +Temp* +Classifier/sample*.* +Classifier/HAV0X/* +!Classifier/HAV0X/readme.txt +Classifier/SmilingWolf/* +!Classifier/SmilingWolf/readme.txt +session.tw_session +session.json +.vscode/launch.json +debug.log +configs/* diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..4724935 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "getPixivpyToken"] + path = getPixivpyToken + url = https://github.com/eggplants/get-pixivpy-token.git diff --git a/Classifier/HAV0X/readme.txt b/Classifier/HAV0X/readme.txt new file mode 100644 index 0000000..0401bde --- /dev/null +++ b/Classifier/HAV0X/readme.txt @@ -0,0 +1 @@ +put all files from https://huggingface.co/HAV0X1014/Kemono-Friends-Sorter/tree/main here \ No newline at end of file diff --git a/Classifier/SmilingWolf/readme.txt b/Classifier/SmilingWolf/readme.txt new file mode 100644 index 0000000..9832559 --- /dev/null +++ b/Classifier/SmilingWolf/readme.txt @@ -0,0 +1 @@ +put model.onnx and selected_tags.csv from https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/tree/main here \ No newline at end of file diff --git a/Classifier/classifyHelper.py b/Classifier/classifyHelper.py new file mode 100644 index 0000000..e6ec234 --- /dev/null +++ b/Classifier/classifyHelper.py @@ -0,0 +1,36 @@ +from io import BytesIO +from PIL import Image, UnidentifiedImageError +import imagehash +from Classifier.havoxClassifier import VoxClassifier +from Classifier.wdClassifier import WDClassifier +from Database.x_classes import ErrorID, HavoxLabel, PostRating +from helpers import twos_complement + +async def classify_all(photo_bytes, wd_classifier: WDClassifier, vox_classifier : VoxClassifier): + vox_label = await vox_classifier.classify_async(photo_bytes) + is_filtered = vox_label == HavoxLabel.Rejected + + tags = [] + filtered_tags = {} + phash = None + dhash = None + error_id = ErrorID.SUCCESS + if not is_filtered: + rating, tags, filtered_tags = await wd_classifier.classify_async(photo_bytes, tag_threshold=0.6) + tags = list(tags.keys()) + + try: + with BytesIO(photo_bytes) as image_file: + with Image.open(image_file) as imag: + phash = twos_complement(str(imagehash.phash(imag)), 64) + dhash = twos_complement(str(imagehash.dhash(imag)), 64) + except UnidentifiedImageError: + error_id = ErrorID.HASHING_Unidentified + except FileNotFoundError: + error_id = ErrorID.HASHING_FileNotFound + except Exception as ex: + error_id = ErrorID.HASHING_OTHER + else: + rating = PostRating.Filtered + + return vox_label, rating, tags, filtered_tags, phash, dhash, error_id \ No newline at end of file diff --git a/Classifier/havoxClassifier.py b/Classifier/havoxClassifier.py new file mode 100644 index 0000000..d77c2a0 --- /dev/null +++ b/Classifier/havoxClassifier.py @@ -0,0 +1,46 @@ +import asyncio +from PIL import Image +from io import BytesIO +from concurrent.futures import ThreadPoolExecutor +from transformers import ViTForImageClassification, ViTImageProcessor + +class VoxClassifier(): + def __init__(self): + self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/') + self.model = ViTForImageClassification.from_pretrained('Classifier/HAV0X/') + + def classify(self, image_bytes): + try: + with BytesIO(image_bytes) as image_file: + with Image.open(image_file).convert("RGB") as image: + inputs = self.feature_extractor(images=image, return_tensors="pt") + outputs = self.model(**inputs) + logits = outputs.logits + # model predicts one of the 3 tags + predicted_class_idx = logits.argmax(-1).item() + return self.model.config.id2label[predicted_class_idx] + except Exception as ex: + print(ex) + return "Rejected" + + async def classify_async(self, image_bytes): + with ThreadPoolExecutor() as executor: + future = executor.submit(self.classify, image_bytes) + + while not future.done(): + await asyncio.sleep(1) + + return future.result() + +if __name__ == "__main__": + async def main(): + classifier = VoxClassifier() + images = [ + "Classifier/sample.jpg", "Classifier/sample2.jpg", "Classifier/sample3.jpg", "Classifier/sample4.png" + ] + for image in images: + with open(image, "rb") as file: + result = await classifier.classify_async(file.read()) + print(result) + + asyncio.run(main()) \ No newline at end of file diff --git a/Classifier/wdClassifier.py b/Classifier/wdClassifier.py new file mode 100644 index 0000000..4bf018d --- /dev/null +++ b/Classifier/wdClassifier.py @@ -0,0 +1,238 @@ +import asyncio +import os +from dataclasses import dataclass +from io import BytesIO +from PIL import Image + +import onnxruntime as onnx +import numpy as np +import pandas as pd + +from concurrent.futures import ThreadPoolExecutor + +@dataclass +class LabelData: + names: list[str] + rating: list[np.int64] + general: list[np.int64] + character: list[np.int64] + +class WDClassifier(): + onnxSession = None + relevant_tags = [] + ignored_tags = [] + + def __init__(self): + self.labels = WDClassifier.load_labels("Classifier/SmilingWolf/selected_tags.csv") + self.relevant_tags = self.get_relevant_tags() + self.onnxSession = onnx.InferenceSession('Classifier/SmilingWolf/model.onnx') + + def classify(self, image_bytes, max_count = -1, tag_threshold = -1): + try: + with BytesIO(image_bytes) as image_file: + with Image.open(image_file) as image: + image_data = self.resize_image(image, 448) + # convert to numpy array and add batch dimension + image_data = np.array(image_data, np.float32) + image_data = np.expand_dims(image_data, axis=0) + # NHWC image RGB to BGR + image_data = image_data[..., ::-1] + + outputs = self.onnxSession.run(None, {'input': image_data}) + outputs = outputs[0].squeeze(0) + + ratings, general, filtered = self.processRating(outputs) + rating = max(ratings, key=ratings.get) + + if tag_threshold != -1: + general = self.filterThreshold(general, tag_threshold) + filtered = self.filterThreshold(filtered, tag_threshold) + + if max_count != -1: + general = self.filterCount(general, max_count) + filtered = self.filterThreshold(filtered, tag_threshold) + + return rating, general, filtered + + except Exception as ex: + print(ex) + return "", {}, {} + + async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1): + with ThreadPoolExecutor() as executor: + future = executor.submit(self.classify, image_bytes, max_count, tag_threshold) + + while not future.done(): + await asyncio.sleep(1) + + return future.result() + + def processRating(self, data): + ratings = {self.labels.names[i]: round(data[i].item(),3) for i in self.labels.rating} + general = {self.labels.names[i]: round(data[i].item(),3) for i in (self.labels.general + self.labels.character)} + filtered ={self.labels.names[i]: round(data[i].item(),3) for i in self.relevant_tags} + return ratings, general, filtered + + def filterThreshold(self, data, threshold = 0.6): + output = {key:value for key,value in data.items() if value > threshold} + return output + + def filterCount(self, data, count = 20): + keys = list(data.keys()) + values = list(data.values()) + results_ordered = np.flip(np.argsort(values)) + output = {} + for idx, dict_index in enumerate(results_ordered): + if idx < count: + output[keys[dict_index]] = values[dict_index] + else: + break + return output + + def filterTagList(self, data, tagList): + output = [] + for result in data: + if result["label"] in tagList: + output.append(result) + return output + + def resize_image(self, image : Image.Image, size = 448): + image = WDClassifier.pil_ensure_rgb(image) + image = WDClassifier.pil_pad_square(image) + image = WDClassifier.pil_resize(image, size) + return image + + def get_tag_index(self, tag: str): + try: + idx = self.labels.names.index(tag) + except: + print(tag, "not found") + idx = -1 + return idx + + def get_relevant_tags(self): + if not os.path.exists("Classifier/SmilingWolf/ignored_tags.txt"): + with open("Classifier/SmilingWolf/ignored_tags.txt", "wt", encoding="utf-8") as file: + file.write("") + + with open("Classifier/SmilingWolf/ignored_tags.txt", "rt", encoding="utf-8") as file: + self.ignored_tags = file.read().splitlines() + + tags = {tag:index for index,tag in enumerate(self.labels.names)} + for tag in self.ignored_tags: + if tag in tags: + del(tags[tag]) + return list(tags.values()) + + def add_ignored_tag(self, tag: str): + idx = self.get_tag_index(tag) + if idx != -1 and idx in self.relevant_tags: + self.ignored_tags.append(tag) + with open("Classifier/SmilingWolf/ignored_tags.txt", "wt", encoding="utf-8") as file: + file.write("\n".join(self.ignored_tags)) + self.relevant_tags.remove(idx) + + @staticmethod + def load_labels(csv_path): + df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"]) + tag_data = LabelData( + names=df["name"].tolist(), + rating=list(np.where(df["category"] == 9)[0]), + general=list(np.where(df["category"] == 0)[0]), + character=list(np.where(df["category"] == 4)[0]), + ) + + return tag_data + + @staticmethod + def pil_ensure_rgb(image: Image.Image) -> Image.Image: + # convert to RGB/RGBA if not already (deals with palette images etc.) + if image.mode not in ["RGB", "RGBA"]: + image = ( + image.convert("RGBA") + if "transparency" in image.info + else image.convert("RGB") + ) + # convert RGBA to RGB with white background + if image.mode == "RGBA": + canvas = Image.new("RGBA", image.size, (255, 255, 255)) + canvas.alpha_composite(image) + image = canvas.convert("RGB") + return image + + @staticmethod + def pil_pad_square(image: Image.Image) -> Image.Image: + w, h = image.size + # get the largest dimension so we can pad to a square + px = max(image.size) + # pad to square with white background + canvas = Image.new("RGB", (px, px), (255, 255, 255)) + canvas.paste(image, ((px - w) // 2, (px - h) // 2)) + return canvas + + @staticmethod + def pil_resize(image: Image.Image, target_size: int) -> Image.Image: + # Resize + max_dim = max(image.size) + if max_dim != target_size: + image = image.resize( + (target_size, target_size), + Image.BICUBIC, + ) + return image + + @staticmethod + def get_tags( + probs, + labels : LabelData, + gen_threshold: float, + char_threshold: float, + ): + # Convert indices+probs to labels + probs = list(zip(labels.names, probs)) + + # First 4 labels are actually ratings + rating_labels = dict([probs[i] for i in labels.rating]) + + # General labels, pick any where prediction confidence > threshold + gen_labels = [probs[i] for i in labels.general] + gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold]) + gen_labels = dict( + sorted( + gen_labels.items(), + key=lambda item: item[1], + reverse=True, + ) + ) + + # Character labels, pick any where prediction confidence > threshold + char_labels = [probs[i] for i in labels.character] + char_labels = dict([x for x in char_labels if x[1] > char_threshold]) + char_labels = dict( + sorted( + char_labels.items(), + key=lambda item: item[1], + reverse=True, + ) + ) + + # Combine general and character labels, sort by confidence + combined_names = [x for x in gen_labels] + combined_names.extend([x for x in char_labels]) + + # Convert to a string suitable for use as a training caption + caption = ", ".join(combined_names) + taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)") + + return caption, taglist, rating_labels, char_labels, gen_labels + +if __name__ == "__main__": + async def main(): + classifier = WDClassifier() + with open("Classifier/sample4.png", "rb") as image: + image_data = image.read() + rating, results, discord = await classifier.classify_async(image_data, 20, 0.6) + + print(rating, results, discord) + + asyncio.run(main()) \ No newline at end of file diff --git a/Database/db_schema.py b/Database/db_schema.py new file mode 100644 index 0000000..78a2a7c --- /dev/null +++ b/Database/db_schema.py @@ -0,0 +1,39 @@ +from enum import StrEnum + +class x_accounts(StrEnum): + table = "x_accounts" + id = "id" + rating = "rating" + name = "name" + is_deleted = "is_deleted" + is_protected = "is_protected" + download_mode = "download_mode" + discord_channel_id = "discord_channel_id" + discord_thread_id = "discord_thread_id" + +class x_posts(StrEnum): + table = "x_posts" + id = "id" + account_id = "account_id" + discord_post_id = "discord_post_id" + error_id = "error_id" + action_taken = "action_taken" + rating = "rating" + tags = "tags" + text = "text" + date = "date" + +class x_posts_images(StrEnum): + table = "x_posts_images" + post_id = "post_id" + index = "index" + phash = "phash" + dhash = "dhash" + error_id = "error_id" + rating = "rating" + tags = "tags" + file = "file" + saved_file = "saved_file" + vox_label = "vox_label" + duplicate_id = "duplicate_id" + duplicate_index = "duplicate_index" \ No newline at end of file diff --git a/Database/dbcontroller.py b/Database/dbcontroller.py new file mode 100644 index 0000000..34ea450 --- /dev/null +++ b/Database/dbcontroller.py @@ -0,0 +1,256 @@ +import html +import re +import psycopg +from psycopg.rows import dict_row + +from Database.x_classes import x_posts_images, x_accounts, x_posts +from Database.db_schema import x_accounts as x_accounts_schema, x_posts as x_posts_schema, x_posts_images as x_posts_images_schema +from config import Global_Config + +class DatabaseController: + def __init__(self): + if Global_Config["postgresql_conninfo"] is None: + raise Exception("Database connection string is required") + self.conn = psycopg.connect(Global_Config["postgresql_conninfo"], row_factory=dict_row) + self.cursor = self.conn.cursor() + + def query_get_stream(self, query, params = ([]), count = -1, page_size = 1000, offset = 0, commit = True): + while(True): + result = self.query_get(query + f' LIMIT {page_size} OFFSET {offset}', params, commit=commit) + if len(result) == 0: + return + for idx, record in enumerate(result): + offset += 1 + if count != -1 and offset > count: + return + yield record + + def query_get(self, query, params = ([]), count = -1, commit = True): + self.cursor.execute(query, params) + if count == -1: + result = self.cursor.fetchall() + elif count == 1: + result = self.cursor.fetchone() + else: + result = self.cursor.fetchmany(count) + if commit: + self.conn.commit() + return result + + def query_set(self, query, params = ([]), commit = True): + try: + self.cursor.execute(query, params) + if commit: + self.conn.commit() + return True + except psycopg.IntegrityError as e: + print(e) + self.conn.rollback() + return False + + def insert(self, table_name, updates, where=None, returning=None): + # insert data into table + try: + query = f"INSERT INTO {table_name} ({', '.join([update[0] for update in updates])}) VALUES ({', '.join(['%s' for update in updates])})" + if where: + query += f" WHERE {where}" + if returning: + query += f" RETURNING {returning}" + self.cursor.execute(query, ([update[1] for update in updates])) + row = self.cursor.fetchone() + self.conn.commit() + return True, row[returning] + else: + self.cursor.execute(query, ([update[1] for update in updates])) + self.conn.commit() + return True, None + except psycopg.IntegrityError as e: + print(e, query) + self.conn.rollback() + return False, None + +#region Pixiv + def pixiv_get_tags(self, tags): + new_tags = [] + for tag in tags: + query = "SELECT id FROM pixiv_tags WHERE name=%s" + self.cursor.execute(query, ([tag.name])) + result = self.cursor.fetchone() + if result is None: + query = "INSERT INTO pixiv_tags(name, translated_name) VALUES(%s, %s) RETURNING id" + self.cursor.execute(query, ([tag.name, tag.translated_name])) + result = self.cursor.fetchone() + new_tags.append(result["id"]) + self.conn.commit() + return new_tags + + def pixiv_get_user(self, id): + query = "SELECT * FROM pixiv_users WHERE id=%s" + return self.query_get(query, ([id]), count=1) + + def pixiv_get_post(self, id): + query = "SELECT * FROM pixiv_posts WHERE id=%s" + return self.query_get(query, ([id]), count=1) + + def pixiv_insert_user(self, json, commit = True): + query = f'INSERT INTO pixiv_users (id, name, account) VALUES (%s,%s,%s)' + return self.query_set(query, ([json.id, json.name, json.account]), commit) + + def pixiv_insert_post(self, json, commit = True): + query = f'INSERT INTO pixiv_posts (id, title, type, caption, restrict, user_id, tags, create_date, page_count, sanity_level, x_restrict, illust_ai_type, is_saved) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)' + tags = self.pixiv_get_tags(json.tags) + return self.query_set(query, ([json.id, json.title, json.type, json.caption, json.restrict, json.user.id, tags, json.create_date, json.page_count, json.sanity_level, json.x_restrict, json.illust_ai_type, True]), commit) + + def pixiv_update_post(self, id, date, commit = True): + query = f'UPDATE pixiv_posts SET create_date = %s WHERE id = %s' + return self.query_set(query, ([date, id]), commit) +#endregion +#region X + def get_all_posts(self, account_id, commit = True): + query = f"SELECT * FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s ORDER BY {x_posts_schema.id}" + results = self.query_get(query, ([account_id]), commit=commit) + + if results is None: + return [] + else: + results = [x_posts(**result) for result in results] + return results + + def get_post_images(self, account_id, commit = True): + query = f"SELECT * FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s ORDER BY {x_posts_schema.id}" + results = self.query_get(query, ([account_id]), commit=commit) + + if results is None: + return [] + else: + results = [x_posts(**result) for result in results] + return results + + def get_all_post_ids(self, account_id, commit = True): + query = f"SELECT {x_posts_schema.id} FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s" + results = self.query_get(query, ([account_id]), commit=commit) + return [result["id"] for result in results] + + def get_max_post_id(self, account_id : str, commit = True): + query = f"SELECT MAX({x_posts_schema.id}) AS max FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s" + result = self.query_get(query, ([account_id]), count=1, commit=commit) + if result is None: + return 1 + else: + return result["max"] + + def x_get_post_from_id(self, post_id, commit = True): + query = f"SELECT * FROM {x_posts_schema.table} WHERE {x_posts_schema.id} = %s" + result = self.query_get(query, ([post_id]), count=1, commit=commit) + + if result is None: + return None + else: + return x_posts(**result) + + def x_insert_post(self, post : x_posts, commit = True): + text = re.sub(r"https://t.co\S+", "", html.unescape(post.text)) + query = f'''INSERT INTO {x_posts_schema.table} + ({x_posts_schema.id}, {x_posts_schema.account_id}, {x_posts_schema.discord_post_id}, + {x_posts_schema.error_id}, {x_posts_schema.action_taken}, {x_posts_schema.rating}, + {x_posts_schema.tags}, {x_posts_schema.text}, {x_posts_schema.date}) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)''' + return self.query_set(query, + ([ + post.id, post.account_id, post.discord_post_id, + post.error_id, post.action_taken, post.rating, + post.tags, text, post.date + ]), commit) + + def x_search_duplicate(self, user_id: int, max_id: int, phash = None, dhash = None, commit = True): + if(phash != None): + query = f'''SELECT i.* FROM {x_posts_images_schema.table} i JOIN {x_posts_schema.table} p ON i.{x_posts_images_schema.post_id} = p.{x_posts_schema.id} WHERE i.post_id < {max_id} AND p.account_id = {user_id} and bit_count(i.phash # {phash}) < 1 Order by post_id ASC''' + temp = self.query_get(query, count=1, commit = commit) + return None if temp == None else x_posts_images(**temp) + elif(dhash != None): + query = f'''SELECT i.* FROM {x_posts_images_schema.table} i JOIN {x_posts_schema.table} p ON i.{x_posts_images_schema.post_id} = p.{x_posts_schema.id} WHERE i.post_id < {max_id} AND p.account_id = {user_id} and bit_count(i.dhash # {dhash}) < 1 Order by post_id ASC''' + temp = self.query_get(query, count=1, commit = commit) + return None if temp == None else x_posts_images(**temp) + else: + return None + + def x_insert_image(self, image : x_posts_images, commit = True): + query = f'''INSERT INTO {x_posts_images_schema.table} + ({x_posts_images_schema.post_id}, {x_posts_images_schema.index}, {x_posts_images_schema.phash}, + {x_posts_images_schema.dhash}, {x_posts_images_schema.error_id}, {x_posts_images_schema.rating}, + {x_posts_images_schema.tags}, {x_posts_images_schema.file}, {x_posts_images_schema.saved_file}, + {x_posts_images_schema.vox_label}) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)''' + return self.query_set(query, + ([ + image.post_id, image.index, image.phash, + image.dhash, image.error_id, image.rating, + image.tags, image.file, image.saved_file, + image.vox_label + ]), commit) + + def x_update_post(self, id, discord_post_id, error_id, action_taken, commit = True): + query = f"UPDATE {x_posts_schema.table} SET {x_posts_schema.discord_post_id} = %s, {x_posts_schema.error_id} = %s, {x_posts_schema.action_taken} = %s WHERE {x_posts_schema.id} = %s" + return self.query_set(query, ([discord_post_id, error_id, action_taken, id]), commit=commit) + + def x_get_all_accounts(self): + accounts = self.query_get(f'SELECT * from {x_accounts_schema.table} ORDER BY {x_accounts_schema.id}') + result = [x_accounts(**account) for account in accounts] + return result + + def x_get_account_by_name(self, handle : str): + """returns TwitterContainer if account exists in database or None""" + query = f'SELECT * from {x_accounts_schema.table} where {x_accounts_schema.name} = %s' + result = self.query_get(query, ([handle]), count=1) + + if result is None: + return None + + container = x_accounts(**result) + return container + + def x_get_account_by_id(self, id : int): + """returns TwitterContainer if account exists in database or None""" + query = f'SELECT * from {x_accounts_schema.table} where {x_accounts_schema.id} = %s' + result = self.query_get(query, ([id]), count=1) + + if result is None: + return None + + container = x_accounts(**result) + return container + + def x_add_account(self, container : x_accounts): + result, id = self.insert(x_accounts_schema.table, [ + (x_accounts_schema.id , container.id), + (x_accounts_schema.name , container.name), + (x_accounts_schema.rating , container.rating), + (x_accounts_schema.discord_channel_id , container.discord_channel_id), + (x_accounts_schema.discord_thread_id , container.discord_thread_id), + (x_accounts_schema.download_mode , container.download_mode), + ], returning= x_accounts_schema.id) + if result: + return id + else: + return None + + def x_update_account(self, container : x_accounts): + updates = [ + (x_accounts_schema.name , container.name), + (x_accounts_schema.rating , container.rating), + (x_accounts_schema.discord_channel_id , container.discord_channel_id), + (x_accounts_schema.discord_thread_id , container.discord_thread_id), + (x_accounts_schema.download_mode , container.download_mode) + ] + self.x_update_account_properties(container.id, updates) + + def x_update_account_properties(self, account_id, updates): + """Example: updates = [("name", container.name), ("rating", container.rating)]""" + query = f"UPDATE {x_accounts_schema.table} SET {', '.join([f'{update[0]} = %s' for update in updates])} WHERE id = {account_id}" + self.cursor.execute(query, ([update[1] for update in updates])) + self.conn.commit() + +#endregion + + def close(self): + self.conn.close() \ No newline at end of file diff --git a/Database/x_classes.py b/Database/x_classes.py new file mode 100644 index 0000000..6d7beb0 --- /dev/null +++ b/Database/x_classes.py @@ -0,0 +1,123 @@ +from datetime import datetime +from enum import IntEnum, StrEnum + +class DownloadMode(IntEnum): + NO_DOWNLOAD = 0 + DOWNLOAD = 1 + DOWNLOAD_ALL = 2 + +class ErrorID(IntEnum): + SUCCESS = 0 + NO_ART = 1 + OTHER_ERROR = 5 + DOWNLOAD_FAIL = 6 + ACCOUNT_DEAD = 7 + ACCOUNT_SKIP = 8 + TAGGING_ERROR = 9 + FILE_READ_ERROR = 10 + HASHING_Unidentified = 11 + HASHING_FileNotFound = 12 + HASHING_OTHER = 13 + FILE_DELETED_OUTSIDE_ARCHIVE = 14 + NO_CHANNEL = 15 + TWEET_DELETED = 16 + +class ActionTaken(IntEnum): + Null = 0 + Accepted = 1 + Rejected = 2 + Hidden = 3 + +class AccountRating(StrEnum): + SFW = "SFW" + NSFW = "NSFW" + NSFL = "NSFL" + +class HavoxLabel(StrEnum): + KF = "KF" + NonKF = "NonKF" + Rejected = "Rejected" + +class PostRating(StrEnum): + Unrated = "unrated" + General = "general" + Sensitive = "sensitive" + Questionable = "questionable" + Explicit = "explicit" + Filtered = "filtered" + +class x_accounts: + id : int + rating : AccountRating + name : str + is_deleted : bool + is_protected : bool + download_mode : int + discord_channel_id : int + discord_thread_id : int + + def __init__(self, id = 0, rating = AccountRating.NSFW, name : str = "", is_deleted = False, is_protected = False, download_mode = DownloadMode.NO_DOWNLOAD, discord_thread_id = 0, discord_channel_id = 0, **kwargs): + self.id = id + self.rating = rating + self.name = name + self.is_deleted = is_deleted + self.is_protected = is_protected + self.download_mode = download_mode + self.discord_thread_id = discord_thread_id + self.discord_channel_id = discord_channel_id + +class x_posts: + id : int + account_id : int + discord_post_id : int + error_id : int + action_taken : ActionTaken + rating : PostRating + "highest rating from all post images" + tags : list[str] + text : str + date : datetime + + def __init__(self, id = -1, account_id = -1, discord_post_id = 0, error_id = ErrorID.SUCCESS, action_taken = ActionTaken.Null, rating = PostRating.Unrated, tags = [], text = "", date = None, **kwargs): + self.id = id + self.account_id = account_id + self.discord_post_id= discord_post_id + self.error_id = error_id + self.action_taken = action_taken + self.rating = rating + self.tags = tags + self.text = text + self.date = date + +class x_posts_images: + post_id : int + index : int + phash : int = None + dhash : int = None + error_id : int = 0 + rating : PostRating = PostRating.Unrated + tags : list[str] = [] + file : str = None + saved_file : str = None + vox_label : HavoxLabel = None + duplicate_id : int = -1 + duplicate_index : int = -1 + + def __init__(self, post_id, index, phash = None, dhash = None, error_id = ErrorID.SUCCESS, rating = PostRating.Unrated, tags = [], file = None, saved_file = None, vox_label : HavoxLabel = None, duplicate_index = -1, duplicate_id = -1, **kwargs): + self.post_id = post_id + self.index = index + self.phash = phash + self.dhash = dhash + self.error_id = error_id + self.rating = rating + self.tags = tags + self.file = file + self.saved_file = saved_file + self.vox_label = vox_label + self.duplicate_id = duplicate_id + self.duplicate_index = duplicate_index + +if __name__ == "__main__": + print(ErrorID.ACCOUNT_DEAD) + print(PostRating.Questionable) + print(x_posts_images(0,1).rating) \ No newline at end of file diff --git a/Discord/discordHelper.py b/Discord/discordHelper.py new file mode 100644 index 0000000..d052f42 --- /dev/null +++ b/Discord/discordHelper.py @@ -0,0 +1,236 @@ +from __future__ import annotations +from datetime import datetime +import html +import io +import re +import nextcord +from nextcord import ChannelType, Message + +from Database.dbcontroller import DatabaseController +from Database.x_classes import ErrorID, x_accounts +from Twitter.tweetHelper import DownloadedMedia +from exceptions import NO_CHANNEL, OTHER_ERROR +from tweety.types.twDataTypes import Tweet + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +def chunks(s, n): + """Produce `n`-character chunks from `s`.""" + for start in range(0, len(s), n): + yield s[start:start+n] + +def build_pixiv_embed(post): + url = "https://www.pixiv.net/en/artworks/" + str(post.id) + text = re.sub(r"https://t.co\S+", "", html.unescape(post.caption)) + date = post.create_date + + embed=nextcord.Embed(description=text) + embed.set_author(name=post.user.name, url=url) + embed.set_footer(text=date) + + return embed + +def build_x_embed(handle : str, post : Tweet): + url = "https://x.com/" + handle + "/status/" + str(post.id) + text = re.sub(r"https://t.co\S+", "", html.unescape(post.text)) + date = datetime.strftime(post.created_on, '%Y-%m-%d %H:%M:%S') + + embed=nextcord.Embed(description=text) + embed.set_author(name=handle, url=url, icon_url=post.author.profile_image_url_https) + embed.set_footer(text=date) + + return embed + +def build_secondary_embed(main_post : Message, handle : str, post : Tweet): + text = re.sub(r"https://t.co\S+", "", html.unescape(post.text)) + date = datetime.strftime(post.created_on, '%Y-%m-%d %H:%M:%S') + attachment_urls = [attachment.url.split("?")[0] for attachment in main_post.attachments] + + embeds = [] + for url in attachment_urls: + embed = nextcord.Embed(url="https://katworks.sytes.net") + embed.set_image(url) + embeds.append(embed) + + if len(embeds) == 0: + return None + + embed : nextcord.Embed = embeds[0] + embed.description = text + embed.set_footer(text=date) + embed.set_author(name=handle, url=main_post.jump_url, icon_url=post.author.profile_image_url_https) + return embeds + +async def send_error(ex : Exception, botData : RuntimeBotData): + print(ex) + errors_channel = nextcord.utils.get(botData.client.guilds[0].channels, name="bot-status") + await errors_channel.send(content=str(ex)) + +def get_secondary_channel(is_animated, is_filtered, rating, tags : list, artist : x_accounts, guild : nextcord.Guild): + if is_animated: + return nextcord.utils.get(guild.channels, name="animation-feed") + if artist.rating == 'NSFL': + return nextcord.utils.get(guild.channels, name="hidden-feed") + if "futanari" in tags: + return nextcord.utils.get(guild.channels, name="hidden-feed") + if is_filtered or not rating: + return nextcord.utils.get(guild.channels, name="filtered-feed") + if rating == "general": + return nextcord.utils.get(guild.channels, name="safe-feed") + if rating == "sensitive" or rating == "questionable": + return nextcord.utils.get(guild.channels, name="unsafe-feed") + if rating == 'explicit': + return nextcord.utils.get(guild.channels, name="explicit-feed") + + return nextcord.utils.get(guild.channels, name="filtered-feed") + +async def edit_existing_embed_color(message : Message, color : nextcord.Colour): + embeds = message.embeds + embeds[0].colour = color + await message.edit(embeds=embeds) + +async def send_x_post(post : Tweet, artist : x_accounts, guild : nextcord.Guild, new_accounts : list, files_to_send : list[DownloadedMedia], is_filtered: bool, rating: str, tags : list, auto_approve : bool = False, vox_labels : list = None, duplicate_posts : list = None, xView: nextcord.ui.View = None, yView: nextcord.ui.View = None): + if vox_labels is None: + vox_labels = [] + if duplicate_posts is None: + duplicate_posts = [] + + if artist.discord_channel_id != 0: + channel = guild.get_channel(artist.discord_channel_id) + elif artist.discord_thread_id != 0: + channel = guild.get_thread(artist.discord_thread_id) + else: + raise NO_CHANNEL("Ensure channel for the account exists") + + embed = build_x_embed(artist.name, post) + if rating != "": + embed.add_field(name = "rating", value=rating, inline=False) + embed.add_field(name = "tags", value=f'{", ".join(tags)}'.replace("_","\\_")[:1024], inline=False) + if len(duplicate_posts) > 0: + links = [f"[{id}]()" for id in duplicate_posts] + embed.add_field(name = "duplicates", value=f'{" ".join(links)}'[:1024], inline=False) + + discord_files = [nextcord.File(fp = io.BytesIO(x.file_bytes), filename = x.file_name, force_close=True) for x in files_to_send] + try: + main_post : Message = await channel.send(embed=embed, files=discord_files) + except Exception as e: + raise OTHER_ERROR(e) + finally: + for file in discord_files: file.close() + + #skip posting in the public feed for accounts with too many posts + if artist.name in new_accounts and not is_filtered: + return main_post + + is_animated = files_to_send[0].file_name.endswith(".mp4") + secondary_channel : nextcord.TextChannel = get_secondary_channel(is_animated, is_filtered, rating, tags, artist, guild) + + if is_animated: + link = "https://vxtwitter.com/" + artist.name + "/status/" + post.id + " " + main_post.jump_url + secondary_post : Message = await secondary_channel.send(content=link, view = None if auto_approve else yView) + else: + embeds = build_secondary_embed(main_post, artist.name, post) + if rating != "": + embeds[0].add_field(name = "rating", value=rating, inline=False) + #embeds[0].add_field(name = "tags", value=f'{", ".join(tags)}'.replace("_","\\_")[:1024], inline=False) + if len(vox_labels) > 0: + embeds[0].add_field(name = "prediction", value=", ".join(vox_labels), inline=False) + if len(duplicate_posts) > 0: + links = [f"[{id}]()" for id in duplicate_posts] + embeds[0].add_field(name = "duplicates", value=f'{" ".join(links)}'[:1024], inline=False) + secondary_post : Message = await secondary_channel.send(embeds=embeds, view = None if auto_approve else xView) + + return main_post + +async def send_pixiv_post(post, files_to_send, channel : nextcord.TextChannel): + embed = build_pixiv_embed(post) + + discord_files = [nextcord.File(fp = io.BytesIO(file["file"]), filename = file["name"], force_close=True) for file in files_to_send] + try: + main_post : Message = await channel.send(embed=embed, files=discord_files) + except Exception as e: + print(e) + return None, ErrorID.OTHER_ERROR + finally: + for file in discord_files: file.close() + + return main_post, ErrorID.SUCCESS + +async def post_result(results, guild : nextcord.Guild, new_accounts : list): + print("Results: ", len(results)) + channel = nextcord.utils.get(guild.channels, name="bot-status") + string = f'Last check: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}' + for result in results: + if result in new_accounts: + string += f'\n{result} (new): {results[result]}' + new_accounts.remove(result) + else: + string += f'\n{result}: {results[result]}' + + for chunk in chunks(string, 6000): + embed = nextcord.Embed(description=chunk) + await channel.send(embed=embed) + +async def get_channel_from_handle(guild: nextcord.Guild, handle : str): + channel = nextcord.utils.get(guild.channels, name=handle.lower()) + if channel is None: + catId = 0 + category = None + while category is None or len(category.channels) == 50: + category = nextcord.utils.get(guild.categories, name=f'TWITTER-{catId}') + if category is None: + category = await guild.create_category(f'TWITTER-{catId}') + catId+=1 + channel = await guild.create_text_channel(name=handle.lower(), category=category) + return channel + +async def get_thread_from_handle(guild: nextcord.Guild, handle : str): + channel = nextcord.utils.get(guild.channels, name="threads") + thread = nextcord.utils.get(channel.threads, name=handle.lower()) + if thread is None: + thread = await channel.create_thread(name=handle.lower(), auto_archive_duration=10080, type=ChannelType.public_thread) + return thread + +async def get_category_by_name(client: nextcord.Client, name): + guild = client.guilds[0] + category = nextcord.utils.get(guild.categories, name=name) + if category is None: + category = await guild.create_category(name) + return category + +def check_permission(interaction : nextcord.Interaction): + role = nextcord.utils.find(lambda r: r.name == 'Archivist', interaction.guild.roles) + return role in interaction.user.roles + +async def message_from_jump_url(server : nextcord.Guild, jump_url:str): + link = jump_url.split('/') + server_id = int(link[4]) + channel_id = int(link[5]) + msg_id = int(link[6]) + + channel = server.get_channel(channel_id) + if channel is None: + channel = server.get_thread(channel_id) + message = await channel.fetch_message(msg_id) + return message + +async def get_main_post_and_data(guild, jump_url, botData : RuntimeBotData): + main_post = await message_from_jump_url(guild, jump_url) + tw_embed = main_post.embeds[0] + x_post_id = int(tw_embed.author.url.split('/')[-1]) + return main_post, x_post_id + +async def ensure_has_channel_or_thread(artist: x_accounts, guild: nextcord.Guild, database: DatabaseController): + if artist.discord_channel_id == 0 and artist.discord_thread_id == 0: + try: + channel = await get_channel_from_handle(guild, artist.name) + artist.discord_channel_id = channel.id + artist.discord_thread_id = 0 + except: + thread = await get_thread_from_handle(guild, artist.name) + artist.discord_thread_id = thread.id + artist.discord_channel_id = 0 + + database.x_update_account(artist) \ No newline at end of file diff --git a/Discord/views.py b/Discord/views.py new file mode 100644 index 0000000..de3c2e3 --- /dev/null +++ b/Discord/views.py @@ -0,0 +1,105 @@ +from __future__ import annotations +import nextcord +from nextcord.ui import View +from Database.x_classes import ActionTaken +from Discord import discordHelper + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +class XView(View): + def __init__(self, botData : RuntimeBotData): + self.botData = botData + super().__init__(timeout=None, prevent_update=False) + + @nextcord.ui.button(label="KF Art", custom_id="button-keep", style=nextcord.ButtonStyle.gray) + async def button_keep(self, button, interaction : nextcord.Interaction): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + self_message = interaction.message + main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, self_message.embeds[0].author.url, self.botData) + self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Accepted) + await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.green()) + await self_message.edit(view=None) + + @nextcord.ui.button(label="Non-KF Art", custom_id="button-hide", style=nextcord.ButtonStyle.gray) + async def button_hide(self, button, interaction : nextcord.Interaction): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + self_message = interaction.message + main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, self_message.embeds[0].author.url, self.botData) + self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Hidden) + await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.yellow()) + await self_message.delete() + + @nextcord.ui.button(label="Delete", custom_id="button-delete", style=nextcord.ButtonStyle.gray) + async def button_delete(self, button, interaction : nextcord.Interaction): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + try: + self_message = interaction.message + main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, self_message.embeds[0].author.url, self.botData) + print("Deleting", x_post_id, main_post.jump_url) + + await main_post.delete() + await self_message.delete() + + self.botData.db.x_update_post(x_post_id, 0, 0, ActionTaken.Rejected) + except Exception as e: + await interaction.response.send_message("Error occured " + str(e), ephemeral=False) + +class YView(View): + + def __init__(self, botData : RuntimeBotData): + self.botData = botData + super().__init__(timeout=None, prevent_update=False) + + @nextcord.ui.button(label="KF Art", custom_id="y-button-keep", style=nextcord.ButtonStyle.gray) + async def button_keep(self, button, interaction : nextcord.Interaction): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + self_message = interaction.message + main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, self_message.content.split(" ")[1], self.botData) + self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Accepted) + await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.green()) + await self_message.edit(view=None) + + @nextcord.ui.button(label="Non-KF Art", custom_id="y-button-hide", style=nextcord.ButtonStyle.gray) + async def button_hide(self, button, interaction : nextcord.Interaction): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + self_message = interaction.message + main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, self_message.content.split(" ")[1], self.botData) + self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Hidden) + await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.yellow()) + await self_message.delete() + + @nextcord.ui.button(label="Delete", custom_id="y-button-delete", style=nextcord.ButtonStyle.gray) + async def button_delete(self, button, interaction : nextcord.Interaction): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + try: + self_message = interaction.message + main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, self_message.content.split(" ")[1], self.botData) + print("Deleting", x_post_id, main_post.jump_url) + + await main_post.delete() + await self_message.delete() + + self.botData.db.x_update_post(x_post_id, 0, 0, ActionTaken.Rejected) + except Exception as e: + await interaction.response.send_message("Error occured " + str(e), ephemeral=False) + \ No newline at end of file diff --git a/Pixiv/downloader.py b/Pixiv/downloader.py new file mode 100644 index 0000000..d4b0889 --- /dev/null +++ b/Pixiv/downloader.py @@ -0,0 +1,37 @@ +from __future__ import annotations +import traceback +from typing import TYPE_CHECKING + +import nextcord + +from Discord import discordHelper +from config import Global_Config +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +async def download_loop(botData: RuntimeBotData): + try: + db = botData.db + guild = botData.client.guilds[0] + + new_posts = await botData.pixivApi.get_new_posts(botData) + new_posts.sort(key= lambda x: x.id) + + for post in new_posts: + media = (await botData.pixivApi.download_illust(post, Global_Config["pixiv_download_path"]))[:4] + if post.x_restrict == 0: + channel = nextcord.utils.get(guild.channels, name="pixiv-sfw-feed") + elif post.x_restrict == 1: + channel = nextcord.utils.get(guild.channels, name="pixiv-r18-feed") + else: + channel = nextcord.utils.get(guild.channels, name="pixiv-r18g-feed") + await discordHelper.send_pixiv_post(post, media, channel) + + if db.pixiv_get_user(post.user.id) == None: + db.pixiv_insert_user(post.user) + db.pixiv_insert_post(post) + + except Exception as ex: + print(ex) + await discordHelper.send_error(traceback.format_exc()[0:256], botData) + print("Pixiv done") diff --git a/Pixiv/pixivapi.py b/Pixiv/pixivapi.py new file mode 100644 index 0000000..695189c --- /dev/null +++ b/Pixiv/pixivapi.py @@ -0,0 +1,230 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from config import Global_Config +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +import os +import asyncio +from io import BytesIO +from zipfile import ZipFile +from pixivpy3 import AppPixivAPI +from Twitter.tweetyapi import TweetyApi +from getPixivpyToken.gppt import GetPixivToken +from PIL import Image + +class PixivApi: + def init(self): + self.api = AppPixivAPI(timeout=60) + self.api.auth(refresh_token=self.get_refresh_token()) + return self + + def get_refresh_token(self) -> str: + if Global_Config["pixiv_token"] is not None: + return Global_Config["pixiv_token"] + else: + if Global_Config["pixiv_username"] is None or Global_Config["pixiv_password"] is None: + raise Exception("Pixiv username and password are required") + g = GetPixivToken(headless=True, username=Global_Config["pixiv_username"], password=Global_Config["pixiv_password"]) + refresh_token = g.login()["refresh_token"] + Global_Config["pixiv_token"] = refresh_token + return refresh_token + + async def search_illust(self, retries = 3, **next_qs): + last_error = "" + sleep_default = 0.125 + for i in range(retries): + try: + json_result = self.api.search_illust(**next_qs) + + if json_result.error != None: + raise Exception(json_result.error) + + return json_result + except Exception as ex: + print(ex) + last_error = ex + await TweetyApi.sleep_wait(sleep_default, i) + self.api = PixivApi().init().api + continue + + raise Exception(last_error) + + async def illust_detail(self, id, retries = 3): + last_error = "" + sleep_default = 0.125 + for i in range(retries): + try: + json_result = self.api.illust_detail(id) + + if json_result.error != None: + if json_result.error.user_message == 'Page not found': + return None + if json_result.error.user_message == 'Artist has made their work private.': + return None + raise Exception(json_result.error) + + return json_result + except Exception as ex: + print(ex) + last_error = ex + await TweetyApi.sleep_wait(sleep_default, i) + self.api = PixivApi().init().api + continue + + raise Exception(last_error) + + async def ugoira_metadata(self, id, retries = 3): + last_error = "" + sleep_default = 0.125 + for i in range(retries): + try: + json_result = self.api.ugoira_metadata(id) + + if json_result.error != None: + if json_result.error.user_message == 'Page not found': + return None + if json_result.error.user_message == 'Artist has made their work private.': + return None + raise Exception(json_result.error) + + return json_result + except Exception as ex: + print(ex) + last_error = ex + await TweetyApi.sleep_wait(sleep_default, i) + self.api = PixivApi().init().api + continue + + raise Exception(last_error) + + async def download(self, url, retries = 3): + for i in range(retries): + try: + with BytesIO() as io_bytes: + def foo(): + self.api.download(url, fname=io_bytes) + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, foo) + return io_bytes.getbuffer().tobytes() + except Exception as ex: + print(ex) + if i == 3: + raise ex + await asyncio.sleep(i * 60) + + async def get_new_posts(self, botData: RuntimeBotData): + posts = [] + # get all page: + next_qs = {'word': 'けものフレンズ', 'search_target': 'partial_match_for_tags', 'sort': 'date_desc', 'filter': 'for_ios', 'end_date': ''} + + page = 0 + while next_qs: + print("Getting pixiv page") + has_new_posts = False + + json_result = await self.search_illust(retries=3, **next_qs) + + for illust in json_result.illusts: + if botData.db.pixiv_get_post(illust.id) == None: + has_new_posts = True + print(illust.id, illust.title, illust.create_date) + posts.append(illust) + + if has_new_posts == False: + break + + next_qs = self.api.parse_qs(json_result.next_url) + + if len(json_result.illusts) > 0: + oldest_date = json_result.illusts[-1].create_date[0:10] + if "end_date" in next_qs and oldest_date != next_qs["end_date"]: + next_qs["end_date"] = oldest_date + next_qs["offset"] = sum(1 for x in json_result.illusts if x.create_date[0:10] == oldest_date) + + page += 1 + await asyncio.sleep(1) + + return posts + + async def download_illust(self, illust, download_path = "Temp/"): + print(illust.id) + downloaded = [] + filepath = os.path.abspath(os.path.join(download_path, str(illust.user.id))) + os.makedirs(filepath, exist_ok=True) + + if illust.type == 'ugoira': + filename = str(illust.id) + "_p0.gif" + filepath = os.path.join(filepath, filename) + bytes = await self.download_ugoira_to_bytes(illust.id) + if bytes is None: + return [] + with open(filepath, "wb") as f: + f.write(bytes) + downloaded.append({"file":bytes, "name":filename}) + else: + if len(illust.meta_pages) == 0: + _, filename = os.path.split(illust.meta_single_page.original_image_url) + filepath = os.path.join(filepath, filename) + bytes = await self.download(illust.meta_single_page.original_image_url) + with open(filepath, "wb") as f: + f.write(bytes) + if len(downloaded) < 4: + downloaded.append({"file":bytes, "name":filename}) + else: + for page in illust.meta_pages: + filepath = os.path.abspath(os.path.join(download_path, str(illust.user.id))) + _, filename = os.path.split(page.image_urls.original) + filepath = os.path.join(filepath, filename) + bytes = await self.download(page.image_urls.original) + with open(filepath, "wb") as f: + f.write(bytes) + if len(downloaded) < 4: + downloaded.append({"file":bytes, "name":filename}) + + return downloaded + + async def download_ugoira_to_bytes(self, id): + closables = [] + + metadata = await self.ugoira_metadata(id) + if metadata is None: + return None + + if len(metadata.ugoira_metadata.zip_urls) > 1: + raise Exception + + bytes = await self.download(metadata.ugoira_metadata.zip_urls["medium"]) + with BytesIO(bytes) as bytes0, ZipFile(bytes0) as input_zip: + frames = metadata.ugoira_metadata.frames + zip_frames = {} + for name in input_zip.namelist(): + im = input_zip.read(name) + im_bytes = BytesIO(im) + im_im = Image.open(im_bytes) + closables += [im_bytes, im_im] + zip_frames[name] = im_im + + with BytesIO() as buffer: + zip_frames[frames[0].file].save( + buffer, + format="GIF", + save_all=True, + append_images=[zip_frames[frame.file] for frame in frames[1:]], + duration=[frame.delay for frame in frames], + loop=0, + optimize=False + ) + + [closable.close() for closable in closables] + return buffer.getbuffer().tobytes() + +if __name__ == "__main__": + async def get_new_posts(): + api = PixivApi().init() + while True: + await api.get_new_posts() + await asyncio.sleep(30 * 60) + + asyncio.run(get_new_posts()) \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..cc3a2a2 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +Discord bot with Pixiv and X post scrapers and a postgresql database + +Will download all of Pixiv posts tagged KF when run + +Uses AI models for tagging and classification: https://huggingface.co/HAV0X1014/Kemono-Friends-Sorter https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/ + +Recommended python 3.13.x + +Postgresql database schema included in file: japariarchive_schema.sql \ No newline at end of file diff --git a/Twitter/downloader.py b/Twitter/downloader.py new file mode 100644 index 0000000..f103c89 --- /dev/null +++ b/Twitter/downloader.py @@ -0,0 +1,147 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +import asyncio +from datetime import datetime +import gc +import traceback +import tracemalloc + +from Classifier.classifyHelper import classify_all +from Database.x_classes import DownloadMode +from Discord import discordHelper +from Twitter import tweetHelper +from tweety.types import Tweet +from exceptions import ACCOUNT_DEAD, ACCOUNT_SKIP, DOWNLOAD_FAIL, NO_CHANNEL, OTHER_ERROR +from Database.x_classes import ActionTaken, DownloadMode, ErrorID, HavoxLabel, PostRating, x_posts, x_posts_images, x_accounts +from Database.db_schema import x_accounts as schema_x_accounts +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +async def download_loop(botData: RuntimeBotData): + guild = botData.client.guilds[0] + try: + results = {} + botData.new_accounts = [] + + for artist in botData.db.x_get_all_accounts(): + if artist.is_deleted: continue + + #sleep to avoid rate limits + await asyncio.sleep(5) + + print("Artist:", artist.name) + + #wait for ALL new posts to be found + try: + match artist.download_mode: + case DownloadMode.NO_DOWNLOAD: + continue + case DownloadMode.DOWNLOAD: + await discordHelper.ensure_has_channel_or_thread(artist, guild, botData.db) + new_posts = await tweetHelper.UpdateMediaPosts(artist, botData) + case DownloadMode.DOWNLOAD_ALL: + await discordHelper.ensure_has_channel_or_thread(artist, guild, botData.db) + new_posts = await tweetHelper.DownloadAllMediaPosts(artist, botData) + case _: + continue + except ACCOUNT_DEAD: + botData.db.x_update_account_properties(artist.id, [(schema_x_accounts.is_deleted, True)]) + continue + except ACCOUNT_SKIP: + continue + + if len(new_posts) == 0: continue + + new_posts_count = len([post for post in new_posts if len(post.media) > 0]) + if new_posts_count > 20: + #skips posting to discord if there are too many posts + botData.new_accounts.append(artist.name) + + new_posts.sort(key= lambda x: x.date) + + for tweet in new_posts: #posts should arrive here in chronological order + await download_post(artist, tweet, botData) + gc.collect() + print(tracemalloc.get_traced_memory()) + + results[artist.name] = new_posts_count + if artist.download_mode == DownloadMode.DOWNLOAD_ALL: + botData.db.x_update_account_properties(artist.id, [(schema_x_accounts.download_mode, DownloadMode.DOWNLOAD)]) + await discordHelper.post_result(results, guild, botData.new_accounts) + except Exception as ex: + print(ex) + await discordHelper.send_error(traceback.format_exc()[0:256], botData) + +async def download_post(artist: x_accounts, tweet: Tweet, botData: RuntimeBotData): + x_post = x_posts(id = tweet.id, account_id = tweet.author.id, date = tweet.date, text = tweet.text) + + if len(tweet.media) == 0: + x_post.error_id = ErrorID.NO_ART + botData.db.x_insert_post(x_post, commit=True) + return + + print("New media post:", str(tweet.url)) + media = await tweetHelper.GetTweetMediaUrls(tweet) + image_containers = [x_posts_images(tweet.id, idx, file = url) for idx, url in enumerate(media)] + + try: + downloaded_media = await tweetHelper.DownloadMedia(tweet.id, tweet.author.id, tweet.author.username, media, botData.session) + except DOWNLOAD_FAIL as e: + x_post.error_id = e.code + + botData.db.x_insert_post(x_post, commit=False) + for image in image_containers: + image.error_id = ErrorID.DOWNLOAD_FAIL + botData.db.x_insert_image(image, commit=False) + botData.db.conn.commit() + return + + def get_rating_value(rating): + return 4 if rating == PostRating.Explicit else 3 if rating == PostRating.Questionable else 2 if rating == PostRating.Sensitive else 1 if rating == PostRating.General else 0 + + vox_labels = [] + final_filtered_tags = {} + duplicates = [] + for idx, attachment in enumerate(downloaded_media): + container = image_containers[idx] + container.saved_file = attachment.file_name + container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes, botData.classifier, botData.vox) + + if container.vox_label not in vox_labels: + vox_labels.append(container.vox_label) + + if container.phash != None: + duplicate = botData.db.x_search_duplicate(user_id=x_post.account_id, max_id = x_post.id, phash=container.phash) + if duplicate != None: + container.duplicate_id = duplicate.post_id + container.duplicate_index = duplicate.index + if duplicate.post_id not in duplicates: + duplicates.append(duplicate.post_id) + + x_post.tags = list(set(x_post.tags + container.tags)) + x_post.rating = container.rating if get_rating_value(container.rating) > get_rating_value(x_post.rating) else x_post.rating + final_filtered_tags = final_filtered_tags | filtered_tags + + is_filtered = len(vox_labels) == 1 and vox_labels[0] == HavoxLabel.Rejected + + try: + discord_post = await discordHelper.send_x_post(tweet, artist, botData.client.guilds[0], botData.new_accounts, downloaded_media, is_filtered, rating=x_post.rating, tags=final_filtered_tags, vox_labels = vox_labels, duplicate_posts=duplicates, xView=botData.xView, yView=botData.yView) + except (NO_CHANNEL, OTHER_ERROR) as e: + x_post.error_id = e.code + x_post.discord_post_id = 0 + else: + x_post.discord_post_id = discord_post.id + + x_post.action_taken = ActionTaken.Rejected if is_filtered else ActionTaken.Null + + try: + if not botData.db.x_insert_post(x_post, commit = False): + raise Exception("Transaction error") + + for image in image_containers: + botData.db.x_insert_image(image, False) + except Exception as ex: + botData.db.conn.rollback() + raise ex + else: + botData.db.conn.commit() diff --git a/Twitter/tweetHelper.py b/Twitter/tweetHelper.py new file mode 100644 index 0000000..45d4051 --- /dev/null +++ b/Twitter/tweetHelper.py @@ -0,0 +1,131 @@ +from __future__ import annotations +import os +from Database.x_classes import x_accounts +from config import Global_Config +import downloadHelper +from tweety.types.twDataTypes import Tweet + +from exceptions import ACCOUNT_DEAD, ACCOUNT_SKIP +from tweety.exceptions_ import UserNotFound, UserProtected + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +class TweetMedia: + url : str + file_name : str + + def __init__(self, url, file_name): + self.url = url + self.file_name = file_name + +class DownloadedMedia: + file_bytes : str + file_name : str + + def __init__(self, bytes, file_name): + self.file_bytes = bytes + self.file_name = file_name + +async def GetTweetMedia(tweet : Tweet) -> list[TweetMedia]: + mediaList : list[TweetMedia] = [] + for idx, media in enumerate(tweet.media): + if media.file_format == 'mp4': + best_stream = await media.best_stream() + fileName = f"{tweet.author.screen_name}_{tweet.id}_{idx}.{media.file_format}" + mediaList.append(TweetMedia(best_stream.direct_url, fileName)) + else: + best_stream = await media.best_stream() + extension = best_stream.file_format + fileName = f"{tweet.author.screen_name}_{tweet.id}_{idx}.{extension}" + mediaList.append(TweetMedia(best_stream.direct_url, fileName)) + + return mediaList + +async def GetTweetMediaUrls(tweet : Tweet): + mediaList = await GetTweetMedia(tweet) + return [media.url for media in mediaList] + +async def DownloadMedia(post_id, account_id, account_name, url_list : list, session) -> list[DownloadedMedia]: + result : list[DownloadedMedia] = [] + path = f"{Global_Config("x_download_path")}{account_id}" + os.makedirs(path, exist_ok=True) + + for idx, file_url in enumerate(url_list): + file_name = get_file_name(account_name, post_id, idx, file_url) + full_path = f"{path}/{file_name}" + + photo_bytes = await downloadHelper.save_to_file(file_url, full_path, session) + + result.append(DownloadedMedia(photo_bytes, file_name)) + + return result + +def get_file_name(account_name: str, post_id: int, image_index: int, image_url: str, account_id : int = None, base_path : str = None): + ''' + `account_id` and `base_path` are optional\n + In `base_path`, do not include trailing slash\n + Example if none are defined:\n `file_name.ext` + Example if `base_path` is defined:\n `c:/base_path/file_name.ext` + Example if `account_id` is defined:\n `account_id/file_name.ext` + Example if both are defined:\n `c:/base_path/account_id/file_name.ext` + ''' + + ext = image_url.split("?")[0].split(".")[-1] + file_name = f"{account_name}_{post_id}_{image_index}.{ext}" + if account_id != None and base_path != None: + return f"{base_path}/{account_id}/{file_name}" + elif base_path != None: + return f"{base_path}/{file_name}" + elif account_id != None: + return f"{account_id}/{file_name}" + return file_name + +async def UpdateMediaPosts(account : x_accounts, botData : RuntimeBotData) -> list[Tweet]: + all_posts = botData.db.get_all_post_ids(account.id) + newest_post = 1 if len(all_posts) == 0 else max(all_posts) + posts = [] + + try: + posts = [tweet async for tweet in botData.twApi.get_tweets(user_name = account.name, bottom_id = newest_post, all_posts = all_posts)] + + except (UserProtected, UserNotFound) as ex: + print("User dead: ", account.name, ex) + raise ACCOUNT_DEAD(ex) + except Exception as ex: + print("Error in ", account.name, ex) + raise ACCOUNT_SKIP(ex) + + return posts + +async def DownloadAllMediaPosts(account : x_accounts, botData : RuntimeBotData) -> list[Tweet]: + all_posts = botData.db.get_all_post_ids(account.id) + posts = [] + + try: + async for tweet in botData.twApi.get_tweets(user_name = account.name, bottom_id = 1, all_posts = []): + if int(tweet.id) not in all_posts: + posts.append(tweet) + + except (UserProtected, UserNotFound) as ex: + print("User dead: ", account.name, ex) + raise ACCOUNT_DEAD(ex) + except Exception as ex: + print("Error in ", account.name, ex) + raise ACCOUNT_SKIP(ex) + + return posts + +def parse_x_url(url : str): + "return account (handle, post id) from full X post url" + url = url.replace("https://", "").replace("http://", "") + split = url.split("?") + if len(split) > 0: + url = split[0] + + split = url.split('/') + if split[2] != "status": + raise Exception("Invalid Format") + + return split[1], int(split[3]) \ No newline at end of file diff --git a/Twitter/tweetyapi.py b/Twitter/tweetyapi.py new file mode 100644 index 0000000..c914df9 --- /dev/null +++ b/Twitter/tweetyapi.py @@ -0,0 +1,102 @@ +import asyncio +from typing import AsyncIterator +from tweety import TwitterAsync +from tweety.types.twDataTypes import Tweet, SelfThread, ConversationThread +from tweety.types.usertweet import UserMedia +from tweety.exceptions_ import RateLimitReached +from config import Global_Config + +class TweetyApi: + async def init(self, skip_login = False, session_name = "session"): + if skip_login: + self.app = TwitterAsync(session_name) + else: + if Global_Config["x_cookies"] == None: + raise Exception("X cookies are required") + cookies = Global_Config["x_cookies"] + self.app = TwitterAsync(session_name) + await self.app.load_cookies(cookies) + print(self.app.user) + return self + + async def get_tweet(self, url): + try: + tweet = await self.app.tweet_detail(url) + return tweet + except: + return None + + async def get_tweets(self, user_name, bottom_id, all_posts : list) -> AsyncIterator[Tweet]: + def validate_tweet(tweet : Tweet): + tweet_id_num = int(tweet.id) + + past_bounds = False + tweet_valid = True + + if tweet_id_num <= bottom_id: + past_bounds = True + if tweet_id_num in all_posts: + tweet_valid = False + + return past_bounds, tweet_valid + + sleep_default = 0.125 + sleep_exponent = 1 + user = None + + while user == None: + try: + user = await self.app.get_user_info(username=user_name) + except RateLimitReached as ex: + sleep_exponent = await self.sleep_wait(sleep_default, sleep_exponent) + except Exception as ex: + print("User error: " + str(ex)) + raise ex + + tweety_api = UserMedia(user.rest_id, self.app, 1, 2, None) + sleep_exponent = 1 + + while True: + await asyncio.sleep(5) + old_cursor = tweety_api.cursor + + try: + tweets = await tweety_api.get_next_page() + sleep_exponent = 1 + except RateLimitReached as ex: + sleep_exponent = await self.sleep_wait(sleep_default, sleep_exponent) + tweety_api.cursor = old_cursor + continue + except Exception as ex: + raise ex + + has_valid_tweets = False + for tweet in tweets: + if isinstance(tweet, ConversationThread) | isinstance(tweet, SelfThread): + tweet:ConversationThread | SelfThread + for tweet1 in tweet.tweets: + _, tweet_valid = validate_tweet(tweet1) + if tweet_valid: + has_valid_tweets = True + yield tweet1 + else: + past_bounds, tweet_valid = validate_tweet(tweet) + if past_bounds: continue + if tweet_valid: + has_valid_tweets = True + yield tweet + + if len(tweets) == 0 or not has_valid_tweets: + break + await asyncio.sleep(1) + + @staticmethod + async def sleep_wait(sleep_default, sleep_exponent): + sleep_amount = min(sleep_default * pow(2,sleep_exponent), 2) + print(f"Sleeping for {round(sleep_amount,2)} hours.") + await asyncio.sleep(sleep_amount * 60 * 60) + print("Sleep done") + sleep_exponent += 1 + return sleep_exponent + +#asyncio.run(TweetyApi().get_tweets("redhood_depth", 0)) \ No newline at end of file diff --git a/Twitter/twitterContainer.py b/Twitter/twitterContainer.py new file mode 100644 index 0000000..48428d4 --- /dev/null +++ b/Twitter/twitterContainer.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from Database.x_classes import AccountRating, DownloadMode, ErrorID +from tweety.exceptions_ import UserNotFound, UserProtected + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +class TwitterContainer(): + id : int + name : str + rating : str + discord_channel_id : int = 0 + discord_thread_id : int = 0 + download_mode : int = 0 + + def __init__(self, name : str, rating = AccountRating.NSFW, discord_channel_id = 0, discord_thread_id = 0, download_mode = DownloadMode.NO_DOWNLOAD, id = 0, **kwargs): + self.id = id + self.name = name + self.rating = rating.upper() + self.discord_channel_id = discord_channel_id + self.discord_thread_id = discord_thread_id + self.download_mode = download_mode + + async def UpdateMediaPosts(self, botData : RuntimeBotData): + all_posts = botData.db.get_all_post_ids(self.id) + newest_post = 1 if len(all_posts) == 0 else max(all_posts) + posts = [] + + try: + async for tweet in botData.twApi.get_tweets(user_name = self.name, bottom_id = newest_post, all_posts = all_posts): + posts.append(tweet) + + except (UserProtected, UserNotFound) as ex: + print("User dead: ", self.name, ex) + return ErrorID.ACCOUNT_DEAD, [] + except Exception as ex: + print("Error in ", self.name, ex) + return ErrorID.ACCOUNT_SKIP, [] + + return ErrorID.SUCCESS, posts \ No newline at end of file diff --git a/bot.py b/bot.py new file mode 100644 index 0000000..bb6ddac --- /dev/null +++ b/bot.py @@ -0,0 +1,54 @@ +from datetime import datetime +import os + +from config import Global_Config + +os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' +from Twitter.downloader import download_loop as x_download_loop +from Pixiv.downloader import download_loop as pixiv_download_loop + +#import asyncio +import nextcord +from runtimeBotData import RuntimeBotData +from nextcord.ext import tasks, commands + +# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +botData = RuntimeBotData() +#loop.run_until_complete(botData.initialize_data()) + +intents = nextcord.Intents.default() +intents.message_content = True +intents.guild_reactions = True + +bot = commands.Bot(command_prefix=".", intents=intents) +botData.client = bot +bot.load_extension("commands", extras={"botData": botData}) + +@bot.event +async def on_ready(): + if botData.initialized: return + await botData.initialize_data() + bot.add_view(botData.xView) + bot.add_view(botData.yView) + botData.initialized = True + print(f'{bot.user} has connected to nextcord!') + x_loop.start() + pixiv_loop.start() + +@tasks.loop(minutes = 3 * 60) +async def x_loop(): + print(datetime.now().strftime("%d/%m/%Y %H:%M:%S")) + await x_download_loop(botData) + print("Finished: " + datetime.now().strftime("%d/%m/%Y %H:%M:%S")) + +@tasks.loop(minutes = 30) +async def pixiv_loop(): + await pixiv_download_loop(botData) + +if Global_Config["discord_token"] is None: + raise Exception("Discord bot token is required") +bot.run(Global_Config["discord_token"]) + diff --git a/commands.py b/commands.py new file mode 100644 index 0000000..b0fc3f5 --- /dev/null +++ b/commands.py @@ -0,0 +1,250 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +import requests + +from Classifier.classifyHelper import classify_all +from Database.x_classes import AccountRating, ActionTaken, DownloadMode, PostRating, x_posts, x_posts_images +from Database.db_schema import x_posts as schema_x_posts, x_accounts as schema_x_accounts +from Discord import discordHelper +from Twitter import tweetHelper +from exceptions import NO_CHANNEL, OTHER_ERROR +from helpers import get_rating_value +from helpers_adv import create_x_account +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +import importlib +import nextcord +from nextcord.ext import commands +from nextcord import application_command + +class Commands(commands.Cog): + def __init__(self, botData : RuntimeBotData): + self.botData = botData + self.db = botData.db + self._last_member = None + + @application_command.slash_command() + @commands.has_permissions(administrator=True) + async def order(self, interaction: nextcord.Interaction): + guild = interaction.guild + categories = list(filter(lambda cat: cat.name.startswith("TWITTER"), guild.categories)) + channels = list(filter(lambda chan: not isinstance(chan, nextcord.CategoryChannel) and chan.category.name.startswith("TWITTER"), guild.channels)) + channelsOrdered = sorted(channels, key=lambda chan: chan.name) + + for idx,category in enumerate(categories): + await category.edit(name=f"temp-{idx}") + + last_id = -1 + target_category = None + for idx, channel in enumerate(channelsOrdered): + id = idx//50 + if id > last_id: + last_id = id + target_category = await guild.create_category(f"TWITTER-{id}") + print("moving", channel.name, "to", target_category.name) + await channel.edit(category=target_category) + + for category in categories: + print("deleting", category.name) + await category.delete() + + @application_command.slash_command(description="Downloads single post from (domain)/(account_name)/status/(post_id) format") + async def download_post(self, interaction: nextcord.Interaction, full_link : str, new_account_rating : str = nextcord.SlashOption(choices= [AccountRating.NSFW, AccountRating.SFW, AccountRating.NSFL], default = AccountRating.NSFW), download_mode : int = nextcord.SlashOption(choices={"No":DownloadMode.NO_DOWNLOAD, "Yes":DownloadMode.DOWNLOAD, "More Yes":DownloadMode.DOWNLOAD_ALL}, default=DownloadMode.NO_DOWNLOAD)): + botData = self.botData + + try: + handle, post_id = tweetHelper.parse_x_url(full_link) + except: + await interaction.response.send_message("Invalid url. Format should be (domain)/(account_name)/status/(post_id)") + return + + await interaction.response.defer() + + query_result = botData.db.x_get_post_from_id(post_id) + if query_result is not None: + if query_result.discord_post_id == 0: + await interaction.followup.send(content=f'Post was previously deleted. This is not implemented yet.' ) + return + else: + await interaction.followup.send(content=f'Post was previously downloaded')#: https://discord.com/channels/{client.guilds[0].id}/{query_result.discord_channel_id}/{query_result["discord_post_id"]}' ) + return + else: + try: + tweet = await botData.twApi.get_tweet(f'x.com/{handle}/status/{post_id}') + x_post = x_posts(id = tweet.id, account_id = tweet.author.id, date = tweet.date, text = tweet.text) + + artist = botData.db.x_get_account_by_name(handle) + + if artist is None: + artist = await create_x_account(handle, botData, new_account_rating, download_mode = download_mode) + + await discordHelper.ensure_has_channel_or_thread(artist, interaction.guild, botData.db) + + image_containers : list[x_posts_images] = [] + media = await tweetHelper.GetTweetMediaUrls(tweet) + image_containers = [x_posts_images(tweet.id, idx, file = url) for idx, url in enumerate(media)] + downloaded_media = await tweetHelper.DownloadMedia(tweet.id, tweet.author.id, tweet.author.username, media, botData.session) + + vox_labels = [] + final_filtered_tags = {} + duplicates = [] + for idx, attachment in enumerate(downloaded_media): + container = image_containers[idx] + container.saved_file = attachment.file_name + container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes, botData.classifier, botData.vox) + + if container.vox_label not in vox_labels: + vox_labels.append(container.vox_label) + + if container.phash != None: + duplicate = botData.db.x_search_duplicate(user_id=x_post.account_id, max_id = x_post.id, phash=container.phash) + if duplicate != None: + container.duplicate_id = duplicate.post_id + container.duplicate_index = duplicate.index + if duplicate.post_id not in duplicates: + duplicates.append(duplicate.post_id) + + x_post.tags = list(set(x_post.tags + container.tags)) + x_post.rating = container.rating if get_rating_value(container.rating) > get_rating_value(x_post.rating) else x_post.rating + final_filtered_tags = final_filtered_tags | filtered_tags + + try: + discord_post = await discordHelper.send_x_post(tweet, artist, interaction.guild, [artist.name], downloaded_media, False, rating=x_post.rating, tags=final_filtered_tags, auto_approve=True, vox_labels = vox_labels, duplicate_posts=duplicates, xView=botData.xView, yView=botData.yView) + except (NO_CHANNEL, OTHER_ERROR) as e: + x_post.error_id = e.code + x_post.discord_post_id = 0 + else: + x_post.discord_post_id = discord_post.id + + x_post.action_taken = ActionTaken.Accepted + + try: + if not botData.db.x_insert_post(x_post, commit = False): + raise Exception("Transaction error") + + for image in image_containers: + botData.db.x_insert_image(image, False) + except Exception as ex: + botData.db.conn.rollback() + raise ex + else: + botData.db.conn.commit() + + if discord_post is not None: + embeds = discordHelper.build_secondary_embed(discord_post, artist.name, tweet) + await discordHelper.edit_existing_embed_color(discord_post, nextcord.Colour.green()) + await interaction.followup.send(embeds=embeds) + return + else: + await interaction.followup.send(content=f'Posting failed: {x_post.error_id}') + return + except Exception as e: + await interaction.followup.send(content=e) + return + + @application_command.slash_command() + @commands.has_permissions(administrator=True) + async def query(self, interaction: nextcord.Interaction, query: str): + self.db.cursor.execute(query) + self.db.conn.commit() + await interaction.response.send_message("ok") + + @application_command.slash_command() + @commands.has_permissions(administrator=True) + async def ignore_tag(self, interaction: nextcord.Interaction, tag: str): + self.botData.classifier.add_ignored_tag(tag) + await interaction.response.send_message("ok", ephemeral=True) + + @application_command.slash_command() + async def add_x_account(self, interaction: nextcord.Interaction, handle: str, rating : str = nextcord.SlashOption(choices= [AccountRating.NSFW, AccountRating.SFW, AccountRating.NSFL], default = AccountRating.NSFW), download_mode : int = nextcord.SlashOption(choices={"No":DownloadMode.NO_DOWNLOAD, "Yes":DownloadMode.DOWNLOAD, "Yes (check all posts)":DownloadMode.DOWNLOAD_ALL}, default=DownloadMode.DOWNLOAD_ALL)): + await interaction.response.defer() + if "x.com" in handle: + handle = handle.split('/') + handle = handle[handle.index('x.com')+1] + try: + print(handle) + result = requests.get(f"https://api.vxtwitter.com/{handle}") + if result.status_code != 200: + raise Exception("Failed to get user id") + id = result.json()["id"] + existing_account = self.db.x_get_account_by_id(id) + if existing_account != None: + if existing_account.download_mode == download_mode: + raise Exception("Account is already on download list") + self.db.x_update_account_properties(id, updates=[(schema_x_accounts.download_mode, download_mode)]) + if download_mode == DownloadMode.DOWNLOAD_ALL: + await interaction.followup.send("Added " + handle + " to download list (with full check)") + elif download_mode == DownloadMode.NO_DOWNLOAD: + await interaction.followup.send("Removed " + handle + " from download list") + elif download_mode == DownloadMode.DOWNLOAD: + await interaction.followup.send("Added " + handle + " to download list") + else: + await interaction.followup.send("huh??") + else: + account = await create_x_account(handle, self.botData, rating, download_mode=download_mode) + await interaction.followup.send("Added " + handle) + except Exception as ex: + print(ex) + await interaction.followup.send(str(ex)) + + @application_command.message_command(guild_ids=[1043267878851457096]) + async def archive_set_KF(self, interaction: nextcord.Interaction, message: nextcord.Message): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + query =f'SELECT {schema_x_posts.id}, {schema_x_posts.account_id}, {schema_x_posts.action_taken} FROM {schema_x_posts.table} WHERE {schema_x_posts.discord_post_id} = %s ' + result = self.db.query_get(query, ([message.id]), count = 1) + if result is None: + await interaction.response.send_message("post not found in database", ephemeral=True) + return + if result[schema_x_posts.action_taken] == ActionTaken.Accepted: + await discordHelper.edit_existing_embed_color(message, nextcord.Colour.green()) + await interaction.response.send_message("post was already KF", ephemeral=True) + return + + await discordHelper.edit_existing_embed_color(message, nextcord.Colour.green()) + result = self.db.x_update_post(result[schema_x_posts.id], message.id, 0, ActionTaken.Accepted) + await interaction.response.send_message("post approved", ephemeral=True) + + @application_command.message_command(guild_ids=[1043267878851457096]) + async def archive_set_nonKF(self, interaction: nextcord.Interaction, message: nextcord.Message): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + query =f'SELECT {schema_x_posts.id}, {schema_x_posts.account_id}, {schema_x_posts.action_taken} FROM {schema_x_posts.table} WHERE {schema_x_posts.discord_post_id} = %s ' + result = self.db.query_get(query, ([message.id]), count = 1) + if result is None: + await interaction.response.send_message("post not found in database", ephemeral=True) + return + if result[schema_x_posts.action_taken] == ActionTaken.Hidden: + await discordHelper.edit_existing_embed_color(message, nextcord.Colour.yellow()) + await interaction.response.send_message("post was already NonKF", ephemeral=True) + return + + await discordHelper.edit_existing_embed_color(message, nextcord.Colour.yellow()) + result = self.db.x_update_post(result[schema_x_posts.id], message.id, 0, ActionTaken.Hidden) + await interaction.response.send_message("post hidden", ephemeral=True) + + @application_command.message_command(guild_ids=[1043267878851457096]) + async def archive_delete(self, interaction: nextcord.Interaction, message: nextcord.Message): + if not discordHelper.check_permission(interaction): + await interaction.response.send_message("No permission", ephemeral=True) + return + + query =f'SELECT {schema_x_posts.id} FROM {schema_x_posts.table} WHERE {schema_x_posts.discord_post_id} = %s' + result = self.db.query_get(query, ([message.id]), count = 1) + if result is None: + await interaction.response.send_message("post not found in database", ephemeral=True) + return + + await message.delete() + result = self.db.x_update_post(result["id"], 0, 0, ActionTaken.Rejected) + await interaction.response.send_message("post deleted", ephemeral=True) + +def setup(bot: commands.Bot, botData: RuntimeBotData): + bot.add_cog(Commands(botData)) + print("Loaded Commands") \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..cc8ec80 --- /dev/null +++ b/config.py @@ -0,0 +1,45 @@ +import json +import os + +class Config: + name = 'config' + config = { + "discord_token": None, + #format: "dbname=xxx user=xxx password=xxx" + "postgresql_conninfo": None, + #format: "auth_token=xxx; ct0=xxx;" + "x_cookies": None, + #format: "C:/Storage/X/" - remember about trailing / + "x_download_path": None, + "pixiv_token": None, + "pixiv_username": None, + "pixiv_password": None, + #format: "C:/Storage/Pixiv/" - remember about trailing / + "pixiv_download_path": None + } + + def __init__(self, name : str = "config"): + '"config" name is used as global config' + self.name = name + if os.path.exists(f"configs/{name}.json"): + with open(f"configs/{name}.json", "rt") as f: + self.config = json.load(f) + else: + self.save() + + def __getitem__(self, key): + if key not in self.config: + self.config[key] = None + self.save() + return self.config[key] + + def __setitem__(self, key, value): + self.config[key] = value + self.save() + + def save(self): + os.makedirs("configs", exist_ok=True) + with open(f"configs/{self.name}.json", "wt") as f: + json.dump(self.config, f, indent=1) + +Global_Config = Config() \ No newline at end of file diff --git a/downloadHelper.py b/downloadHelper.py new file mode 100644 index 0000000..d48fd30 --- /dev/null +++ b/downloadHelper.py @@ -0,0 +1,29 @@ +import os +import shutil +import aiohttp + +from exceptions import DOWNLOAD_FAIL + +async def downloadFileToMemory(url, session : aiohttp.ClientSession): + print("Downloading", url) + try: + async with session.get(url) as resp: + result = await resp.read() + if len(result) == 0: + raise DOWNLOAD_FAIL("received 0 bytes") + else: + return result + except Exception as e: + print(url, e) + raise DOWNLOAD_FAIL(e) + +async def save_to_file(url, out_file, session): + file_bytes = await downloadFileToMemory(url, session) + + if os.path.exists(out_file): + shutil.copy(out_file, out_file + "_bck") + + with open(out_file, "wb") as binary_file: + binary_file.write(file_bytes) + + return file_bytes \ No newline at end of file diff --git a/exceptions.py b/exceptions.py new file mode 100644 index 0000000..ee22b83 --- /dev/null +++ b/exceptions.py @@ -0,0 +1,37 @@ +from Database.x_classes import ErrorID + +class NO_ART(Exception): + code = ErrorID.NO_ART + +class OTHER_ERROR(Exception): + code = ErrorID.OTHER_ERROR + +class DOWNLOAD_FAIL(Exception): + code = ErrorID.DOWNLOAD_FAIL + +class ACCOUNT_DEAD(Exception): + code = ErrorID.ACCOUNT_DEAD + +class ACCOUNT_SKIP(Exception): + code = ErrorID.ACCOUNT_SKIP + +class TAGGING_ERROR(Exception): + code = ErrorID.TAGGING_ERROR + +class FILE_READ_ERROR(Exception): + code = ErrorID.FILE_READ_ERROR + +class HASHING_Unidentified(Exception): + code = ErrorID.HASHING_Unidentified + +class HASHING_FileNotFound(Exception): + code = ErrorID.HASHING_FileNotFound + +class HASHING_OTHER(Exception): + code = ErrorID.HASHING_OTHER + +class FILE_DELETED_OUTSIDE_ARCHIVE(Exception): + code = ErrorID.FILE_DELETED_OUTSIDE_ARCHIVE + +class NO_CHANNEL(Exception): + code = ErrorID.NO_CHANNEL \ No newline at end of file diff --git a/getPixivpyToken b/getPixivpyToken new file mode 160000 index 0000000..93544c2 --- /dev/null +++ b/getPixivpyToken @@ -0,0 +1 @@ +Subproject commit 93544c23c75c462a43555c909640ac9f9c2cf40c diff --git a/helpers.py b/helpers.py new file mode 100644 index 0000000..6c8c236 --- /dev/null +++ b/helpers.py @@ -0,0 +1,24 @@ +from Database.x_classes import PostRating + + +def twos_complement(hexstr, bits): + value = int(hexstr,16) #convert hexadecimal to integer + + #convert from unsigned number to signed number with "bits" bits + if value & (1 << (bits-1)): + value -= 1 << bits + return value + +def get_rating_value(rating): + "for deciding final post rating" + match rating: + case PostRating.Explicit: + return 4 + case PostRating.Questionable: + return 3 + case PostRating.Sensitive: + return 2 + case PostRating.General: + return 1 + case _: + return 0 \ No newline at end of file diff --git a/helpers_adv.py b/helpers_adv.py new file mode 100644 index 0000000..fde4352 --- /dev/null +++ b/helpers_adv.py @@ -0,0 +1,26 @@ +#helpers that require other helpers +from __future__ import annotations +from typing import TYPE_CHECKING + +from Database.x_classes import AccountRating, DownloadMode, x_accounts +from Discord import discordHelper +if TYPE_CHECKING: + from runtimeBotData import RuntimeBotData + +async def create_x_account(handle, botData : RuntimeBotData, new_account_rating = AccountRating.NSFW, download_mode = DownloadMode.NO_DOWNLOAD): + id = await botData.twApi.app.get_user_id(handle) + try: + thread = None + channel = await discordHelper.get_channel_from_handle(botData.client.guilds[0], handle) + except: + thread = await discordHelper.get_thread_from_handle(botData.client.guilds[0], handle) + channel = None + artist = x_accounts(id= id, + rating= new_account_rating, + name= handle, + discord_channel_id= 0 if channel is None else channel.id, + discord_thread_id= 0 if thread is None else thread.id, + download_mode= download_mode) + if not botData.db.x_add_account(artist): + raise Exception("Failed to add to database") + return artist \ No newline at end of file diff --git a/japariarchive_schema.sql b/japariarchive_schema.sql new file mode 100644 index 0000000..0880664 Binary files /dev/null and b/japariarchive_schema.sql differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e60da60 Binary files /dev/null and b/requirements.txt differ diff --git a/runtimeBotData.py b/runtimeBotData.py new file mode 100644 index 0000000..3fa499c --- /dev/null +++ b/runtimeBotData.py @@ -0,0 +1,36 @@ +import aiohttp +import nextcord +from Discord.views import XView, YView +from Pixiv.pixivapi import PixivApi +from Twitter.tweetyapi import TweetyApi +from Classifier.wdClassifier import WDClassifier +from Database.dbcontroller import DatabaseController +from Classifier.havoxClassifier import VoxClassifier + +class RuntimeBotData: + initialized = False + dead_accounts = [] + new_accounts = [] + + client : nextcord.Client = None + twApi : TweetyApi = None + pixivApi : PixivApi = None + db : DatabaseController = None + vox : VoxClassifier = None + classifier : WDClassifier = None + session : aiohttp.ClientSession = None + + xView : XView = None + yView : YView = None + + async def initialize_data(self): + self.twApi = await TweetyApi().init() + self.pixivApi = PixivApi().init() + self.db = DatabaseController() + self.vox = VoxClassifier() + self.classifier = WDClassifier() + self.xView = XView(self) + self.yView = YView(self) + + #connector = aiohttp.TCPConnector(limit=60) + self.session = aiohttp.ClientSession() \ No newline at end of file