import logging
import random
import re
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional

import aiosqlite
import asyncpg

from config import CONFIG

logger = logging.getLogger(__name__)

# asyncpg-style positional placeholders ($1, $2, ...); rewritten to "?" for SQLite.
_PG_PLACEHOLDER_RE = re.compile(r"\$\d+")

# Text layouts tried first when SQLite returns a TIMESTAMP column as a string.
_TIMESTAMP_FORMATS = ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S.%f")

# Invite codes: 8 characters drawn from A-Z and 0-9.
_INVITE_ALPHABET = string.ascii_uppercase + string.digits
_INVITE_CODE_LENGTH = 8

# Subscriptions granted for more than this many days are pinned to _FOREVER.
_UNLIMITED_DAYS_THRESHOLD = 10000
_FOREVER = datetime(2099, 12, 31)


def _parse_timestamp(value) -> Optional[datetime]:
    """Coerce a TIMESTAMP column value to a naive ``datetime``.

    ``datetime`` values are returned as-is; strings (what aiosqlite yields
    without ``detect_types``) are parsed with ``_TIMESTAMP_FORMATS`` and then
    ``datetime.fromisoformat``. Timezone-aware results are converted to naive
    local time so they compare cleanly with ``datetime.now()``.

    Returns ``None`` for empty/NULL values and for values that cannot be
    parsed (the latter is logged as a warning).
    """
    if not value:
        return None

    if isinstance(value, datetime):
        parsed = value
    elif isinstance(value, str):
        parsed = None
        for fmt in _TIMESTAMP_FORMATS:
            try:
                parsed = datetime.strptime(value, fmt)
                break
            except ValueError:
                continue
        if parsed is None:
            try:
                parsed = datetime.fromisoformat(value)
            except ValueError:
                logger.warning("Unparseable timestamp value: %r", value)
                return None
    else:
        logger.warning("Unexpected timestamp type %s: %r", type(value).__name__, value)
        return None

    if parsed.tzinfo is not None:
        # Everything this module writes is naive local time; normalise so
        # comparisons with datetime.now() don't raise TypeError.
        parsed = parsed.astimezone().replace(tzinfo=None)
    return parsed


