# KemoFureApi/modules/Archive/database.py
#
# File-viewer metadata from the original listing: 128 lines, 5.0 KiB, Python.

import json
from flask import Flask
from modules.Archive.endpoints.tag_stats import TagStats
from .databaseController import DatabaseController
from .endpoints.query import Query
from .endpoints.command import Command
from .endpoints.commands import Commands
from .endpoints.post import GetPost
from .endpoints.posts import GetPosts
from .endpoints.new_query import NewQuery
from .endpoints.set_action import SetAction
from .endpoints.posts_count import GetPostsCount
from .endpoints.account_stats import AccountStats
class Database:
    """Archive database facade.

    On construction it registers itself as ``app.databases["Archive"]`` and
    mounts the Archive REST resources; its query methods build SQL text and
    execute it through ``self.db.run_query``.

    Values that reach the SQL text can come from HTTP endpoints, so each one
    is either coerced with ``int()`` or emitted through ``_sql_str``.
    TODO(review): if ``DatabaseController.run_query`` supports bound
    parameters, switch to them instead of building literals by hand.
    """

    # Controller that executes the SQL; (re)created by reload_data().
    # String annotations keep the class body from evaluating these types.
    db: "DatabaseController" = None
    # Flask application taken from ``api.app`` in __init__.
    app: "Flask" = None

    # Filter values that mean "no filter" in build_where_query.
    _ALL_ACTIONS = (0, 1, 2, 3)
    _ALL_RATINGS = ("SFW", "NSFW", "NSFL")
    # Rating filter applied when the caller passes none (NSFL is left out).
    _DEFAULT_RATINGS = ("SFW", "NSFW")

    def __init__(self, api) -> None:
        """Bind to ``api.app``, (re)register as the "Archive" database and
        mount the Archive endpoints on *api*."""
        self.app = api.app
        # Drop any previously registered Archive instance before a new
        # controller is created for this one.
        if "Archive" in self.app.databases:
            del self.app.databases["Archive"]
        self.reload_data()
        self.app.databases["Archive"] = self
        api.add_resource(Query, "/Query")
        api.add_resource(NewQuery, "/NewQuery")
        api.add_resource(Command, "/Command")
        api.add_resource(Commands, "/Commands")
        api.add_resource(GetPost, "/GetPost/<id>")
        api.add_resource(GetPosts, "/GetPosts")
        api.add_resource(GetPostsCount, "/GetPosts/Count")
        api.add_resource(AccountStats, "/AccountStats")
        api.add_resource(TagStats, "/TagStats")
        api.add_resource(SetAction, "/SetAction")

    @staticmethod
    def _sql_str(value) -> str:
        """Return ``str(value)`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled (the standard SQL escape).
        NOTE(review): assumes backslashes are literal inside '...' strings
        (PostgreSQL with standard_conforming_strings = on) -- confirm.
        """
        return "'" + str(value).replace("'", "''") + "'"

    @staticmethod
    def _in_clause(column: str, literals: list) -> str:
        """Return ``column IN (lit, ...)``, or ``FALSE`` when *literals* is
        empty (an empty ``IN ()`` is a SQL syntax error)."""
        if not literals:
            return "FALSE"
        return f"{column} IN ({', '.join(literals)})"

    def get_accounts(self):
        """Return ``x_handle`` and ``id`` of every ``x_accounts`` row, ordered by handle."""
        query = """
            SELECT x_handle, id FROM x_accounts
            ORDER BY x_handle ASC"""
        return self.db.run_query(query)

    def get_account_stats(self):
        """Return every row of the "Artist Stats" relation."""
        return self.db.run_query('SELECT * FROM "Artist Stats"')

    def get_tag_stats(self):
        """Return every row of the "Tag Stats" relation."""
        return self.db.run_query('SELECT * FROM "Tag Stats"')

    def get_post(self, id):
        """Fetch one post, joined with its account's handle and rating.

        :param id: post id; anything ``int()`` accepts (e.g. the ``<id>`` URL
            segment). The name shadows the builtin but is kept for callers.
        :returns: whatever ``self.db.run_query`` returns for the query.
        :raises ValueError: if *id* is not an integer, rather than splicing
            arbitrary text into the SQL.
        """
        post_id = int(id)
        query = f"""
            SELECT cast(x_posts.id as TEXT), x_posts.*, x_accounts.x_handle, x_accounts.rating
            FROM x_posts
            LEFT JOIN x_accounts ON x_posts.account_id = x_accounts.id
            WHERE x_posts.id = {post_id}
            LIMIT 1"""
        return self.db.run_query(query)

    def build_where_query(self, artist, actions_taken=None, last_id=-1, include_ratings=None, tags=None):
        """Build the WHERE clause shared by get_posts and get_posts_count.

        Always keeps only ``x_posts.error_id = 0``; each active filter below
        is ANDed on.

        :param artist: exact ``x_accounts.x_handle``; ``None`` or ``""`` disables it.
        :param actions_taken: allowed ``action_taken`` values. ``None`` (the
            default) or exactly ``[0, 1, 2, 3]`` adds no filter; an empty
            sequence matches nothing.
        :param last_id: keep only posts with an id strictly below this;
            ``-1`` disables it.
        :param include_ratings: allowed account ratings. ``None`` means
            ``["SFW", "NSFW"]``; exactly ``["SFW", "NSFW", "NSFL"]`` adds no
            filter; an empty sequence matches nothing.
        :param tags: tags that must all be present (``tags @> ARRAY[...]``).
        :returns: the clause text, starting with ``WHERE``.
        :raises ValueError: if ``last_id`` or an action is not an integer.
        """
        actions = tuple(self._ALL_ACTIONS if actions_taken is None else actions_taken)
        ratings = tuple(self._DEFAULT_RATINGS if include_ratings is None else include_ratings)
        tag_list = [] if tags is None else list(tags)

        conditions = ["x_posts.error_id = 0"]
        if last_id != -1:
            conditions.append(f"x_posts.id < {int(last_id)}")
        if actions != self._ALL_ACTIONS:
            conditions.append(self._in_clause("x_posts.action_taken", [str(int(a)) for a in actions]))
        if ratings != self._ALL_RATINGS:
            conditions.append(self._in_clause("x_accounts.rating", [self._sql_str(r) for r in ratings]))
        if artist is not None and artist != "":
            conditions.append(f"x_accounts.x_handle = {self._sql_str(artist)}")
        if tag_list:
            tag_literals = ", ".join(self._sql_str(tag) for tag in tag_list)
            conditions.append(f"tags @> ARRAY[{tag_literals}]")
        return "WHERE " + " AND ".join(conditions)

    def get_posts_count(self, artist, actions_taken=None, last_id=-1, include_ratings=None, tags=None):
        """Count posts matching the build_where_query filters.

        :returns: the ``count`` column of the first row that
            ``self.db.run_query`` returns (it is unpacked as a 2-tuple whose
            second item is the row list).
        """
        where_query = self.build_where_query(artist, actions_taken, last_id, include_ratings, tags)
        query = f"""
            SELECT count(*) as count FROM x_posts
            LEFT JOIN x_accounts ON x_posts.account_id = x_accounts.id
            {where_query}"""
        _, rows = self.db.run_query(query)
        return rows[0]["count"]

    def get_posts(self, num_posts, artist, actions_taken=None, last_id=-1, offset=0, include_ratings=None, tags=None, order="DESC"):
        """Return one page of posts matching the build_where_query filters.

        :param num_posts: page size, clamped to 1..100.
        :param offset: rows to skip; coerced with ``int()``, negatives become 0.
        :param order: ``"RAND"`` for random order, ``"ASC"`` for ascending id,
            anything else for descending id.
        :returns: whatever ``self.db.run_query`` returns for the query.
        """
        limit = max(1, min(int(num_posts), 100))
        skip = max(0, int(offset))
        if order == "RAND":
            order_by = "RANDOM()"
        elif order == "ASC":
            order_by = "x_posts.id ASC"
        else:
            order_by = "x_posts.id DESC"
        where_query = self.build_where_query(artist, actions_taken, last_id, include_ratings, tags)
        # '{{}}' renders as '{}': is_saved is true when saved_files is not that literal.
        query = f"""
            SELECT x_posts.id, cast(x_posts.id as TEXT) as id_str, action_taken,
                   saved_files != '{{}}' as is_saved, text, files, date, x_handle, x_accounts.rating
            FROM x_posts
            LEFT JOIN x_accounts ON x_posts.account_id = x_accounts.id
            {where_query}
            ORDER BY {order_by}
            LIMIT {limit} OFFSET {skip}"""
        return self.db.run_query(query)

    def wrap_query_response(self, result, mode="json", error: str = ""):
        """Wrap *result* in an ``app.response_class`` with a ``*`` CORS header.

        :param result: payload; ``None`` produces a 400 whose body is *error*.
        :param mode: ``"json"`` -> 200 with ``json.dumps`` (non-serializable
            values stringified via ``default=str``); ``"text"`` -> 200 with
            ``str(result)`` as text/plain.
        :raises ValueError: for any other *mode* when *result* is not None
            (previously this surfaced as an UnboundLocalError).
        """
        if result is None:
            response = self.app.response_class(response=error, status=400)
        elif mode == "json":
            response = self.app.response_class(
                response=json.dumps(result, ensure_ascii=False, indent=1, default=str),
                status=200,
                mimetype='application/json',
            )
        elif mode == "text":
            response = self.app.response_class(
                response=str(result),
                status=200,
                mimetype='text/plain',
            )
        else:
            raise ValueError(f"unsupported response mode: {mode!r}")
        response.headers.add("Access-Control-Allow-Origin", "*")
        return response

    def reload_data(self):
        """Replace ``self.db`` with a freshly constructed DatabaseController."""
        self.db = DatabaseController()