fresh start
Classifier/HAV0X/readme.txt | 1 | Normal file
@@ -0,0 +1 @@
put all files from https://huggingface.co/HAV0X1014/Kemono-Friends-Sorter/tree/main here
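(Not part of the commit: if it helps, the whole repository named in the readme can be mirrored into that directory with huggingface_hub; a minimal sketch, assuming huggingface_hub is installed and the repo id stays HAV0X1014/Kemono-Friends-Sorter.)

# Sketch only: mirror the HAV0X classifier repo into Classifier/HAV0X/.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="HAV0X1014/Kemono-Friends-Sorter",  # repo named in the readme above
    local_dir="Classifier/HAV0X",
)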
Classifier/SmilingWolf/readme.txt | 1 | Normal file
@@ -0,0 +1 @@
put model.onnx and selected_tags.csv from https://huggingface.co/SmilingWolf/wd-vit-tagger-v3/tree/main here
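(Not part of the commit: the two files named above can likewise be fetched with hf_hub_download; a minimal sketch, assuming default Hugging Face Hub access.)

# Sketch only: fetch just the two files the readme above asks for.
from huggingface_hub import hf_hub_download

for filename in ("model.onnx", "selected_tags.csv"):
    hf_hub_download(
        repo_id="SmilingWolf/wd-vit-tagger-v3",
        filename=filename,
        local_dir="Classifier/SmilingWolf",
    )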
Classifier/classifyHelper.py | 36 | Normal file
@@ -0,0 +1,36 @@
from io import BytesIO
from PIL import Image, UnidentifiedImageError
import imagehash
from Classifier.havoxClassifier import VoxClassifier
from Classifier.wdClassifier import WDClassifier
from Database.x_classes import ErrorID, HavoxLabel, PostRating
from helpers import twos_complement


async def classify_all(photo_bytes, wd_classifier: WDClassifier, vox_classifier: VoxClassifier):
    vox_label = await vox_classifier.classify_async(photo_bytes)
    is_filtered = vox_label == HavoxLabel.Rejected

    tags = []
    filtered_tags = {}
    phash = None
    dhash = None
    error_id = ErrorID.SUCCESS
    if not is_filtered:
        rating, tags, filtered_tags = await wd_classifier.classify_async(photo_bytes, tag_threshold=0.6)
        tags = list(tags.keys())

        try:
            with BytesIO(photo_bytes) as image_file:
                with Image.open(image_file) as imag:
                    phash = twos_complement(str(imagehash.phash(imag)), 64)
                    dhash = twos_complement(str(imagehash.dhash(imag)), 64)
        except UnidentifiedImageError:
            error_id = ErrorID.HASHING_Unidentified
        except FileNotFoundError:
            error_id = ErrorID.HASHING_FileNotFound
        except Exception as ex:
            error_id = ErrorID.HASHING_OTHER
    else:
        rating = PostRating.Filtered

    return vox_label, rating, tags, filtered_tags, phash, dhash, error_id
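(Note, not part of the commit: twos_complement comes from the project's helpers module, which this diff does not include. Presumably it reinterprets the 64-bit hex digest from imagehash as a signed integer so the hash fits a signed 64-bit column; a sketch under that assumption, not the actual helper.)

def twos_complement(hex_str: str, bits: int) -> int:
    # Assumed behaviour: interpret a hex hash digest as a signed integer
    # of the given width (the real helpers.twos_complement may differ).
    value = int(hex_str, 16)
    if value >= 1 << (bits - 1):
        value -= 1 << bits
    return value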
Classifier/havoxClassifier.py | 46 | Normal file
@@ -0,0 +1,46 @@
import asyncio
from PIL import Image
from io import BytesIO
from concurrent.futures import ThreadPoolExecutor
from transformers import ViTForImageClassification, ViTImageProcessor


class VoxClassifier():
    def __init__(self):
        self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
        self.model = ViTForImageClassification.from_pretrained('Classifier/HAV0X/')

    def classify(self, image_bytes):
        try:
            with BytesIO(image_bytes) as image_file:
                with Image.open(image_file).convert("RGB") as image:
                    inputs = self.feature_extractor(images=image, return_tensors="pt")
                    outputs = self.model(**inputs)
                    logits = outputs.logits
                    # model predicts one of the 3 tags
                    predicted_class_idx = logits.argmax(-1).item()
                    return self.model.config.id2label[predicted_class_idx]
        except Exception as ex:
            print(ex)
            return "Rejected"

    async def classify_async(self, image_bytes):
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.classify, image_bytes)

            while not future.done():
                await asyncio.sleep(1)

            return future.result()


if __name__ == "__main__":
    async def main():
        classifier = VoxClassifier()
        images = [
            "Classifier/sample.jpg", "Classifier/sample2.jpg", "Classifier/sample3.jpg", "Classifier/sample4.png"
        ]
        for image in images:
            with open(image, "rb") as file:
                result = await classifier.classify_async(file.read())
                print(result)

    asyncio.run(main())
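(Aside, not part of the commit: classify_async polls the future once a second from the event loop. An equivalent pattern with lower latency awaits the blocking call through the loop's default executor; the subclass below is purely illustrative.)

import asyncio

class VoxClassifierAlt(VoxClassifier):  # hypothetical variant for illustration
    async def classify_async(self, image_bytes):
        # Hand the blocking classify() call to the default thread pool and
        # suspend until it finishes, instead of sleeping in a polling loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.classify, image_bytes)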
Classifier/wdClassifier.py | 238 | Normal file
@@ -0,0 +1,238 @@
import asyncio
import os
from dataclasses import dataclass
from io import BytesIO
from PIL import Image

import onnxruntime as onnx
import numpy as np
import pandas as pd

from concurrent.futures import ThreadPoolExecutor


@dataclass
class LabelData:
    names: list[str]
    rating: list[np.int64]
    general: list[np.int64]
    character: list[np.int64]


class WDClassifier():
    onnxSession = None
    relevant_tags = []
    ignored_tags = []

    def __init__(self):
        self.labels = WDClassifier.load_labels("Classifier/SmilingWolf/selected_tags.csv")
        self.relevant_tags = self.get_relevant_tags()
        self.onnxSession = onnx.InferenceSession('Classifier/SmilingWolf/model.onnx')

    def classify(self, image_bytes, max_count = -1, tag_threshold = -1):
        try:
            with BytesIO(image_bytes) as image_file:
                with Image.open(image_file) as image:
                    image_data = self.resize_image(image, 448)
                    # convert to numpy array and add batch dimension
                    image_data = np.array(image_data, np.float32)
                    image_data = np.expand_dims(image_data, axis=0)
                    # NHWC image RGB to BGR
                    image_data = image_data[..., ::-1]

                    outputs = self.onnxSession.run(None, {'input': image_data})
                    outputs = outputs[0].squeeze(0)

                    ratings, general, filtered = self.processRating(outputs)
                    rating = max(ratings, key=ratings.get)

                    if tag_threshold != -1:
                        general = self.filterThreshold(general, tag_threshold)
                        filtered = self.filterThreshold(filtered, tag_threshold)

                    if max_count != -1:
                        general = self.filterCount(general, max_count)
                        filtered = self.filterCount(filtered, max_count)

                    return rating, general, filtered

        except Exception as ex:
            print(ex)
            return "", {}, {}

    async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
        with ThreadPoolExecutor() as executor:
            future = executor.submit(self.classify, image_bytes, max_count, tag_threshold)

            while not future.done():
                await asyncio.sleep(1)

            return future.result()

    def processRating(self, data):
        ratings = {self.labels.names[i]: round(data[i].item(), 3) for i in self.labels.rating}
        general = {self.labels.names[i]: round(data[i].item(), 3) for i in (self.labels.general + self.labels.character)}
        filtered = {self.labels.names[i]: round(data[i].item(), 3) for i in self.relevant_tags}
        return ratings, general, filtered

    def filterThreshold(self, data, threshold = 0.6):
        output = {key: value for key, value in data.items() if value > threshold}
        return output

    def filterCount(self, data, count = 20):
        keys = list(data.keys())
        values = list(data.values())
        results_ordered = np.flip(np.argsort(values))
        output = {}
        for idx, dict_index in enumerate(results_ordered):
            if idx < count:
                output[keys[dict_index]] = values[dict_index]
            else:
                break
        return output

    def filterTagList(self, data, tagList):
        output = []
        for result in data:
            if result["label"] in tagList:
                output.append(result)
        return output

    def resize_image(self, image: Image.Image, size = 448):
        image = WDClassifier.pil_ensure_rgb(image)
        image = WDClassifier.pil_pad_square(image)
        image = WDClassifier.pil_resize(image, size)
        return image

    def get_tag_index(self, tag: str):
        try:
            idx = self.labels.names.index(tag)
        except ValueError:
            print(tag, "not found")
            idx = -1
        return idx

    def get_relevant_tags(self):
        if not os.path.exists("Classifier/SmilingWolf/ignored_tags.txt"):
            with open("Classifier/SmilingWolf/ignored_tags.txt", "wt", encoding="utf-8") as file:
                file.write("")

        with open("Classifier/SmilingWolf/ignored_tags.txt", "rt", encoding="utf-8") as file:
            self.ignored_tags = file.read().splitlines()

        tags = {tag: index for index, tag in enumerate(self.labels.names)}
        for tag in self.ignored_tags:
            if tag in tags:
                del tags[tag]
        return list(tags.values())

    def add_ignored_tag(self, tag: str):
        idx = self.get_tag_index(tag)
        if idx != -1 and idx in self.relevant_tags:
            self.ignored_tags.append(tag)
            with open("Classifier/SmilingWolf/ignored_tags.txt", "wt", encoding="utf-8") as file:
                file.write("\n".join(self.ignored_tags))
            self.relevant_tags.remove(idx)

    @staticmethod
    def load_labels(csv_path):
        df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
        tag_data = LabelData(
            names=df["name"].tolist(),
            rating=list(np.where(df["category"] == 9)[0]),
            general=list(np.where(df["category"] == 0)[0]),
            character=list(np.where(df["category"] == 4)[0]),
        )

        return tag_data

    @staticmethod
    def pil_ensure_rgb(image: Image.Image) -> Image.Image:
        # convert to RGB/RGBA if not already (deals with palette images etc.)
        if image.mode not in ["RGB", "RGBA"]:
            image = (
                image.convert("RGBA")
                if "transparency" in image.info
                else image.convert("RGB")
            )
        # convert RGBA to RGB with white background
        if image.mode == "RGBA":
            canvas = Image.new("RGBA", image.size, (255, 255, 255))
            canvas.alpha_composite(image)
            image = canvas.convert("RGB")
        return image

    @staticmethod
    def pil_pad_square(image: Image.Image) -> Image.Image:
        w, h = image.size
        # get the largest dimension so we can pad to a square
        px = max(image.size)
        # pad to square with white background
        canvas = Image.new("RGB", (px, px), (255, 255, 255))
        canvas.paste(image, ((px - w) // 2, (px - h) // 2))
        return canvas

    @staticmethod
    def pil_resize(image: Image.Image, target_size: int) -> Image.Image:
        # Resize
        max_dim = max(image.size)
        if max_dim != target_size:
            image = image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return image

    @staticmethod
    def get_tags(
        probs,
        labels: LabelData,
        gen_threshold: float,
        char_threshold: float,
    ):
        # Convert indices+probs to labels
        probs = list(zip(labels.names, probs))

        # First 4 labels are actually ratings
        rating_labels = dict([probs[i] for i in labels.rating])

        # General labels, pick any where prediction confidence > threshold
        gen_labels = [probs[i] for i in labels.general]
        gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
        gen_labels = dict(
            sorted(
                gen_labels.items(),
                key=lambda item: item[1],
                reverse=True,
            )
        )

        # Character labels, pick any where prediction confidence > threshold
        char_labels = [probs[i] for i in labels.character]
        char_labels = dict([x for x in char_labels if x[1] > char_threshold])
        char_labels = dict(
            sorted(
                char_labels.items(),
                key=lambda item: item[1],
                reverse=True,
            )
        )

        # Combine general and character labels, sort by confidence
        combined_names = [x for x in gen_labels]
        combined_names.extend([x for x in char_labels])

        # Convert to a string suitable for use as a training caption
        caption = ", ".join(combined_names)
        taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

        return caption, taglist, rating_labels, char_labels, gen_labels


if __name__ == "__main__":
    async def main():
        classifier = WDClassifier()
        with open("Classifier/sample4.png", "rb") as image:
            image_data = image.read()
            rating, results, discord = await classifier.classify_async(image_data, 20, 0.6)

            print(rating, results, discord)

    asyncio.run(main())
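(Putting the three new modules together, not part of the commit: a hypothetical caller of classify_all; the sample path reuses the one from wdClassifier's __main__ block and the printout is illustrative.)

import asyncio

from Classifier.classifyHelper import classify_all
from Classifier.havoxClassifier import VoxClassifier
from Classifier.wdClassifier import WDClassifier

async def main():
    # Load both models once and reuse them for every image.
    wd = WDClassifier()
    vox = VoxClassifier()
    with open("Classifier/sample4.png", "rb") as file:  # sample image used elsewhere in the commit
        photo_bytes = file.read()
    vox_label, rating, tags, filtered_tags, phash, dhash, error_id = await classify_all(
        photo_bytes, wd, vox
    )
    print(vox_label, rating, tags, phash, dhash, error_id)

asyncio.run(main())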