# JapariArchive/Classifier/havoxClassifier.py  (52 lines, 1.9 KiB, Python)

import asyncio
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

from Database.x_classes import HavoxLabel


class VoxClassifier:
    """Classify encoded images with the local HAV0X ViT model.

    The image processor and classification model are both loaded from
    ``model_path`` when the instance is created.
    """

    def __init__(self, model_path='Classifier/HAV0X/'):
        """Load the ViT image processor and classification model.

        Args:
            model_path: Path passed to ``from_pretrained`` for both the
                processor and the model. Defaults to ``'Classifier/HAV0X/'``,
                which is resolved relative to the current working directory.
        """
        self.feature_extractor = ViTImageProcessor.from_pretrained(model_path)
        self.model = ViTForImageClassification.from_pretrained(model_path)

    def classify(self, image_bytes):
        """Return the model's predicted label for one encoded image.

        Args:
            image_bytes: Raw encoded image data in any format PIL can open.

        Returns:
            The ``model.config.id2label`` entry for the highest-scoring class,
            or ``HavoxLabel.Rejected`` if decoding or inference raises.
        """
        # NOTE(review): the success path returns an id2label value (a str in
        # transformers configs) while the failure path returns
        # HavoxLabel.Rejected; confirm HavoxLabel compares equal to those
        # label strings (e.g. a str-based Enum) wherever callers compare.
        try:
            with BytesIO(image_bytes) as image_file, Image.open(image_file) as raw_image:
                # convert() returns a new image; close it as well as the source.
                with raw_image.convert("RGB") as image:
                    inputs = self.feature_extractor(images=image, return_tensors="pt")
            # Pure inference: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # model predicts one of the 3 tags
            predicted_class_idx = logits.argmax(-1).item()
            return self.model.config.id2label[predicted_class_idx]
        except Exception as ex:
            # Deliberate best-effort: any failure is reported and mapped to the
            # rejected label instead of propagating to the caller.
            print(ex)
            return HavoxLabel.Rejected

    async def classify_async(self, image_bytes):
        """Run :meth:`classify` in a worker thread without blocking the event loop.

        Args:
            image_bytes: Raw encoded image data. Falsy values (``None``, ``b""``)
                are rejected immediately without touching the model.

        Returns:
            Same as :meth:`classify`.
        """
        if not image_bytes:
            # mainly for video files
            print("image bytes was null")
            return HavoxLabel.Rejected
        # Await the worker future directly. Polling it once per second added up
        # to 1 s of latency per image. Building a fresh thread pool per call had
        # a blocking shutdown that stalled the event loop on cancellation.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.classify, image_bytes)
if __name__ == "__main__":
    async def main():
        """Classify each bundled sample image in order and print its label."""
        classifier = VoxClassifier()
        sample_paths = (
            "Classifier/sample.jpg",
            "Classifier/sample2.jpg",
            "Classifier/sample3.jpg",
            "Classifier/sample4.png",
        )
        for path in sample_paths:
            with open(path, "rb") as handle:
                label = await classifier.classify_async(handle.read())
            print(label)

    asyncio.run(main())