import asyncio
import os
from dataclasses import dataclass
from io import BytesIO
from PIL import Image
import onnxruntime as onnx
import numpy as np
import pandas as pd
from concurrent.futures import ThreadPoolExecutor


@dataclass
class LabelData:
    """Tag names read from selected_tags.csv plus the row indices of each tag category."""

    names: list[str]
    rating: list[np.int64]  # rows whose "category" column is 9
    general: list[np.int64]  # rows whose "category" column is 0
    character: list[np.int64]  # rows whose "category" column is 4


class WDClassifier():
    """Image tagger backed by a SmilingWolf WD ONNX model.

    Files are read from ``model_dir``: ``selected_tags.csv`` (labels),
    ``model.onnx`` (the model) and ``ignored_tags.txt`` (tags excluded from the
    ``filtered`` result; created empty when missing).
    """

    # Class-level defaults; __init__ assigns per-instance values for all of them.
    onnxSession = None
    relevant_tags = []
    ignored_tags = []
    model_dir = "Classifier/SmilingWolf"
    input_name = "input"
    input_size = 448

    def __init__(self, model_dir: str = "Classifier/SmilingWolf"):
        """Load labels, the ignored-tag list and the ONNX session from `model_dir`."""
        self.model_dir = model_dir
        self.labels = WDClassifier.load_labels(os.path.join(model_dir, "selected_tags.csv"))
        self.relevant_tags = self.get_relevant_tags()
        self.onnxSession = onnx.InferenceSession(os.path.join(model_dir, "model.onnx"))
        # Ask the model for its input name/size instead of hard-coding them, so
        # models whose input tensor is not literally called "input" also work.
        model_input = self.onnxSession.get_inputs()[0]
        self.input_name = model_input.name
        shape = model_input.shape
        # NOTE(review): assumes an NHWC input (batch, H, W, 3) as the
        # preprocessing in classify() produces; falls back to 448 otherwise.
        if len(shape) == 4 and isinstance(shape[1], int) and shape[1] == shape[2]:
            self.input_size = shape[1]
        else:
            self.input_size = 448

    def classify(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Tag an encoded image.

        Args:
            image_bytes: encoded image file contents (anything PIL can open).
            max_count: keep at most this many tags in each result dict; -1 disables.
            tag_threshold: keep only tags scoring strictly above this; -1 disables.

        Returns:
            ``(rating, general, filtered)``: the highest-scoring rating label, a
            dict of general+character tag scores, and a dict of scores for tags
            not listed in ignored_tags.txt. On any error the exception message is
            printed and ``("", {}, {})`` is returned (best-effort by design).
        """
        try:
            with BytesIO(image_bytes) as image_file:
                with Image.open(image_file) as image:
                    image_data = self.resize_image(image, self.input_size)
                    # convert to numpy array (H, W, 3) float32
                    image_data = np.array(image_data, np.float32)
            # add batch dimension -> NHWC
            image_data = np.expand_dims(image_data, axis=0)
            # RGB -> BGR; copy so the runtime receives a contiguous buffer
            image_data = np.ascontiguousarray(image_data[..., ::-1])
            outputs = self.onnxSession.run(None, {self.input_name: image_data})
            # presumably the first output is (1, num_labels) — drop the batch axis
            outputs = outputs[0].squeeze(0)
            ratings, general, filtered = self.processRating(outputs)
            rating = max(ratings, key=ratings.get)
            if tag_threshold != -1:
                general = self.filterThreshold(general, tag_threshold)
                filtered = self.filterThreshold(filtered, tag_threshold)
            if max_count != -1:
                general = self.filterCount(general, max_count)
                # cap `filtered` by count as well (must be filterCount, not filterThreshold)
                filtered = self.filterCount(filtered, max_count)
            return rating, general, filtered
        except Exception as ex:
            print(ex)
            return "", {}, {}

    async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
        """Run classify() in a worker thread so the event loop is not blocked.

        Resolves as soon as inference finishes, without polling.
        """
        return await asyncio.to_thread(self.classify, image_bytes, max_count, tag_threshold)

    def processRating(self, data):
        """Split raw model scores into (ratings, general+character, filtered) dicts.

        Each dict maps tag name -> score rounded to 3 decimals. `filtered` covers
        only the indices in self.relevant_tags (tags not ignored).
        """
        names = self.labels.names
        ratings = {names[i]: round(data[i].item(), 3) for i in self.labels.rating}
        general = {names[i]: round(data[i].item(), 3) for i in (self.labels.general + self.labels.character)}
        filtered = {names[i]: round(data[i].item(), 3) for i in self.relevant_tags}
        return ratings, general, filtered

    def filterThreshold(self, data, threshold = 0.6):
        """Return the entries of `data` whose score is strictly greater than `threshold`."""
        return {key: value for key, value in data.items() if value > threshold}

    def filterCount(self, data, count = 20):
        """Return at most `count` highest-scoring entries of `data`, best first.

        A non-positive `count` yields an empty dict.
        """
        if count <= 0:
            return {}
        keys = list(data.keys())
        values = list(data.values())
        # argsort is ascending; flip to descending and keep the top `count`
        top = np.flip(np.argsort(values))[:count]
        return {keys[i]: values[i] for i in top}

    def filterTagList(self, data, tagList):
        """Keep only result dicts whose "label" entry is in `tagList`."""
        return [result for result in data if result["label"] in tagList]

    def resize_image(self, image : Image.Image, size = 448):
        """Convert to RGB, pad to a white square and resize to `size` x `size`."""
        image = WDClassifier.pil_ensure_rgb(image)
        image = WDClassifier.pil_pad_square(image)
        image = WDClassifier.pil_resize(image, size)
        return image

    def get_tag_index(self, tag: str):
        """Return the first label index of `tag`, or -1 (with a message) if unknown."""
        try:
            return self.labels.names.index(tag)
        except ValueError:
            print(tag, "not found")
            return -1

    def _ignored_tags_path(self):
        """Path of the ignored-tag list inside the model directory."""
        return os.path.join(self.model_dir, "ignored_tags.txt")

    def get_relevant_tags(self):
        """Load ignored_tags.txt (creating it empty if absent) into self.ignored_tags
        and return the label indices whose name is not ignored."""
        path = self._ignored_tags_path()
        if not os.path.exists(path):
            with open(path, "wt", encoding="utf-8") as file:
                file.write("")
        with open(path, "rt", encoding="utf-8") as file:
            self.ignored_tags = file.read().splitlines()
        ignored = set(self.ignored_tags)
        # keep every index (duplicate names included) whose name isn't ignored
        return [index for index, name in enumerate(self.labels.names) if name not in ignored]

    def add_ignored_tag(self, tag: str):
        """Ignore `tag`: append it to ignored_tags.txt and drop it from relevant_tags.

        No-op when the tag is unknown or its index is already excluded.
        """
        idx = self.get_tag_index(tag)
        if idx != -1 and idx in self.relevant_tags:
            self.ignored_tags.append(tag)
            with open(self._ignored_tags_path(), "wt", encoding="utf-8") as file:
                file.write("\n".join(self.ignored_tags))
            self.relevant_tags.remove(idx)

    @staticmethod
    def load_labels(csv_path):
        """Read the "name" and "category" columns of `csv_path` into a LabelData."""
        df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
        tag_data = LabelData(
            names=df["name"].tolist(),
            rating=list(np.where(df["category"] == 9)[0]),
            general=list(np.where(df["category"] == 0)[0]),
            character=list(np.where(df["category"] == 4)[0]),
        )
        return tag_data

    @staticmethod
    def pil_ensure_rgb(image: Image.Image) -> Image.Image:
        """Return an RGB version of `image`, compositing any transparency onto white."""
        # convert to RGB/RGBA if not already (deals with palette images etc.)
        if image.mode not in ["RGB", "RGBA"]:
            # LA/PA carry an alpha band and palette images may declare a
            # transparent colour; converting those straight to RGB would drop
            # the alpha, so route them through RGBA instead.
            has_alpha = image.mode in ("LA", "PA") or "transparency" in image.info
            image = image.convert("RGBA") if has_alpha else image.convert("RGB")
        # convert RGBA to RGB with white background
        if image.mode == "RGBA":
            canvas = Image.new("RGBA", image.size, (255, 255, 255))
            canvas.alpha_composite(image)
            image = canvas.convert("RGB")
        return image

    @staticmethod
    def pil_pad_square(image: Image.Image) -> Image.Image:
        """Center `image` on a white square canvas whose side is its largest dimension."""
        w, h = image.size
        px = max(image.size)
        canvas = Image.new("RGB", (px, px), (255, 255, 255))
        canvas.paste(image, ((px - w) // 2, (px - h) // 2))
        return canvas

    @staticmethod
    def pil_resize(image: Image.Image, target_size: int) -> Image.Image:
        """Bicubic-resize to `target_size` x `target_size` unless already that size."""
        max_dim = max(image.size)
        if max_dim != target_size:
            image = image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return image

    @staticmethod
    def get_tags(
        probs,
        labels : LabelData,
        gen_threshold: float,
        char_threshold: float,
    ):
        """Turn per-label probabilities into caption text and per-category dicts.

        Returns ``(caption, taglist, rating_labels, char_labels, gen_labels)``.
        The general and character dicts keep only scores above their threshold
        and are sorted by descending score. `caption` joins the general then
        character names with ", "; `taglist` is the caption with "_" -> " " and
        parentheses backslash-escaped.
        """
        # Pair every label name with its probability
        probs = list(zip(labels.names, probs))

        # Rating labels (category 9 rows)
        rating_labels = dict([probs[i] for i in labels.rating])

        def above(indices, threshold):
            # keep (name, prob) pairs above threshold, sorted by descending prob
            picked = dict([probs[i] for i in indices if probs[i][1] > threshold])
            return dict(sorted(picked.items(), key=lambda item: item[1], reverse=True))

        gen_labels = above(labels.general, gen_threshold)
        char_labels = above(labels.character, char_threshold)

        # General names first, then character names
        combined_names = [*gen_labels, *char_labels]

        # Convert to a string suitable for use as a training caption
        caption = ", ".join(combined_names)
        taglist = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

        return caption, taglist, rating_labels, char_labels, gen_labels


if __name__ == "__main__":
    async def main():
        """Tag a sample image and print the results."""
        classifier = WDClassifier()
        with open("Classifier/sample4.png", "rb") as file:
            image_data = file.read()
        rating, results, discord = await classifier.classify_async(image_data, 20, 0.6)
        print(rating, results, discord)

    asyncio.run(main())