# From the JapariArchive repository (Python, 238 lines, 8.5 KiB).
import asyncio
|
|
import os
|
|
from dataclasses import dataclass
|
|
from io import BytesIO
|
|
from PIL import Image
|
|
|
|
import onnxruntime as onnx
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
@dataclass
class LabelData:
    """Tag vocabulary of the tagger model, partitioned by category.

    ``names`` holds every tag name in file row order. The remaining fields
    are lists of row positions into ``names`` for the rating, general and
    character categories.
    """

    names: list[str]  # all tag names, row order preserved
    rating: list[np.int64]  # positions of rating-category rows in `names`
    general: list[np.int64]  # positions of general-category rows in `names`
    character: list[np.int64]  # positions of character-category rows in `names`
|
|
|
|
class WDClassifier():
    """Image tagger backed by a SmilingWolf WD ONNX model.

    ``__init__`` loads the tag vocabulary (``selected_tags.csv``), the ignore
    list (``ignored_tags.txt``, created empty when missing) and the ONNX model
    (``model.onnx``) from ``model_dir``.
    """

    # Directory holding selected_tags.csv, model.onnx and ignored_tags.txt.
    # Also kept as a class attribute so instances created without __init__
    # still resolve the original default location.
    model_dir = "Classifier/SmilingWolf"

    # onnxruntime InferenceSession; assigned by __init__.
    onnxSession = None

    # Indices into self.labels.names whose tag is not on the ignore list.
    # Rebound per instance by __init__.
    relevant_tags = []

    # Tag names read from ignored_tags.txt; rebound per instance by
    # get_relevant_tags().
    ignored_tags = []

    def __init__(self, model_dir: str = "Classifier/SmilingWolf"):
        """Load labels, ignore list and ONNX session from ``model_dir``.

        Args:
            model_dir: Directory containing the model files. Defaults to the
                location this class has always used.
        """
        self.model_dir = model_dir
        self.labels = WDClassifier.load_labels(os.path.join(model_dir, "selected_tags.csv"))
        self.relevant_tags = self.get_relevant_tags()
        self.onnxSession = onnx.InferenceSession(os.path.join(model_dir, "model.onnx"))

    def classify(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Tag an encoded image.

        Args:
            image_bytes: Encoded image file contents (any format PIL opens).
            max_count: When not -1, keep at most this many highest-scoring
                entries in each of the returned ``general`` and ``filtered``.
            tag_threshold: When not -1, keep only entries whose score is
                strictly greater than this value (applied before max_count).

        Returns:
            ``(rating, general, filtered)``: the highest-scoring rating tag
            name; a dict of general+character tag -> score; and a dict of
            every non-ignored tag (any category) -> score. Scores are rounded
            to 3 decimals. On any error the exception is printed and
            ``("", {}, {})`` is returned instead of raising.
        """
        try:
            with BytesIO(image_bytes) as image_file, Image.open(image_file) as image:
                image_data = self.resize_image(image, 448)
                # convert to a float32 array and add a leading batch dimension
                image_data = np.array(image_data, np.float32)
                image_data = np.expand_dims(image_data, axis=0)

            # NHWC image RGB to BGR: reverse the channel (last) axis
            image_data = image_data[..., ::-1]

            # Ask the session for its input name rather than assuming 'input'.
            input_name = self.onnxSession.get_inputs()[0].name
            outputs = self.onnxSession.run(None, {input_name: image_data})
            outputs = outputs[0].squeeze(0)

            ratings, general, filtered = self.processRating(outputs)
            rating = max(ratings, key=ratings.get)

            if tag_threshold != -1:
                general = self.filterThreshold(general, tag_threshold)
                filtered = self.filterThreshold(filtered, tag_threshold)

            if max_count != -1:
                # Cap both result dicts to their max_count best entries.
                general = self.filterCount(general, max_count)
                filtered = self.filterCount(filtered, max_count)

            return rating, general, filtered

        except Exception as ex:
            # Deliberate best-effort boundary: report and return empty results.
            print(ex)
            return "", {}, {}

    async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Run :meth:`classify` in the event loop's default executor.

        Awaiting this does not block the event loop, and the result is
        delivered as soon as the worker finishes (no polling delay).
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.classify, image_bytes, max_count, tag_threshold)

    def processRating(self, data):
        """Split a 1-D score vector into rating / general / filtered dicts.

        Args:
            data: Per-label scores indexable by label position (numpy array
                elements, since ``.item()`` is called on each).

        Returns:
            ``(ratings, general, filtered)`` mapping tag name -> score rounded
            to 3 decimals, for rating labels, general+character labels, and
            the non-ignored labels in ``self.relevant_tags`` respectively.
        """
        names = self.labels.names

        def scored(indices):
            return {names[i]: round(data[i].item(), 3) for i in indices}

        ratings = scored(self.labels.rating)
        general = scored(self.labels.general + self.labels.character)
        filtered = scored(self.relevant_tags)
        return ratings, general, filtered

    def filterThreshold(self, data, threshold = 0.6):
        """Return the entries of ``data`` whose value is strictly above ``threshold``."""
        return {key: value for key, value in data.items() if value > threshold}

    def filterCount(self, data, count = 20):
        """Return the ``count`` highest-valued entries of ``data``, best first.

        A ``count`` of 0 or less yields an empty dict.
        """
        ranked = sorted(data.items(), key=lambda item: item[1], reverse=True)
        return dict(ranked[:max(count, 0)])

    def filterTagList(self, data, tagList):
        """Keep the items of ``data`` whose ``"label"`` key is in ``tagList``."""
        return [result for result in data if result["label"] in tagList]

    def resize_image(self, image : "Image.Image", size = 448):
        """Convert to RGB, pad to a white square, then resize to ``size`` x ``size``."""
        image = WDClassifier.pil_ensure_rgb(image)
        image = WDClassifier.pil_pad_square(image)
        image = WDClassifier.pil_resize(image, size)
        return image

    def get_tag_index(self, tag: str):
        """Return the first index of ``tag`` in the label names, or -1 if absent."""
        try:
            return self.labels.names.index(tag)
        except ValueError:
            print(tag, "not found")
            return -1

    def _ignored_tags_path(self):
        """Path of the ignore-list file inside ``self.model_dir``."""
        return os.path.join(self.model_dir, "ignored_tags.txt")

    def get_relevant_tags(self):
        """Load the ignore list and return indices of all non-ignored labels.

        Side effects: creates an empty ignore-list file if none exists and
        stores its lines in ``self.ignored_tags``.
        """
        path = self._ignored_tags_path()
        if not os.path.exists(path):
            with open(path, "wt", encoding="utf-8") as file:
                file.write("")

        with open(path, "rt", encoding="utf-8") as file:
            self.ignored_tags = file.read().splitlines()

        ignored = set(self.ignored_tags)
        return [index for index, tag in enumerate(self.labels.names) if tag not in ignored]

    def add_ignored_tag(self, tag: str):
        """Add ``tag`` to the ignore list, persist it, and drop it from relevant_tags.

        Does nothing when the tag is unknown or already excluded.
        """
        idx = self.get_tag_index(tag)
        if idx == -1 or idx not in self.relevant_tags:
            return

        self.ignored_tags.append(tag)
        with open(self._ignored_tags_path(), "wt", encoding="utf-8") as file:
            file.write("\n".join(self.ignored_tags))
        self.relevant_tags.remove(idx)

    @staticmethod
    def load_labels(csv_path):
        """Read the tag CSV and group row indices by its ``category`` column.

        Category 9 rows are ratings, 0 are general tags, 4 are characters.
        """
        df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
        return LabelData(
            names=df["name"].tolist(),
            rating=list(np.where(df["category"] == 9)[0]),
            general=list(np.where(df["category"] == 0)[0]),
            character=list(np.where(df["category"] == 4)[0]),
        )

    @staticmethod
    def pil_ensure_rgb(image: "Image.Image") -> "Image.Image":
        """Return an RGB version of ``image``, compositing any alpha onto white."""
        # convert to RGB/RGBA if not already (deals with palette images etc.)
        if image.mode not in ["RGB", "RGBA"]:
            image = (
                image.convert("RGBA")
                if "transparency" in image.info
                else image.convert("RGB")
            )
        # convert RGBA to RGB with white background
        if image.mode == "RGBA":
            canvas = Image.new("RGBA", image.size, (255, 255, 255))
            canvas.alpha_composite(image)
            image = canvas.convert("RGB")
        return image

    @staticmethod
    def pil_pad_square(image: "Image.Image") -> "Image.Image":
        """Center ``image`` on a white square canvas sized to its longest side."""
        w, h = image.size
        px = max(image.size)
        canvas = Image.new("RGB", (px, px), (255, 255, 255))
        canvas.paste(image, ((px - w) // 2, (px - h) // 2))
        return canvas

    @staticmethod
    def pil_resize(image: "Image.Image", target_size: int) -> "Image.Image":
        """Bicubic-resize to ``target_size`` square unless the longest side already matches."""
        if max(image.size) != target_size:
            image = image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return image

    @staticmethod
    def _sorted_above(pairs, threshold):
        """Keep (name, prob) pairs with prob > threshold as a dict, highest prob first."""
        kept = {name: prob for name, prob in pairs if prob > threshold}
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    @staticmethod
    def get_tags(
        probs,
        labels : "LabelData",
        gen_threshold: float,
        char_threshold: float,
    ):
        """Convert raw per-label probabilities into captions and label dicts.

        Returns:
            ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
            ``rating_labels`` maps every rating label to its probability.
            ``gen_labels`` and ``char_labels`` keep labels above their
            threshold, ordered by descending probability. ``caption`` joins
            the general names and then the character names with ", ".
            ``taglist`` is ``caption`` with "_" replaced by spaces and
            parentheses backslash-escaped.
        """
        # Pair each label name with its probability, in label order
        probs = list(zip(labels.names, probs))

        # Rating-category labels are reported unfiltered
        rating_labels = dict([probs[i] for i in labels.rating])

        # General / character labels above their thresholds, best first
        gen_labels = WDClassifier._sorted_above((probs[i] for i in labels.general), gen_threshold)
        char_labels = WDClassifier._sorted_above((probs[i] for i in labels.character), char_threshold)

        # General names first, then character names (each already sorted)
        combined_names = [*gen_labels, *char_labels]

        # Convert to a string suitable for use as a training caption
        caption = ", ".join(combined_names)
        taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

        return caption, taglist, rating_labels, char_labels, gen_labels
|
|
|
|
if __name__ == "__main__":

    async def main():
        """Tag the bundled sample image and print the three result parts."""
        classifier = WDClassifier()

        with open("Classifier/sample4.png", "rb") as sample_file:
            sample_bytes = sample_file.read()

        outcome = await classifier.classify_async(sample_bytes, 20, 0.6)
        rating, results, discord = outcome
        print(rating, results, discord)

    asyncio.run(main())