You've already forked JapariArchive
fresh start
This commit is contained in:
46
Classifier/havoxClassifier.py
Normal file
46
Classifier/havoxClassifier.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import asyncio
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from transformers import ViTForImageClassification, ViTImageProcessor
|
||||
|
||||
class VoxClassifier():
|
||||
def __init__(self):
|
||||
self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
|
||||
self.model = ViTForImageClassification.from_pretrained('Classifier/HAV0X/')
|
||||
|
||||
def classify(self, image_bytes):
|
||||
try:
|
||||
with BytesIO(image_bytes) as image_file:
|
||||
with Image.open(image_file).convert("RGB") as image:
|
||||
inputs = self.feature_extractor(images=image, return_tensors="pt")
|
||||
outputs = self.model(**inputs)
|
||||
logits = outputs.logits
|
||||
# model predicts one of the 3 tags
|
||||
predicted_class_idx = logits.argmax(-1).item()
|
||||
return self.model.config.id2label[predicted_class_idx]
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
return "Rejected"
|
||||
|
||||
async def classify_async(self, image_bytes):
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(self.classify, image_bytes)
|
||||
|
||||
while not future.done():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
return future.result()
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
classifier = VoxClassifier()
|
||||
images = [
|
||||
"Classifier/sample.jpg", "Classifier/sample2.jpg", "Classifier/sample3.jpg", "Classifier/sample4.png"
|
||||
]
|
||||
for image in images:
|
||||
with open(image, "rb") as file:
|
||||
result = await classifier.classify_async(file.read())
|
||||
print(result)
|
||||
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user