class Database:
    """Async data-access layer backed by SQLite (aiosqlite) or PostgreSQL (asyncpg).

    Queries are written with asyncpg-style ``$N`` placeholders; on SQLite they
    are rewritten to ``?`` before execution. An empty ``CONFIG["DATABASE_URL"]``
    or one starting with "sqlite" selects SQLite (file ``bot.db``); any other
    value is passed to ``asyncpg.create_pool``.
    """

    def __init__(self):
        self.url = CONFIG["DATABASE_URL"]
        self.is_sqlite = not self.url or self.url.startswith("sqlite")
        self.conn = None  # aiosqlite connection (SQLite backend only)
        self.pool = None  # asyncpg pool (PostgreSQL backend only)

    async def connect(self):
        """Open the connection/pool (once) and make sure the schema exists."""
        if self.conn or self.pool:
            return
        if self.is_sqlite:
            db_path = "bot.db"
            self.conn = await aiosqlite.connect(db_path)
            self.conn.row_factory = aiosqlite.Row
            logger.info(f"Using SQLite database: {db_path}")
        else:
            self.pool = await asyncpg.create_pool(self.url)
            logger.info("Using PostgreSQL database")
        await self.create_tables()

    async def close(self):
        """Close the SQLite connection or PostgreSQL pool, if one is open."""
        if self.conn is not None:
            await self.conn.close()
            self.conn = None
        if self.pool is not None:
            await self.pool.close()
            self.pool = None

    # ------------------------------------------------------------------ #
    # Low-level query helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_sqlite(query: str) -> str:
        """Rewrite ``$N`` placeholders to SQLite's positional ``?``.

        Positional rewriting is only correct when each ``$N`` appears once
        and in ascending order — true for every query in this module.
        """
        return _PG_PLACEHOLDER_RE.sub("?", query)

    async def execute(self, query: str, *args):
        """Run a write/DDL statement.

        On SQLite the change is committed immediately and the aiosqlite
        cursor is returned (its ``async with`` has already exited, so do not
        fetch from it). On PostgreSQL asyncpg's status string is returned.
        """
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                await self.conn.commit()
                return cursor
        async with self.pool.acquire() as conn:
            return await conn.execute(query, *args)

    async def fetchrow(self, query: str, *args):
        """Return the first result row, or ``None`` if there is none."""
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                return await cursor.fetchone()
        async with self.pool.acquire() as conn:
            return await conn.fetchrow(query, *args)

    async def fetchval(self, query: str, *args):
        """Return the first column of the first row, or ``None`` if no rows."""
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                row = await cursor.fetchone()
                return row[0] if row is not None else None
        async with self.pool.acquire() as conn:
            return await conn.fetchval(query, *args)

    async def fetch(self, query: str, *args):
        """Return all result rows as a list."""
        if self.is_sqlite:
            async with self.conn.execute(self._to_sqlite(query), args) as cursor:
                return await cursor.fetchall()
        async with self.pool.acquire() as conn:
            return await conn.fetch(query, *args)

    # ------------------------------------------------------------------ #
    # Schema
    # ------------------------------------------------------------------ #

    def _false_default(self) -> str:
        """SQL literal for a FALSE column default in the active dialect.

        PostgreSQL rejects an integer default on a BOOLEAN column
        ("default expression is of type integer"); SQLite stores booleans as 0/1.
        """
        return "0" if self.is_sqlite else "FALSE"

    async def create_tables(self):
        """Create all tables if missing, then bring older schemas up to date."""
        now_default = "CURRENT_TIMESTAMP" if self.is_sqlite else "NOW()"
        serial_type = "INTEGER PRIMARY KEY AUTOINCREMENT" if self.is_sqlite else "SERIAL PRIMARY KEY"
        false_default = self._false_default()
        queries = [
            f"""CREATE TABLE IF NOT EXISTS users (
                user_id BIGINT PRIMARY KEY,
                username TEXT,
                marzban_username TEXT UNIQUE,
                subscription_until TIMESTAMP,
                data_limit INTEGER,
                invited_by BIGINT,
                last_traffic_reset TIMESTAMP DEFAULT {now_default},
                personal_discount INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS invite_codes (
                code TEXT PRIMARY KEY,
                created_by BIGINT,
                used_by BIGINT,
                used_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS promo_codes (
                code TEXT PRIMARY KEY,
                discount INTEGER,
                uses_left INTEGER,
                expires_at TIMESTAMP NULL,
                is_unlimited BOOLEAN DEFAULT {false_default},
                bonus_days INTEGER DEFAULT 0,
                is_sticky BOOLEAN DEFAULT {false_default},
                created_by BIGINT,
                created_at TIMESTAMP DEFAULT {now_default}
            )""",
            f"""CREATE TABLE IF NOT EXISTS payments (
                id {serial_type},
                user_id BIGINT,
                plan TEXT,
                amount INTEGER,
                promo_code TEXT,
                paid_at TIMESTAMP DEFAULT {now_default}
            )""",
        ]
        for q in queries:
            await self.execute(q)
        await self.migrate_db()

    async def _column_exists(self, table: str, column: str) -> bool:
        """Return True if ``table`` already has a column named ``column``."""
        if self.is_sqlite:
            # PRAGMA cannot take bound parameters; ``table`` is always an
            # internal constant here, never user input.
            rows = await self.fetch(f"PRAGMA table_info({table})")
            return any(row["name"] == column for row in rows)
        found = await self.fetchval(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_schema = current_schema() AND table_name = $1 AND column_name = $2",
            table, column,
        )
        return found is not None

    async def _add_column(self, table: str, column: str, ddl: str, backfill: Optional[str] = None):
        """Add ``column`` to ``table`` if it is missing (best effort).

        ``ddl`` is the type/default clause. ``backfill`` is an optional
        statement run right after the column is added. Failures are logged
        with a traceback instead of aborting startup.
        """
        if await self._column_exists(table, column):
            return
        try:
            await self.execute(f"ALTER TABLE {table} ADD COLUMN {column} {ddl}")
            if backfill:
                await self.execute(backfill)
        except Exception:
            logger.exception("Migration failed: could not add column %s.%s", table, column)
        else:
            logger.info("Migration: added column %s.%s", table, column)

    async def migrate_db(self):
        """Add columns introduced after the initial schema to existing databases."""
        false_default = self._false_default()
        await self._add_column("promo_codes", "expires_at", "TIMESTAMP NULL")
        await self._add_column("promo_codes", "is_unlimited", f"BOOLEAN DEFAULT {false_default}")
        await self._add_column("promo_codes", "bonus_days", "INTEGER DEFAULT 0")
        await self._add_column("promo_codes", "is_sticky", f"BOOLEAN DEFAULT {false_default}")
        if self.is_sqlite:
            # SQLite's ALTER TABLE ... ADD COLUMN rejects a CURRENT_TIMESTAMP
            # default, so add the column bare and backfill existing rows.
            # NOTE(review): rows inserted later into a migrated SQLite DB get
            # NULL here unless the INSERT sets it explicitly.
            await self._add_column(
                "users", "last_traffic_reset", "TIMESTAMP",
                backfill="UPDATE users SET last_traffic_reset = CURRENT_TIMESTAMP "
                         "WHERE last_traffic_reset IS NULL",
            )
        else:
            await self._add_column("users", "last_traffic_reset", "TIMESTAMP DEFAULT CURRENT_TIMESTAMP")
        await self._add_column("users", "personal_discount", "INTEGER DEFAULT 0")

    # ------------------------------------------------------------------ #
    # Users
    # ------------------------------------------------------------------ #

    async def get_user(self, user_id: int):
        """Return the ``users`` row for ``user_id``, or ``None``."""
        return await self.fetchrow("SELECT * FROM users WHERE user_id = $1", user_id)

    async def get_user_by_username(self, username: str):
        """Return the user whose username matches case-insensitively.

        Leading ``@`` characters are stripped before matching.
        """
        username = username.lstrip('@')
        return await self.fetchrow("SELECT * FROM users WHERE LOWER(username) = LOWER($1)", username)

    async def search_users(self, query: str):
        """Search users by exact numeric ID or by username substring.

        An all-decimal ``query`` is treated as a user ID and returns 0 or 1
        rows. Otherwise up to 20 rows whose username contains ``query`` are
        returned (SQL ``LIKE`` wildcards in ``query`` are not escaped).
        """
        # isdecimal() (not isdigit()) guarantees int() will accept the string.
        if query.isdecimal():
            return await self.fetch("SELECT * FROM users WHERE user_id = $1", int(query))

        term = f"%{query}%"
        if self.is_sqlite:
            sql = "SELECT * FROM users WHERE username LIKE $1 LIMIT 20"
        else:
            sql = "SELECT * FROM users WHERE username ILIKE $1 LIMIT 20"
        return await self.fetch(sql, term)

    async def get_all_users(self):
        """Return every row of ``users``."""
        return await self.fetch("SELECT * FROM users")

    async def create_user(self, user_id: int, username: str, marzban_username: str, invited_by: Optional[int] = None):
        """Insert a new user row; other columns take their table defaults."""
        await self.execute(
            "INSERT INTO users (user_id, username, marzban_username, invited_by) VALUES ($1, $2, $3, $4)",
            user_id, username, marzban_username, invited_by,
        )

    async def update_traffic_reset_date(self, user_id: int):
        """Set ``last_traffic_reset`` to the current local time."""
        now = datetime.now()
        await self.execute("UPDATE users SET last_traffic_reset = $1 WHERE user_id = $2", now, user_id)

    async def remove_subscription(self, user_id: int):
        """Clear the user's subscription end date."""
        await self.execute("UPDATE users SET subscription_until = NULL WHERE user_id = $1", user_id)

    async def set_user_discount(self, user_id: int, discount: int):
        """Set the user's ``personal_discount``."""
        await self.execute("UPDATE users SET personal_discount = $1 WHERE user_id = $2", discount, user_id)

    async def update_subscription(self, user_id: int, days: int, data_limit: int):
        """Extend (or start) a subscription by ``days`` and set ``data_limit``.

        An active subscription is extended from its current end date; an
        expired, missing or unparseable one starts from now. ``days`` above
        ``_UNLIMITED_DAYS_THRESHOLD`` (10000) pins the end date to 2099-12-31.

        Raises:
            ValueError: if no ``users`` row exists for ``user_id``.
        """
        user = await self.get_user(user_id)
        if user is None:
            raise ValueError(f"User {user_id} not found")

        # NOTE(review): an earlier comment described the "forever" case as
        # "9999+ days" while the check is > 10000 — confirm the intended cutoff.
        if days > _UNLIMITED_DAYS_THRESHOLD:
            # Checked before any arithmetic: a huge timedelta would overflow.
            new_date = _FOREVER
        else:
            now = datetime.now()
            current_end = _parse_timestamp(user['subscription_until'])
            base = current_end if current_end is not None and current_end > now else now
            new_date = base + timedelta(days=days)

        await self.execute(
            "UPDATE users SET subscription_until = $1, data_limit = $2 WHERE user_id = $3",
            new_date, data_limit, user_id,
        )

    # ------------------------------------------------------------------ #
    # Invite codes
    # ------------------------------------------------------------------ #

    async def create_invite_code(self, created_by: int):
        """Create and store a new random invite code, and return it."""
        # Invite codes gate access, so use a CSPRNG rather than `random`.
        code = ''.join(secrets.choice(_INVITE_ALPHABET) for _ in range(_INVITE_CODE_LENGTH))
        await self.execute(
            "INSERT INTO invite_codes (code, created_by) VALUES ($1, $2)",
            code, created_by,
        )
        return code

    async def use_invite_code(self, code: str, user_id: int):
        """Mark ``code`` as used by ``user_id`` (timestamp from the database clock)."""
        await self.execute(
            "UPDATE invite_codes SET used_by = $1, used_at = CURRENT_TIMESTAMP WHERE code = $2",
            user_id, code,
        )

    async def check_invite_code(self, code: str):
        """Return the invite row for ``code`` if it exists and is unused, else ``None``."""
        return await self.fetchrow("SELECT * FROM invite_codes WHERE code = $1 AND used_by IS NULL", code)

    # ------------------------------------------------------------------ #
    # Promo codes
    # ------------------------------------------------------------------ #

    async def create_promo_code(self, code: str, discount: int, uses: int, created_by: int,
                                expires_at: Optional[datetime] = None, is_unlimited: bool = False,
                                bonus_days: int = 0, is_sticky: bool = False):
        """Insert a promo code row."""
        await self.execute(
            "INSERT INTO promo_codes (code, discount, uses_left, created_by, expires_at, is_unlimited, bonus_days, is_sticky) "
            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
            code, discount, uses, created_by, expires_at, is_unlimited, bonus_days, is_sticky,
        )

    async def get_promo_code(self, code: str):
        """Return the promo row if it has uses left and has not expired, else ``None``.

        Expiry is checked in Python (not SQL) so the comparison behaves the
        same on SQLite and PostgreSQL. A missing or unparseable ``expires_at``
        means "no expiry".
        """
        promo = await self.fetchrow("SELECT * FROM promo_codes WHERE code = $1 AND uses_left > 0", code)
        if not promo:
            return None
        expires_at = _parse_timestamp(promo['expires_at'])
        if expires_at is not None and expires_at < datetime.now():
            return None
        return promo

    async def get_active_promos(self):
        """Return promo rows with uses left whose expiry is unset or in the future."""
        promos = await self.fetch("SELECT * FROM promo_codes WHERE uses_left > 0")
        now = datetime.now()
        active = []
        for promo in promos:
            expires_at = _parse_timestamp(promo['expires_at'])
            if expires_at is None or expires_at > now:
                active.append(promo)
        return active

    async def decrement_promo_usage(self, code: str):
        """Consume one use of ``code``."""
        await self.execute("UPDATE promo_codes SET uses_left = uses_left - 1 WHERE code = $1", code)

    # ------------------------------------------------------------------ #
    # Payments & stats
    # ------------------------------------------------------------------ #

    async def add_payment(self, user_id: int, plan: str, amount: int, promo_code: Optional[str] = None):
        """Record a payment row."""
        await self.execute(
            "INSERT INTO payments (user_id, plan, amount, promo_code) VALUES ($1, $2, $3, $4)",
            user_id, plan, amount, promo_code,
        )

    async def get_stats(self):
        """Return ``{"total", "revenue", "active"}`` aggregate counters.

        ``active`` counts users whose ``subscription_until`` is after Python's
        local ``datetime.now()``, the same clock used when that column is
        written. (SQLite's ``datetime('now')`` is UTC and would skew the count.)
        """
        total_users = await self.fetchval("SELECT COUNT(*) FROM users")
        active_revenue = await self.fetchval("SELECT SUM(amount) FROM payments") or 0
        active_subs = await self.fetchval(
            "SELECT COUNT(*) FROM users WHERE subscription_until > $1", datetime.now()
        )
        return {
            "total": total_users,
            "revenue": active_revenue,
            "active": active_subs,
        }

    async def get_users_for_broadcast(self):
        """Return the ``user_id`` of every user."""
        return await self.fetch("SELECT user_id FROM users")


db = Database()