121 lines
4.5 KiB
Python
121 lines
4.5 KiB
Python
# database.py
|
||
import os
|
||
from datetime import datetime
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||
|
||
# Read the connection URL from the environment first (recommended for Docker
# deployments). The fallback host "db" presumably refers to a docker-compose
# service — confirm against the deployment config.
SQLALCHEMY_DATABASE_URL = os.getenv(
    "DATABASE_URL", "postgresql+asyncpg://user:password@db:5432/fastapi_db"
)

# SQL statement echo was previously hard-coded to True. It is now controlled by
# the SQL_ECHO environment variable so it can be switched off in production.
# The default stays "on" to preserve the previous behaviour.
SQL_ECHO = os.getenv("SQL_ECHO", "true").strip().lower() in {"1", "true", "yes", "on"}

# Shared async engine used by every session and by init_db().
engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=SQL_ECHO)

# Session factory producing AsyncSession objects bound to `engine`.
# expire_on_commit=False: ORM objects are not expired on commit, so their
# attributes remain readable after commit without being reloaded.
AsyncSessionLocal = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base class; ORM models register their tables on Base.metadata.
Base = declarative_base()
|
||
|
||
async def get_db():
    """Yield a database session and close it once the caller is done.

    Meant to be consumed as a request-scoped dependency (e.g. FastAPI
    ``Depends``). The session is closed when the ``async with`` block exits.

    Note: no commit happens here on purpose. Endpoints are expected to call
    ``await db.commit()`` themselves when they write data.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session
|
||
|
||
# Default subscription plans seeded on startup. They are inserted before the
# admin user because that user references the "Unlimited" plan (FK constraint).
_DEFAULT_SUBSCRIPTION_TIERS = (
    {
        "name": "Free",
        "max_heirs": 1,
        "weekly_token_limit": 1000,
        "max_assets": 5,
        "max_storage_mb": 10,
        "can_use_ai_proxy": False,
        "description": "Standard free tier",
    },
    {
        "name": "Pro",
        "max_heirs": 5,
        "weekly_token_limit": 10000,
        "max_assets": 50,
        "max_storage_mb": 100,
        "can_use_ai_proxy": True,
        "description": "Professional tier for active users",
    },
    {
        "name": "Ultra",
        "max_heirs": 100,
        "weekly_token_limit": 100000,
        "max_assets": 1000,
        "max_storage_mb": 1024,
        "can_use_ai_proxy": True,
        "description": "Ultimate tier for power users",
    },
    {
        "name": "Unlimited",
        "max_heirs": 9999,
        "weekly_token_limit": 999999,
        "max_assets": 9999,
        "max_storage_mb": 999999,
        "can_use_ai_proxy": True,
        "description": "Internal unlimited tier",
    },
)

# Fallback admin password, used only when ADMIN_PASSWORD is not set.
# Kept for backward compatibility; it should be changed after first login.
_DEFAULT_ADMIN_PASSWORD = "admin123"


async def _seed_subscription_plans(session):
    """Insert every default subscription plan not yet present (by name), then commit."""
    from sqlalchemy import select

    from . import models

    wanted_names = [tier["name"] for tier in _DEFAULT_SUBSCRIPTION_TIERS]
    # A single round-trip replaces one SELECT per tier.
    result = await session.execute(
        select(models.SubscriptionPlans.name).where(
            models.SubscriptionPlans.name.in_(wanted_names)
        )
    )
    existing_names = set(result.scalars().all())

    for tier_data in _DEFAULT_SUBSCRIPTION_TIERS:
        if tier_data["name"] in existing_names:
            continue
        session.add(models.SubscriptionPlans(**tier_data))
        print(f"✅ Default subscription tier '{tier_data['name']}' created")

    await session.commit()


async def _seed_admin_user(session):
    """Create the ``admin`` account (Unlimited tier) if no user named "admin" exists.

    The password comes from the ADMIN_PASSWORD environment variable when set
    and non-empty; otherwise the legacy default is used and a warning is printed.
    """
    from datetime import timezone

    from sqlalchemy import select

    from . import auth, models

    result = await session.execute(
        select(models.User).where(models.User.username == "admin")
    )
    if result.scalars().first():
        return

    env_password = os.getenv("ADMIN_PASSWORD")
    password = env_password or _DEFAULT_ADMIN_PASSWORD

    private_key, public_key = auth.generate_key_pair()
    admin_user = models.User(
        username="admin",
        hashed_password=auth.hash_password(password),
        private_key=private_key,
        public_key=public_key,
        is_admin=True,
        guale=False,
        tier="Unlimited",
        # Naive UTC timestamp: the same value datetime.utcnow() returned,
        # without the call that is deprecated since Python 3.12.
        last_active_at=datetime.now(timezone.utc).replace(tzinfo=None),
    )
    session.add(admin_user)
    await session.commit()

    if env_password:
        # Never echo a caller-supplied secret to stdout/logs.
        print("✅ Default admin user created (username: admin)")
    else:
        print("✅ Default admin user created (username: admin, password: admin123)")
        print("⚠️  Built-in default admin password in use; set ADMIN_PASSWORD or change it immediately.")


async def _seed_gemini_config(session):
    """Create the "gemini" AIConfig row if none exists yet."""
    from sqlalchemy import select

    from . import models

    result = await session.execute(
        select(models.AIConfig).where(models.AIConfig.provider_name == "gemini")
    )
    if result.scalars().first():
        return

    gemini_config = models.AIConfig(
        provider_name="gemini",
        api_key=os.getenv("GEMINI_API_KEY", "your-gemini-api-key"),
        api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        default_model="gemini-3-flash-preview",
        is_active=True,
    )
    session.add(gemini_config)
    await session.commit()
    print("✅ Default Gemini AI configuration created")
    if "GEMINI_API_KEY" not in os.environ:
        print("⚠️  GEMINI_API_KEY not set; Gemini config stored with a placeholder API key.")


async def init_db():
    """Create all tables, then seed default plans, the admin user and the Gemini config.

    Each seed step first checks for existing rows and inserts only what is
    missing, so re-running on startup does not duplicate rows. Concurrent runs
    (e.g. several workers starting at once) are not guarded against.
    """
    # Import the models so their tables are registered on Base.metadata
    # before create_all runs.
    from . import models  # noqa: F401

    async with engine.begin() as conn:
        # Create any tables that do not exist yet.
        await conn.run_sync(Base.metadata.create_all)

    async with AsyncSessionLocal() as session:
        # Order matters: plans must exist before the admin user that
        # references the "Unlimited" plan (FK constraint).
        await _seed_subscription_plans(session)
        await _seed_admin_user(session)
        await _seed_gemini_config(session)