You've already forked JapariArchive
46 lines
1.7 KiB
Python
46 lines
1.7 KiB
Python
import asyncio
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from transformers import ViTForImageClassification, ViTImageProcessor
|
|
|
|
class VoxClassifier():
    """Image tagger backed by a ViT image-classification checkpoint.

    Both the image processor and the model are loaded from the local
    ``Classifier/HAV0X/`` directory when the object is constructed.
    """

    def __init__(self):
        self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
        self.model = ViTForImageClassification.from_pretrained('Classifier/HAV0X/')
        # This wrapper only runs inference. Freezing every parameter means the
        # forward pass records no autograd graph, which saves memory per call.
        self.model.requires_grad_(False)

    def classify(self, image_bytes):
        """Return the predicted label for one encoded image.

        Args:
            image_bytes: Raw bytes of an encoded image file.

        Returns:
            The ``id2label`` entry of the highest-scoring class, or the
            string ``"Rejected"`` if decoding or inference raises.
        """
        try:
            # Manage the decoded source image as well as the RGB copy so that
            # neither is left open once classification is finished.
            with BytesIO(image_bytes) as image_file, Image.open(image_file) as source:
                with source.convert("RGB") as image:
                    inputs = self.feature_extractor(images=image, return_tensors="pt")
                    outputs = self.model(**inputs)
                    logits = outputs.logits
                    # model predicts one of the 3 tags
                    predicted_class_idx = logits.argmax(-1).item()
                    return self.model.config.id2label[predicted_class_idx]
        except Exception as ex:
            # Deliberate best-effort: report the problem and reject the image
            # instead of propagating the failure to the caller.
            print(ex)
            return "Rejected"

    async def classify_async(self, image_bytes):
        """Run :meth:`classify` in a worker thread without blocking the loop.

        Args:
            image_bytes: Raw bytes of an encoded image file.

        Returns:
            Whatever :meth:`classify` returns for ``image_bytes``.
        """
        loop = asyncio.get_running_loop()
        # Awaiting the executor future resumes as soon as the worker finishes,
        # instead of polling once per second, and reuses the loop's default
        # thread pool rather than building a new pool for every call.
        return await loop.run_in_executor(None, self.classify, image_bytes)
|
|
|
|
if __name__ == "__main__":

    async def main():
        """Classify the bundled sample images and print each predicted label."""
        vox = VoxClassifier()
        sample_paths = (
            "Classifier/sample.jpg",
            "Classifier/sample2.jpg",
            "Classifier/sample3.jpg",
            "Classifier/sample4.png",
        )
        for path in sample_paths:
            # Read the whole file first; the classifier takes raw bytes.
            with open(path, "rb") as handle:
                payload = handle.read()
            label = await vox.classify_async(payload)
            print(label)

    asyncio.run(main())