fresh start

This commit is contained in:
2026-01-08 22:07:03 +01:00
commit 715be97289
30 changed files with 2302 additions and 0 deletions

View File

@@ -0,0 +1 @@
put all files from https://huggingface.co/HAV0X1014/Kemono-Friends-Sorter/tree/main here

View File

@@ -0,0 +1 @@
put model.onnx and selected_tags.csv from https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/tree/main here

View File

@@ -0,0 +1,36 @@
from io import BytesIO
from PIL import Image, UnidentifiedImageError
import imagehash
from Classifier.havoxClassifier import VoxClassifier
from Classifier.wdClassifier import WDClassifier
from Database.x_classes import ErrorID, HavoxLabel, PostRating
from helpers import twos_complement
async def classify_all(photo_bytes, wd_classifier: WDClassifier, vox_classifier : VoxClassifier):
    """Run both classifiers on an image and compute its perceptual hashes.

    The Vox classifier runs first. When it labels the image as
    ``HavoxLabel.Rejected`` nothing else is computed and the rating is
    ``PostRating.Filtered``. Otherwise the WD tagger is queried with a tag
    threshold of 0.6 and the pHash / dHash of the image are calculated.

    Returns a 7-tuple:
    ``(vox_label, rating, tags, filtered_tags, phash, dhash, error_id)``
    where ``tags`` is a list of tag names, ``filtered_tags`` is the mapping
    returned by the WD tagger, and ``phash`` / ``dhash`` stay ``None`` when
    hashing did not get that far. ``error_id`` reports which hashing failure
    occurred, or ``ErrorID.SUCCESS``.
    """
    vox_label = await vox_classifier.classify_async(photo_bytes)

    # Rejected images skip tagging and hashing entirely.
    if vox_label == HavoxLabel.Rejected:
        return vox_label, PostRating.Filtered, [], {}, None, None, ErrorID.SUCCESS

    rating, tag_scores, filtered_tags = await wd_classifier.classify_async(
        photo_bytes, tag_threshold=0.6
    )
    tags = list(tag_scores.keys())

    phash = None
    dhash = None
    error_id = ErrorID.SUCCESS
    try:
        with BytesIO(photo_bytes) as buffer, Image.open(buffer) as picture:
            # Hash strings are run through twos_complement(..., 64);
            # presumably to fit a signed 64-bit field - confirm in helpers.
            phash = twos_complement(str(imagehash.phash(picture)), 64)
            dhash = twos_complement(str(imagehash.dhash(picture)), 64)
    except UnidentifiedImageError:
        error_id = ErrorID.HASHING_Unidentified
    except FileNotFoundError:
        error_id = ErrorID.HASHING_FileNotFound
    except Exception:
        error_id = ErrorID.HASHING_OTHER

    return vox_label, rating, tags, filtered_tags, phash, dhash, error_id

View File

@@ -0,0 +1,46 @@
import asyncio
from PIL import Image
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor
from transformers import ViTForImageClassification, ViTImageProcessor
class VoxClassifier():
    """Image classifier backed by the ViT model stored in ``Classifier/HAV0X/``."""

    def __init__(self):
        """Load the image processor and the classification model from disk."""
        self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
        self.model = ViTForImageClassification.from_pretrained('Classifier/HAV0X/')

    def classify(self, image_bytes):
        """Return the predicted label name for an encoded image.

        Args:
            image_bytes: Raw bytes of an image file.

        Returns:
            The ``id2label`` entry of the highest-scoring class. If anything
            goes wrong (undecodable image, model failure, ...) the exception
            is printed and ``"Rejected"`` is returned instead.
        """
        try:
            with BytesIO(image_bytes) as image_file:
                # Close the decoded source image explicitly; convert() returns
                # an independent in-memory copy that outlives the file.
                with Image.open(image_file) as source:
                    image = source.convert("RGB")

            # Imported here because only this inference step needs it and the
            # module does not import torch at file level.
            import torch

            inputs = self.feature_extractor(images=image, return_tensors="pt")
            # Pure inference: skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model(**inputs)
            # model predicts one of the 3 tags
            predicted_class_idx = outputs.logits.argmax(-1).item()
            return self.model.config.id2label[predicted_class_idx]
        except Exception as ex:
            # Best effort: any failure is treated as a rejection.
            print(ex)
            return "Rejected"

    async def classify_async(self, image_bytes):
        """Run :meth:`classify` in a worker thread and await its result.

        The coroutine resumes as soon as the worker finishes instead of
        polling once per second, and no executor is created per call.
        """
        return await asyncio.to_thread(self.classify, image_bytes)
