# File: marzban_tg_bot/database.py
# (Python source, 310 lines, 12 KiB)

import logging
import random
import re
import secrets
import string
from datetime import datetime, timedelta

import aiosqlite
import asyncpg

from config import CONFIG
logger = logging.getLogger(__name__)


class Database:
    """Async storage layer for the bot, backed by SQLite or PostgreSQL.

    The backend is chosen from ``CONFIG["DATABASE_URL"]``: an empty URL or one
    starting with ``"sqlite"`` selects a local SQLite file through aiosqlite;
    anything else is handed to ``asyncpg.create_pool``.

    All SQL in this class is written with PostgreSQL ``$N`` placeholders. On
    SQLite they are rewritten to SQLite's numbered ``?N`` form, which binds the
    same positional arguments (including reused or out-of-order numbers).
    """

    # PostgreSQL-style positional placeholder: $1, $2, ...
    _PG_PLACEHOLDER = re.compile(r"\$(\d+)")
    # Text layouts tried (before ISO 8601) when a TIMESTAMP column comes back
    # as a string, as it does from SQLite.
    _TIMESTAMP_FORMATS = ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f")

    def __init__(self):
        self.url = CONFIG["DATABASE_URL"]
        self.is_sqlite = not self.url or self.url.startswith("sqlite")
        self.conn = None  # aiosqlite connection (SQLite backend only)
        self.pool = None  # asyncpg pool (PostgreSQL backend only)

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    @classmethod
    def _to_sqlite_placeholders(cls, query: str) -> str:
        """Rewrite ``$N`` placeholders in *query* to SQLite's ``?N`` form.

        Unlike a bare ``?``, ``?N`` keeps the parameter number, so a query
        that reuses ``$1`` or references parameters out of order still binds
        the intended values.
        """
        return cls._PG_PLACEHOLDER.sub(r"?\1", query)

    @classmethod
    def _parse_timestamp(cls, value):
        """Coerce a TIMESTAMP column value to a ``datetime``.

        ``datetime`` values (as returned by asyncpg) pass through unchanged.
        Strings are parsed with ``_TIMESTAMP_FORMATS`` and then ISO 8601.
        Returns ``None`` for ``None``, empty or non-string values, and for text
        that cannot be parsed.
        """
        if isinstance(value, datetime):
            return value
        if not value or not isinstance(value, str):
            return None
        for fmt in cls._TIMESTAMP_FORMATS:
            try:
                return datetime.strptime(value, fmt)
            except ValueError:
                continue
        try:
            return datetime.fromisoformat(value)
        except ValueError:
            return None

    @classmethod
    def _promo_expired(cls, promo, now: datetime) -> bool:
        """Return True if *promo*'s ``expires_at`` is earlier than *now*.

        Promos without an expiry never expire. An expiry that cannot be parsed
        is logged and treated as "no expiry".
        """
        raw = promo['expires_at']
        if not raw:
            return False
        expires_at = cls._parse_timestamp(raw)
        if expires_at is None:
            logger.warning(
                "Unparseable expires_at %r for promo %r; treating it as non-expiring",
                raw, promo['code'],
            )
            return False
        return expires_at < now

    @staticmethod
    def _like_contains_pattern(text: str) -> str:
        """Build a ``%text%`` LIKE pattern that uses ``!`` as the escape character.

        ``!``, ``%`` and ``_`` in *text* are escaped so they match literally.
        The SQL using the pattern must specify ``ESCAPE '!'``.
        """
        escaped = text.replace("!", "!!").replace("%", "!%").replace("_", "!_")
        return f"%{escaped}%"

    # ------------------------------------------------------------------
    # Connection and low-level query helpers
    # ------------------------------------------------------------------

    async def connect(self):
        """Open the backend connection/pool and make sure the schema exists.

        Calling this again after a successful connect is a no-op.
        """
        if self.conn or self.pool:
            return
        if self.is_sqlite:
            # NOTE(review): any path in a sqlite DATABASE_URL is ignored; the
            # database file is always "bot.db" relative to the working directory.
            db_path = "bot.db"
            self.conn = await aiosqlite.connect(db_path)
            # Lets rows be indexed by column name, as the methods below do.
            self.conn.row_factory = aiosqlite.Row
            logger.info(f"Using SQLite database: {db_path}")
        else:
            self.pool = await asyncpg.create_pool(self.url)
            logger.info("Using PostgreSQL database")
        await self.create_tables()

    async def execute(self, query: str, *args):
        """Run a statement that returns no rows (DDL / INSERT / UPDATE).

        On SQLite the change is committed immediately, and the cursor is
        returned after its ``async with`` block has exited. On PostgreSQL the
        asyncpg command status string is returned.
        """
        if self.is_sqlite:
            sql = self._to_sqlite_placeholders(query)
            async with self.conn.execute(sql, args) as cursor:
                await self.conn.commit()
                return cursor
        async with self.pool.acquire() as conn:
            return await conn.execute(query, *args)

    async def fetchrow(self, query: str, *args):
        """Return the first row of *query*, or None if it yields no rows."""
        if self.is_sqlite:
            sql = self._to_sqlite_placeholders(query)
            async with self.conn.execute(sql, args) as cursor:
                return await cursor.fetchone()
        async with self.pool.acquire() as conn:
            return await conn.fetchrow(query, *args)

    async def fetchval(self, query: str, *args):
        """Return the first column of the first row, or None if there are no rows."""
        if self.is_sqlite:
            sql = self._to_sqlite_placeholders(query)
            async with self.conn.execute(sql, args) as cursor:
                row = await cursor.fetchone()
                return row[0] if row else None
        async with self.pool.acquire() as conn:
            return await conn.fetchval(query, *args)

    async def fetch(self, query: str, *args):
        """Return all rows produced by *query*."""
        if self.is_sqlite:
            sql = self._to_sqlite_placeholders(query)
            async with self.conn.execute(sql, args) as cursor:
                return await cursor.fetchall()
        async with self.pool.acquire() as conn:
            return await conn.fetch(query, *args)

    # ------------------------------------------------------------------
    # Schema
    # ------------------------------------------------------------------

    async def create_tables(self):
        """Create every table that does not exist yet, then run column migrations."""
        now_default = "CURRENT_TIMESTAMP" if self.is_sqlite else "NOW()"
        serial_type = "INTEGER PRIMARY KEY AUTOINCREMENT" if self.is_sqlite else "SERIAL PRIMARY KEY"
        # PostgreSQL rejects an integer literal as a BOOLEAN default ("default
        # expression is of type integer"), so "false" is spelled per dialect.
        false_default = "0" if self.is_sqlite else "FALSE"
        # NOTE(review): INTEGER is 32-bit on PostgreSQL. If data_limit stores
        # bytes, limits above ~2 GiB will not fit; confirm the unit.
        # users.personal_discount is not declared here; migrate_db adds it.
        queries = [
            f"""CREATE TABLE IF NOT EXISTS users (
                user_id BIGINT PRIMARY KEY,
                username TEXT,
                marzban_username TEXT UNIQUE,
                subscription_until TIMESTAMP,
                data_limit INTEGER,
                invited_by BIGINT,
                last_traffic_reset TIMESTAMP DEFAULT {now_default},
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS invite_codes (
                code TEXT PRIMARY KEY,
                created_by BIGINT,
                used_by BIGINT,
                used_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS promo_codes (
                code TEXT PRIMARY KEY,
                discount INTEGER,
                uses_left INTEGER,
                expires_at TIMESTAMP NULL,
                is_unlimited BOOLEAN DEFAULT {false_default},
                bonus_days INTEGER DEFAULT 0,
                is_sticky BOOLEAN DEFAULT {false_default},
                created_by BIGINT,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS payments (
                id {serial_type},
                user_id BIGINT,
                plan TEXT,
                amount INTEGER,
                promo_code TEXT,
                paid_at TIMESTAMP DEFAULT {now_default}
            )""",
        ]
        for q in queries:
            await self.execute(q)
        await self.migrate_db()

    async def migrate_db(self):
        """Add columns introduced after the first release, on SQLite or PostgreSQL.

        Each ``ALTER TABLE ... ADD COLUMN`` is attempted independently. A
        failure (normally "column already exists") is logged at debug level and
        skipped, so startup never stops here.
        """
        false_default = "0" if self.is_sqlite else "FALSE"
        migrations = [
            ("promo_codes", "expires_at TIMESTAMP NULL"),
            ("promo_codes", f"is_unlimited BOOLEAN DEFAULT {false_default}"),
            ("promo_codes", "bonus_days INTEGER DEFAULT 0"),
            ("promo_codes", f"is_sticky BOOLEAN DEFAULT {false_default}"),
            # NOTE(review): SQLite refuses ADD COLUMN with a non-constant
            # default such as CURRENT_TIMESTAMP (at least when the table has
            # rows). On an old SQLite database this step therefore keeps
            # failing and the column is never added.
            ("users", "last_traffic_reset TIMESTAMP DEFAULT CURRENT_TIMESTAMP"),
            ("users", "personal_discount INTEGER DEFAULT 0"),
        ]
        for table, column_def in migrations:
            try:
                await self.execute(f"ALTER TABLE {table} ADD COLUMN {column_def}")
            except Exception as exc:  # best effort: usually the column already exists
                logger.debug("Migration skipped (%s ADD COLUMN %s): %s", table, column_def, exc)

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    async def get_user(self, user_id: int):
        """Return the users row for *user_id*, or None."""
        return await self.fetchrow("SELECT * FROM users WHERE user_id = $1", user_id)

    async def get_user_by_username(self, username: str):
        """Return the user whose username matches case-insensitively; a leading "@" is ignored."""
        username = username.lstrip('@')
        return await self.fetchrow("SELECT * FROM users WHERE LOWER(username) = LOWER($1)", username)

    async def search_users(self, query: str):
        """Find users by exact numeric ID or by username substring.

        A query made only of decimal digits is looked up as a ``user_id``.
        Anything else is matched as a case-insensitive substring of
        ``username`` (on SQLite, case-insensitive for ASCII only), returning at
        most 20 rows. A leading "@" is ignored, and ``%`` / ``_`` in the query
        match literally instead of acting as LIKE wildcards.
        """
        # isdecimal(), unlike isdigit(), rejects characters such as "²"
        # that int() cannot parse.
        if query.isdecimal():
            return await self.fetch("SELECT * FROM users WHERE user_id = $1", int(query))
        term = self._like_contains_pattern(query.lstrip('@'))
        operator = "LIKE" if self.is_sqlite else "ILIKE"
        sql = f"SELECT * FROM users WHERE username {operator} $1 ESCAPE '!' LIMIT 20"
        return await self.fetch(sql, term)

    async def get_all_users(self):
        """Return every row of the users table."""
        return await self.fetch("SELECT * FROM users")

    async def create_user(self, user_id: int, username: str, marzban_username: str, invited_by: int = None):
        """Insert a new user; *invited_by* is the inviting user's ID, if any."""
        await self.execute(
            "INSERT INTO users (user_id, username, marzban_username, invited_by) VALUES ($1, $2, $3, $4)",
            user_id, username, marzban_username, invited_by
        )

    async def update_traffic_reset_date(self, user_id: int):
        """Set the user's last_traffic_reset to the bot's current local time."""
        now = datetime.now()
        await self.execute("UPDATE users SET last_traffic_reset = $1 WHERE user_id = $2", now, user_id)

    async def remove_subscription(self, user_id: int):
        """Clear subscription_until for the user."""
        await self.execute("UPDATE users SET subscription_until = NULL WHERE user_id = $1", user_id)

    async def update_subscription(self, user_id: int, days: int, data_limit: int):
        """Extend a user's subscription by *days* and set their data limit.

        An active subscription is extended from its current end date. An
        expired, missing or unreadable one restarts from now. Durations above
        10000 days pin the end date to 2099-12-31 ("forever").

        Raises:
            TypeError: if no user with *user_id* exists (the row is None).
        """
        user = await self.get_user(user_id)
        current_until = self._parse_timestamp(user['subscription_until'])
        now = datetime.now()
        if days > 10000:
            # Checked before any date arithmetic so huge values cannot overflow.
            # NOTE(review): the "unlimited" value callers pass is not visible
            # here; confirm it exceeds 10000 (exactly 9999 does NOT hit this branch).
            new_date = datetime(2099, 12, 31)
        else:
            start = current_until if current_until and current_until > now else now
            new_date = start + timedelta(days=days)
        await self.execute(
            "UPDATE users SET subscription_until = $1, data_limit = $2 WHERE user_id = $3",
            new_date, data_limit, user_id
        )

    async def set_user_discount(self, user_id: int, discount: int):
        """Store a personal discount value for the user."""
        await self.execute("UPDATE users SET personal_discount = $1 WHERE user_id = $2", discount, user_id)

    async def get_users_for_broadcast(self):
        """Return the user_id of every user."""
        return await self.fetch("SELECT user_id FROM users")

    # ------------------------------------------------------------------
    # Invite codes
    # ------------------------------------------------------------------

    async def create_invite_code(self, created_by: int):
        """Generate, store and return a random 8-character [A-Z0-9] invite code.

        Uses ``secrets`` (a CSPRNG) so that codes cannot be predicted from
        previously issued ones.
        """
        alphabet = string.ascii_uppercase + string.digits
        code = ''.join(secrets.choice(alphabet) for _ in range(8))
        await self.execute(
            "INSERT INTO invite_codes (code, created_by) VALUES ($1, $2)",
            code, created_by
        )
        return code

    async def use_invite_code(self, code: str, user_id: int):
        """Mark *code* as used by *user_id*.

        This is a no-op if the code was already used, so a concurrent second
        redemption cannot overwrite the first.
        """
        await self.execute(
            "UPDATE invite_codes SET used_by = $1, used_at = CURRENT_TIMESTAMP "
            "WHERE code = $2 AND used_by IS NULL",
            user_id, code
        )

    async def check_invite_code(self, code: str):
        """Return the invite_codes row if *code* exists and is unused, else None."""
        return await self.fetchrow("SELECT * FROM invite_codes WHERE code = $1 AND used_by IS NULL", code)

    # ------------------------------------------------------------------
    # Promo codes and payments
    # ------------------------------------------------------------------

    async def create_promo_code(self, code: str, discount: int, uses: int, created_by: int, expires_at: datetime = None, is_unlimited: bool = False, bonus_days: int = 0, is_sticky: bool = False):
        """Insert a promo code with *uses* remaining uses and an optional expiry."""
        await self.execute(
            "INSERT INTO promo_codes (code, discount, uses_left, created_by, expires_at, is_unlimited, bonus_days, is_sticky) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
            code, discount, uses, created_by, expires_at, is_unlimited, bonus_days, is_sticky
        )

    async def get_promo_code(self, code: str):
        """Return the promo row if it has uses left and has not expired, else None.

        Expiry is checked in Python rather than SQL so that the same logic
        works for SQLite's text timestamps and PostgreSQL's native ones.
        """
        promo = await self.fetchrow("SELECT * FROM promo_codes WHERE code = $1 AND uses_left > 0", code)
        if not promo or self._promo_expired(promo, datetime.now()):
            return None
        return promo

    async def get_active_promos(self):
        """Return a list of promo rows that have uses left and have not expired."""
        now = datetime.now()
        promos = await self.fetch("SELECT * FROM promo_codes WHERE uses_left > 0")
        return [promo for promo in promos if not self._promo_expired(promo, now)]

    async def decrement_promo_usage(self, code: str):
        """Consume one use of *code*; never drives uses_left below zero."""
        await self.execute(
            "UPDATE promo_codes SET uses_left = uses_left - 1 WHERE code = $1 AND uses_left > 0",
            code
        )

    async def add_payment(self, user_id: int, plan: str, amount: int, promo_code: str = None):
        """Record a payment row (paid_at defaults to the database's current time)."""
        await self.execute(
            "INSERT INTO payments (user_id, plan, amount, promo_code) VALUES ($1, $2, $3, $4)",
            user_id, plan, amount, promo_code
        )

    async def get_stats(self):
        """Return ``{"total": users, "revenue": sum of payments, "active": live subs}``."""
        total_users = await self.fetchval("SELECT COUNT(*) FROM users")
        active_revenue = await self.fetchval("SELECT SUM(amount) FROM payments") or 0
        # subscription_until is written from Python's naive datetime.now()
        # (see update_subscription), so compare it against the same clock
        # rather than SQLite's datetime('now'), which is UTC.
        active_subs = await self.fetchval(
            "SELECT COUNT(*) FROM users WHERE subscription_until > $1", datetime.now()
        )
        return {
            "total": total_users,
            "revenue": active_revenue,
            "active": active_subs
        }
# Module-level Database instance. Construction only reads CONFIG["DATABASE_URL"];
# no connection or pool exists until `await db.connect()` has run.
db = Database()