322 lines
13 KiB
Python
322 lines
13 KiB
Python
import logging
import random
import re
import secrets
import string
from datetime import datetime, timedelta

import aiosqlite
import asyncpg

from config import CONFIG
|
|
|
|
# Module-level logger, keyed by this module's import name.
logger = logging.getLogger(name=__name__)
|
|
|
|
class Database:
    """Async data-access layer backed by SQLite (aiosqlite) or PostgreSQL (asyncpg).

    All queries are written with PostgreSQL-style ``$1, $2, ...`` placeholders.
    On the SQLite backend they are rewritten to ``?`` before execution.

    The backend is chosen from ``CONFIG["DATABASE_URL"]``:

    * An empty value, or one starting with ``"sqlite"``, selects the local
      SQLite file ``bot.db``.
    * Anything else is handed to ``asyncpg.create_pool``.

    Call ``await connect()`` before using any other method.
    """

    # PostgreSQL positional placeholders ($1, $2, ...) to rewrite for SQLite.
    _PG_PLACEHOLDER = re.compile(r"\$\d+")

    # Text formats tried, in order, when a timestamp column comes back as a string.
    _TIMESTAMP_FORMATS = ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f")

    # Subscriptions longer than this many days are treated as "forever" and
    # pinned to a fixed far-future date instead of doing date arithmetic.
    _UNLIMITED_DAYS_THRESHOLD = 10000
    _UNLIMITED_SUBSCRIPTION_END = datetime(2099, 12, 31)

    def __init__(self):
        self.url = CONFIG["DATABASE_URL"]
        # No URL configured -> fall back to a local SQLite file.
        self.is_sqlite = not self.url or self.url.startswith("sqlite")
        self.conn = None  # aiosqlite connection (SQLite backend only)
        self.pool = None  # asyncpg pool (PostgreSQL backend only)

    # ------------------------------------------------------------------ #
    # Connection management
    # ------------------------------------------------------------------ #

    async def connect(self):
        """Open the connection or pool and ensure the schema exists.

        Calling this again once connected does nothing.
        """
        if self.conn or self.pool:
            return
        if self.is_sqlite:
            db_path = "bot.db"
            self.conn = await aiosqlite.connect(db_path)
            # Rows support access by column name as well as by index.
            self.conn.row_factory = aiosqlite.Row
            logger.info(f"Using SQLite database: {db_path}")
        else:
            self.pool = await asyncpg.create_pool(self.url)
            logger.info("Using PostgreSQL database")
        await self.create_tables()

    async def close(self):
        """Close the SQLite connection or PostgreSQL pool, if one is open."""
        if self.conn is not None:
            await self.conn.close()
            self.conn = None
        if self.pool is not None:
            await self.pool.close()
            self.pool = None

    # ------------------------------------------------------------------ #
    # Low-level query helpers
    # ------------------------------------------------------------------ #

    @classmethod
    def _to_sqlite(cls, query: str) -> str:
        """Rewrite ``$N`` placeholders in *query* to SQLite's ``?``."""
        return cls._PG_PLACEHOLDER.sub("?", query)

    async def execute(self, query: str, *args):
        """Run a statement that returns no rows.

        On SQLite the change is committed immediately. The aiosqlite cursor is
        returned after its context has closed; callers should only read
        attributes from it (e.g. ``rowcount``), not fetch rows.
        On PostgreSQL the return value is whatever ``asyncpg``'s
        ``Connection.execute`` returns.
        """
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                await self.conn.commit()
                return cursor
        async with self.pool.acquire() as conn:
            return await conn.execute(query, *args)

    async def fetchrow(self, query: str, *args):
        """Return the first row of *query*, or None when it yields no rows."""
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                return await cursor.fetchone()
        async with self.pool.acquire() as conn:
            return await conn.fetchrow(query, *args)

    async def fetchval(self, query: str, *args):
        """Return the first column of the first row, or None when there is no row."""
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                row = await cursor.fetchone()
                return row[0] if row is not None else None
        async with self.pool.acquire() as conn:
            return await conn.fetchval(query, *args)

    async def fetch(self, query: str, *args):
        """Return all rows produced by *query*."""
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                return await cursor.fetchall()
        async with self.pool.acquire() as conn:
            return await conn.fetch(query, *args)

    # ------------------------------------------------------------------ #
    # Value normalization
    # ------------------------------------------------------------------ #

    @classmethod
    def _parse_timestamp(cls, value):
        """Normalize a timestamp column value to a naive ``datetime``.

        The SQLite connection is opened without type detection, so timestamps
        may come back as text such as ``'2024-01-02 03:04:05'`` or
        ``'2024-01-02 03:04:05.123456'``. The PostgreSQL driver presumably
        returns ``datetime`` objects already.

        Returns:
            None for None or an empty string. Non-string values unchanged.
            Otherwise the parsed datetime.

        Raises:
            ValueError: if a string matches none of the supported formats.
        """
        if value is None or value == "":
            return None
        if not isinstance(value, str):
            return value
        for fmt in cls._TIMESTAMP_FORMATS:
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        # Last resort: ISO 8601 variants ('T' separator, offsets, ...).
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is not None:
            # Everything else in this class compares against naive local
            # datetime.now(); comparing naive and aware values raises TypeError.
            parsed = parsed.astimezone().replace(tzinfo=None)
        return parsed

    @classmethod
    def _compute_subscription_end(cls, current, days: int, now: datetime = None) -> datetime:
        """Return the new subscription end date after adding *days*.

        The extension starts from *current* if it is still in the future,
        otherwise from *now*. Lengths above ``_UNLIMITED_DAYS_THRESHOLD`` map to
        ``_UNLIMITED_SUBSCRIPTION_END``. That check happens before any date
        arithmetic, so huge day counts cannot overflow ``datetime``.
        """
        if now is None:
            now = datetime.now()
        if days > cls._UNLIMITED_DAYS_THRESHOLD:
            return cls._UNLIMITED_SUBSCRIPTION_END
        base = current if isinstance(current, datetime) and current > now else now
        return base + timedelta(days=days)

    # ------------------------------------------------------------------ #
    # Schema
    # ------------------------------------------------------------------ #

    async def create_tables(self):
        """Create all tables if missing, then apply column migrations."""
        now_default = "CURRENT_TIMESTAMP" if self.is_sqlite else "NOW()"
        serial_type = "INTEGER PRIMARY KEY AUTOINCREMENT" if self.is_sqlite else "SERIAL PRIMARY KEY"

        # NOTE(review): data_limit is BIGINT because a traffic limit may exceed
        # 32 bits (e.g. if it is stored in bytes). Existing PostgreSQL tables
        # keep their original INTEGER type.
        queries = [
            f"""CREATE TABLE IF NOT EXISTS users (
                user_id BIGINT PRIMARY KEY,
                username TEXT,
                marzban_username TEXT UNIQUE,
                subscription_until TIMESTAMP,
                data_limit BIGINT,
                invited_by BIGINT,
                last_traffic_reset TIMESTAMP DEFAULT {now_default},
                personal_discount INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS invite_codes (
                code TEXT PRIMARY KEY,
                created_by BIGINT,
                used_by BIGINT,
                used_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS promo_codes (
                code TEXT PRIMARY KEY,
                discount INTEGER,
                uses_left INTEGER,
                expires_at TIMESTAMP NULL,
                is_unlimited BOOLEAN DEFAULT FALSE,
                bonus_days INTEGER DEFAULT 0,
                is_sticky BOOLEAN DEFAULT FALSE,
                created_by BIGINT,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS payments (
                id {serial_type},
                user_id BIGINT,
                plan TEXT,
                amount INTEGER,
                promo_code TEXT,
                paid_at TIMESTAMP DEFAULT {now_default}
            )""",
        ]
        for q in queries:
            await self.execute(q)
        await self.migrate_db()

    async def _add_column(self, table: str, column_def: str) -> bool:
        """Try to add one column to *table* (internal constant names only).

        Returns True if the ALTER statement succeeded, False if the column
        already existed (SQLite) or the statement failed. Failures other than
        a duplicate column are logged, not raised, so startup continues.
        """
        # PostgreSQL can skip existing columns itself; SQLite has no such clause.
        if_not_exists = "" if self.is_sqlite else "IF NOT EXISTS "
        try:
            await self.execute(f"ALTER TABLE {table} ADD COLUMN {if_not_exists}{column_def}")
        except Exception as exc:
            if "duplicate column" in str(exc).lower():
                return False  # already migrated
            logger.warning("Migration of %s (%s) failed: %s", table, column_def, exc)
            return False
        return True

    async def migrate_db(self):
        """Add columns introduced after the initial schema to existing tables.

        This is best-effort. Columns that already exist are skipped, and other
        failures are logged without aborting the remaining migrations.
        """
        await self._add_column("promo_codes", "expires_at TIMESTAMP NULL")
        await self._add_column("promo_codes", "is_unlimited BOOLEAN DEFAULT FALSE")
        await self._add_column("promo_codes", "bonus_days INTEGER DEFAULT 0")
        await self._add_column("promo_codes", "is_sticky BOOLEAN DEFAULT FALSE")

        if self.is_sqlite:
            # SQLite rejects ADD COLUMN with a CURRENT_TIMESTAMP default, so
            # add the column bare and backfill rows that predate it.
            if await self._add_column("users", "last_traffic_reset TIMESTAMP"):
                await self.execute(
                    "UPDATE users SET last_traffic_reset = CURRENT_TIMESTAMP "
                    "WHERE last_traffic_reset IS NULL"
                )
        else:
            await self._add_column("users", "last_traffic_reset TIMESTAMP DEFAULT CURRENT_TIMESTAMP")

        await self._add_column("users", "personal_discount INTEGER DEFAULT 0")

    # ------------------------------------------------------------------ #
    # Users
    # ------------------------------------------------------------------ #

    async def get_user(self, user_id: int):
        """Return the users row for *user_id*, or None."""
        return await self.fetchrow("SELECT * FROM users WHERE user_id = $1", user_id)

    async def get_user_by_username(self, username: str):
        """Return the user whose username matches case-insensitively (leading '@' ignored)."""
        username = username.lstrip('@')
        return await self.fetchrow("SELECT * FROM users WHERE LOWER(username) = LOWER($1)", username)

    async def search_users(self, query: str):
        """Search users by exact numeric ID or by username substring (max 20 rows)."""
        # isdecimal(), not isdigit(): int() rejects digits such as '²'.
        if query.isdecimal():
            return await self.fetch("SELECT * FROM users WHERE user_id = $1", int(query))

        term = f"%{query}%"
        # SQLite's LIKE is already case-insensitive for ASCII; PostgreSQL needs ILIKE.
        if self.is_sqlite:
            sql = "SELECT * FROM users WHERE username LIKE $1 LIMIT 20"
        else:
            sql = "SELECT * FROM users WHERE username ILIKE $1 LIMIT 20"
        return await self.fetch(sql, term)

    async def get_all_users(self):
        """Return every row of the users table."""
        return await self.fetch("SELECT * FROM users")

    async def create_user(self, user_id: int, username: str, marzban_username: str, invited_by: int = None):
        """Insert a new user row."""
        await self.execute(
            "INSERT INTO users (user_id, username, marzban_username, invited_by) VALUES ($1, $2, $3, $4)",
            user_id, username, marzban_username, invited_by,
        )

    async def update_traffic_reset_date(self, user_id: int):
        """Set the user's last_traffic_reset to the current local time."""
        now = datetime.now()
        await self.execute("UPDATE users SET last_traffic_reset = $1 WHERE user_id = $2", now, user_id)

    async def get_referrals_count(self, user_id: int):
        """Return how many users list *user_id* as their inviter."""
        return await self.fetchval("SELECT COUNT(*) FROM users WHERE invited_by = $1", user_id) or 0

    async def get_user_payments_info(self, user_id: int):
        """Return ``{"total_amount": ..., "total_count": ...}`` over the user's payments."""
        row = await self.fetchrow(
            "SELECT COALESCE(SUM(amount), 0), COUNT(*) FROM payments WHERE user_id = $1",
            user_id,
        )
        if row is None:
            return {"total_amount": 0, "total_count": 0}
        return {"total_amount": row[0] or 0, "total_count": row[1] or 0}

    async def remove_subscription(self, user_id: int):
        """Clear the user's subscription end date."""
        await self.execute("UPDATE users SET subscription_until = NULL WHERE user_id = $1", user_id)

    async def update_subscription(self, user_id: int, days: int, data_limit: int):
        """Extend the user's subscription by *days* and set *data_limit*.

        An active subscription is extended from its current end date. An
        expired or missing one starts from now. More than 10000 days means
        "forever" (fixed end date 2099-12-31). Unknown users are logged and
        ignored.
        """
        user = await self.get_user(user_id)
        if user is None:
            logger.warning("update_subscription: user %s not found", user_id)
            return

        raw = user['subscription_until']
        try:
            current = self._parse_timestamp(raw)
        except ValueError:
            # Don't block a (possibly paid) extension on a corrupt value; restart from now.
            logger.warning("Unparseable subscription_until %r for user %s; extending from now", raw, user_id)
            current = None

        new_date = self._compute_subscription_end(current, days)
        await self.execute(
            "UPDATE users SET subscription_until = $1, data_limit = $2 WHERE user_id = $3",
            new_date, data_limit, user_id,
        )

    async def set_user_discount(self, user_id: int, discount: int):
        """Store a personal discount for the user."""
        await self.execute("UPDATE users SET personal_discount = $1 WHERE user_id = $2", discount, user_id)

    # ------------------------------------------------------------------ #
    # Invite codes
    # ------------------------------------------------------------------ #

    async def create_invite_code(self, created_by: int):
        """Create and store a random 8-character invite code, and return it."""
        alphabet = string.ascii_uppercase + string.digits
        # secrets, not random: these codes gate account creation.
        code = ''.join(secrets.choice(alphabet) for _ in range(8))
        await self.execute(
            "INSERT INTO invite_codes (code, created_by) VALUES ($1, $2)",
            code, created_by,
        )
        return code

    async def use_invite_code(self, code: str, user_id: int):
        """Mark *code* as used by *user_id*, unless another user already claimed it."""
        await self.execute(
            "UPDATE invite_codes SET used_by = $1, used_at = CURRENT_TIMESTAMP "
            "WHERE code = $2 AND used_by IS NULL",
            user_id, code,
        )

    async def check_invite_code(self, code: str):
        """Return the invite_codes row if *code* exists and is unused, else None."""
        return await self.fetchrow("SELECT * FROM invite_codes WHERE code = $1 AND used_by IS NULL", code)

    # ------------------------------------------------------------------ #
    # Promo codes
    # ------------------------------------------------------------------ #

    async def create_promo_code(self, code: str, discount: int, uses: int, created_by: int, expires_at: datetime = None, is_unlimited: bool = False, bonus_days: int = 0, is_sticky: bool = False):
        """Insert a promo code row."""
        await self.execute(
            "INSERT INTO promo_codes (code, discount, uses_left, created_by, expires_at, is_unlimited, bonus_days, is_sticky) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
            code, discount, uses, created_by, expires_at, is_unlimited, bonus_days, is_sticky,
        )

    async def delete_promo_code(self, code: str):
        """Delete a promo code."""
        await self.execute("DELETE FROM promo_codes WHERE code = $1", code)

    def _promo_expiry(self, promo):
        """Return the promo's expiry as a datetime, or None if it has none.

        An unparseable value is logged and treated as "no expiry", the same
        way for both lookups and listings.
        """
        raw = promo['expires_at']
        try:
            parsed = self._parse_timestamp(raw)
        except ValueError:
            logger.warning("Unparseable expires_at %r for promo %r; treating as non-expiring", raw, promo['code'])
            return None
        return parsed if isinstance(parsed, datetime) else None

    async def get_promo_code(self, code: str):
        """Return the promo row if it has uses left and has not expired, else None."""
        promo = await self.fetchrow("SELECT * FROM promo_codes WHERE code = $1 AND uses_left > 0", code)
        if promo is None:
            return None
        # Expiry is checked in Python to avoid dialect-specific SQL.
        expires_at = self._promo_expiry(promo)
        if expires_at is not None and expires_at < datetime.now():
            return None
        return promo

    async def get_active_promos(self):
        """Return promos with uses left that have no expiry or expire in the future."""
        promos = await self.fetch("SELECT * FROM promo_codes WHERE uses_left > 0")
        now = datetime.now()
        active = []
        for promo in promos:
            expires_at = self._promo_expiry(promo)
            if expires_at is None or expires_at > now:
                active.append(promo)
        return active

    async def decrement_promo_usage(self, code: str):
        """Consume one use of *code*; delete the promo once no uses remain."""
        await self.execute("UPDATE promo_codes SET uses_left = uses_left - 1 WHERE code = $1", code)
        await self.execute("DELETE FROM promo_codes WHERE code = $1 AND uses_left <= 0", code)

    # ------------------------------------------------------------------ #
    # Payments & stats
    # ------------------------------------------------------------------ #

    async def add_payment(self, user_id: int, plan: str, amount: int, promo_code: str = None):
        """Record a payment."""
        await self.execute(
            "INSERT INTO payments (user_id, plan, amount, promo_code) VALUES ($1, $2, $3, $4)",
            user_id, plan, amount, promo_code,
        )

    async def get_stats(self):
        """Return total users, total revenue, and active subscription count."""
        total_users = await self.fetchval("SELECT COUNT(*) FROM users")
        active_revenue = await self.fetchval("SELECT SUM(amount) FROM payments") or 0
        # Compare against the same naive local clock used when writing
        # subscription_until (datetime.now()), rather than SQLite's UTC
        # datetime('now') or the server's NOW().
        active_subs = await self.fetchval(
            "SELECT COUNT(*) FROM users WHERE subscription_until > $1", datetime.now()
        )
        return {
            "total": total_users,
            "revenue": active_revenue,
            "active": active_subs,
        }

    async def get_users_for_broadcast(self):
        """Return the user_id of every user."""
        return await self.fetch("SELECT user_id FROM users")
|
# Shared module-level instance. Construction only reads CONFIG["DATABASE_URL"];
# callers must `await db.connect()` before running any query.
db = Database()
|