if __name__ == "__main__":
    async def main():
        """Classify the bundled sample images and print each predicted label."""
        vox = VoxClassifier()
        sample_paths = (
            "Classifier/sample.jpg",
            "Classifier/sample2.jpg",
            "Classifier/sample3.jpg",
            "Classifier/sample4.png",
        )
        for sample_path in sample_paths:
            with open(sample_path, "rb") as handle:
                label = await vox.classify_async(handle.read())
                print(label)

    asyncio.run(main())

238
Classifier/wdClassifier.py Normal file
View File

@@ -0,0 +1,238 @@
import asyncio
import os
from dataclasses import dataclass
from io import BytesIO
from PIL import Image
import onnxruntime as onnx
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
@dataclass
class LabelData:
    """Tag names together with the row indices of each tag category.

    Instances are built by ``WDClassifier.load_labels`` from the tagger's
    CSV file; the index lists refer to positions in ``names``.
    """
    # Tag names, in CSV row order.
    names: list[str]
    # Row indices whose CSV "category" is 9 (rating labels).
    rating: list[np.int64]
    # Row indices whose CSV "category" is 0 (general tags).
    general: list[np.int64]
    # Row indices whose CSV "category" is 4 (character tags).
    character: list[np.int64]
class WDClassifier():
    """Image tagger backed by the ONNX model in ``Classifier/SmilingWolf/``.

    Besides the model's rating / general / character output, the classifier
    maintains a list of *ignored* tags (persisted in ``ignored_tags.txt``);
    every label that is not ignored is considered *relevant* and reported in
    the separate "filtered" result.
    """

    _IGNORED_TAGS_PATH = "Classifier/SmilingWolf/ignored_tags.txt"

    # Class-level placeholders; each instance assigns its own values while
    # __init__ runs.
    onnxSession = None
    relevant_tags = []
    ignored_tags = []

    def __init__(self):
        """Load the label table, the ignored-tag list and the ONNX session."""
        self.labels = WDClassifier.load_labels("Classifier/SmilingWolf/selected_tags.csv")
        self.relevant_tags = self.get_relevant_tags()
        self.onnxSession = onnx.InferenceSession('Classifier/SmilingWolf/model.onnx')

    def classify(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Tag an encoded image.

        Args:
            image_bytes: Raw bytes of an image file.
            max_count: Keep at most this many highest-scoring entries in each
                result dict; ``-1`` disables the limit.
            tag_threshold: Keep only entries scoring above this value;
                ``-1`` disables the threshold.

        Returns:
            ``(rating, general, filtered)`` where ``rating`` is the name of
            the highest-scoring rating label, ``general`` maps general and
            character tags to scores, and ``filtered`` maps the relevant
            (non-ignored) labels to scores. On any failure the exception is
            printed and ``("", {}, {})`` is returned.
        """
        try:
            with BytesIO(image_bytes) as image_file:
                with Image.open(image_file) as image:
                    image_data = self.resize_image(image, 448)
                    # convert to numpy array and add batch dimension
                    image_data = np.array(image_data, np.float32)
            image_data = np.expand_dims(image_data, axis=0)
            # NHWC image RGB to BGR
            image_data = image_data[..., ::-1]

            outputs = self.onnxSession.run(None, {'input': image_data})
            outputs = outputs[0].squeeze(0)

            ratings, general, filtered = self.processRating(outputs)
            rating = max(ratings, key=ratings.get)
            general, filtered = self._apply_limits(
                general, filtered, max_count, tag_threshold
            )
            return rating, general, filtered
        except Exception as ex:
            # Best effort: callers receive empty results instead of an error.
            print(ex)
            return "", {}, {}

    def _apply_limits(self, general, filtered, max_count = -1, tag_threshold = -1):
        """Apply the optional threshold and count limit to both result dicts."""
        if tag_threshold != -1:
            general = self.filterThreshold(general, tag_threshold)
            filtered = self.filterThreshold(filtered, tag_threshold)
        if max_count != -1:
            general = self.filterCount(general, max_count)
            # The count limit applies to the filtered tags as well; previously
            # they were thresholded a second time and never limited.
            filtered = self.filterCount(filtered, max_count)
        return general, filtered

    async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Run :meth:`classify` in a worker thread and await its result.

        The coroutine resumes as soon as the worker finishes instead of
        polling once per second, and no executor is created per call.
        """
        return await asyncio.to_thread(
            self.classify, image_bytes, max_count, tag_threshold
        )

    def processRating(self, data):
        """Split raw model scores into rating, general and relevant dicts.

        ``data`` is indexed by label position; every score is rounded to
        three decimals. ``general`` covers general plus character labels,
        ``filtered`` covers the indices in ``self.relevant_tags``.
        """
        names = self.labels.names

        def scores_for(indices):
            return {names[i]: round(data[i].item(), 3) for i in indices}

        ratings = scores_for(self.labels.rating)
        general = scores_for(self.labels.general + self.labels.character)
        filtered = scores_for(self.relevant_tags)
        return ratings, general, filtered

    def filterThreshold(self, data, threshold = 0.6):
        """Return the entries of ``data`` whose value is strictly above ``threshold``."""
        return {key: value for key, value in data.items() if value > threshold}

    def filterCount(self, data, count = 20):
        """Return the ``count`` highest-valued entries of ``data``, best first.

        A ``count`` of zero or less yields an empty dict.
        """
        keys = list(data.keys())
        values = list(data.values())
        best_first = np.flip(np.argsort(values))[:max(count, 0)]
        return {keys[position]: values[position] for position in best_first}

    def filterTagList(self, data, tagList):
        """Return the items of ``data`` whose ``"label"`` entry is in ``tagList``."""
        return [result for result in data if result["label"] in tagList]

    def resize_image(self, image : "Image.Image", size = 448):
        """Convert to RGB, pad to a white square and resize to ``size`` x ``size``."""
        image = WDClassifier.pil_ensure_rgb(image)
        image = WDClassifier.pil_pad_square(image)
        image = WDClassifier.pil_resize(image, size)
        return image

    def get_tag_index(self, tag: str):
        """Return the position of ``tag`` in the label list, or ``-1`` if unknown."""
        try:
            return self.labels.names.index(tag)
        except ValueError:
            # list.index raises ValueError for a missing element.
            print(tag, "not found")
            return -1

    def get_relevant_tags(self):
        """Load the ignored-tag file and return indices of all other labels.

        The file is created empty when missing. As a side effect
        ``self.ignored_tags`` is replaced with the file's lines.
        """
        if not os.path.exists(self._IGNORED_TAGS_PATH):
            with open(self._IGNORED_TAGS_PATH, "wt", encoding="utf-8") as file:
                file.write("")
        with open(self._IGNORED_TAGS_PATH, "rt", encoding="utf-8") as file:
            self.ignored_tags = file.read().splitlines()

        tags = {tag: index for index, tag in enumerate(self.labels.names)}
        for tag in self.ignored_tags:
            tags.pop(tag, None)
        return list(tags.values())

    def add_ignored_tag(self, tag: str):
        """Mark ``tag`` as ignored and persist the updated list.

        Unknown tags and tags that are already excluded are left untouched.
        """
        idx = self.get_tag_index(tag)
        if idx != -1 and idx in self.relevant_tags:
            self.ignored_tags.append(tag)
            with open(self._IGNORED_TAGS_PATH, "wt", encoding="utf-8") as file:
                file.write("\n".join(self.ignored_tags))
            self.relevant_tags.remove(idx)

    @staticmethod
    def load_labels(csv_path):
        """Read the ``name`` / ``category`` columns of the CSV into label data.

        Categories 9, 0 and 4 become the rating, general and character index
        lists respectively.
        """
        df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
        tag_data = LabelData(
            names=df["name"].tolist(),
            rating=list(np.where(df["category"] == 9)[0]),
            general=list(np.where(df["category"] == 0)[0]),
            character=list(np.where(df["category"] == 4)[0]),
        )
        return tag_data

    @staticmethod
    def pil_ensure_rgb(image: "Image.Image") -> "Image.Image":
        """Return ``image`` as RGB, flattening transparency onto white."""
        # convert to RGB/RGBA if not already (deals with palette images etc.)
        if image.mode not in ["RGB", "RGBA"]:
            image = (
                image.convert("RGBA")
                if "transparency" in image.info
                else image.convert("RGB")
            )
        # convert RGBA to RGB with white background
        if image.mode == "RGBA":
            canvas = Image.new("RGBA", image.size, (255, 255, 255))
            canvas.alpha_composite(image)
            image = canvas.convert("RGB")
        return image

    @staticmethod
    def pil_pad_square(image: "Image.Image") -> "Image.Image":
        """Center ``image`` on a white square sized to its longest side."""
        w, h = image.size
        # get the largest dimension so we can pad to a square
        px = max(image.size)
        # pad to square with white background
        canvas = Image.new("RGB", (px, px), (255, 255, 255))
        canvas.paste(image, ((px - w) // 2, (px - h) // 2))
        return canvas

    @staticmethod
    def pil_resize(image: "Image.Image", target_size: int) -> "Image.Image":
        """Resize to ``target_size`` squared unless the longest side already matches."""
        max_dim = max(image.size)
        if max_dim != target_size:
            image = image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return image

    @staticmethod
    def get_tags(
        probs,
        labels : "LabelData",
        gen_threshold: float,
        char_threshold: float,
    ):
        """Turn per-label probabilities into caption text and label dicts.

        Args:
            probs: One score per entry of ``labels.names``.
            labels: Label names and category indices.
            gen_threshold: General tags must score above this value.
            char_threshold: Character tags must score above this value.

        Returns:
            ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
            ``caption`` joins the kept general tags followed by the kept
            character tags (each group sorted by descending score);
            ``taglist`` is the caption with underscores replaced by spaces
            and parentheses backslash-escaped.
        """
        def above(indices, threshold):
            # Keep (name, score) pairs over the threshold, best score first.
            kept = {name: score for name, score in (probs[i] for i in indices)
                    if score > threshold}
            return dict(
                sorted(kept.items(), key=lambda item: item[1], reverse=True)
            )

        # Convert indices+probs to labels
        probs = list(zip(labels.names, probs))

        rating_labels = dict([probs[i] for i in labels.rating])
        gen_labels = above(labels.general, gen_threshold)
        char_labels = above(labels.character, char_threshold)

        # General names first, then character names.
        combined_names = list(gen_labels)
        combined_names.extend(char_labels)

        # Convert to a string suitable for use as a training caption
        caption = ", ".join(combined_names)
        taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

        return caption, taglist, rating_labels, char_labels, gen_labels
if __name__ == "__main__":
    async def main():
        """Tag one sample image and print the rating and both tag dicts."""
        tagger = WDClassifier()
        with open("Classifier/sample4.png", "rb") as sample:
            payload = sample.read()
        outcome = await tagger.classify_async(payload, 20, 0.6)
        rating, results, discord = outcome
        print(rating, results, discord)

    asyncio.run(main())