You've already forked JapariArchive
fresh start
This commit is contained in:
19
.gitignore
vendored
Normal file
19
.gitignore
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
Classify/*
|
||||
token.json
|
||||
google.json
|
||||
*__pycache__*
|
||||
test*.py
|
||||
session.json
|
||||
database.db
|
||||
token.txt
|
||||
Temp*
|
||||
Classifier/sample*.*
|
||||
Classifier/HAV0X/*
|
||||
!Classifier/HAV0X/readme.txt
|
||||
Classifier/SmilingWolf/*
|
||||
!Classifier/SmilingWolf/readme.txt
|
||||
session.tw_session
|
||||
session.json
|
||||
.vscode/launch.json
|
||||
debug.log
|
||||
configs/*
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "getPixivpyToken"]
|
||||
path = getPixivpyToken
|
||||
url = https://github.com/eggplants/get-pixivpy-token.git
|
||||
1
Classifier/HAV0X/readme.txt
Normal file
1
Classifier/HAV0X/readme.txt
Normal file
@@ -0,0 +1 @@
|
||||
Put all files from https://huggingface.co/HAV0X1014/Kemono-Friends-Sorter/tree/main here.
|
||||
1
Classifier/SmilingWolf/readme.txt
Normal file
1
Classifier/SmilingWolf/readme.txt
Normal file
@@ -0,0 +1 @@
|
||||
Put model.onnx and selected_tags.csv from https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/tree/main here.
|
||||
36
Classifier/classifyHelper.py
Normal file
36
Classifier/classifyHelper.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from io import BytesIO
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
import imagehash
|
||||
from Classifier.havoxClassifier import VoxClassifier
|
||||
from Classifier.wdClassifier import WDClassifier
|
||||
from Database.x_classes import ErrorID, HavoxLabel, PostRating
|
||||
from helpers import twos_complement
|
||||
|
||||
async def classify_all(photo_bytes, wd_classifier: WDClassifier, vox_classifier : VoxClassifier):
    """Run both classifiers on an image and compute its perceptual hashes.

    The HAV0X classifier runs first. When it labels the image as rejected,
    nothing else is computed and the rating is ``PostRating.Filtered``.
    Otherwise the WD tagger is queried (tag threshold 0.6) and the pHash and
    dHash of the image are calculated.

    Returns:
        A tuple ``(vox_label, rating, tags, filtered_tags, phash, dhash,
        error_id)``. ``tags`` is a list of tag names, ``filtered_tags`` is the
        mapping returned by the WD tagger, and ``error_id`` records a hashing
        failure (hashes that were not computed stay ``None``).
    """
    vox_label = await vox_classifier.classify_async(photo_bytes)

    # Rejected images skip tagging and hashing entirely.
    if vox_label == HavoxLabel.Rejected:
        return vox_label, PostRating.Filtered, [], {}, None, None, ErrorID.SUCCESS

    rating, tag_scores, filtered_tags = await wd_classifier.classify_async(photo_bytes, tag_threshold=0.6)
    tags = list(tag_scores)

    phash = None
    dhash = None
    error_id = ErrorID.SUCCESS
    try:
        with BytesIO(photo_bytes) as buffer, Image.open(buffer) as picture:
            phash = twos_complement(str(imagehash.phash(picture)), 64)
            dhash = twos_complement(str(imagehash.dhash(picture)), 64)
    except UnidentifiedImageError:
        error_id = ErrorID.HASHING_Unidentified
    except FileNotFoundError:
        error_id = ErrorID.HASHING_FileNotFound
    except Exception:
        error_id = ErrorID.HASHING_OTHER

    return vox_label, rating, tags, filtered_tags, phash, dhash, error_id
|
||||
46
Classifier/havoxClassifier.py
Normal file
46
Classifier/havoxClassifier.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import asyncio
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from transformers import ViTForImageClassification, ViTImageProcessor
|
||||
|
||||
class VoxClassifier():
    """Image classifier backed by the ViT model stored in ``Classifier/HAV0X/``."""

    def __init__(self):
        """Load the image processor and the classification model from disk."""
        self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
        self.model = ViTForImageClassification.from_pretrained('Classifier/HAV0X/')

    def classify(self, image_bytes):
        """Classify an encoded image and return the predicted label name.

        Args:
            image_bytes: Raw bytes of an image file readable by Pillow.

        Returns:
            The label from ``model.config.id2label`` with the highest logit,
            or ``"Rejected"`` when anything goes wrong (the exception is
            printed, not raised).
        """
        try:
            with BytesIO(image_bytes) as image_file:
                # Close both the decoded source image and the RGB copy.
                with Image.open(image_file) as source:
                    with source.convert("RGB") as image:
                        inputs = self.feature_extractor(images=image, return_tensors="pt")
                        outputs = self.model(**inputs)
                        logits = outputs.logits
                        # model predicts one of the 3 tags
                        predicted_class_idx = logits.argmax(-1).item()
                        return self.model.config.id2label[predicted_class_idx]
        except Exception as ex:
            print(ex)
            return "Rejected"

    async def classify_async(self, image_bytes):
        """Run :meth:`classify` in a worker thread and await its result.

        The coroutine resumes as soon as the worker finishes instead of
        polling a future once per second from a freshly created thread pool.
        """
        return await asyncio.to_thread(self.classify, image_bytes)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: classify the local sample images and print each label.
    SAMPLE_PATHS = (
        "Classifier/sample.jpg",
        "Classifier/sample2.jpg",
        "Classifier/sample3.jpg",
        "Classifier/sample4.png",
    )

    async def main():
        """Classify every sample image in turn and print the result."""
        classifier = VoxClassifier()
        for sample_path in SAMPLE_PATHS:
            with open(sample_path, "rb") as sample_file:
                label = await classifier.classify_async(sample_file.read())
            print(label)

    asyncio.run(main())
|
||||
238
Classifier/wdClassifier.py
Normal file
238
Classifier/wdClassifier.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
import onnxruntime as onnx
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
@dataclass
class LabelData:
    """Label table of the tagger model, as built by ``WDClassifier.load_labels``.

    ``names`` holds every label name in CSV row order; the other fields hold
    the row indices of the labels belonging to each category.
    """

    # All label names, positioned by their row index in the CSV.
    names: list[str]
    # Row indices whose CSV category is 9.
    rating: list[np.int64]
    # Row indices whose CSV category is 0.
    general: list[np.int64]
    # Row indices whose CSV category is 4.
    character: list[np.int64]
|
||||
|
||||
class WDClassifier():
    """Tag and rate images with the ONNX tagger stored in ``Classifier/SmilingWolf/``.

    Besides the full tag set, the classifier keeps a list of "relevant" label
    indices: every label whose name is not listed in ``ignored_tags.txt``.
    """

    MODEL_PATH = 'Classifier/SmilingWolf/model.onnx'
    LABELS_PATH = "Classifier/SmilingWolf/selected_tags.csv"
    IGNORED_TAGS_PATH = "Classifier/SmilingWolf/ignored_tags.txt"

    # Class-level placeholders; each instance rebinds them during __init__.
    onnxSession = None
    relevant_tags = []
    ignored_tags = []

    def __init__(self):
        """Load the label table, the ignore list and the ONNX session."""
        self.labels = WDClassifier.load_labels(self.LABELS_PATH)
        self.relevant_tags = self.get_relevant_tags()
        self.onnxSession = onnx.InferenceSession(self.MODEL_PATH)

    def classify(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Tag an encoded image.

        Args:
            image_bytes: Raw bytes of an image file readable by Pillow.
            max_count: Keep at most this many general tags; ``-1`` disables
                the limit.
            tag_threshold: Drop tags scoring at or below this value; ``-1``
                disables the filter.

        Returns:
            ``(rating, general, filtered)`` where ``rating`` is the name of
            the best-scoring rating label, ``general`` maps general and
            character tag names to scores, and ``filtered`` maps the
            non-ignored label names to scores. On any error the exception is
            printed and ``("", {}, {})`` is returned.
        """
        try:
            with BytesIO(image_bytes) as image_file:
                with Image.open(image_file) as image:
                    image_data = self.resize_image(image, 448)
                    # convert to numpy array and add batch dimension
                    image_data = np.array(image_data, np.float32)
                    image_data = np.expand_dims(image_data, axis=0)
                    # NHWC image RGB to BGR
                    image_data = image_data[..., ::-1]

            outputs = self.onnxSession.run(None, {'input': image_data})
            outputs = outputs[0].squeeze(0)

            ratings, general, filtered = self.processRating(outputs)
            rating = max(ratings, key=ratings.get)

            if tag_threshold != -1:
                general = self.filterThreshold(general, tag_threshold)
                filtered = self.filterThreshold(filtered, tag_threshold)

            if max_count != -1:
                general = self.filterCount(general, max_count)
                # NOTE(review): this re-applies the threshold instead of
                # limiting the count; possibly a copy-paste slip, kept as-is.
                filtered = self.filterThreshold(filtered, tag_threshold)

            return rating, general, filtered

        except Exception as ex:
            print(ex)
            return "", {}, {}

    async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Run :meth:`classify` in a worker thread and await its result.

        The coroutine resumes as soon as the worker finishes instead of
        polling a future once per second from a freshly created thread pool.
        """
        return await asyncio.to_thread(self.classify, image_bytes, max_count, tag_threshold)

    def processRating(self, data):
        """Split raw model output into rating, general and relevant-tag score maps.

        Scores are rounded to three decimals and keyed by label name.
        """
        def scores(indices):
            return {self.labels.names[i]: round(data[i].item(), 3) for i in indices}

        ratings = scores(self.labels.rating)
        general = scores(self.labels.general + self.labels.character)
        filtered = scores(self.relevant_tags)
        return ratings, general, filtered

    def filterThreshold(self, data, threshold = 0.6):
        """Return the entries of ``data`` whose value is strictly above ``threshold``."""
        return {key: value for key, value in data.items() if value > threshold}

    def filterCount(self, data, count = 20):
        """Return the ``count`` highest-valued entries of ``data``, best first."""
        keys = list(data.keys())
        values = list(data.values())
        # Ascending argsort flipped gives highest score first; a negative
        # count selects nothing.
        top_indices = np.flip(np.argsort(values))[:max(count, 0)]
        return {keys[i]: values[i] for i in top_indices}

    def filterTagList(self, data, tagList):
        """Return the items of ``data`` whose ``"label"`` entry is in ``tagList``."""
        return [result for result in data if result["label"] in tagList]

    def resize_image(self, image : Image.Image, size = 448):
        """Convert to RGB, pad to a white square and resize to ``size`` x ``size``."""
        image = WDClassifier.pil_ensure_rgb(image)
        image = WDClassifier.pil_pad_square(image)
        image = WDClassifier.pil_resize(image, size)
        return image

    def get_tag_index(self, tag: str):
        """Return the label index of ``tag``, or ``-1`` (with a printed notice) if unknown."""
        try:
            return self.labels.names.index(tag)
        except ValueError:
            # list.index raises ValueError for a missing item.
            print(tag, "not found")
            return -1

    def get_relevant_tags(self):
        """Load the ignore list and return the indices of all non-ignored labels.

        An empty ignore file is created when none exists. As a side effect
        ``self.ignored_tags`` is set to the lines of that file.
        """
        if not os.path.exists(self.IGNORED_TAGS_PATH):
            with open(self.IGNORED_TAGS_PATH, "wt", encoding="utf-8") as file:
                file.write("")

        with open(self.IGNORED_TAGS_PATH, "rt", encoding="utf-8") as file:
            self.ignored_tags = file.read().splitlines()

        tags = {tag: index for index, tag in enumerate(self.labels.names)}
        for tag in self.ignored_tags:
            tags.pop(tag, None)
        return list(tags.values())

    def add_ignored_tag(self, tag: str):
        """Add ``tag`` to the ignore list, persist the list and stop reporting the tag.

        Unknown tags and tags that are already ignored are left alone.
        """
        idx = self.get_tag_index(tag)
        if idx != -1 and idx in self.relevant_tags:
            self.ignored_tags.append(tag)
            with open(self.IGNORED_TAGS_PATH, "wt", encoding="utf-8") as file:
                file.write("\n".join(self.ignored_tags))
            self.relevant_tags.remove(idx)

    @staticmethod
    def load_labels(csv_path):
        """Read the label CSV and group row indices by category (9, 0 and 4)."""
        df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
        tag_data = LabelData(
            names=df["name"].tolist(),
            rating=list(np.where(df["category"] == 9)[0]),
            general=list(np.where(df["category"] == 0)[0]),
            character=list(np.where(df["category"] == 4)[0]),
        )

        return tag_data

    @staticmethod
    def pil_ensure_rgb(image: Image.Image) -> Image.Image:
        """Return ``image`` as RGB, compositing any alpha channel onto white."""
        # convert to RGB/RGBA if not already (deals with palette images etc.)
        if image.mode not in ["RGB", "RGBA"]:
            image = (
                image.convert("RGBA")
                if "transparency" in image.info
                else image.convert("RGB")
            )
        # convert RGBA to RGB with white background
        if image.mode == "RGBA":
            canvas = Image.new("RGBA", image.size, (255, 255, 255))
            canvas.alpha_composite(image)
            image = canvas.convert("RGB")
        return image

    @staticmethod
    def pil_pad_square(image: Image.Image) -> Image.Image:
        """Centre ``image`` on a white square whose side is its largest dimension."""
        w, h = image.size
        # get the largest dimension so we can pad to a square
        px = max(image.size)
        # pad to square with white background
        canvas = Image.new("RGB", (px, px), (255, 255, 255))
        canvas.paste(image, ((px - w) // 2, (px - h) // 2))
        return canvas

    @staticmethod
    def pil_resize(image: Image.Image, target_size: int) -> Image.Image:
        """Resize to ``target_size`` square (bicubic) unless the largest side already matches."""
        max_dim = max(image.size)
        if max_dim != target_size:
            image = image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return image

    @staticmethod
    def get_tags(
        probs,
        labels : LabelData,
        gen_threshold: float,
        char_threshold: float,
    ):
        """Turn raw probabilities into caption strings and per-category score maps.

        Returns:
            ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
            The general and character maps only contain labels scoring above
            their threshold and are ordered by descending score; ``caption``
            joins their names with ``", "`` and ``taglist`` is the caption
            with underscores replaced by spaces and parentheses escaped.
        """
        def above(indices, threshold):
            # Labels over the threshold, highest confidence first.
            kept = {name: prob for name, prob in (probs[i] for i in indices) if prob > threshold}
            return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

        # Convert indices+probs to labels
        probs = list(zip(labels.names, probs))

        rating_labels = dict(probs[i] for i in labels.rating)
        gen_labels = above(labels.general, gen_threshold)
        char_labels = above(labels.character, char_threshold)

        # General names first, then character names.
        combined_names = list(gen_labels)
        combined_names.extend(char_labels)

        # Convert to a string suitable for use as a training caption
        caption = ", ".join(combined_names)
        taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

        return caption, taglist, rating_labels, char_labels, gen_labels
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: tag one local sample image and print the outcome.
    async def main():
        """Classify ``Classifier/sample4.png`` with a 20-tag cap and 0.6 threshold."""
        classifier = WDClassifier()
        with open("Classifier/sample4.png", "rb") as sample_file:
            sample_bytes = sample_file.read()
        outcome = await classifier.classify_async(sample_bytes, 20, 0.6)
        rating, results, discord = outcome

        print(rating, results, discord)

    asyncio.run(main())
|
||||
39
Database/db_schema.py
Normal file
39
Database/db_schema.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from enum import StrEnum
|
||||
|
||||
class x_accounts(StrEnum):
|
||||
table = "x_accounts"
|
||||
id = "id"
|
||||
rating = "rating"
|
||||
name = "name"
|
||||
is_deleted = "is_deleted"
|
||||
is_protected = "is_protected"
|
||||
download_mode = "download_mode"
|
||||
discord_channel_id = "discord_channel_id"
|
||||
discord_thread_id = "discord_thread_id"
|
||||
|
||||
class x_posts(StrEnum):
|
||||
table = "x_posts"
|
||||
id = "id"
|
||||
account_id = "account_id"
|
||||
discord_post_id = "discord_post_id"
|
||||
error_id = "error_id"
|
||||
action_taken = "action_taken"
|
||||
rating = "rating"
|
||||
tags = "tags"
|
||||
text = "text"
|
||||
date = "date"
|
||||
|
||||
class x_posts_images(StrEnum):
|
||||
table = "x_posts_images"
|
||||
post_id = "post_id"
|
||||
index = "index"
|
||||
phash = "phash"
|
||||
dhash = "dhash"
|
||||
error_id = "error_id"
|
||||
rating = "rating"
|
||||
tags = "tags"
|
||||
file = "file"
|
||||
saved_file = "saved_file"
|
||||
vox_label = "vox_label"
|
||||
duplicate_id = "duplicate_id"
|
||||
duplicate_index = "duplicate_index"
|
||||
256
Database/dbcontroller.py
Normal file
256
Database/dbcontroller.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import html
|
||||
import re
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row
|
||||
|
||||
from Database.x_classes import x_posts_images, x_accounts, x_posts
|
||||
from Database.db_schema import x_accounts as x_accounts_schema, x_posts as x_posts_schema, x_posts_images as x_posts_images_schema
|
||||
from config import Global_Config
|
||||
|
||||
class DatabaseController:
    """Wrapper around one psycopg connection for the Pixiv and X archive tables.

    Rows are returned as dictionaries (``dict_row``). A single cursor is
    shared by all methods.
    """

    def __init__(self):
        """Connect using ``Global_Config["postgresql_conninfo"]``.

        Raises:
            Exception: If no connection string is configured.
        """
        if Global_Config["postgresql_conninfo"] is None:
            raise Exception("Database connection string is required")
        self.conn = psycopg.connect(Global_Config["postgresql_conninfo"], row_factory=dict_row)
        self.cursor = self.conn.cursor()

    def query_get_stream(self, query, params = (), count = -1, page_size = 1000, offset = 0, commit = True):
        """Yield the rows of ``query`` page by page using ``LIMIT``/``OFFSET``.

        Args:
            query: SELECT statement without a LIMIT/OFFSET clause.
            params: Query parameters.
            count: Stop once the running row position exceeds this value;
                ``-1`` means no limit. The position starts at ``offset``,
                not at zero.
            page_size: Number of rows fetched per round trip.
            offset: Position of the first row to fetch.
            commit: Passed through to :meth:`query_get`.
        """
        while True:
            page = self.query_get(query + f' LIMIT {page_size} OFFSET {offset}', params, commit=commit)
            if len(page) == 0:
                return
            for record in page:
                offset += 1
                if count != -1 and offset > count:
                    return
                yield record

    def query_get(self, query, params = (), count = -1, commit = True):
        """Execute a query and fetch its rows.

        Args:
            count: ``-1`` fetches all rows (list), ``1`` fetches one row (a
                dict or ``None``), any other value fetches up to that many
                rows (list).
            commit: Commit the transaction after fetching.
        """
        self.cursor.execute(query, params)
        if count == -1:
            result = self.cursor.fetchall()
        elif count == 1:
            result = self.cursor.fetchone()
        else:
            result = self.cursor.fetchmany(count)
        if commit:
            self.conn.commit()
        return result

    def query_set(self, query, params = (), commit = True):
        """Execute a modifying statement.

        Returns:
            ``True`` on success; ``False`` when an integrity error occurred
            (the error is printed and the transaction rolled back).
        """
        try:
            self.cursor.execute(query, params)
            if commit:
                self.conn.commit()
            return True
        except psycopg.IntegrityError as e:
            print(e)
            self.conn.rollback()
            return False

    def insert(self, table_name, updates, where=None, returning=None):
        """Insert one row built from ``(column, value)`` pairs.

        Args:
            table_name: Target table.
            updates: Sequence of ``(column, value)`` tuples.
            where: Text appended after ``WHERE`` when given.
            returning: Column to return from the inserted row.

        Returns:
            ``(True, value)`` on success, where ``value`` is the returned
            column or ``None``; ``(False, None)`` on an integrity error.
        """
        columns = ', '.join([update[0] for update in updates])
        placeholders = ', '.join(['%s' for _ in updates])
        values = [update[1] for update in updates]

        query = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
        if where:
            # NOTE(review): PostgreSQL does not accept WHERE after
            # INSERT ... VALUES; kept because callers may pass it.
            query += f" WHERE {where}"
        if returning:
            query += f" RETURNING {returning}"

        try:
            self.cursor.execute(query, values)
            row = self.cursor.fetchone() if returning else None
            self.conn.commit()
        except psycopg.IntegrityError as e:
            print(e, query)
            self.conn.rollback()
            return False, None
        return True, (row[returning] if returning else None)

    #region Pixiv
    def pixiv_get_tags(self, tags):
        """Return the ``pixiv_tags`` ids for ``tags``, inserting any missing tag."""
        new_tags = []
        for tag in tags:
            self.cursor.execute("SELECT id FROM pixiv_tags WHERE name=%s", [tag.name])
            result = self.cursor.fetchone()
            if result is None:
                self.cursor.execute(
                    "INSERT INTO pixiv_tags(name, translated_name) VALUES(%s, %s) RETURNING id",
                    [tag.name, tag.translated_name],
                )
                result = self.cursor.fetchone()
            new_tags.append(result["id"])
        self.conn.commit()
        return new_tags

    def pixiv_get_user(self, id):
        """Return the ``pixiv_users`` row with this id, or ``None``."""
        query = "SELECT * FROM pixiv_users WHERE id=%s"
        return self.query_get(query, [id], count=1)

    def pixiv_get_post(self, id):
        """Return the ``pixiv_posts`` row with this id, or ``None``."""
        query = "SELECT * FROM pixiv_posts WHERE id=%s"
        return self.query_get(query, [id], count=1)

    def pixiv_insert_user(self, json, commit = True):
        """Insert a Pixiv user from an object exposing ``id``, ``name`` and ``account``."""
        query = 'INSERT INTO pixiv_users (id, name, account) VALUES (%s,%s,%s)'
        return self.query_set(query, [json.id, json.name, json.account], commit)

    def pixiv_insert_post(self, json, commit = True):
        """Insert a Pixiv post; its tags are resolved to ids first and ``is_saved`` is set."""
        query = 'INSERT INTO pixiv_posts (id, title, type, caption, restrict, user_id, tags, create_date, page_count, sanity_level, x_restrict, illust_ai_type, is_saved) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)'
        tags = self.pixiv_get_tags(json.tags)
        return self.query_set(query, [json.id, json.title, json.type, json.caption, json.restrict, json.user.id, tags, json.create_date, json.page_count, json.sanity_level, json.x_restrict, json.illust_ai_type, True], commit)

    def pixiv_update_post(self, id, date, commit = True):
        """Set ``create_date`` of the Pixiv post with this id."""
        query = 'UPDATE pixiv_posts SET create_date = %s WHERE id = %s'
        return self.query_set(query, [date, id], commit)
    #endregion

    #region X
    def get_all_posts(self, account_id, commit = True):
        """Return all posts of an account as ``x_posts`` objects, ordered by id."""
        query = f"SELECT * FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s ORDER BY {x_posts_schema.id}"
        results = self.query_get(query, [account_id], commit=commit)
        return [x_posts(**result) for result in results or []]

    def get_post_images(self, account_id, commit = True):
        """Return the same result as :meth:`get_all_posts`.

        NOTE(review): despite the name this has always queried the posts
        table, not the images table - confirm the intended behaviour.
        """
        return self.get_all_posts(account_id, commit=commit)

    def get_all_post_ids(self, account_id, commit = True):
        """Return the ids of all posts stored for an account."""
        query = f"SELECT {x_posts_schema.id} FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s"
        results = self.query_get(query, [account_id], commit=commit)
        return [result["id"] for result in results]

    def get_max_post_id(self, account_id : str, commit = True):
        """Return the highest post id stored for an account, or ``1`` if it has none."""
        query = f"SELECT MAX({x_posts_schema.id}) AS max FROM {x_posts_schema.table} WHERE {x_posts_schema.account_id} = %s"
        result = self.query_get(query, [account_id], count=1, commit=commit)
        # MAX() always yields a row; without posts its value is NULL.
        if result is None or result["max"] is None:
            return 1
        return result["max"]

    def x_get_post_from_id(self, post_id, commit = True):
        """Return the post with this id as ``x_posts``, or ``None``."""
        query = f"SELECT * FROM {x_posts_schema.table} WHERE {x_posts_schema.id} = %s"
        result = self.query_get(query, [post_id], count=1, commit=commit)
        return None if result is None else x_posts(**result)

    def x_insert_post(self, post : x_posts, commit = True):
        """Insert a post; HTML entities are unescaped and t.co links stripped from its text."""
        text = re.sub(r"https://t.co\S+", "", html.unescape(post.text))
        query = f'''INSERT INTO {x_posts_schema.table}
        ({x_posts_schema.id}, {x_posts_schema.account_id}, {x_posts_schema.discord_post_id},
        {x_posts_schema.error_id}, {x_posts_schema.action_taken}, {x_posts_schema.rating},
        {x_posts_schema.tags}, {x_posts_schema.text}, {x_posts_schema.date})
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)'''
        return self.query_set(query,
            [
                post.id, post.account_id, post.discord_post_id,
                post.error_id, post.action_taken, post.rating,
                post.tags, text, post.date
            ], commit)

    def x_search_duplicate(self, user_id: int, max_id: int, phash = None, dhash = None, commit = True):
        """Find the earliest image of an account whose hash matches exactly.

        Only images of posts with an id below ``max_id`` are considered.
        ``phash`` takes precedence over ``dhash``; with neither given the
        result is ``None``.

        Returns:
            The matching ``x_posts_images`` with the lowest post id, or ``None``.
        """
        if phash is not None:
            hash_column, hash_value = x_posts_images_schema.phash, phash
        elif dhash is not None:
            hash_column, hash_value = x_posts_images_schema.dhash, dhash
        else:
            return None

        # Values are bound as parameters; only schema identifiers are interpolated.
        query = f'''SELECT i.* FROM {x_posts_images_schema.table} i JOIN {x_posts_schema.table} p ON i.{x_posts_images_schema.post_id} = p.{x_posts_schema.id} WHERE i.{x_posts_images_schema.post_id} < %s AND p.{x_posts_schema.account_id} = %s AND bit_count(i.{hash_column} # %s) < 1 ORDER BY i.{x_posts_images_schema.post_id} ASC'''
        match = self.query_get(query, [max_id, user_id, hash_value], count=1, commit=commit)
        return None if match is None else x_posts_images(**match)

    def x_insert_image(self, image : x_posts_images, commit = True):
        """Insert one image row (the duplicate columns are not written)."""
        query = f'''INSERT INTO {x_posts_images_schema.table}
        ({x_posts_images_schema.post_id}, {x_posts_images_schema.index}, {x_posts_images_schema.phash},
        {x_posts_images_schema.dhash}, {x_posts_images_schema.error_id}, {x_posts_images_schema.rating},
        {x_posts_images_schema.tags}, {x_posts_images_schema.file}, {x_posts_images_schema.saved_file},
        {x_posts_images_schema.vox_label})
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)'''
        return self.query_set(query,
            [
                image.post_id, image.index, image.phash,
                image.dhash, image.error_id, image.rating,
                image.tags, image.file, image.saved_file,
                image.vox_label
            ], commit)

    def x_update_post(self, id, discord_post_id, error_id, action_taken, commit = True):
        """Update the Discord post id, error id and action of a post."""
        query = f"UPDATE {x_posts_schema.table} SET {x_posts_schema.discord_post_id} = %s, {x_posts_schema.error_id} = %s, {x_posts_schema.action_taken} = %s WHERE {x_posts_schema.id} = %s"
        return self.query_set(query, [discord_post_id, error_id, action_taken, id], commit=commit)

    def x_get_all_accounts(self):
        """Return every account as ``x_accounts``, ordered by id."""
        accounts = self.query_get(f'SELECT * from {x_accounts_schema.table} ORDER BY {x_accounts_schema.id}')
        return [x_accounts(**account) for account in accounts]

    def x_get_account_by_name(self, handle : str):
        """Return the ``x_accounts`` with this name, or ``None`` if it is not stored."""
        query = f'SELECT * from {x_accounts_schema.table} where {x_accounts_schema.name} = %s'
        result = self.query_get(query, [handle], count=1)
        if result is None:
            return None
        return x_accounts(**result)

    def x_get_account_by_id(self, id : int):
        """Return the ``x_accounts`` with this id, or ``None`` if it is not stored."""
        query = f'SELECT * from {x_accounts_schema.table} where {x_accounts_schema.id} = %s'
        result = self.query_get(query, [id], count=1)
        if result is None:
            return None
        return x_accounts(**result)

    def x_add_account(self, container : x_accounts):
        """Insert an account and return its id, or ``None`` on an integrity error."""
        succeeded, account_id = self.insert(x_accounts_schema.table, [
            (x_accounts_schema.id                   , container.id),
            (x_accounts_schema.name                 , container.name),
            (x_accounts_schema.rating               , container.rating),
            (x_accounts_schema.discord_channel_id   , container.discord_channel_id),
            (x_accounts_schema.discord_thread_id    , container.discord_thread_id),
            (x_accounts_schema.download_mode        , container.download_mode),
        ], returning= x_accounts_schema.id)
        return account_id if succeeded else None

    def x_update_account(self, container : x_accounts):
        """Write name, rating, Discord ids and download mode of ``container`` to its row."""
        updates = [
            (x_accounts_schema.name                 , container.name),
            (x_accounts_schema.rating               , container.rating),
            (x_accounts_schema.discord_channel_id   , container.discord_channel_id),
            (x_accounts_schema.discord_thread_id    , container.discord_thread_id),
            (x_accounts_schema.download_mode        , container.download_mode)
        ]
        self.x_update_account_properties(container.id, updates)

    def x_update_account_properties(self, account_id, updates):
        """Update selected columns of one account.

        Example: updates = [("name", container.name), ("rating", container.rating)]

        An empty ``updates`` list is a no-op.
        """
        if not updates:
            return
        assignments = ', '.join([f'{update[0]} = %s' for update in updates])
        # The account id is bound as a parameter rather than interpolated.
        query = f"UPDATE {x_accounts_schema.table} SET {assignments} WHERE id = %s"
        self.cursor.execute(query, [update[1] for update in updates] + [account_id])
        self.conn.commit()

    #endregion

    def close(self):
        """Close the database connection."""
        self.conn.close()
|
||||
123
Database/x_classes.py
Normal file
123
Database/x_classes.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from datetime import datetime
|
||||
from enum import IntEnum, StrEnum
|
||||
|
||||
class DownloadMode(IntEnum):
|
||||
NO_DOWNLOAD = 0
|
||||
DOWNLOAD = 1
|
||||
DOWNLOAD_ALL = 2
|
||||
|
||||
class ErrorID(IntEnum):
|
||||
SUCCESS = 0
|
||||
NO_ART = 1
|
||||
OTHER_ERROR = 5
|
||||
DOWNLOAD_FAIL = 6
|
||||
ACCOUNT_DEAD = 7
|
||||
ACCOUNT_SKIP = 8
|
||||
TAGGING_ERROR = 9
|
||||
FILE_READ_ERROR = 10
|
||||
HASHING_Unidentified = 11
|
||||
HASHING_FileNotFound = 12
|
||||
HASHING_OTHER = 13
|
||||
FILE_DELETED_OUTSIDE_ARCHIVE = 14
|
||||
NO_CHANNEL = 15
|
||||
TWEET_DELETED = 16
|
||||
|
||||
class ActionTaken(IntEnum):
|
||||
Null = 0
|
||||
Accepted = 1
|
||||
Rejected = 2
|
||||
Hidden = 3
|
||||
|
||||
class AccountRating(StrEnum):
|
||||
SFW = "SFW"
|
||||
NSFW = "NSFW"
|
||||
NSFL = "NSFL"
|
||||
|
||||
class HavoxLabel(StrEnum):
|
||||
KF = "KF"
|
||||
NonKF = "NonKF"
|
||||
Rejected = "Rejected"
|
||||
|
||||
class PostRating(StrEnum):
|
||||
Unrated = "unrated"
|
||||
General = "general"
|
||||
Sensitive = "sensitive"
|
||||
Questionable = "questionable"
|
||||
Explicit = "explicit"
|
||||
Filtered = "filtered"
|
||||
|
||||
class x_accounts:
    """In-memory representation of one ``x_accounts`` row.

    Unknown keyword arguments are accepted and ignored, so an instance can be
    built from a row dictionary that carries extra columns.
    """

    id : int
    rating : AccountRating
    name : str
    is_deleted : bool
    is_protected : bool
    download_mode : int
    discord_channel_id : int
    discord_thread_id : int

    def __init__(
        self,
        id = 0,
        rating = AccountRating.NSFW,
        name : str = "",
        is_deleted = False,
        is_protected = False,
        download_mode = DownloadMode.NO_DOWNLOAD,
        discord_thread_id = 0,
        discord_channel_id = 0,
        **kwargs,
    ):
        """Store the given column values; anything in ``kwargs`` is dropped."""
        # Identity
        self.id = id
        self.name = name
        # Classification and state flags
        self.rating = rating
        self.is_deleted = is_deleted
        self.is_protected = is_protected
        self.download_mode = download_mode
        # Discord targets
        self.discord_channel_id = discord_channel_id
        self.discord_thread_id = discord_thread_id
|
||||
|
||||
class x_posts:
    """In-memory representation of one ``x_posts`` row.

    Unknown keyword arguments are accepted and ignored, so an instance can be
    built from a row dictionary that carries extra columns.
    """

    # Sentinel marking "tags not supplied"; keeps an explicit None distinct
    # from the default and avoids a list shared between instances.
    _TAGS_UNSET = object()

    id : int
    account_id : int
    discord_post_id : int
    error_id : int
    action_taken : ActionTaken
    # highest rating from all post images
    rating : PostRating
    tags : list[str]
    text : str
    date : datetime

    def __init__(self, id = -1, account_id = -1, discord_post_id = 0, error_id = ErrorID.SUCCESS, action_taken = ActionTaken.Null, rating = PostRating.Unrated, tags = _TAGS_UNSET, text = "", date = None, **kwargs):
        """Store the given column values; anything in ``kwargs`` is dropped.

        When ``tags`` is omitted each instance gets its own new empty list.
        """
        self.id             = id
        self.account_id     = account_id
        self.discord_post_id= discord_post_id
        self.error_id       = error_id
        self.action_taken   = action_taken
        self.rating         = rating
        self.tags           = [] if tags is self._TAGS_UNSET else tags
        self.text           = text
        self.date           = date
|
||||
|
||||
class x_posts_images:
    """In-memory representation of one ``x_posts_images`` row.

    Unknown keyword arguments are accepted and ignored, so an instance can be
    built from a row dictionary that carries extra columns.
    """

    # Sentinel marking "tags not supplied"; keeps an explicit None distinct
    # from the default and avoids a list shared between instances.
    _TAGS_UNSET = object()

    post_id : int
    index : int
    phash : int = None
    dhash : int = None
    error_id : int = 0
    rating : PostRating = PostRating.Unrated
    # No class-level list here: every instance receives its own in __init__.
    tags : list[str]
    file : str = None
    saved_file : str = None
    vox_label : HavoxLabel = None
    duplicate_id : int = -1
    duplicate_index : int = -1

    def __init__(self, post_id, index, phash = None, dhash = None, error_id = ErrorID.SUCCESS, rating = PostRating.Unrated, tags = _TAGS_UNSET, file = None, saved_file = None, vox_label : HavoxLabel = None, duplicate_index = -1, duplicate_id = -1, **kwargs):
        """Store the given column values; anything in ``kwargs`` is dropped.

        When ``tags`` is omitted each instance gets its own new empty list.
        """
        self.post_id = post_id
        self.index = index
        self.phash = phash
        self.dhash = dhash
        self.error_id = error_id
        self.rating = rating
        self.tags = [] if tags is self._TAGS_UNSET else tags
        self.file = file
        self.saved_file = saved_file
        self.vox_label = vox_label
        self.duplicate_id = duplicate_id
        self.duplicate_index = duplicate_index
|
||||
|
||||
if __name__ == "__main__":
    # Quick manual check of how the enums and a default image row print.
    for sample in (
        ErrorID.ACCOUNT_DEAD,
        PostRating.Questionable,
        x_posts_images(0,1).rating,
    ):
        print(sample)
|
||||
236
Discord/discordHelper.py
Normal file
236
Discord/discordHelper.py
Normal file
@@ -0,0 +1,236 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
import html
|
||||
import io
|
||||
import re
|
||||
import nextcord
|
||||
from nextcord import ChannelType, Message
|
||||
|
||||
from Database.dbcontroller import DatabaseController
|
||||
from Database.x_classes import ErrorID, x_accounts
|
||||
from Twitter.tweetHelper import DownloadedMedia
|
||||
from exceptions import NO_CHANNEL, OTHER_ERROR
|
||||
from tweety.types.twDataTypes import Tweet
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
def chunks(s, n):
    """Yield consecutive slices of `s`, each at most `n` items long."""
    yield from (s[offset:offset + n] for offset in range(0, len(s), n))
|
||||
|
||||
def build_pixiv_embed(post):
    """Build the Discord embed for a Pixiv post.

    The description is the post caption with HTML entities unescaped and
    t.co links removed; the author links to the artwork page and the footer
    shows the creation date.
    """
    artwork_url = "https://www.pixiv.net/en/artworks/" + str(post.id)
    caption = html.unescape(post.caption)
    description = re.sub(r"https://t.co\S+", "", caption)
    created = post.create_date

    embed = nextcord.Embed(description=description)
    embed.set_author(name=post.user.name, url=artwork_url)
    embed.set_footer(text=created)
    return embed
|
||||
|
||||
def build_x_embed(handle : str, post : Tweet):
    """Build the Discord embed for an X post.

    The description is the post text with HTML entities unescaped and t.co
    links removed; the author entry links to the post and carries the
    author's avatar, and the footer shows the creation time.
    """
    status_url = "https://x.com/" + handle + "/status/" + str(post.id)
    unescaped = html.unescape(post.text)
    description = re.sub(r"https://t.co\S+", "", unescaped)
    timestamp = datetime.strftime(post.created_on, '%Y-%m-%d %H:%M:%S')

    embed = nextcord.Embed(description=description)
    embed.set_author(
        name=handle,
        url=status_url,
        icon_url=post.author.profile_image_url_https,
    )
    embed.set_footer(text=timestamp)
    return embed
|
||||
|
||||
def build_secondary_embed(main_post : Message, handle : str, post : Tweet):
    """Build the embeds used to mirror an already-sent post into a feed channel.

    One embed is created per attachment of `main_post`, each showing that
    attachment (query string stripped from its URL) as its image. All embeds
    share the same `url` value. The first embed additionally carries the tweet
    text, the creation date and an author entry linking back to `main_post`.

    Returns:
        The list of embeds, or ``None`` when `main_post` has no attachments.
    """
    # Escaped dot / required slash: strip only genuine t.co short links.
    text = re.sub(r"https://t\.co/\S*", "", html.unescape(post.text))
    date = post.created_on.strftime('%Y-%m-%d %H:%M:%S')

    embeds = []
    for attachment in main_post.attachments:
        image_embed = nextcord.Embed(url="https://katworks.sytes.net")
        image_embed.set_image(url=attachment.url.split("?")[0])
        embeds.append(image_embed)

    if not embeds:
        return None

    first : nextcord.Embed = embeds[0]
    first.description = text
    first.set_footer(text=date)
    first.set_author(name=handle, url=main_post.jump_url, icon_url=post.author.profile_image_url_https)
    return embeds
|
||||
|
||||
async def send_error(ex : Exception, botData : RuntimeBotData):
    """Print an error and forward it to the "bot-status" channel.

    Args:
        ex: The error to report. Callers in this project also pass plain
            strings (truncated tracebacks); anything with a ``str()`` works.
        botData: Runtime data; the first guild of its client is searched for
            the channel.
    """
    print(ex)
    errors_channel = nextcord.utils.get(botData.client.guilds[0].channels, name="bot-status")
    if errors_channel is None:
        # Reporting must not raise a second error from inside an error handler.
        print('send_error: channel "bot-status" not found, error was not forwarded')
        return

    # Discord rejects empty messages and messages longer than 2000 characters.
    content = (str(ex) or repr(ex))[:2000]
    await errors_channel.send(content=content)
|
||||
|
||||
def get_secondary_channel(is_animated, is_filtered, rating, tags : list, artist : x_accounts, guild : nextcord.Guild):
    """Pick the feed channel a post should be mirrored to.

    Rules are checked in this order: animated media, NSFL artists or the
    "futanari" tag, filtered/unrated posts, then the rating itself. Anything
    that matches no rule goes to "filtered-feed".

    Returns:
        The channel of `guild` with the selected name, as found by
        ``nextcord.utils.get`` (``None`` if no such channel exists).
    """
    if is_animated:
        feed_name = "animation-feed"
    elif artist.rating == 'NSFL' or "futanari" in tags:
        feed_name = "hidden-feed"
    elif is_filtered or not rating:
        feed_name = "filtered-feed"
    elif rating == "general":
        feed_name = "safe-feed"
    elif rating == "sensitive" or rating == "questionable":
        feed_name = "unsafe-feed"
    elif rating == 'explicit':
        feed_name = "explicit-feed"
    else:
        feed_name = "filtered-feed"

    return nextcord.utils.get(guild.channels, name=feed_name)
|
||||
|
||||
async def edit_existing_embed_color(message : Message, color : nextcord.Colour):
    """Recolour the first embed of `message` and push the edit to Discord."""
    current = message.embeds
    first_embed = current[0]
    first_embed.colour = color
    await message.edit(embeds=current)
|
||||
|
||||
async def send_x_post(post : Tweet, artist : x_accounts, guild : nextcord.Guild, new_accounts : list, files_to_send : list[DownloadedMedia], is_filtered: bool, rating: str, tags : list, auto_approve : bool = False, vox_labels : list = None, duplicate_posts : list = None, xView: nextcord.ui.View = None, yView: nextcord.ui.View = None):
    """Send an X post to the artist's channel/thread and mirror it to a feed.

    Args:
        post: The tweet being archived.
        artist: Account record; supplies the handle and the target
            channel/thread ids.
        guild: Guild whose channels are used.
        new_accounts: Handles for which feed mirroring is skipped (unless the
            post is filtered).
        files_to_send: Downloaded media to attach to the main post.
        is_filtered: Whether the classifier rejected the post.
        rating: Rating string; an empty string means "no rating".
        tags: Tag names to show on the main post (any iterable of strings).
        auto_approve: When true, the mirrored message gets no buttons.
        vox_labels: Classifier labels shown as "prediction" on the mirror.
        duplicate_posts: Ids of posts detected as duplicates.
        xView: Button view attached to mirrored image posts.
        yView: Button view attached to mirrored animated posts.

    Returns:
        The message sent to the artist's channel/thread.

    Raises:
        NO_CHANNEL: The artist has neither a channel id nor a thread id.
        OTHER_ERROR: Sending the main post failed.
    """
    if vox_labels is None:
        vox_labels = []
    if duplicate_posts is None:
        duplicate_posts = []

    if artist.discord_channel_id != 0:
        channel = guild.get_channel(artist.discord_channel_id)
    elif artist.discord_thread_id != 0:
        channel = guild.get_thread(artist.discord_thread_id)
    else:
        raise NO_CHANNEL("Ensure channel for the account exists")

    # Built once; shown on both the main post and the mirrored embed.
    duplicate_links = " ".join(
        f"[{dup_id}](<https://fxtwitter.com/i/status/{dup_id}>)" for dup_id in duplicate_posts
    )[:1024]

    embed = build_x_embed(artist.name, post)
    if rating != "":
        embed.add_field(name="rating", value=rating, inline=False)
        tag_text = ", ".join(tags).replace("_", "\\_")[:1024]
        # An empty field value would make Discord reject the whole message.
        if tag_text:
            embed.add_field(name="tags", value=tag_text, inline=False)
    if duplicate_links:
        embed.add_field(name="duplicates", value=duplicate_links, inline=False)

    discord_files = [nextcord.File(fp=io.BytesIO(x.file_bytes), filename=x.file_name, force_close=True) for x in files_to_send]
    try:
        main_post : Message = await channel.send(embed=embed, files=discord_files)
    except Exception as e:
        raise OTHER_ERROR(e) from e
    finally:
        for file in discord_files:
            file.close()

    # skip posting in the public feed for accounts with too many posts
    if artist.name in new_accounts and not is_filtered:
        return main_post

    # Guard against an empty media list instead of failing on index 0.
    is_animated = bool(files_to_send) and files_to_send[0].file_name.endswith(".mp4")
    secondary_channel : nextcord.TextChannel = get_secondary_channel(is_animated, is_filtered, rating, tags, artist, guild)

    if is_animated:
        # str-format the id: it is not guaranteed to be a string (build_x_embed
        # converts it too). YView later splits this content on the single space.
        link = f"https://vxtwitter.com/{artist.name}/status/{post.id} {main_post.jump_url}"
        await secondary_channel.send(content=link, view=None if auto_approve else yView)
    else:
        embeds = build_secondary_embed(main_post, artist.name, post)
        if embeds is None:
            # Main post ended up without attachments: nothing to mirror.
            return main_post
        if rating != "":
            embeds[0].add_field(name="rating", value=rating, inline=False)
        if len(vox_labels) > 0:
            embeds[0].add_field(name="prediction", value=", ".join(vox_labels), inline=False)
        if duplicate_links:
            embeds[0].add_field(name="duplicates", value=duplicate_links, inline=False)
        await secondary_channel.send(embeds=embeds, view=None if auto_approve else xView)

    return main_post
|
||||
|
||||
async def send_pixiv_post(post, files_to_send, channel : nextcord.TextChannel):
    """Send a Pixiv post with its files to `channel`.

    `files_to_send` is a list of dicts with the keys "file" (bytes) and
    "name" (file name).

    Returns:
        ``(message, ErrorID.SUCCESS)`` on success, or
        ``(None, ErrorID.OTHER_ERROR)`` if sending raised; the error is
        printed, not re-raised.
    """
    embed = build_pixiv_embed(post)

    attachments = []
    for entry in files_to_send:
        stream = io.BytesIO(entry["file"])
        attachments.append(nextcord.File(fp=stream, filename=entry["name"], force_close=True))

    try:
        sent : Message = await channel.send(embed=embed, files=attachments)
    except Exception as e:
        print(e)
        return None, ErrorID.OTHER_ERROR
    else:
        return sent, ErrorID.SUCCESS
    finally:
        # Runs on both paths so the in-memory files are always released.
        for attachment in attachments:
            attachment.close()
|
||||
|
||||
async def post_result(results, guild : nextcord.Guild, new_accounts : list):
    """Post a summary of the last download run to the "bot-status" channel.

    Args:
        results: Mapping of account name to the number of new media posts.
        guild: Guild that owns the "bot-status" channel.
        new_accounts: Account names to label "(new)"; each name that appears
            in `results` is removed from this list.
    """
    print("Results: ", len(results))
    channel = nextcord.utils.get(guild.channels, name="bot-status")

    lines = [f'Last check: {datetime.now().strftime("%d/%m/%Y %H:%M:%S")}']
    for name, count in results.items():
        if name in new_accounts:
            lines.append(f'{name} (new): {count}')
            new_accounts.remove(name)
        else:
            lines.append(f'{name}: {count}')
    summary = "\n".join(lines)

    # An embed description may hold at most 4096 characters (6000 is the
    # limit for a whole embed), so split the summary accordingly.
    for chunk in chunks(summary, 4096):
        await channel.send(embed=nextcord.Embed(description=chunk))
|
||||
|
||||
async def get_channel_from_handle(guild: nextcord.Guild, handle : str):
    """Return the text channel named after `handle`, creating it if missing.

    New channels go into the first category named ``TWITTER-<n>`` (n counting
    up from 0) that does not hold exactly 50 channels; categories that do not
    exist yet are created on the way.
    """
    existing = nextcord.utils.get(guild.channels, name=handle.lower())
    if existing is not None:
        return existing

    index = 0
    while True:
        category_name = f'TWITTER-{index}'
        category = nextcord.utils.get(guild.categories, name=category_name)
        if category is None:
            category = await guild.create_category(category_name)
        index += 1
        if len(category.channels) != 50:
            break

    return await guild.create_text_channel(name=handle.lower(), category=category)
|
||||
|
||||
async def get_thread_from_handle(guild: nextcord.Guild, handle : str):
    """Return the thread named after `handle` in the "threads" channel.

    A public thread with a 10080-minute auto-archive duration is created when
    no thread with that name is found.
    """
    parent = nextcord.utils.get(guild.channels, name="threads")
    found = nextcord.utils.get(parent.threads, name=handle.lower())
    if found is not None:
        return found
    return await parent.create_thread(name=handle.lower(), auto_archive_duration=10080, type=ChannelType.public_thread)
|
||||
|
||||
async def get_category_by_name(client: nextcord.Client, name):
    """Return the category called `name` in the client's first guild, creating it if absent."""
    first_guild = client.guilds[0]
    found = nextcord.utils.get(first_guild.categories, name=name)
    if found is not None:
        return found
    return await first_guild.create_category(name)
|
||||
|
||||
def check_permission(interaction : nextcord.Interaction):
    """Tell whether the interacting user holds the guild's 'Archivist' role."""
    archivist = nextcord.utils.get(interaction.guild.roles, name='Archivist')
    user_roles = interaction.user.roles
    return archivist in user_roles
|
||||
|
||||
async def message_from_jump_url(server : nextcord.Guild, jump_url:str):
    """Fetch the message a Discord jump URL points to.

    The URL is split on "/" and the channel id and message id are read from
    positions 5 and 6 (``https://host/channels/<guild>/<channel>/<message>``).
    The channel is looked up among the guild's channels first, then among its
    threads.

    Raises:
        LookupError: No channel or thread with that id is known to `server`.
        ValueError / IndexError: The URL does not have the expected shape.
    """
    parts = jump_url.split('/')
    channel_id = int(parts[5])
    msg_id = int(parts[6])

    channel = server.get_channel(channel_id)
    if channel is None:
        channel = server.get_thread(channel_id)
    if channel is None:
        # Fail with a clear message rather than an AttributeError on None.
        raise LookupError(f"No channel or thread {channel_id} found for {jump_url}")

    return await channel.fetch_message(msg_id)
|
||||
|
||||
async def get_main_post_and_data(guild, jump_url, botData : RuntimeBotData):
    """Resolve a jump URL to the main post and the X post id it archives.

    The id is the last path segment of the author URL of the main post's
    first embed. `botData` is accepted but not used here.

    Returns:
        ``(main_post, x_post_id)``.
    """
    main_post = await message_from_jump_url(guild, jump_url)
    author_url = main_post.embeds[0].author.url
    id_segment = author_url.rsplit('/', 1)[-1]
    return main_post, int(id_segment)
|
||||
|
||||
async def ensure_has_channel_or_thread(artist: x_accounts, guild: nextcord.Guild, database: DatabaseController):
    """Give `artist` a Discord channel, or a thread as fallback, if it has neither.

    When both ids on `artist` are 0, a channel named after the account is
    looked up or created; if that fails, a thread is used instead. The
    changed record is then written back through `database`.
    """
    if artist.discord_channel_id != 0 or artist.discord_thread_id != 0:
        return

    try:
        channel = await get_channel_from_handle(guild, artist.name)
        artist.discord_channel_id = channel.id
        artist.discord_thread_id = 0
    except Exception:
        # `Exception` rather than a bare except, so task cancellation and
        # interpreter-exit signals are not swallowed by the fallback.
        thread = await get_thread_from_handle(guild, artist.name)
        artist.discord_thread_id = thread.id
        artist.discord_channel_id = 0

    # NOTE(review): indentation was lost in the source; this update is assumed
    # to run only when the ids were just assigned — confirm.
    database.x_update_account(artist)
|
||||
105
Discord/views.py
Normal file
105
Discord/views.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
import nextcord
|
||||
from nextcord.ui import View
|
||||
from Database.x_classes import ActionTaken
|
||||
from Discord import discordHelper
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
class XView(View):
    """Moderation buttons for mirrored posts whose first embed links to the main post.

    The author URL of the clicked message's first embed is used as the jump
    URL of the main post.
    """

    def __init__(self, botData : RuntimeBotData):
        self.botData = botData
        super().__init__(timeout=None, prevent_update=False)

    async def _is_allowed(self, interaction : nextcord.Interaction):
        """Return True if the user may act; otherwise reply "No permission"."""
        if discordHelper.check_permission(interaction):
            return True
        await interaction.response.send_message("No permission", ephemeral=True)
        return False

    async def _lookup(self, interaction : nextcord.Interaction):
        """Return (clicked message, main post, X post id) for an interaction."""
        clicked = interaction.message
        jump_url = clicked.embeds[0].author.url
        main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, jump_url, self.botData)
        return clicked, main_post, x_post_id

    @nextcord.ui.button(label="KF Art", custom_id="button-keep", style=nextcord.ButtonStyle.gray)
    async def button_keep(self, button, interaction : nextcord.Interaction):
        """Mark the post as accepted, colour the main post green, drop the buttons."""
        if not await self._is_allowed(interaction):
            return

        clicked, main_post, x_post_id = await self._lookup(interaction)
        self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Accepted)
        await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.green())
        await clicked.edit(view=None)

    @nextcord.ui.button(label="Non-KF Art", custom_id="button-hide", style=nextcord.ButtonStyle.gray)
    async def button_hide(self, button, interaction : nextcord.Interaction):
        """Mark the post as hidden, colour the main post yellow, delete the mirror."""
        if not await self._is_allowed(interaction):
            return

        clicked, main_post, x_post_id = await self._lookup(interaction)
        self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Hidden)
        await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.yellow())
        await clicked.delete()

    @nextcord.ui.button(label="Delete", custom_id="button-delete", style=nextcord.ButtonStyle.gray)
    async def button_delete(self, button, interaction : nextcord.Interaction):
        """Delete both messages and mark the post as rejected; report any error."""
        if not await self._is_allowed(interaction):
            return

        try:
            clicked, main_post, x_post_id = await self._lookup(interaction)
            print("Deleting", x_post_id, main_post.jump_url)

            await main_post.delete()
            await clicked.delete()

            self.botData.db.x_update_post(x_post_id, 0, 0, ActionTaken.Rejected)
        except Exception as e:
            await interaction.response.send_message("Error occured " + str(e), ephemeral=False)
|
||||
|
||||
class YView(View):
    """Moderation buttons for mirrored posts sent as plain "<link> <jump url>" text.

    The second space-separated token of the clicked message's content is used
    as the jump URL of the main post.
    """

    def __init__(self, botData : RuntimeBotData):
        self.botData = botData
        super().__init__(timeout=None, prevent_update=False)

    async def _is_allowed(self, interaction : nextcord.Interaction):
        """Return True if the user may act; otherwise reply "No permission"."""
        if discordHelper.check_permission(interaction):
            return True
        await interaction.response.send_message("No permission", ephemeral=True)
        return False

    async def _lookup(self, interaction : nextcord.Interaction):
        """Return (clicked message, main post, X post id) for an interaction."""
        clicked = interaction.message
        jump_url = clicked.content.split(" ")[1]
        main_post, x_post_id = await discordHelper.get_main_post_and_data(interaction.guild, jump_url, self.botData)
        return clicked, main_post, x_post_id

    @nextcord.ui.button(label="KF Art", custom_id="y-button-keep", style=nextcord.ButtonStyle.gray)
    async def button_keep(self, button, interaction : nextcord.Interaction):
        """Mark the post as accepted, colour the main post green, drop the buttons."""
        if not await self._is_allowed(interaction):
            return

        clicked, main_post, x_post_id = await self._lookup(interaction)
        self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Accepted)
        await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.green())
        await clicked.edit(view=None)

    @nextcord.ui.button(label="Non-KF Art", custom_id="y-button-hide", style=nextcord.ButtonStyle.gray)
    async def button_hide(self, button, interaction : nextcord.Interaction):
        """Mark the post as hidden, colour the main post yellow, delete the mirror."""
        if not await self._is_allowed(interaction):
            return

        clicked, main_post, x_post_id = await self._lookup(interaction)
        self.botData.db.x_update_post(x_post_id, main_post.id, 0, ActionTaken.Hidden)
        await discordHelper.edit_existing_embed_color(main_post, nextcord.Colour.yellow())
        await clicked.delete()

    @nextcord.ui.button(label="Delete", custom_id="y-button-delete", style=nextcord.ButtonStyle.gray)
    async def button_delete(self, button, interaction : nextcord.Interaction):
        """Delete both messages and mark the post as rejected; report any error."""
        if not await self._is_allowed(interaction):
            return

        try:
            clicked, main_post, x_post_id = await self._lookup(interaction)
            print("Deleting", x_post_id, main_post.jump_url)

            await main_post.delete()
            await clicked.delete()

            self.botData.db.x_update_post(x_post_id, 0, 0, ActionTaken.Rejected)
        except Exception as e:
            await interaction.response.send_message("Error occured " + str(e), ephemeral=False)
|
||||
|
||||
37
Pixiv/downloader.py
Normal file
37
Pixiv/downloader.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import nextcord
|
||||
|
||||
from Discord import discordHelper
|
||||
from config import Global_Config
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
async def download_loop(botData: RuntimeBotData):
    """Fetch new Pixiv posts, mirror them to Discord and record them in the database.

    Posts are handled in ascending id order. At most four files per post are
    attached, and the feed channel is chosen from the post's `x_restrict`
    value. Any exception ends the run and is reported via
    `discordHelper.send_error` (first 256 characters of the traceback).
    """
    feeds_by_restriction = {0: "pixiv-sfw-feed", 1: "pixiv-r18-feed"}
    try:
        db = botData.db
        guild = botData.client.guilds[0]

        new_posts = await botData.pixivApi.get_new_posts(botData)
        new_posts.sort(key=lambda x: x.id)

        for post in new_posts:
            media = (await botData.pixivApi.download_illust(post, Global_Config["pixiv_download_path"]))[:4]
            # Every x_restrict value other than 0 and 1 goes to the R-18G feed.
            feed_name = feeds_by_restriction.get(post.x_restrict, "pixiv-r18g-feed")
            channel = nextcord.utils.get(guild.channels, name=feed_name)
            await discordHelper.send_pixiv_post(post, media, channel)

            # Identity test: `== None` would defer to the record's own __eq__.
            if db.pixiv_get_user(post.user.id) is None:
                db.pixiv_insert_user(post.user)
            db.pixiv_insert_post(post)

    except Exception as ex:
        print(ex)
        await discordHelper.send_error(traceback.format_exc()[0:256], botData)
    # NOTE(review): indentation was lost in the source; this is assumed to
    # run after both the success and the error path — confirm.
    print("Pixiv done")
|
||||
230
Pixiv/pixivapi.py
Normal file
230
Pixiv/pixivapi.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from config import Global_Config
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from io import BytesIO
|
||||
from zipfile import ZipFile
|
||||
from pixivpy3 import AppPixivAPI
|
||||
from Twitter.tweetyapi import TweetyApi
|
||||
from getPixivpyToken.gppt import GetPixivToken
|
||||
from PIL import Image
|
||||
|
||||
class PixivApi:
    """Wrapper around pixivpy3's ``AppPixivAPI`` with retries and download helpers."""

    # Error messages that mean "the work is gone"; such lookups yield None.
    _GONE_MESSAGES = ('Page not found', 'Artist has made their work private.')

    def init(self):
        """Create the API client (60 s timeout), authenticate it and return `self`."""
        self.api = AppPixivAPI(timeout=60)
        self.api.auth(refresh_token=self.get_refresh_token())
        return self

    def get_refresh_token(self) -> str:
        """Return the configured refresh token, logging in to obtain one if needed.

        A token obtained by login is stored back into ``Global_Config``.

        Raises:
            Exception: No token is configured and username or password is missing.
        """
        if Global_Config["pixiv_token"] is not None:
            return Global_Config["pixiv_token"]

        if Global_Config["pixiv_username"] is None or Global_Config["pixiv_password"] is None:
            raise Exception("Pixiv username and password are required")
        g = GetPixivToken(headless=True, username=Global_Config["pixiv_username"], password=Global_Config["pixiv_password"])
        refresh_token = g.login()["refresh_token"]
        Global_Config["pixiv_token"] = refresh_token
        return refresh_token

    async def _call_with_retries(self, request, retries, gone_messages=()):
        """Run `request()` up to `retries` times, re-authenticating after each failure.

        `request` must look up ``self.api`` when called, because the client is
        replaced after a failed attempt. A result whose error has a
        `user_message` listed in `gone_messages` yields ``None``; any other
        error counts as a failed attempt.

        Raises:
            Exception: All attempts failed; wraps the last error.
        """
        last_error = ""
        sleep_default = 0.125
        for attempt in range(retries):
            try:
                json_result = request()
                error = json_result.error
                if error is not None:
                    if gone_messages and error.user_message in gone_messages:
                        return None
                    raise Exception(error)
                return json_result
            except Exception as ex:
                print(ex)
                last_error = ex
                await TweetyApi.sleep_wait(sleep_default, attempt)
                self.api = PixivApi().init().api

        raise Exception(last_error)

    async def search_illust(self, retries = 3, **next_qs):
        """Call ``api.search_illust(**next_qs)`` with retries; raise if all attempts fail."""
        return await self._call_with_retries(lambda: self.api.search_illust(**next_qs), retries)

    async def illust_detail(self, id, retries = 3):
        """Call ``api.illust_detail(id)`` with retries; ``None`` if the work is gone or private."""
        return await self._call_with_retries(lambda: self.api.illust_detail(id), retries, self._GONE_MESSAGES)

    async def ugoira_metadata(self, id, retries = 3):
        """Call ``api.ugoira_metadata(id)`` with retries; ``None`` if the work is gone or private."""
        return await self._call_with_retries(lambda: self.api.ugoira_metadata(id), retries, self._GONE_MESSAGES)

    async def download(self, url, retries = 3):
        """Download `url` into memory and return its bytes.

        The blocking download runs in the default executor. Failed attempts
        are printed and retried after ``attempt * 60`` seconds; the error of
        the final attempt is re-raised.
        """
        loop = asyncio.get_running_loop()
        for attempt in range(retries):
            try:
                with BytesIO() as io_bytes:
                    await loop.run_in_executor(None, lambda: self.api.download(url, fname=io_bytes))
                    return io_bytes.getvalue()
            except Exception as ex:
                print(ex)
                # The last attempt is index retries - 1; a fixed `== 3` check
                # never fires and lets the caller receive None instead.
                if attempt == retries - 1:
                    raise
                await asyncio.sleep(attempt * 60)

    async def get_new_posts(self, botData: RuntimeBotData):
        """Collect search results that are not yet stored in the database.

        Pages are fetched newest-first until a page contains no unknown post
        or there is no next page. Between pages the query is re-anchored: when
        the oldest date on the page differs from the query's `end_date`, that
        date becomes the new `end_date` and `offset` is set to the number of
        posts on the page sharing it.
        """
        posts = []
        # get all page:
        next_qs = {'word': 'けものフレンズ', 'search_target': 'partial_match_for_tags', 'sort': 'date_desc', 'filter': 'for_ios', 'end_date': ''}

        while next_qs:
            print("Getting pixiv page")
            json_result = await self.search_illust(retries=3, **next_qs)

            has_new_posts = False
            for illust in json_result.illusts:
                if botData.db.pixiv_get_post(illust.id) is None:
                    has_new_posts = True
                    print(illust.id, illust.title, illust.create_date)
                    posts.append(illust)

            if not has_new_posts:
                break

            next_qs = self.api.parse_qs(json_result.next_url)
            if not next_qs:
                # No further page: stop here instead of probing a None query.
                break

            if len(json_result.illusts) > 0:
                oldest_date = json_result.illusts[-1].create_date[0:10]
                if "end_date" in next_qs and oldest_date != next_qs["end_date"]:
                    next_qs["end_date"] = oldest_date
                    next_qs["offset"] = sum(1 for x in json_result.illusts if x.create_date[0:10] == oldest_date)

            await asyncio.sleep(1)

        return posts

    async def download_illust(self, illust, download_path = "Temp/"):
        """Download every file of `illust` to disk and return up to four of them.

        Files are written to ``<download_path>/<user id>/``. Ugoira works are
        converted to a single GIF named ``<id>_p0.gif``.

        Returns:
            A list of ``{"file": bytes, "name": str}`` dicts (at most four), or
            an empty list when ugoira metadata is unavailable.
        """
        print(illust.id)
        downloaded = []
        user_dir = os.path.abspath(os.path.join(download_path, str(illust.user.id)))
        os.makedirs(user_dir, exist_ok=True)

        if illust.type == 'ugoira':
            filename = str(illust.id) + "_p0.gif"
            data = await self.download_ugoira_to_bytes(illust.id)
            if data is None:
                return []
            with open(os.path.join(user_dir, filename), "wb") as f:
                f.write(data)
            downloaded.append({"file": data, "name": filename})
            return downloaded

        if len(illust.meta_pages) == 0:
            urls = [illust.meta_single_page.original_image_url]
        else:
            urls = [page.image_urls.original for page in illust.meta_pages]

        for url in urls:
            filename = os.path.basename(url)
            data = await self.download(url)
            with open(os.path.join(user_dir, filename), "wb") as f:
                f.write(data)
            # Every page is saved to disk, but only the first four are returned.
            if len(downloaded) < 4:
                downloaded.append({"file": data, "name": filename})

        return downloaded

    async def download_ugoira_to_bytes(self, id):
        """Download an ugoira's frame archive and return it rendered as GIF bytes.

        Returns ``None`` when the metadata lookup yields nothing.

        Raises:
            Exception: The metadata lists more than one zip URL.
        """
        metadata = await self.ugoira_metadata(id)
        if metadata is None:
            return None

        if len(metadata.ugoira_metadata.zip_urls) > 1:
            raise Exception(f"Unexpected number of zip urls for ugoira {id}")

        archive = await self.download(metadata.ugoira_metadata.zip_urls["medium"])
        frames = metadata.ugoira_metadata.frames

        closables = []
        try:
            zip_frames = {}
            with BytesIO(archive) as archive_io, ZipFile(archive_io) as input_zip:
                for name in input_zip.namelist():
                    im_bytes = BytesIO(input_zip.read(name))
                    closables.append(im_bytes)
                    im_im = Image.open(im_bytes)
                    closables.append(im_im)
                    zip_frames[name] = im_im

            with BytesIO() as buffer:
                zip_frames[frames[0].file].save(
                    buffer,
                    format="GIF",
                    save_all=True,
                    append_images=[zip_frames[frame.file] for frame in frames[1:]],
                    duration=[frame.delay for frame in frames],
                    loop=0,
                    optimize=False
                )
                return buffer.getvalue()
        finally:
            # Release frames and their buffers even if decoding/saving failed.
            for closable in reversed(closables):
                closable.close()
|
||||
|
||||
if __name__ == "__main__":
    # Stand-alone polling loop: query Pixiv every 30 minutes.
    async def get_new_posts():
        client = PixivApi().init()
        while True:
            # NOTE(review): PixivApi.get_new_posts takes a `botData` argument
            # that is not supplied here — confirm how this entry point is
            # meant to be run.
            await client.get_new_posts()
            await asyncio.sleep(30 * 60)

    asyncio.run(get_new_posts())
|
||||
9
README.md
Normal file
9
README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
Discord bot with Pixiv and X post scrapers and a PostgreSQL database
|
||||
|
||||
When run, it downloads all Pixiv posts tagged Kemono Friends (KF)
|
||||
|
||||
Uses AI models for tagging and classification: https://huggingface.co/HAV0X1014/Kemono-Friends-Sorter https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/
|
||||
|
||||
Recommended Python version: 3.13.x
|
||||
|
||||
The PostgreSQL database schema is included in the file japariarchive_schema.sql
|
||||
147
Twitter/downloader.py
Normal file
147
Twitter/downloader.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import gc
|
||||
import traceback
|
||||
import tracemalloc
|
||||
|
||||
from Classifier.classifyHelper import classify_all
|
||||
from Database.x_classes import DownloadMode
|
||||
from Discord import discordHelper
|
||||
from Twitter import tweetHelper
|
||||
from tweety.types import Tweet
|
||||
from exceptions import ACCOUNT_DEAD, ACCOUNT_SKIP, DOWNLOAD_FAIL, NO_CHANNEL, OTHER_ERROR
|
||||
from Database.x_classes import ActionTaken, DownloadMode, ErrorID, HavoxLabel, PostRating, x_posts, x_posts_images, x_accounts
|
||||
from Database.db_schema import x_accounts as schema_x_accounts
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
async def download_loop(botData: RuntimeBotData):
    """Check every tracked X account for new posts and process them.

    Deleted accounts and accounts whose mode is neither DOWNLOAD nor
    DOWNLOAD_ALL are skipped. Accounts with more than 20 new media posts are
    added to `botData.new_accounts`. After all accounts are handled a summary
    is posted; any exception ends the run and is reported via
    `discordHelper.send_error` (first 256 characters of the traceback).
    """
    guild = botData.client.guilds[0]
    try:
        results = {}
        botData.new_accounts = []

        for artist in botData.db.x_get_all_accounts():
            if artist.is_deleted:
                continue

            # sleep to avoid rate limits
            await asyncio.sleep(5)

            print("Artist:", artist.name)

            # wait for ALL new posts to be found
            try:
                mode = artist.download_mode
                if mode == DownloadMode.NO_DOWNLOAD:
                    continue
                elif mode == DownloadMode.DOWNLOAD:
                    fetch_posts = tweetHelper.UpdateMediaPosts
                elif mode == DownloadMode.DOWNLOAD_ALL:
                    fetch_posts = tweetHelper.DownloadAllMediaPosts
                else:
                    continue
                await discordHelper.ensure_has_channel_or_thread(artist, guild, botData.db)
                new_posts = await fetch_posts(artist, botData)
            except ACCOUNT_DEAD:
                botData.db.x_update_account_properties(artist.id, [(schema_x_accounts.is_deleted, True)])
                continue
            except ACCOUNT_SKIP:
                continue

            if len(new_posts) == 0:
                continue

            new_posts_count = sum(1 for post in new_posts if len(post.media) > 0)
            if new_posts_count > 20:
                # skips posting to discord if there are too many posts
                botData.new_accounts.append(artist.name)

            new_posts.sort(key=lambda x: x.date)

            # posts should arrive here in chronological order
            for tweet in new_posts:
                await download_post(artist, tweet, botData)
                # NOTE(review): indentation was lost in the source; the
                # collection and memory print are assumed to run per post.
                gc.collect()
                print(tracemalloc.get_traced_memory())

            results[artist.name] = new_posts_count
            if artist.download_mode == DownloadMode.DOWNLOAD_ALL:
                # A full download is a one-off; switch back to incremental mode.
                botData.db.x_update_account_properties(artist.id, [(schema_x_accounts.download_mode, DownloadMode.DOWNLOAD)])

        await discordHelper.post_result(results, guild, botData.new_accounts)
    except Exception as ex:
        print(ex)
        await discordHelper.send_error(traceback.format_exc()[0:256], botData)
|
||||
|
||||
async def download_post(artist: x_accounts, tweet: Tweet, botData: RuntimeBotData):
    """Download, classify, publish and record a single tweet.

    Steps:
      1. Tweets without media are stored with ``ErrorID.NO_ART`` and skipped.
      2. Media is downloaded; on ``DOWNLOAD_FAIL`` the post and its images are
         stored with error ids and processing stops.
      3. Each file is classified; tags are merged, the highest rating wins,
         and perceptual-hash duplicates are looked up.
      4. The post is sent to Discord; ``NO_CHANNEL`` / ``OTHER_ERROR`` are
         recorded on the post instead of being raised.
      5. Post and images are inserted in one transaction (rolled back and
         re-raised on failure).
    """
    x_post = x_posts(id = tweet.id, account_id = tweet.author.id, date = tweet.date, text = tweet.text)

    if len(tweet.media) == 0:
        x_post.error_id = ErrorID.NO_ART
        botData.db.x_insert_post(x_post, commit=True)
        return

    print("New media post:", str(tweet.url))
    media = await tweetHelper.GetTweetMediaUrls(tweet)
    image_containers = [x_posts_images(tweet.id, idx, file = url) for idx, url in enumerate(media)]

    try:
        downloaded_media = await tweetHelper.DownloadMedia(tweet.id, tweet.author.id, tweet.author.username, media, botData.session)
    except DOWNLOAD_FAIL as e:
        x_post.error_id = e.code

        botData.db.x_insert_post(x_post, commit=False)
        for image in image_containers:
            image.error_id = ErrorID.DOWNLOAD_FAIL
            botData.db.x_insert_image(image, commit=False)
        botData.db.conn.commit()
        return

    # Severity order, most severe first; anything else ranks 0.
    rating_ranks = (
        (4, PostRating.Explicit),
        (3, PostRating.Questionable),
        (2, PostRating.Sensitive),
        (1, PostRating.General),
    )

    def get_rating_value(rating):
        for rank, known in rating_ranks:
            if rating == known:
                return rank
        return 0

    vox_labels = []
    final_filtered_tags = {}
    duplicates = []
    for idx, attachment in enumerate(downloaded_media):
        container = image_containers[idx]
        container.saved_file = attachment.file_name
        container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes, botData.classifier, botData.vox)

        if container.vox_label not in vox_labels:
            vox_labels.append(container.vox_label)

        if container.phash is not None:
            duplicate = botData.db.x_search_duplicate(user_id=x_post.account_id, max_id = x_post.id, phash=container.phash)
            if duplicate is not None:
                container.duplicate_id = duplicate.post_id
                container.duplicate_index = duplicate.index
                if duplicate.post_id not in duplicates:
                    duplicates.append(duplicate.post_id)

        x_post.tags = list(set(x_post.tags + container.tags))
        if get_rating_value(container.rating) > get_rating_value(x_post.rating):
            x_post.rating = container.rating
        final_filtered_tags = final_filtered_tags | filtered_tags

    # Filtered only when every file was labelled Rejected.
    is_filtered = len(vox_labels) == 1 and vox_labels[0] == HavoxLabel.Rejected

    try:
        discord_post = await discordHelper.send_x_post(tweet, artist, botData.client.guilds[0], botData.new_accounts, downloaded_media, is_filtered, rating=x_post.rating, tags=final_filtered_tags, vox_labels = vox_labels, duplicate_posts=duplicates, xView=botData.xView, yView=botData.yView)
    except (NO_CHANNEL, OTHER_ERROR) as e:
        x_post.error_id = e.code
        x_post.discord_post_id = 0
    else:
        x_post.discord_post_id = discord_post.id

    x_post.action_taken = ActionTaken.Rejected if is_filtered else ActionTaken.Null

    try:
        if not botData.db.x_insert_post(x_post, commit = False):
            raise Exception("Transaction error")

        for image in image_containers:
            botData.db.x_insert_image(image, False)
    except Exception:
        botData.db.conn.rollback()
        # Bare raise keeps the original traceback intact.
        raise
    else:
        botData.db.conn.commit()
|
||||
131
Twitter/tweetHelper.py
Normal file
131
Twitter/tweetHelper.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from Database.x_classes import x_accounts
|
||||
from config import Global_Config
|
||||
import downloadHelper
|
||||
from tweety.types.twDataTypes import Tweet
|
||||
|
||||
from exceptions import ACCOUNT_DEAD, ACCOUNT_SKIP
|
||||
from tweety.exceptions_ import UserNotFound, UserProtected
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
class TweetMedia:
    """A media item of a tweet: its download URL and the file name to save it under."""

    url : str
    file_name : str

    def __init__(self, url, file_name):
        self.url, self.file_name = url, file_name
|
||||
|
||||
class DownloadedMedia:
    """A downloaded media file held in memory together with its file name."""

    # Raw file content; consumers wrap it in io.BytesIO, so it is bytes.
    file_bytes : bytes
    file_name : str

    def __init__(self, bytes, file_name):
        # The parameter keeps its original name (it shadows the builtin) so
        # that keyword callers continue to work.
        self.file_bytes = bytes
        self.file_name = file_name
|
||||
|
||||
async def GetTweetMedia(tweet : Tweet) -> list[TweetMedia]:
    """Build one TweetMedia (direct URL + target file name) per media entry of ``tweet``.

    File names follow ``<screen_name>_<tweet id>_<index>.<extension>``. For entries whose
    ``file_format`` is ``'mp4'`` the extension is the entry's own format; for everything
    else it is the format reported by the entry's best stream.
    """
    collected : list[TweetMedia] = []
    for position, item in enumerate(tweet.media):
        is_video = item.file_format == 'mp4'
        stream = await item.best_stream()
        suffix = item.file_format if is_video else stream.file_format
        name = f"{tweet.author.screen_name}_{tweet.id}_{position}.{suffix}"
        collected.append(TweetMedia(stream.direct_url, name))

    return collected
|
||||
|
||||
async def GetTweetMediaUrls(tweet : Tweet):
    """Return only the direct URLs of the media attached to ``tweet``, in media order."""
    urls = []
    for entry in await GetTweetMedia(tweet):
        urls.append(entry.url)
    return urls
|
||||
|
||||
async def DownloadMedia(post_id, account_id, account_name, url_list : list, session) -> list[DownloadedMedia]:
    """Download every URL in ``url_list`` into the account's folder and return the results.

    Files are written to ``<x_download_path><account_id>/<file name>`` where the file name
    comes from ``get_file_name``. Returns one DownloadedMedia per URL, in input order.
    Errors raised by ``downloadHelper.save_to_file`` propagate to the caller.
    """
    # Config exposes its values through __getitem__; it is not callable.
    path = f"{Global_Config['x_download_path']}{account_id}"
    os.makedirs(path, exist_ok=True)

    result : list[DownloadedMedia] = []
    for idx, file_url in enumerate(url_list):
        file_name = get_file_name(account_name, post_id, idx, file_url)
        photo_bytes = await downloadHelper.save_to_file(file_url, f"{path}/{file_name}", session)
        result.append(DownloadedMedia(photo_bytes, file_name))

    return result
|
||||
|
||||
def get_file_name(account_name: str, post_id: int, image_index: int, image_url: str, account_id : int = None, base_path : str = None):
    '''
    Build the storage name ``<account_name>_<post_id>_<image_index>.<ext>`` for a media URL.

    The extension is whatever follows the last ``.`` of the URL once the query string
    (everything from the first ``?``) is removed.

    `account_id` and `base_path` are optional; do not put a trailing slash on `base_path`.
    Neither given:        `file_name.ext`
    Only `base_path`:     `c:/base_path/file_name.ext`
    Only `account_id`:    `account_id/file_name.ext`
    Both given:           `c:/base_path/account_id/file_name.ext`
    '''
    without_query = image_url.split("?")[0]
    ext = without_query.split(".")[-1]
    file_name = f"{account_name}_{post_id}_{image_index}.{ext}"

    # Optional prefixes, outermost first; absent ones are simply left out.
    segments = [part for part in (base_path, account_id) if part is not None]
    segments.append(file_name)
    return "/".join(f"{part}" for part in segments)
|
||||
|
||||
async def UpdateMediaPosts(account : x_accounts, botData : RuntimeBotData) -> list[Tweet]:
    """Fetch the tweets of ``account`` that are newer than the newest post already stored.

    The lower bound handed to the API is the largest known post id, or 1 when the
    account has no stored posts.

    Raises:
        ACCOUNT_DEAD: the API reported the user as protected or not found.
        ACCOUNT_SKIP: any other error while fetching.
    """
    known_ids = botData.db.get_all_post_ids(account.id)
    newest_post = max(known_ids) if len(known_ids) != 0 else 1

    try:
        fetched = []
        async for tweet in botData.twApi.get_tweets(user_name = account.name, bottom_id = newest_post, all_posts = known_ids):
            fetched.append(tweet)
    except (UserProtected, UserNotFound) as ex:
        print("User dead: ", account.name, ex)
        raise ACCOUNT_DEAD(ex)
    except Exception as ex:
        print("Error in ", account.name, ex)
        raise ACCOUNT_SKIP(ex)

    return fetched
|
||||
|
||||
async def DownloadAllMediaPosts(account : x_accounts, botData : RuntimeBotData) -> list[Tweet]:
    """Walk the whole media timeline of ``account`` and return every tweet not yet stored.

    Unlike ``UpdateMediaPosts`` the API is queried with ``bottom_id = 1`` and an empty
    ``all_posts`` list; tweets whose id is already in the database are dropped here instead.

    Raises:
        ACCOUNT_DEAD: the API reported the user as protected or not found.
        ACCOUNT_SKIP: any other error while fetching.
    """
    # Built once so the per-tweet membership test does not rescan the whole id list.
    # Assumes the stored ids are hashable ints, as the int(tweet.id) comparison implies.
    known_ids = set(botData.db.get_all_post_ids(account.id))
    posts = []

    try:
        async for tweet in botData.twApi.get_tweets(user_name = account.name, bottom_id = 1, all_posts = []):
            if int(tweet.id) not in known_ids:
                posts.append(tweet)
    except (UserProtected, UserNotFound) as ex:
        print("User dead: ", account.name, ex)
        raise ACCOUNT_DEAD(ex)
    except Exception as ex:
        print("Error in ", account.name, ex)
        raise ACCOUNT_SKIP(ex)

    return posts
|
||||
|
||||
def parse_x_url(url : str):
    """Return ``(handle, post_id)`` parsed from a full X post url.

    Accepted shape: ``[http(s)://]<domain>/<handle>/status/<post_id>[/...][?query]``.

    Raises:
        ValueError: with message "Invalid Format" when the url does not have that shape
            (too few path segments, an empty handle, a third segment other than
            ``status`` or a non-numeric post id).
    """
    bare = url.replace("https://", "").replace("http://", "")
    bare = bare.split("?")[0]  # drop the query string, if any

    parts = bare.split('/')
    # parts: [domain, handle, "status", post_id, ...]
    if len(parts) < 4 or not parts[1] or parts[2] != "status":
        raise ValueError("Invalid Format")

    try:
        post_id = int(parts[3])
    except ValueError:
        raise ValueError("Invalid Format") from None

    return parts[1], post_id
|
||||
102
Twitter/tweetyapi.py
Normal file
102
Twitter/tweetyapi.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import asyncio
|
||||
from typing import AsyncIterator
|
||||
from tweety import TwitterAsync
|
||||
from tweety.types.twDataTypes import Tweet, SelfThread, ConversationThread
|
||||
from tweety.types.usertweet import UserMedia
|
||||
from tweety.exceptions_ import RateLimitReached
|
||||
from config import Global_Config
|
||||
|
||||
class TweetyApi:
    """Async helper around a tweety ``TwitterAsync`` session."""

    async def init(self, skip_login = False, session_name = "session"):
        """Create the tweety session and, unless ``skip_login`` is set, load the X cookies.

        Returns ``self`` so construction can be chained: ``await TweetyApi().init()``.

        Raises:
            Exception: cookies are needed but ``Global_Config["x_cookies"]`` is None.
        """
        if not skip_login and Global_Config["x_cookies"] is None:
            raise Exception("X cookies are required")

        self.app = TwitterAsync(session_name)
        if not skip_login:
            await self.app.load_cookies(Global_Config["x_cookies"])

        print(self.app.user)
        return self

    async def get_tweet(self, url):
        """Return the tweet detail for ``url``, or None when fetching it fails."""
        try:
            return await self.app.tweet_detail(url)
        except Exception:
            # Best effort: any fetch error is reported as "no tweet". Task cancellation
            # is not an Exception subclass and therefore still propagates.
            return None

    async def get_tweets(self, user_name, bottom_id, all_posts : list) -> "AsyncIterator[Tweet]":
        """Yield media tweets of ``user_name`` page by page until a page has nothing new.

        A top-level tweet is yielded when its id is greater than ``bottom_id`` and not in
        ``all_posts``. Tweets inside a conversation/self thread are only checked against
        ``all_posts``. Paging stops at the first page that is empty or yields nothing.
        Rate limits are waited out via ``sleep_wait`` and the same page is retried.
        """
        def validate_tweet(tweet : "Tweet"):
            # -> (past_bounds, tweet_valid)
            tweet_id_num = int(tweet.id)
            return tweet_id_num <= bottom_id, tweet_id_num not in all_posts

        sleep_default = 0.125
        sleep_exponent = 1
        user = None

        while user is None:
            try:
                user = await self.app.get_user_info(username=user_name)
            except RateLimitReached:
                sleep_exponent = await self.sleep_wait(sleep_default, sleep_exponent)
            except Exception as ex:
                print("User error: " + str(ex))
                raise

        tweety_api = UserMedia(user.rest_id, self.app, 1, 2, None)
        sleep_exponent = 1

        while True:
            await asyncio.sleep(5)
            old_cursor = tweety_api.cursor

            try:
                tweets = await tweety_api.get_next_page()
                sleep_exponent = 1
            except RateLimitReached:
                sleep_exponent = await self.sleep_wait(sleep_default, sleep_exponent)
                # restore the cursor so the same page is requested again
                tweety_api.cursor = old_cursor
                continue

            has_valid_tweets = False
            for tweet in tweets:
                if isinstance(tweet, (ConversationThread, SelfThread)):
                    for tweet1 in tweet.tweets:
                        _, tweet_valid = validate_tweet(tweet1)
                        if tweet_valid:
                            has_valid_tweets = True
                            yield tweet1
                else:
                    past_bounds, tweet_valid = validate_tweet(tweet)
                    if past_bounds:
                        continue
                    if tweet_valid:
                        has_valid_tweets = True
                        yield tweet

            if len(tweets) == 0 or not has_valid_tweets:
                break
            await asyncio.sleep(1)

    @staticmethod
    async def sleep_wait(sleep_default, sleep_exponent):
        """Sleep ``sleep_default * 2**sleep_exponent`` hours (at most 2) and return the next exponent."""
        sleep_amount = min(sleep_default * pow(2,sleep_exponent), 2)
        print(f"Sleeping for {round(sleep_amount,2)} hours.")
        await asyncio.sleep(sleep_amount * 60 * 60)
        print("Sleep done")
        sleep_exponent += 1
        return sleep_exponent
|
||||
|
||||
#asyncio.run(TweetyApi().get_tweets("redhood_depth", 0))
|
||||
41
Twitter/twitterContainer.py
Normal file
41
Twitter/twitterContainer.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
from Database.x_classes import AccountRating, DownloadMode, ErrorID
|
||||
from tweety.exceptions_ import UserNotFound, UserProtected
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
class TwitterContainer():
    """In-memory description of a tracked X account and where its posts go on Discord."""

    id : int
    name : str
    rating : str
    discord_channel_id : int = 0
    discord_thread_id : int = 0
    download_mode : int = 0

    def __init__(self, name : str, rating = AccountRating.NSFW, discord_channel_id = 0, discord_thread_id = 0, download_mode = DownloadMode.NO_DOWNLOAD, id = 0, **kwargs):
        # Extra keyword arguments are accepted and ignored; the rating is stored upper-cased.
        self.id, self.name = id, name
        self.rating = rating.upper()
        self.discord_channel_id, self.discord_thread_id = discord_channel_id, discord_thread_id
        self.download_mode = download_mode

    async def UpdateMediaPosts(self, botData : RuntimeBotData):
        """Fetch tweets newer than the newest stored post of this account.

        Returns ``(ErrorID, tweets)``: ``SUCCESS`` with the fetched tweets, ``ACCOUNT_DEAD``
        with an empty list when the user is protected or not found, ``ACCOUNT_SKIP`` with
        an empty list on any other error.
        """
        known_ids = botData.db.get_all_post_ids(self.id)
        newest_post = max(known_ids) if len(known_ids) != 0 else 1

        try:
            posts = [tweet async for tweet in botData.twApi.get_tweets(user_name = self.name, bottom_id = newest_post, all_posts = known_ids)]
        except (UserProtected, UserNotFound) as ex:
            print("User dead: ", self.name, ex)
            return ErrorID.ACCOUNT_DEAD, []
        except Exception as ex:
            print("Error in ", self.name, ex)
            return ErrorID.ACCOUNT_SKIP, []

        return ErrorID.SUCCESS, posts
|
||||
54
bot.py
Normal file
54
bot.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
from config import Global_Config
|
||||
|
||||
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
|
||||
from Twitter.downloader import download_loop as x_download_loop
|
||||
from Pixiv.downloader import download_loop as pixiv_download_loop
|
||||
|
||||
#import asyncio
|
||||
import nextcord
|
||||
from runtimeBotData import RuntimeBotData
|
||||
from nextcord.ext import tasks, commands
|
||||
|
||||
# asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# Shared runtime state; its services are created later, in on_ready.
botData = RuntimeBotData()
#loop.run_until_complete(botData.initialize_data())

intents = nextcord.Intents.default()
intents.message_content = intents.guild_reactions = True

bot = commands.Bot(command_prefix=".", intents=intents)
botData.client = bot
# The "commands" extension receives botData through the extras mapping.
bot.load_extension("commands", extras={"botData": botData})
|
||||
|
||||
@bot.event
async def on_ready():
    """Initialise the shared bot data, register the views and start both download loops.

    Everything runs only while ``botData.initialized`` is still False.
    """
    if not botData.initialized:
        await botData.initialize_data()
        for view in (botData.xView, botData.yView):
            bot.add_view(view)
        botData.initialized = True
        print(f'{bot.user} has connected to nextcord!')
        x_loop.start()
        pixiv_loop.start()
|
||||
|
||||
@tasks.loop(minutes = 3 * 60)
async def x_loop():
    """Run one X download pass, printing a timestamp before and after it."""
    stamp_format = "%d/%m/%Y %H:%M:%S"
    print(datetime.now().strftime(stamp_format))
    await x_download_loop(botData)
    print("Finished: " + datetime.now().strftime(stamp_format))
|
||||
|
||||
@tasks.loop(minutes = 30)
async def pixiv_loop():
    """Run one Pixiv download pass with the shared bot data."""
    await pixiv_download_loop(botData)
|
||||
|
||||
# Refuse to start without a token: bot.run() needs it.
if Global_Config["discord_token"] is None:
    raise Exception("Discord bot token is required")

bot.run(Global_Config["discord_token"])
|
||||
|
||||
250
commands.py
Normal file
250
commands.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
|
||||
from Classifier.classifyHelper import classify_all
|
||||
from Database.x_classes import AccountRating, ActionTaken, DownloadMode, PostRating, x_posts, x_posts_images
|
||||
from Database.db_schema import x_posts as schema_x_posts, x_accounts as schema_x_accounts
|
||||
from Discord import discordHelper
|
||||
from Twitter import tweetHelper
|
||||
from exceptions import NO_CHANNEL, OTHER_ERROR
|
||||
from helpers import get_rating_value
|
||||
from helpers_adv import create_x_account
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
import importlib
|
||||
import nextcord
|
||||
from nextcord.ext import commands
|
||||
from nextcord import application_command
|
||||
|
||||
class Commands(commands.Cog):
    """Slash commands and message context-menu commands for maintaining the archive."""

    def __init__(self, botData : RuntimeBotData):
        self.botData = botData
        self._last_member = None

    @property
    def db(self):
        """Database controller of the shared bot data, looked up on every access.

        ``RuntimeBotData.db`` starts out as None and is only assigned in
        ``initialize_data()``, so a value captured in ``__init__`` can be a stale None.
        """
        return self.botData.db

    # NOTE(review): has_permissions below comes from nextcord.ext.commands; confirm that it
    # is actually enforced for application (slash) commands, or switch to the
    # application-command equivalent.
    @application_command.slash_command()
    @commands.has_permissions(administrator=True)
    async def order(self, interaction: nextcord.Interaction):
        """Re-sort all channels under "TWITTER*" categories by name into new categories of 50.

        Existing matching categories are renamed to ``temp-<n>``, channels are moved into
        freshly created ``TWITTER-<n>`` categories, then the old categories are deleted.
        """
        # NOTE(review): the interaction is never answered or deferred here — confirm
        # whether a response should be sent.
        guild = interaction.guild
        categories = [cat for cat in guild.categories if cat.name.startswith("TWITTER")]
        channels = [
            chan for chan in guild.channels
            if not isinstance(chan, nextcord.CategoryChannel)
            and chan.category is not None  # channels without a category cannot match
            and chan.category.name.startswith("TWITTER")
        ]
        channelsOrdered = sorted(channels, key=lambda chan: chan.name)

        for idx, category in enumerate(categories):
            await category.edit(name=f"temp-{idx}")

        last_id = -1
        target_category = None
        for idx, channel in enumerate(channelsOrdered):
            id = idx//50
            if id > last_id:
                # first channel of the next group of 50: open a new category
                last_id = id
                target_category = await guild.create_category(f"TWITTER-{id}")
            print("moving", channel.name, "to", target_category.name)
            await channel.edit(category=target_category)

        for category in categories:
            print("deleting", category.name)
            await category.delete()

    @application_command.slash_command(description="Downloads single post from (domain)/(account_name)/status/(post_id) format")
    async def download_post(self, interaction: nextcord.Interaction, full_link : str, new_account_rating : str = nextcord.SlashOption(choices= [AccountRating.NSFW, AccountRating.SFW, AccountRating.NSFL], default = AccountRating.NSFW), download_mode : int = nextcord.SlashOption(choices={"No":DownloadMode.NO_DOWNLOAD, "Yes":DownloadMode.DOWNLOAD, "More Yes":DownloadMode.DOWNLOAD_ALL}, default=DownloadMode.NO_DOWNLOAD)):
        """Download, classify, post and record a single X post given its full link.

        Posts already in the database are only reported. An unknown author is created
        with ``new_account_rating`` / ``download_mode``. The post is stored as Accepted;
        any error after the interaction was deferred is sent back as the follow-up text.
        """
        botData = self.botData

        try:
            handle, post_id = tweetHelper.parse_x_url(full_link)
        except Exception:
            await interaction.response.send_message("Invalid url. Format should be (domain)/(account_name)/status/(post_id)")
            return

        await interaction.response.defer()

        query_result = botData.db.x_get_post_from_id(post_id)
        if query_result is not None:
            if query_result.discord_post_id == 0:
                await interaction.followup.send(content='Post was previously deleted. This is not implemented yet.' )
            else:
                await interaction.followup.send(content='Post was previously downloaded')#: https://discord.com/channels/{client.guilds[0].id}/{query_result.discord_channel_id}/{query_result["discord_post_id"]}' )
            return

        try:
            tweet = await botData.twApi.get_tweet(f'x.com/{handle}/status/{post_id}')
            if tweet is None:
                # get_tweet reports every fetch failure as None
                raise Exception("Failed to fetch the post")
            x_post = x_posts(id = tweet.id, account_id = tweet.author.id, date = tweet.date, text = tweet.text)

            artist = botData.db.x_get_account_by_name(handle)
            if artist is None:
                artist = await create_x_account(handle, botData, new_account_rating, download_mode = download_mode)

            await discordHelper.ensure_has_channel_or_thread(artist, interaction.guild, botData.db)

            media = await tweetHelper.GetTweetMediaUrls(tweet)
            image_containers : list[x_posts_images] = [x_posts_images(tweet.id, idx, file = url) for idx, url in enumerate(media)]
            downloaded_media = await tweetHelper.DownloadMedia(tweet.id, tweet.author.id, tweet.author.username, media, botData.session)

            vox_labels = []
            final_filtered_tags = {}
            duplicates = []
            for idx, attachment in enumerate(downloaded_media):
                container = image_containers[idx]
                container.saved_file = attachment.file_name
                (container.vox_label, container.rating, container.tags, filtered_tags,
                 container.phash, container.dhash, container.error_id) = await classify_all(attachment.file_bytes, botData.classifier, botData.vox)

                if container.vox_label not in vox_labels:
                    vox_labels.append(container.vox_label)

                if container.phash is not None:
                    duplicate = botData.db.x_search_duplicate(user_id=x_post.account_id, max_id = x_post.id, phash=container.phash)
                    if duplicate is not None:
                        container.duplicate_id = duplicate.post_id
                        container.duplicate_index = duplicate.index
                        if duplicate.post_id not in duplicates:
                            duplicates.append(duplicate.post_id)

                # the post collects the union of its image tags and the highest image rating
                x_post.tags = list(set(x_post.tags + container.tags))
                x_post.rating = container.rating if get_rating_value(container.rating) > get_rating_value(x_post.rating) else x_post.rating
                final_filtered_tags = final_filtered_tags | filtered_tags

            try:
                discord_post = await discordHelper.send_x_post(tweet, artist, interaction.guild, [artist.name], downloaded_media, False, rating=x_post.rating, tags=final_filtered_tags, auto_approve=True, vox_labels = vox_labels, duplicate_posts=duplicates, xView=botData.xView, yView=botData.yView)
            except (NO_CHANNEL, OTHER_ERROR) as e:
                # keep the name bound so the failure branch below can report the error id
                discord_post = None
                x_post.error_id = e.code
                x_post.discord_post_id = 0
            else:
                x_post.discord_post_id = discord_post.id

            x_post.action_taken = ActionTaken.Accepted

            # post and images are written in one transaction
            try:
                if not botData.db.x_insert_post(x_post, commit = False):
                    raise Exception("Transaction error")
                for image in image_containers:
                    botData.db.x_insert_image(image, False)
            except Exception:
                botData.db.conn.rollback()
                raise
            else:
                botData.db.conn.commit()

            if discord_post is not None:
                embeds = discordHelper.build_secondary_embed(discord_post, artist.name, tweet)
                await discordHelper.edit_existing_embed_color(discord_post, nextcord.Colour.green())
                await interaction.followup.send(embeds=embeds)
            else:
                await interaction.followup.send(content=f'Posting failed: {x_post.error_id}')
        except Exception as e:
            # message content must be text, not the exception object itself
            await interaction.followup.send(content=str(e))

    @application_command.slash_command()
    @commands.has_permissions(administrator=True)
    async def query(self, interaction: nextcord.Interaction, query: str):
        """Execute ``query`` on the database cursor, commit and answer "ok"."""
        # SECURITY: this runs caller-supplied SQL verbatim. It must stay restricted to
        # trusted administrators — verify the permission check above is effective.
        self.db.cursor.execute(query)
        self.db.conn.commit()
        await interaction.response.send_message("ok")

    @application_command.slash_command()
    @commands.has_permissions(administrator=True)
    async def ignore_tag(self, interaction: nextcord.Interaction, tag: str):
        """Add ``tag`` to the classifier's ignored tags."""
        self.botData.classifier.add_ignored_tag(tag)
        await interaction.response.send_message("ok", ephemeral=True)

    @application_command.slash_command()
    async def add_x_account(self, interaction: nextcord.Interaction, handle: str, rating : str = nextcord.SlashOption(choices= [AccountRating.NSFW, AccountRating.SFW, AccountRating.NSFL], default = AccountRating.NSFW), download_mode : int = nextcord.SlashOption(choices={"No":DownloadMode.NO_DOWNLOAD, "Yes":DownloadMode.DOWNLOAD, "Yes (check all posts)":DownloadMode.DOWNLOAD_ALL}, default=DownloadMode.DOWNLOAD_ALL)):
        """Add an X account, or change the download mode of one that already exists.

        ``handle`` may be a bare handle or a link containing an ``x.com`` path segment.
        Every failure is printed and sent back as the follow-up message.
        """
        await interaction.response.defer()
        try:
            if "x.com" in handle:
                # inside the try so a link without an exact "x.com" segment is reported
                parts = handle.split('/')
                handle = parts[parts.index('x.com')+1]
            print(handle)
            # NOTE(review): requests.get blocks the event loop while the request runs.
            result = requests.get(f"https://api.vxtwitter.com/{handle}")
            if result.status_code != 200:
                raise Exception("Failed to get user id")
            id = result.json()["id"]
            existing_account = self.db.x_get_account_by_id(id)
            if existing_account is not None:
                if existing_account.download_mode == download_mode:
                    raise Exception("Account is already on download list")
                self.db.x_update_account_properties(id, updates=[(schema_x_accounts.download_mode, download_mode)])
                if download_mode == DownloadMode.DOWNLOAD_ALL:
                    await interaction.followup.send("Added " + handle + " to download list (with full check)")
                elif download_mode == DownloadMode.NO_DOWNLOAD:
                    await interaction.followup.send("Removed " + handle + " from download list")
                elif download_mode == DownloadMode.DOWNLOAD:
                    await interaction.followup.send("Added " + handle + " to download list")
                else:
                    await interaction.followup.send("huh??")
            else:
                await create_x_account(handle, self.botData, rating, download_mode=download_mode)
                await interaction.followup.send("Added " + handle)
        except Exception as ex:
            print(ex)
            await interaction.followup.send(str(ex))

    @application_command.message_command(guild_ids=[1043267878851457096])
    async def archive_set_KF(self, interaction: nextcord.Interaction, message: nextcord.Message):
        """Mark the archived post behind ``message`` as Accepted and colour its embed green."""
        if not discordHelper.check_permission(interaction):
            await interaction.response.send_message("No permission", ephemeral=True)
            return

        query =f'SELECT {schema_x_posts.id}, {schema_x_posts.account_id}, {schema_x_posts.action_taken} FROM {schema_x_posts.table} WHERE {schema_x_posts.discord_post_id} = %s '
        result = self.db.query_get(query, ([message.id]), count = 1)
        if result is None:
            await interaction.response.send_message("post not found in database", ephemeral=True)
            return
        if result[schema_x_posts.action_taken] == ActionTaken.Accepted:
            await discordHelper.edit_existing_embed_color(message, nextcord.Colour.green())
            await interaction.response.send_message("post was already KF", ephemeral=True)
            return

        await discordHelper.edit_existing_embed_color(message, nextcord.Colour.green())
        self.db.x_update_post(result[schema_x_posts.id], message.id, 0, ActionTaken.Accepted)
        await interaction.response.send_message("post approved", ephemeral=True)

    @application_command.message_command(guild_ids=[1043267878851457096])
    async def archive_set_nonKF(self, interaction: nextcord.Interaction, message: nextcord.Message):
        """Mark the archived post behind ``message`` as Hidden and colour its embed yellow."""
        if not discordHelper.check_permission(interaction):
            await interaction.response.send_message("No permission", ephemeral=True)
            return

        query =f'SELECT {schema_x_posts.id}, {schema_x_posts.account_id}, {schema_x_posts.action_taken} FROM {schema_x_posts.table} WHERE {schema_x_posts.discord_post_id} = %s '
        result = self.db.query_get(query, ([message.id]), count = 1)
        if result is None:
            await interaction.response.send_message("post not found in database", ephemeral=True)
            return
        if result[schema_x_posts.action_taken] == ActionTaken.Hidden:
            await discordHelper.edit_existing_embed_color(message, nextcord.Colour.yellow())
            await interaction.response.send_message("post was already NonKF", ephemeral=True)
            return

        await discordHelper.edit_existing_embed_color(message, nextcord.Colour.yellow())
        self.db.x_update_post(result[schema_x_posts.id], message.id, 0, ActionTaken.Hidden)
        await interaction.response.send_message("post hidden", ephemeral=True)

    @application_command.message_command(guild_ids=[1043267878851457096])
    async def archive_delete(self, interaction: nextcord.Interaction, message: nextcord.Message):
        """Delete ``message`` and record its archived post as Rejected with discord post id 0."""
        if not discordHelper.check_permission(interaction):
            await interaction.response.send_message("No permission", ephemeral=True)
            return

        query =f'SELECT {schema_x_posts.id} FROM {schema_x_posts.table} WHERE {schema_x_posts.discord_post_id} = %s'
        result = self.db.query_get(query, ([message.id]), count = 1)
        if result is None:
            await interaction.response.send_message("post not found in database", ephemeral=True)
            return

        await message.delete()
        self.db.x_update_post(result["id"], 0, 0, ActionTaken.Rejected)
        await interaction.response.send_message("post deleted", ephemeral=True)
|
||||
|
||||
def setup(bot: commands.Bot, botData: RuntimeBotData):
    """Extension entry point: register the Commands cog built from ``botData`` on ``bot``."""
    cog = Commands(botData)
    bot.add_cog(cog)
    print("Loaded Commands")
|
||||
45
config.py
Normal file
45
config.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
class Config:
    """JSON-backed settings stored in ``configs/<name>.json``.

    Values are read and written with ``cfg[key]``; every write is saved immediately.
    Reading an unknown key registers it with the value None and saves the file.
    """

    name = 'config'
    # Default keys. Each instance works on its own copy, so this dict is never mutated.
    config = {
        "discord_token": None,
        #format: "dbname=xxx user=xxx password=xxx"
        "postgresql_conninfo": None,
        #format: "auth_token=xxx; ct0=xxx;"
        "x_cookies": None,
        #format: "C:/Storage/X/" - remember about trailing /
        "x_download_path": None,
        "pixiv_token": None,
        "pixiv_username": None,
        "pixiv_password": None,
        #format: "C:/Storage/Pixiv/" - remember about trailing /
        "pixiv_download_path": None
    }

    def __init__(self, name : str = "config"):
        '"config" name is used as global config'
        self.name = name
        # Private copy of the defaults: instances must not share (or alter) the class dict.
        self.config = dict(type(self).config)
        path = f"configs/{name}.json"
        if os.path.exists(path):
            with open(path, "rt", encoding="utf-8") as f:
                # stored values win; keys missing from the file keep their default
                self.config.update(json.load(f))
        else:
            self.save()

    def __getitem__(self, key):
        if key not in self.config:
            self.config[key] = None
            self.save()
        return self.config[key]

    def __setitem__(self, key, value):
        self.config[key] = value
        self.save()

    def save(self):
        """Write the current values to ``configs/<name>.json``, creating the folder if needed."""
        os.makedirs("configs", exist_ok=True)
        with open(f"configs/{self.name}.json", "wt", encoding="utf-8") as f:
            json.dump(self.config, f, indent=1)
|
||||
|
||||
# Shared settings instance; uses the default name "config", i.e. configs/config.json.
Global_Config = Config()
|
||||
29
downloadHelper.py
Normal file
29
downloadHelper.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
import shutil
|
||||
import aiohttp
|
||||
|
||||
from exceptions import DOWNLOAD_FAIL
|
||||
|
||||
async def downloadFileToMemory(url, session : "aiohttp.ClientSession"):
    """Fetch ``url`` with ``session`` and return the response body as bytes.

    Raises:
        DOWNLOAD_FAIL: the request failed (original error attached as the cause) or the
            body was empty ("received 0 bytes").
    """
    # NOTE(review): the HTTP status is not checked, so an error page body is returned
    # like any other content — confirm whether that is intended.
    print("Downloading", url)
    try:
        async with session.get(url) as resp:
            result = await resp.read()
    except Exception as e:
        print(url, e)
        raise DOWNLOAD_FAIL(e) from e

    # Checked outside the try so this failure is not wrapped in a second DOWNLOAD_FAIL.
    if len(result) == 0:
        error = DOWNLOAD_FAIL("received 0 bytes")
        print(url, error)
        raise error

    return result
|
||||
|
||||
async def save_to_file(url, out_file, session):
    """Download ``url``, write it to ``out_file`` and return the downloaded bytes.

    An already existing ``out_file`` is first copied to ``out_file + "_bck"``.
    Download errors from ``downloadFileToMemory`` propagate before anything is written.
    """
    payload = await downloadFileToMemory(url, session)

    already_there = os.path.exists(out_file)
    if already_there:
        shutil.copy(out_file, out_file + "_bck")

    with open(out_file, "wb") as sink:
        sink.write(payload)

    return payload
|
||||
37
exceptions.py
Normal file
37
exceptions.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from Database.x_classes import ErrorID
|
||||
|
||||
class NO_ART(Exception):
    """Error whose ``code`` is ErrorID.NO_ART."""
    code = ErrorID.NO_ART


class OTHER_ERROR(Exception):
    """Error whose ``code`` is ErrorID.OTHER_ERROR."""
    code = ErrorID.OTHER_ERROR


class DOWNLOAD_FAIL(Exception):
    """Error whose ``code`` is ErrorID.DOWNLOAD_FAIL."""
    code = ErrorID.DOWNLOAD_FAIL


class ACCOUNT_DEAD(Exception):
    """Error whose ``code`` is ErrorID.ACCOUNT_DEAD."""
    code = ErrorID.ACCOUNT_DEAD


class ACCOUNT_SKIP(Exception):
    """Error whose ``code`` is ErrorID.ACCOUNT_SKIP."""
    code = ErrorID.ACCOUNT_SKIP


class TAGGING_ERROR(Exception):
    """Error whose ``code`` is ErrorID.TAGGING_ERROR."""
    code = ErrorID.TAGGING_ERROR


class FILE_READ_ERROR(Exception):
    """Error whose ``code`` is ErrorID.FILE_READ_ERROR."""
    code = ErrorID.FILE_READ_ERROR


class HASHING_Unidentified(Exception):
    """Error whose ``code`` is ErrorID.HASHING_Unidentified."""
    code = ErrorID.HASHING_Unidentified


class HASHING_FileNotFound(Exception):
    """Error whose ``code`` is ErrorID.HASHING_FileNotFound."""
    code = ErrorID.HASHING_FileNotFound


class HASHING_OTHER(Exception):
    """Error whose ``code`` is ErrorID.HASHING_OTHER."""
    code = ErrorID.HASHING_OTHER


class FILE_DELETED_OUTSIDE_ARCHIVE(Exception):
    """Error whose ``code`` is ErrorID.FILE_DELETED_OUTSIDE_ARCHIVE."""
    code = ErrorID.FILE_DELETED_OUTSIDE_ARCHIVE


class NO_CHANNEL(Exception):
    """Error whose ``code`` is ErrorID.NO_CHANNEL."""
    code = ErrorID.NO_CHANNEL
|
||||
1
getPixivpyToken
Submodule
1
getPixivpyToken
Submodule
Submodule getPixivpyToken added at 93544c23c7
24
helpers.py
Normal file
24
helpers.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from Database.x_classes import PostRating
|
||||
|
||||
|
||||
def twos_complement(hexstr, bits):
    """Interpret the hexadecimal string ``hexstr`` as a signed ``bits``-bit integer."""
    unsigned = int(hexstr, 16)
    sign_bit = 1 << (bits - 1)
    # A set top bit means the value is negative: shift it down by 2**bits.
    return unsigned - (1 << bits) if unsigned & sign_bit else unsigned
|
||||
|
||||
def get_rating_value(rating):
    """Rank a post rating for deciding the final post rating.

    Explicit is 4, Questionable 3, Sensitive 2, General 1; anything else is 0.
    """
    if rating == PostRating.Explicit:
        return 4
    if rating == PostRating.Questionable:
        return 3
    if rating == PostRating.Sensitive:
        return 2
    if rating == PostRating.General:
        return 1
    return 0
|
||||
26
helpers_adv.py
Normal file
26
helpers_adv.py
Normal file
@@ -0,0 +1,26 @@
|
||||
#helpers that require other helpers
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from Database.x_classes import AccountRating, DownloadMode, x_accounts
|
||||
from Discord import discordHelper
|
||||
if TYPE_CHECKING:
|
||||
from runtimeBotData import RuntimeBotData
|
||||
|
||||
async def create_x_account(handle, botData : RuntimeBotData, new_account_rating = AccountRating.NSFW, download_mode = DownloadMode.NO_DOWNLOAD):
    """Create the database record for the X account ``handle`` and return it.

    The account id is resolved through the X API. The record points at the Discord
    channel found for the handle in the bot's first guild; when that lookup raises, the
    thread lookup is used instead and the channel id is stored as 0.

    Raises:
        Exception: "Failed to add to database" when the insert is refused. Errors from
            the user-id lookup and the thread lookup propagate unchanged.
    """
    account_id = await botData.twApi.app.get_user_id(handle)
    try:
        thread = None
        channel = await discordHelper.get_channel_from_handle(botData.client.guilds[0], handle)
    except Exception:
        # Any lookup error falls back to the thread lookup; task cancellation is not
        # an Exception subclass and still propagates.
        thread = await discordHelper.get_thread_from_handle(botData.client.guilds[0], handle)
        channel = None

    artist = x_accounts(id= account_id,
                        rating= new_account_rating,
                        name= handle,
                        discord_channel_id= 0 if channel is None else channel.id,
                        discord_thread_id= 0 if thread is None else thread.id,
                        download_mode= download_mode)
    if not botData.db.x_add_account(artist):
        raise Exception("Failed to add to database")
    return artist
|
||||
BIN
japariarchive_schema.sql
Normal file
BIN
japariarchive_schema.sql
Normal file
Binary file not shown.
BIN
requirements.txt
Normal file
BIN
requirements.txt
Normal file
Binary file not shown.
36
runtimeBotData.py
Normal file
36
runtimeBotData.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import aiohttp
|
||||
import nextcord
|
||||
from Discord.views import XView, YView
|
||||
from Pixiv.pixivapi import PixivApi
|
||||
from Twitter.tweetyapi import TweetyApi
|
||||
from Classifier.wdClassifier import WDClassifier
|
||||
from Database.dbcontroller import DatabaseController
|
||||
from Classifier.havoxClassifier import VoxClassifier
|
||||
|
||||
class RuntimeBotData:
    """Container for the services and state shared by the bot's loops and commands.

    All service attributes stay None until ``initialize_data()`` has run.
    """

    initialized = False
    # Class-level defaults kept for compatibility; __init__ gives every instance its own lists.
    dead_accounts = []
    new_accounts = []

    client : nextcord.Client = None
    twApi : TweetyApi = None
    pixivApi : PixivApi = None
    db : DatabaseController = None
    vox : VoxClassifier = None
    classifier : WDClassifier = None
    session : aiohttp.ClientSession = None

    xView : XView = None
    yView : YView = None

    def __init__(self):
        # Fresh lists per instance so appends never land in the shared class attributes.
        self.dead_accounts = []
        self.new_accounts = []

    async def initialize_data(self):
        """Create the API clients, database controller, classifiers, views and HTTP session."""
        self.twApi = await TweetyApi().init()
        self.pixivApi = PixivApi().init()
        self.db = DatabaseController()
        self.vox = VoxClassifier()
        self.classifier = WDClassifier()
        self.xView = XView(self)
        self.yView = YView(self)

        #connector = aiohttp.TCPConnector(limit=60)
        self.session = aiohttp.ClientSession()
|
||||
Reference in New Issue
Block a user