fresh start

This commit is contained in:
2026-01-08 22:07:03 +01:00
commit 715be97289
30 changed files with 2302 additions and 0 deletions

238
Classifier/wdClassifier.py Normal file
View File

@@ -0,0 +1,238 @@
import asyncio
import os
from dataclasses import dataclass
from io import BytesIO
from PIL import Image
import onnxruntime as onnx
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
@dataclass
class LabelData:
    """Label table produced by ``WDClassifier.load_labels``.

    ``names`` holds one tag name per CSV row; the three other lists hold row
    indices into ``names``, grouped by the row's ``category`` value.
    """

    # Tag name for every row of the label CSV, in file order.
    names: list[str]
    # Row indices whose category marks them as rating labels.
    rating: list[np.int64]
    # Row indices whose category marks them as general tags.
    general: list[np.int64]
    # Row indices whose category marks them as character tags.
    character: list[np.int64]
class WDClassifier():
    """Tag images with an ONNX model and a CSV table of label names.

    The constructor loads ``selected_tags.csv`` and ``model.onnx`` from
    ``model_dir`` and reads (creating it if missing) ``ignored_tags.txt``,
    a newline-separated list of tag names excluded from the "filtered"
    result of :meth:`classify`.
    """

    # Class-level defaults. Each instance rebinds these in __init__ and
    # get_relevant_tags, so the lists are not shared between instances.
    onnxSession = None
    relevant_tags = []
    ignored_tags = []
    # Directory that holds the model, the label CSV and the ignore list.
    model_dir = "Classifier/SmilingWolf"

    def __init__(self, model_dir="Classifier/SmilingWolf"):
        """Load labels, the ignore list and the ONNX session from *model_dir*."""
        self.model_dir = model_dir
        self.labels = WDClassifier.load_labels(f"{model_dir}/selected_tags.csv")
        self.relevant_tags = self.get_relevant_tags()
        self.onnxSession = onnx.InferenceSession(f"{model_dir}/model.onnx")

    def _ignored_tags_path(self):
        """Return the path of the ignore-list file inside ``model_dir``."""
        return f"{self.model_dir}/ignored_tags.txt"

    def classify(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Run the model on an encoded image and return its tags.

        Args:
            image_bytes: Raw bytes of an image file readable by PIL.
            max_count: Keep at most this many highest-scoring tags in each
                result dict; ``-1`` disables the limit.
            tag_threshold: Keep only tags scoring strictly above this value;
                ``-1`` disables the threshold.

        Returns:
            ``(rating, general, filtered)``: the name of the highest-scoring
            rating label, a ``{tag: score}`` dict over general and character
            tags, and a ``{tag: score}`` dict over the non-ignored tags.
            On any exception the error is printed and ``("", {}, {})`` is
            returned (best-effort behaviour kept from the original).
        """
        try:
            with BytesIO(image_bytes) as image_file:
                with Image.open(image_file) as image:
                    image_data = self.resize_image(image, 448)
            # convert to numpy array and add batch dimension
            image_data = np.array(image_data, np.float32)
            image_data = np.expand_dims(image_data, axis=0)
            # NHWC image RGB to BGR
            image_data = image_data[..., ::-1]
            outputs = self.onnxSession.run(None, {'input': image_data})
            scores = outputs[0].squeeze(0)
            return self._select_tags(scores, max_count, tag_threshold)
        except Exception as ex:
            print(ex)
            return "", {}, {}

    def _select_tags(self, scores, max_count, tag_threshold):
        """Turn a flat score vector into ``(rating, general, filtered)``.

        Both result dicts get the same treatment: threshold first (when
        enabled), then the count limit (when enabled).
        """
        ratings, general, filtered = self.processRating(scores)
        rating = max(ratings, key=ratings.get)
        if tag_threshold != -1:
            general = self.filterThreshold(general, tag_threshold)
            filtered = self.filterThreshold(filtered, tag_threshold)
        if max_count != -1:
            general = self.filterCount(general, max_count)
            # The count limit applies to the filtered tags as well; the
            # previous code re-applied the threshold here by mistake.
            filtered = self.filterCount(filtered, max_count)
        return rating, general, filtered

    async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Run :meth:`classify` in a worker thread and await its result.

        Uses ``asyncio.to_thread`` so the coroutine resumes as soon as the
        classification finishes instead of polling once per second.
        """
        return await asyncio.to_thread(
            self.classify, image_bytes, max_count, tag_threshold
        )

    def processRating(self, data):
        """Map a score vector to three ``{tag name: score}`` dicts.

        Scores are rounded to three decimals. Returns ``(ratings, general,
        filtered)`` built from the rating indices, the general plus character
        indices, and ``self.relevant_tags`` respectively.
        """
        names = self.labels.names

        def scored(indices):
            return {names[i]: round(data[i].item(), 3) for i in indices}

        ratings = scored(self.labels.rating)
        general = scored(self.labels.general + self.labels.character)
        filtered = scored(self.relevant_tags)
        return ratings, general, filtered

    def filterThreshold(self, data, threshold = 0.6):
        """Return the entries of *data* whose value is strictly above *threshold*."""
        return {key: value for key, value in data.items() if value > threshold}

    def filterCount(self, data, count = 20):
        """Return the *count* highest-valued entries of *data*, best first.

        A non-positive *count* yields an empty dict.
        """
        keys = list(data.keys())
        values = list(data.values())
        # Same ordering as before (reversed argsort) so ties resolve identically.
        best_first = np.argsort(values)[::-1]
        return {keys[i]: values[i] for i in best_first[:max(count, 0)]}

    def filterTagList(self, data, tagList):
        """Return the items of *data* whose ``"label"`` entry is in *tagList*."""
        return [result for result in data if result["label"] in tagList]

    def resize_image(self, image : "Image.Image", size = 448):
        """Convert *image* to RGB, pad it to a white square and resize to *size*."""
        image = WDClassifier.pil_ensure_rgb(image)
        image = WDClassifier.pil_pad_square(image)
        image = WDClassifier.pil_resize(image, size)
        return image

    def get_tag_index(self, tag: str):
        """Return the index of *tag* in the label names, or ``-1`` if absent.

        A missing tag is reported on stdout.
        """
        try:
            return self.labels.names.index(tag)
        except ValueError:
            print(tag, "not found")
            return -1

    def get_relevant_tags(self):
        """Reload the ignore list and return the indices of all other tags.

        Creates an empty ignore-list file when none exists, stores the file's
        lines in ``self.ignored_tags`` and returns the label indices whose
        names are not in that list.
        """
        path = self._ignored_tags_path()
        if not os.path.exists(path):
            with open(path, "wt", encoding="utf-8") as file:
                file.write("")
        with open(path, "rt", encoding="utf-8") as file:
            self.ignored_tags = file.read().splitlines()
        tags = {tag: index for index, tag in enumerate(self.labels.names)}
        for tag in self.ignored_tags:
            tags.pop(tag, None)
        return list(tags.values())

    def add_ignored_tag(self, tag: str):
        """Add *tag* to the ignore list, persist it and drop it from the relevant tags.

        Does nothing when the tag is unknown or already ignored.
        """
        idx = self.get_tag_index(tag)
        if idx != -1 and idx in self.relevant_tags:
            self.ignored_tags.append(tag)
            with open(self._ignored_tags_path(), "wt", encoding="utf-8") as file:
                file.write("\n".join(self.ignored_tags))
            self.relevant_tags.remove(idx)

    @staticmethod
    def load_labels(csv_path) -> "LabelData":
        """Read the label CSV and group its row indices by category.

        Uses the ``name`` and ``category`` columns; category 9 rows are
        ratings, 0 are general tags and 4 are character tags.
        """
        df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
        tag_data = LabelData(
            names=df["name"].tolist(),
            rating=list(np.where(df["category"] == 9)[0]),
            general=list(np.where(df["category"] == 0)[0]),
            character=list(np.where(df["category"] == 4)[0]),
        )
        return tag_data

    @staticmethod
    def pil_ensure_rgb(image: "Image.Image") -> "Image.Image":
        """Return *image* as RGB, flattening any transparency onto white."""
        # convert to RGB/RGBA if not already (deals with palette images etc.)
        if image.mode not in ["RGB", "RGBA"]:
            image = (
                image.convert("RGBA")
                if "transparency" in image.info
                else image.convert("RGB")
            )
        # convert RGBA to RGB with white background
        if image.mode == "RGBA":
            canvas = Image.new("RGBA", image.size, (255, 255, 255))
            canvas.alpha_composite(image)
            image = canvas.convert("RGB")
        return image

    @staticmethod
    def pil_pad_square(image: "Image.Image") -> "Image.Image":
        """Return *image* centred on a white square as wide as its longest side."""
        w, h = image.size
        # get the largest dimension so we can pad to a square
        px = max(image.size)
        # pad to square with white background
        canvas = Image.new("RGB", (px, px), (255, 255, 255))
        canvas.paste(image, ((px - w) // 2, (px - h) // 2))
        return canvas

    @staticmethod
    def pil_resize(image: "Image.Image", target_size: int) -> "Image.Image":
        """Resize *image* to ``target_size`` x ``target_size`` with bicubic filtering.

        The image is returned unchanged when its longest side already equals
        *target_size*.
        """
        max_dim = max(image.size)
        if max_dim != target_size:
            image = image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return image

    @staticmethod
    def get_tags(
        probs,
        labels : "LabelData",
        gen_threshold: float,
        char_threshold: float,
    ):
        """Build caption strings and score dicts from per-label probabilities.

        Args:
            probs: One score per entry of ``labels.names``.
            labels: Label table providing names and category indices.
            gen_threshold: General tags must score strictly above this.
            char_threshold: Character tags must score strictly above this.

        Returns:
            ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
            ``caption`` joins the kept general tags followed by the kept
            character tags with ``", "``; ``taglist`` is the caption with
            underscores replaced by spaces and parentheses backslash-escaped.
            The two tag dicts are ordered by descending score.
        """
        # Convert indices+probs to labels
        scored = list(zip(labels.names, probs))

        def ranked_above(indices, threshold):
            # Keep entries above the threshold, then order by score, best first.
            kept = dict([scored[i] for i in indices if scored[i][1] > threshold])
            return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

        rating_labels = dict([scored[i] for i in labels.rating])
        gen_labels = ranked_above(labels.general, gen_threshold)
        char_labels = ranked_above(labels.character, char_threshold)

        # General tag names first, then character tag names
        combined_names = list(gen_labels)
        combined_names.extend(char_labels)

        # Convert to a string suitable for use as a training caption
        caption = ", ".join(combined_names)
        taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")
        return caption, taglist, rating_labels, char_labels, gen_labels
if __name__ == "__main__":
    async def main():
        """Classify the bundled sample image and print the results."""
        classifier = WDClassifier()
        with open("Classifier/sample4.png", "rb") as sample_file:
            image_data = sample_file.read()
        # Keep at most 20 tags per result, each scoring above 0.6.
        rating, general_tags, filtered_tags = await classifier.classify_async(image_data, 20, 0.6)
        print(rating, general_tags, filtered_tags)

    asyncio.run(main())