You've already forked JapariArchive
fixed memory leak related to video files
This commit is contained in:
@@ -15,7 +15,7 @@ async def classify_all(photo_bytes, wd_classifier: WDClassifier, vox_classifier
|
|||||||
phash = None
|
phash = None
|
||||||
dhash = None
|
dhash = None
|
||||||
error_id = ErrorID.SUCCESS
|
error_id = ErrorID.SUCCESS
|
||||||
if not is_filtered:
|
if not is_filtered: #won't execute if photo bytes is None (is video)
|
||||||
rating, tags, filtered_tags = await wd_classifier.classify_async(photo_bytes, tag_threshold=0.6)
|
rating, tags, filtered_tags = await wd_classifier.classify_async(photo_bytes, tag_threshold=0.6)
|
||||||
tags = list(tags.keys())
|
tags = list(tags.keys())
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from io import BytesIO
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from transformers import ViTForImageClassification, ViTImageProcessor
|
from transformers import ViTForImageClassification, ViTImageProcessor
|
||||||
|
|
||||||
|
from Database.x_classes import HavoxLabel
|
||||||
|
|
||||||
class VoxClassifier():
|
class VoxClassifier():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
|
self.feature_extractor = ViTImageProcessor.from_pretrained('Classifier/HAV0X/')
|
||||||
@@ -21,9 +23,13 @@ class VoxClassifier():
|
|||||||
return self.model.config.id2label[predicted_class_idx]
|
return self.model.config.id2label[predicted_class_idx]
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
return "Rejected"
|
return HavoxLabel.Rejected
|
||||||
|
|
||||||
async def classify_async(self, image_bytes):
|
async def classify_async(self, image_bytes):
|
||||||
|
if not image_bytes:
|
||||||
|
#mainly for video files
|
||||||
|
print("image bytes was null")
|
||||||
|
return HavoxLabel.Rejected
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
future = executor.submit(self.classify, image_bytes)
|
future = executor.submit(self.classify, image_bytes)
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,10 @@ class WDClassifier():
|
|||||||
return "", {}, {}
|
return "", {}, {}
|
||||||
|
|
||||||
async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
|
async def classify_async(self, image_bytes, max_count = -1, tag_threshold = -1):
|
||||||
|
if not image_bytes:
|
||||||
|
#mainly for video files
|
||||||
|
print("image bytes was null")
|
||||||
|
return "", {}, {}
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor() as executor:
|
||||||
future = executor.submit(self.classify, image_bytes, max_count, tag_threshold)
|
future = executor.submit(self.classify, image_bytes, max_count, tag_threshold)
|
||||||
|
|
||||||
|
|||||||
@@ -63,10 +63,10 @@ def build_secondary_embed(main_post : Message, handle : str, post : Tweet):
|
|||||||
embed.set_author(name=handle, url=main_post.jump_url, icon_url=post.author.profile_image_url_https)
|
embed.set_author(name=handle, url=main_post.jump_url, icon_url=post.author.profile_image_url_https)
|
||||||
return embeds
|
return embeds
|
||||||
|
|
||||||
async def send_error(ex : Exception, botData : RuntimeBotData):
|
async def send_error(ex : str, botData : RuntimeBotData):
|
||||||
print(ex)
|
print(ex)
|
||||||
errors_channel = nextcord.utils.get(botData.client.guilds[0].channels, name="bot-status")
|
errors_channel = nextcord.utils.get(botData.client.guilds[0].channels, name="bot-status")
|
||||||
await errors_channel.send(content=str(ex))
|
await errors_channel.send(content=ex[0:512])
|
||||||
|
|
||||||
def get_secondary_channel(is_animated, is_filtered, rating, tags : list, artist : x_accounts, guild : nextcord.Guild):
|
def get_secondary_channel(is_animated, is_filtered, rating, tags : list, artist : x_accounts, guild : nextcord.Guild):
|
||||||
if is_animated:
|
if is_animated:
|
||||||
|
|||||||
@@ -33,5 +33,5 @@ async def download_loop(botData: RuntimeBotData):
|
|||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
await discordHelper.send_error(traceback.format_exc()[0:256], botData)
|
await discordHelper.send_error(str(ex) + " " + traceback.format_exc(), botData)
|
||||||
print("Pixiv done")
|
print("Pixiv done")
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ async def download_loop(botData: RuntimeBotData):
|
|||||||
await discordHelper.post_result(results, guild, botData.new_accounts)
|
await discordHelper.post_result(results, guild, botData.new_accounts)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(ex)
|
print(ex)
|
||||||
await discordHelper.send_error(traceback.format_exc()[0:256], botData)
|
await discordHelper.send_error(str(ex) + " " + traceback.format_exc(), botData)
|
||||||
|
|
||||||
async def download_post(artist: x_accounts, tweet: Tweet, botData: RuntimeBotData):
|
async def download_post(artist: x_accounts, tweet: Tweet, botData: RuntimeBotData):
|
||||||
x_post = x_posts(id = tweet.id, account_id = tweet.author.id, date = tweet.date, text = tweet.text)
|
x_post = x_posts(id = tweet.id, account_id = tweet.author.id, date = tweet.date, text = tweet.text)
|
||||||
@@ -79,8 +79,8 @@ async def download_post(artist: x_accounts, tweet: Tweet, botData: RuntimeBotDat
|
|||||||
return
|
return
|
||||||
|
|
||||||
print("New media post:", str(tweet.url))
|
print("New media post:", str(tweet.url))
|
||||||
media = await tweetHelper.GetTweetMediaUrls(tweet)
|
media = await tweetHelper.GetTweetMedia(tweet)
|
||||||
image_containers = [x_posts_images(tweet.id, idx, file = url) for idx, url in enumerate(media)]
|
image_containers = [x_posts_images(tweet.id, idx, file = med.url) for idx, med in enumerate(media)]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
@@ -104,7 +104,7 @@ async def download_post(artist: x_accounts, tweet: Tweet, botData: RuntimeBotDat
|
|||||||
for idx, attachment in enumerate(downloaded_media):
|
for idx, attachment in enumerate(downloaded_media):
|
||||||
container = image_containers[idx]
|
container = image_containers[idx]
|
||||||
container.saved_file = attachment.file_name
|
container.saved_file = attachment.file_name
|
||||||
container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes, botData.classifier, botData.vox)
|
container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes if not attachment.is_video else None, botData.classifier, botData.vox)
|
||||||
|
|
||||||
if container.vox_label not in vox_labels:
|
if container.vox_label not in vox_labels:
|
||||||
vox_labels.append(container.vox_label)
|
vox_labels.append(container.vox_label)
|
||||||
|
|||||||
@@ -15,18 +15,22 @@ if TYPE_CHECKING:
|
|||||||
class TweetMedia:
|
class TweetMedia:
|
||||||
url : str
|
url : str
|
||||||
file_name : str
|
file_name : str
|
||||||
|
is_video: bool
|
||||||
|
|
||||||
def __init__(self, url, file_name):
|
def __init__(self, url, file_name, is_video: bool):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.file_name = file_name
|
self.file_name = file_name
|
||||||
|
self.is_video = is_video
|
||||||
|
|
||||||
class DownloadedMedia:
|
class DownloadedMedia:
|
||||||
file_bytes : str
|
file_bytes : bytes
|
||||||
file_name : str
|
file_name : str
|
||||||
|
is_video: bool
|
||||||
|
|
||||||
def __init__(self, bytes, file_name):
|
def __init__(self, bytes, file_name, is_video: bool):
|
||||||
self.file_bytes = bytes
|
self.file_bytes = bytes
|
||||||
self.file_name = file_name
|
self.file_name = file_name
|
||||||
|
self.is_video = is_video
|
||||||
|
|
||||||
async def GetTweetMedia(tweet : Tweet) -> list[TweetMedia]:
|
async def GetTweetMedia(tweet : Tweet) -> list[TweetMedia]:
|
||||||
mediaList : list[TweetMedia] = []
|
mediaList : list[TweetMedia] = []
|
||||||
@@ -34,12 +38,12 @@ async def GetTweetMedia(tweet : Tweet) -> list[TweetMedia]:
|
|||||||
if media.file_format == 'mp4':
|
if media.file_format == 'mp4':
|
||||||
best_stream = await media.best_stream()
|
best_stream = await media.best_stream()
|
||||||
fileName = f"{tweet.author.screen_name}_{tweet.id}_{idx}.{media.file_format}"
|
fileName = f"{tweet.author.screen_name}_{tweet.id}_{idx}.{media.file_format}"
|
||||||
mediaList.append(TweetMedia(best_stream.direct_url, fileName))
|
mediaList.append(TweetMedia(best_stream.direct_url, fileName, True))
|
||||||
else:
|
else:
|
||||||
best_stream = await media.best_stream()
|
best_stream = await media.best_stream()
|
||||||
extension = best_stream.file_format
|
extension = best_stream.file_format
|
||||||
fileName = f"{tweet.author.screen_name}_{tweet.id}_{idx}.{extension}"
|
fileName = f"{tweet.author.screen_name}_{tweet.id}_{idx}.{extension}"
|
||||||
mediaList.append(TweetMedia(best_stream.direct_url, fileName))
|
mediaList.append(TweetMedia(best_stream.direct_url, fileName, False))
|
||||||
|
|
||||||
return mediaList
|
return mediaList
|
||||||
|
|
||||||
@@ -47,18 +51,18 @@ async def GetTweetMediaUrls(tweet : Tweet):
|
|||||||
mediaList = await GetTweetMedia(tweet)
|
mediaList = await GetTweetMedia(tweet)
|
||||||
return [media.url for media in mediaList]
|
return [media.url for media in mediaList]
|
||||||
|
|
||||||
async def DownloadMedia(post_id, account_id, account_name, url_list : list, session) -> list[DownloadedMedia]:
|
async def DownloadMedia(post_id, account_id, account_name, media_list : list[TweetMedia], session) -> list[DownloadedMedia]:
|
||||||
result : list[DownloadedMedia] = []
|
result : list[DownloadedMedia] = []
|
||||||
path = f"{Global_Config["x_download_path"]}{account_id}"
|
path = f"{Global_Config["x_download_path"]}{account_id}"
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
for idx, file_url in enumerate(url_list):
|
for idx, media in enumerate(media_list):
|
||||||
file_name = get_file_name(account_name, post_id, idx, file_url)
|
file_name = get_file_name(account_name, post_id, idx, media.url)
|
||||||
full_path = f"{path}/{file_name}"
|
full_path = f"{path}/{file_name}"
|
||||||
|
|
||||||
photo_bytes = await downloadHelper.save_to_file(file_url, full_path, session)
|
photo_bytes = await downloadHelper.save_to_file(media.url, full_path, session)
|
||||||
|
|
||||||
result.append(DownloadedMedia(photo_bytes, file_name))
|
result.append(DownloadedMedia(photo_bytes, file_name, media.is_video))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -83,8 +83,8 @@ class Commands(commands.Cog):
|
|||||||
await discordHelper.ensure_has_channel_or_thread(artist, interaction.guild, botData.db)
|
await discordHelper.ensure_has_channel_or_thread(artist, interaction.guild, botData.db)
|
||||||
|
|
||||||
image_containers : list[x_posts_images] = []
|
image_containers : list[x_posts_images] = []
|
||||||
media = await tweetHelper.GetTweetMediaUrls(tweet)
|
media = await tweetHelper.GetTweetMedia(tweet)
|
||||||
image_containers = [x_posts_images(tweet.id, idx, file = url) for idx, url in enumerate(media)]
|
image_containers = [x_posts_images(tweet.id, idx, file = med.url) for idx, med in enumerate(media)]
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
downloaded_media = await tweetHelper.DownloadMedia(tweet.id, tweet.author.id, tweet.author.username, media, session)
|
downloaded_media = await tweetHelper.DownloadMedia(tweet.id, tweet.author.id, tweet.author.username, media, session)
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ class Commands(commands.Cog):
|
|||||||
for idx, attachment in enumerate(downloaded_media):
|
for idx, attachment in enumerate(downloaded_media):
|
||||||
container = image_containers[idx]
|
container = image_containers[idx]
|
||||||
container.saved_file = attachment.file_name
|
container.saved_file = attachment.file_name
|
||||||
container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes, botData.classifier, botData.vox)
|
container.vox_label, container.rating, container.tags, filtered_tags, container.phash, container.dhash, container.error_id = await classify_all(attachment.file_bytes if not attachment.is_video else None, botData.classifier, botData.vox)
|
||||||
|
|
||||||
if container.vox_label not in vox_labels:
|
if container.vox_label not in vox_labels:
|
||||||
vox_labels.append(container.vox_label)
|
vox_labels.append(container.vox_label)
|
||||||
|
|||||||
Reference in New Issue
Block a user