import json

from flask import Flask

from modules.Archive.endpoints.tag_stats import TagStats
from .databaseController import DatabaseController
from .endpoints.query import Query
from .endpoints.command import Command
from .endpoints.commands import Commands
from .endpoints.post import GetPost
from .endpoints.posts import GetPosts
from .endpoints.new_query import NewQuery
from .endpoints.set_action import SetAction
from .endpoints.posts_count import GetPostsCount
from .endpoints.account_stats import AccountStats


class Database:
    """Archive database facade.

    On construction it registers itself as ``app.databases["Archive"]`` and
    mounts the Archive REST resources. Its methods build SQL strings and run
    them through ``DatabaseController.run_query``.

    Security: ``run_query`` is only ever given a finished SQL string here, so
    every caller-supplied value is sanitised before interpolation. Numbers
    pass through ``int()`` and text goes through ``_quote``.
    TODO(review): if ``DatabaseController.run_query`` supports bound
    parameters, switch to them; that is more robust than manual quoting.
    """

    db: DatabaseController = None  # set by reload_data()
    app: Flask = None  # taken from api.app in __init__

    # Passing exactly this action list means "no action_taken filter".
    _ALL_ACTIONS = (0, 1, 2, 3)
    # Rating filter used when the caller does not pass one (NSFL is excluded).
    _DEFAULT_RATINGS = ("SFW", "NSFW")
    # Passing exactly this rating list means "no rating filter"; presumably
    # these are all ratings in use — confirm against x_accounts.rating.
    _ALL_RATINGS = ("SFW", "NSFW", "NSFL")

    def __init__(self, api) -> None:
        """Register this instance as the "Archive" database and mount its endpoints.

        Args:
            api: Object exposing ``.app`` (the Flask app, which must have a
                ``databases`` mapping) and ``.add_resource(resource, route)``.
        """
        self.app = api.app
        # Drop any previous instance before opening a new controller.
        if "Archive" in self.app.databases:
            del self.app.databases["Archive"]
        self.reload_data()
        self.app.databases["Archive"] = self

        api.add_resource(Query, "/Query")
        api.add_resource(NewQuery, "/NewQuery")
        api.add_resource(Command, "/Command")
        api.add_resource(Commands, "/Commands")
        api.add_resource(GetPost, "/GetPost/")
        api.add_resource(GetPosts, "/GetPosts")
        api.add_resource(GetPostsCount, "/GetPosts/Count")
        api.add_resource(AccountStats, "/AccountStats")
        api.add_resource(TagStats, "/TagStats")
        api.add_resource(SetAction, "/SetAction")

    @staticmethod
    def _quote(value) -> str:
        """Render ``str(value)`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled, per the SQL standard. This is
        sufficient when backslashes are literal inside '...' strings, i.e.
        PostgreSQL with standard_conforming_strings = on (its default since 9.1).
        NOTE(review): the ``@>``/``ARRAY`` syntax below suggests PostgreSQL —
        confirm that setting is not turned off on the server.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _in_clause(column: str, literals: list) -> str:
        """Return ``column IN (...)`` for already-rendered SQL literals.

        An empty list yields ``FALSE`` (match nothing), because ``IN ()`` and
        the former ``AND ()`` are both invalid SQL.
        """
        if not literals:
            return "FALSE"
        return f"{column} IN ({', '.join(literals)})"

    def get_accounts(self):
        """Return ``x_handle`` and ``id`` for every account, ordered by handle."""
        query = '''
            SELECT x_handle, id
            FROM x_accounts
            ORDER BY x_handle ASC'''
        return self.db.run_query(query)

    def get_account_stats(self):
        """Return every row of the "Artist Stats" relation."""
        query = 'SELECT * FROM "Artist Stats"'
        return self.db.run_query(query)

    def get_tag_stats(self):
        """Return every row of the "Tag Stats" relation."""
        query = 'SELECT * FROM "Tag Stats"'
        return self.db.run_query(query)

    def get_post(self, id):
        """Return the single post with the given id, joined with its account.

        Args:
            id: Post id (int or numeric string). It is coerced with ``int()`` so
                it cannot inject SQL.

        Raises:
            ValueError: if ``id`` is not an integer value.
        """
        post_id = int(id)
        # The id is also returned cast to TEXT; presumably so large ids survive
        # JSON consumers that parse numbers as doubles — confirm with frontend.
        query = f'''SELECT cast(x_posts.id as TEXT), x_posts.*, x_accounts.x_handle, x_accounts.rating
            FROM x_posts
            LEFT JOIN x_accounts ON x_posts.account_id = x_accounts.id
            WHERE x_posts.id = {post_id}
            LIMIT 1'''
        return self.db.run_query(query)

    def build_where_query(self, artist, actions_taken=None, last_id=-1,
                          include_ratings=None, tags=None):
        """Build the WHERE clause shared by ``get_posts`` and ``get_posts_count``.

        Args:
            artist: Exact ``x_accounts.x_handle`` to match. None or "" means any artist.
            actions_taken: ``action_taken`` values to keep. Default: all of
                ``[0, 1, 2, 3]``, which adds no filter.
            last_id: Keyset-pagination cursor. Only posts with a smaller id are
                kept; ``-1`` disables it.
            include_ratings: Account ratings to keep. Default ``["SFW", "NSFW"]``.
                Exactly ``["SFW", "NSFW", "NSFL"]`` adds no filter.
            tags: Tags the post's ``tags`` array must all contain. Default: none.

        Returns:
            A string starting with ``"WHERE "``. Posts with a non-zero
            ``error_id`` are always excluded.

        Raises:
            ValueError / TypeError: if ``last_id`` or an ``actions_taken``
                element is not an integer value.
        """
        if actions_taken is None:
            actions_taken = list(self._ALL_ACTIONS)
        if include_ratings is None:
            include_ratings = list(self._DEFAULT_RATINGS)
        if tags is None:
            tags = []

        conditions = ["x_posts.error_id = 0"]
        if last_id != -1:
            conditions.append(f"x_posts.id < {int(last_id)}")
        # Compare against lists, as before: only an exact "everything" list skips the filter.
        if actions_taken != list(self._ALL_ACTIONS):
            conditions.append(self._in_clause(
                "x_posts.action_taken", [str(int(action)) for action in actions_taken]))
        if include_ratings != list(self._ALL_RATINGS):
            conditions.append(self._in_clause(
                "x_accounts.rating", [self._quote(rating) for rating in include_ratings]))
        if artist is not None and artist != "":
            conditions.append(f"x_accounts.x_handle = {self._quote(artist)}")
        if len(tags) > 0:
            tag_list = ", ".join(self._quote(tag) for tag in tags)
            # @> is array containment: the post must carry every listed tag.
            conditions.append(f"tags @> ARRAY[{tag_list}]")
        return "WHERE " + " AND ".join(conditions)

    def get_posts_count(self, artist, actions_taken=None, last_id=-1,
                        include_ratings=None, tags=None):
        """Return the number of posts matching the same filters as ``get_posts``.

        Filter arguments and their defaults are documented on ``build_where_query``.
        """
        where_query = self.build_where_query(artist, actions_taken, last_id, include_ratings, tags)
        query = f'''
            SELECT count(*) as count
            FROM x_posts
            LEFT JOIN x_accounts ON x_posts.account_id = x_accounts.id
            {where_query}'''
        # run_query appears to return a pair whose second item is a list of
        # mapping rows — inferred from this unpacking; verify in DatabaseController.
        _, result = self.db.run_query(query)
        return result[0]["count"]

    def get_posts(self, num_posts, artist, actions_taken=None, last_id=-1, offset=0,
                  include_ratings=None, tags=None, order="DESC"):
        """Return one page of posts matching the given filters.

        Args:
            num_posts: Page size, clamped to the range 1..100.
            artist, actions_taken, last_id, include_ratings, tags: See
                ``build_where_query``.
            offset: Rows to skip. Coerced to int; negative values are clamped to 0.
            order: "RAND" for random order, "ASC" for ascending id; anything
                else means descending id.

        Returns:
            Whatever ``DatabaseController.run_query`` returns for the query.
        """
        num_posts = max(1, min(int(num_posts), 100))
        offset = max(0, int(offset))
        # Whitelisted ORDER BY fragments — ``order`` itself never reaches the SQL.
        if order == "RAND":
            order_by = "RANDOM()"
        elif order == "ASC":
            order_by = "x_posts.id ASC"
        else:
            order_by = "x_posts.id DESC"
        where_query = self.build_where_query(artist, actions_taken, last_id, include_ratings, tags)
        # '{{}}' renders as '{}' — the empty-array literal, so is_saved is
        # true when saved_files is non-empty.
        query = f'''
            SELECT x_posts.id, cast(x_posts.id as TEXT) as id_str, action_taken, saved_files != '{{}}' as is_saved, text, files, date, x_handle, x_accounts.rating
            FROM x_posts
            LEFT JOIN x_accounts ON x_posts.account_id = x_accounts.id
            {where_query}
            ORDER BY {order_by}
            LIMIT {num_posts} OFFSET {offset}'''
        return self.db.run_query(query)

    def wrap_query_response(self, result, mode="json", error: str = ""):
        """Wrap a query result in a Flask response that allows any origin (CORS *).

        Args:
            result: Query result. None produces a 400 response whose body is ``error``.
            mode: "json" serialises ``result`` with ``json.dumps`` (non-JSON
                values via ``str``); "text" sends ``str(result)`` as text/plain.
            error: Body of the 400 response used when ``result`` is None.

        Raises:
            ValueError: if ``result`` is not None and ``mode`` is neither
                "json" nor "text". Previously this raised UnboundLocalError.
        """
        if result is None:
            response = self.app.response_class(response=error, status=400)
        elif mode == "json":
            response = self.app.response_class(
                response=json.dumps(result, ensure_ascii=False, indent=1, default=str),
                status=200,
                mimetype='application/json'
            )
        elif mode == "text":
            response = self.app.response_class(
                response=str(result),
                status=200,
                mimetype='text/plain'
            )
        else:
            raise ValueError(f"Unsupported response mode: {mode!r}")
        response.headers.add("Access-Control-Allow-Origin", "*")
        return response

    def reload_data(self):
        """Replace ``self.db`` with a freshly constructed ``DatabaseController``."""
        self.db = DatabaseController()