Files
backend/app/database.py
2026-01-30 14:59:48 -08:00

121 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# database.py
import os
from datetime import datetime
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
# Prefer configuration from environment variables (recommended for Docker deployments).
# The fallback URL points at a Postgres host named "db" with placeholder credentials;
# NOTE(review): presumably only meant for local/dev compose setups — set DATABASE_URL elsewhere.
SQLALCHEMY_DATABASE_URL = os.getenv(
    "DATABASE_URL", "postgresql+asyncpg://user:password@db:5432/fastapi_db"
)

# SQL statement logging. Defaults to on (the previous hard-coded behaviour); set
# SQLALCHEMY_ECHO to anything other than 1/true/yes/on (e.g. "false") to silence it,
# since echo logs every statement and its bound parameters.
SQLALCHEMY_ECHO = os.getenv("SQLALCHEMY_ECHO", "true").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=SQLALCHEMY_ECHO)

# Session factory producing AsyncSession objects bound to `engine`.
# expire_on_commit=False keeps loaded attributes available after commit, so objects
# can still be read without an implicit refresh (which an async session cannot do lazily).
AsyncSessionLocal = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base class that ORM models register their tables on.
Base = declarative_base()
async def get_db():
    """Yield one AsyncSession from ``AsyncSessionLocal`` and close it afterwards.

    NOTE(review): presumably consumed as a FastAPI ``Depends`` dependency —
    confirm against callers. This helper never commits; as a rule, endpoints
    are expected to call ``await session.commit()`` themselves. Leaving the
    ``async with`` block closes the session.
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session
# Default subscription tiers seeded on startup. Values are passed verbatim as
# keyword arguments to models.SubscriptionPlans.
_DEFAULT_SUBSCRIPTION_TIERS = (
    {
        "name": "Free",
        "max_heirs": 1,
        "weekly_token_limit": 1000,
        "max_assets": 5,
        "max_storage_mb": 10,
        "can_use_ai_proxy": False,
        "description": "Standard free tier",
    },
    {
        "name": "Pro",
        "max_heirs": 5,
        "weekly_token_limit": 10000,
        "max_assets": 50,
        "max_storage_mb": 100,
        "can_use_ai_proxy": True,
        "description": "Professional tier for active users",
    },
    {
        "name": "Ultra",
        "max_heirs": 100,
        "weekly_token_limit": 100000,
        "max_assets": 1000,
        "max_storage_mb": 1024,
        "can_use_ai_proxy": True,
        "description": "Ultimate tier for power users",
    },
    {
        "name": "Unlimited",
        "max_heirs": 9999,
        "weekly_token_limit": 999999,
        "max_assets": 9999,
        "max_storage_mb": 999999,
        "can_use_ai_proxy": True,
        "description": "Internal unlimited tier",
    },
)

# Fallback admin password, used only when DEFAULT_ADMIN_PASSWORD is unset or empty.
# It is public (it lives in source control), so deployments should override it.
_DEFAULT_ADMIN_PASSWORD = "admin123"

# Placeholder stored when GEMINI_API_KEY is not set.
_GEMINI_PLACEHOLDER_KEY = "your-gemini-api-key"


async def _seed_subscription_tiers(session):
    """Insert every default tier whose name is not already present, then commit."""
    from sqlalchemy.future import select

    from . import models

    plan = models.SubscriptionPlans
    wanted_names = [tier["name"] for tier in _DEFAULT_SUBSCRIPTION_TIERS]
    # One round trip for all tiers instead of one SELECT per tier.
    result = await session.execute(select(plan.name).where(plan.name.in_(wanted_names)))
    existing_names = set(result.scalars().all())

    for tier_data in _DEFAULT_SUBSCRIPTION_TIERS:
        if tier_data["name"] in existing_names:
            continue
        session.add(plan(**tier_data))
        print(f"✅ Default subscription tier '{tier_data['name']}' created")
    await session.commit()


async def _seed_admin_user(session):
    """Create the "admin" user on the "Unlimited" tier if no user named "admin" exists.

    The password comes from the DEFAULT_ADMIN_PASSWORD environment variable,
    falling back to the public default. The password is echoed to stdout only
    when that public default was used.
    """
    from datetime import timezone

    from sqlalchemy.future import select

    from . import auth, models

    result = await session.execute(
        select(models.User).where(models.User.username == "admin")
    )
    if result.scalars().first():
        return

    # `or` also treats an empty DEFAULT_ADMIN_PASSWORD as unset (never create an empty password).
    password = os.getenv("DEFAULT_ADMIN_PASSWORD") or _DEFAULT_ADMIN_PASSWORD
    private_key, public_key = auth.generate_key_pair()
    admin_user = models.User(
        username="admin",
        hashed_password=auth.hash_password(password),
        private_key=private_key,
        public_key=public_key,
        is_admin=True,
        # NOTE(review): `guale` looks like pinyin for "deceased" — confirm its meaning in models.
        guale=False,
        tier="Unlimited",
        # Naive UTC timestamp — the same value datetime.utcnow() returned (deprecated
        # since Python 3.12). It stays naive in case the column is timezone-less.
        last_active_at=datetime.now(timezone.utc).replace(tzinfo=None),
    )
    session.add(admin_user)
    await session.commit()

    if password == _DEFAULT_ADMIN_PASSWORD:
        print("✅ Default admin user created (username: admin, password: admin123)")
        print("⚠️ Using the built-in admin password; set DEFAULT_ADMIN_PASSWORD and change it")
    else:
        print("✅ Default admin user created (username: admin)")


async def _seed_gemini_config(session):
    """Create an active "gemini" AIConfig row if none exists, then commit."""
    from sqlalchemy.future import select

    from . import models

    result = await session.execute(
        select(models.AIConfig).where(models.AIConfig.provider_name == "gemini")
    )
    if result.scalars().first():
        return

    api_key = os.getenv("GEMINI_API_KEY", _GEMINI_PLACEHOLDER_KEY)
    gemini_config = models.AIConfig(
        provider_name="gemini",
        api_key=api_key,
        api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        default_model="gemini-3-flash-preview",
        is_active=True,
    )
    session.add(gemini_config)
    await session.commit()
    print("✅ Default Gemini AI configuration created")
    if api_key == _GEMINI_PLACEHOLDER_KEY:
        print("⚠️ GEMINI_API_KEY is not set; the Gemini configuration was stored with a placeholder key")


async def init_db():
    """Create all ORM tables, then seed default rows.

    The seeded rows are the subscription tiers, the admin user and the Gemini
    AI config. Each seed step checks for existing rows first and inserts only
    what is missing.

    NOTE(review): check-then-insert is not guarded against concurrent startup
    of several processes.
    """
    async with engine.begin() as conn:
        # Import models so their tables are registered on Base.metadata.
        from . import models  # noqa: F401

        await conn.run_sync(Base.metadata.create_all)

    async with AsyncSessionLocal() as session:
        # Tiers must come first: the admin user references the "Unlimited" tier
        # (the original code notes an FK constraint on it).
        await _seed_subscription_tiers(session)
        await _seed_admin_user(session)
        await _seed_gemini_config(session)