Files
backend/app/database.py
2026-02-02 22:22:11 -08:00

168 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# database.py
import os
from datetime import datetime
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
# Prefer the connection URL from the environment (recommended for Docker deployments).
# NOTE(review): the fallback URL embeds placeholder credentials ("user"/"password");
# real deployments should always set DATABASE_URL.
SQLALCHEMY_DATABASE_URL = os.getenv(
    "DATABASE_URL", "postgresql+asyncpg://user:password@db:5432/fastapi_db"
)

# SQL statement logging. It stays ON by default (the previous hard-coded behaviour);
# set SQL_ECHO to anything other than 1/true/yes/on (e.g. "false") to silence it.
# Echo logs bound parameters, which may include secrets, so production should disable it.
SQL_ECHO = os.getenv("SQL_ECHO", "true").strip().lower() in {"1", "true", "yes", "on"}

engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=SQL_ECHO)

# Session factory producing AsyncSession objects bound to `engine`.
# expire_on_commit=False keeps already-loaded attributes readable after commit(),
# instead of expiring them (a later attribute access would otherwise need new DB I/O).
AsyncSessionLocal = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base class that the ORM models register their tables on.
Base = declarative_base()
async def get_db():
    """Yield one database session for the duration of the caller's work.

    The session comes from ``AsyncSessionLocal`` and is closed automatically
    once the generator is finalized (the ``async with`` exit). Nothing is
    committed here on purpose: endpoints are expected to call ``commit()``
    themselves when they want their changes persisted.
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session
# Default subscription plans, matched against existing rows by "name".
# They must be seeded before any user row that references a tier
# (the default admin uses tier="Unlimited"); the original code notes FK constraints.
_DEFAULT_SUBSCRIPTION_TIERS = (
    {
        "name": "Free",
        "max_heirs": 1,
        "weekly_token_limit": 1000,
        "max_assets": 5,
        "max_storage_mb": 10,
        "can_use_ai_proxy": False,
        "description": "Standard free tier",
    },
    {
        "name": "Pro",
        "max_heirs": 5,
        "weekly_token_limit": 10000,
        "max_assets": 50,
        "max_storage_mb": 100,
        "can_use_ai_proxy": True,
        "description": "Professional tier for active users",
    },
    {
        "name": "Ultra",
        "max_heirs": 100,
        "weekly_token_limit": 100000,
        "max_assets": 1000,
        "max_storage_mb": 1024,
        "can_use_ai_proxy": True,
        "description": "Ultimate tier for power users",
    },
    {
        "name": "Unlimited",
        "max_heirs": 9999,
        "weekly_token_limit": 999999,
        "max_assets": 9999,
        "max_storage_mb": 999999,
        "can_use_ai_proxy": True,
        "description": "Internal unlimited tier",
    },
)

# Default AI roles, matched against existing rows by explicit primary key "id".
_DEFAULT_AI_ROLES = (
    {
        "id": 0,
        "name": "Reflective Assistant",
        "description": "Helps you dive deep into your thoughts and feelings through meaningful reflection.",
        "systemPrompt": "You are a helpful journal assistant. Help the user reflect on their thoughts and feelings.",
        "icon": "journal-outline",
        "iconFamily": "Ionicons",
    },
    {
        "id": 1,
        "name": "Creative Spark",
        "description": "A partner for brainstorming, creative writing, and exploring new ideas.",
        "systemPrompt": "You are a creative brainstorming partner. Help the user explore new ideas, write stories, or look at things from a fresh perspective.",
        "icon": "bulb-outline",
        "iconFamily": "Ionicons",
    },
    {
        "id": 2,
        "name": "Action Planner",
        "description": "Focused on turning thoughts into actionable plans and organized goals.",
        "systemPrompt": "You are a productivity coach. Help the user break down their thoughts into actionable steps and clear goals.",
        "icon": "list-outline",
        "iconFamily": "Ionicons",
    },
    {
        "id": 3,
        "name": "Empathetic Guide",
        "description": "Provides a safe, non-judgmental space for emotional support and empathy.",
        "systemPrompt": "You are a supportive and empathetic friend. Listen to the user's concerns and provide emotional support without judgment.",
        "icon": "heart-outline",
        "iconFamily": "Ionicons",
    },
)

_DEFAULT_ADMIN_USERNAME = "admin"
# Used only when the ADMIN_PASSWORD environment variable is unset or empty.
# WARNING: well-known value; set ADMIN_PASSWORD in any non-local deployment.
_FALLBACK_ADMIN_PASSWORD = "admin123"


async def init_db():
    """Create all tables, then seed default rows.

    Seeding happens in this order, and each step commits on its own:
      1. subscription tiers (first, because users reference a tier),
      2. the default admin user (password from ADMIN_PASSWORD, else "admin123"),
      3. the default Gemini AI configuration,
      4. the default AI roles.

    A row that already exists (same tier name / admin username / provider
    name / role id) is left untouched, so re-running this adds only missing rows.
    """
    async with engine.begin() as conn:
        # Importing the models registers their tables on Base.metadata.
        from . import models  # noqa: F401
        # Create any tables that do not exist yet.
        await conn.run_sync(Base.metadata.create_all)

    async with AsyncSessionLocal() as session:
        await _seed_subscription_tiers(session)
        await _seed_admin_user(session)
        await _seed_gemini_config(session)
        await _seed_ai_roles(session)


async def _seed_subscription_tiers(session):
    """Insert every default subscription tier whose name is not already stored."""
    from sqlalchemy import select
    from . import models

    plans = models.SubscriptionPlans
    wanted_names = [tier["name"] for tier in _DEFAULT_SUBSCRIPTION_TIERS]
    # One round-trip for all names instead of one query per tier.
    result = await session.execute(select(plans.name).where(plans.name.in_(wanted_names)))
    existing_names = set(result.scalars().all())

    for tier_data in _DEFAULT_SUBSCRIPTION_TIERS:
        if tier_data["name"] in existing_names:
            continue
        session.add(plans(**tier_data))
        print(f"✅ Default subscription tier '{tier_data['name']}' created")
    await session.commit()


async def _seed_admin_user(session):
    """Create the default admin account if no user named "admin" exists."""
    from sqlalchemy import select
    from . import auth, models

    result = await session.execute(
        select(models.User).where(models.User.username == _DEFAULT_ADMIN_USERNAME)
    )
    if result.scalars().first():
        return

    env_password = os.getenv("ADMIN_PASSWORD")
    password = env_password or _FALLBACK_ADMIN_PASSWORD

    private_key, public_key = auth.generate_key_pair()
    admin_user = models.User(
        username=_DEFAULT_ADMIN_USERNAME,
        hashed_password=auth.hash_password(password),
        private_key=private_key,
        public_key=public_key,
        is_admin=True,
        guale=False,
        tier="Unlimited",
        # NOTE(review): naive UTC timestamp (utcnow is deprecated since Python 3.12);
        # switching to an aware datetime requires a timezone-aware column — confirm first.
        last_active_at=datetime.utcnow(),
    )
    session.add(admin_user)
    await session.commit()

    if env_password:
        # Never echo a real, operator-supplied password into the logs.
        print("✅ Default admin user created (username: admin, password from ADMIN_PASSWORD)")
    else:
        print("✅ Default admin user created (username: admin, password: admin123)")


async def _seed_gemini_config(session):
    """Insert the default "gemini" AIConfig row if none exists."""
    from sqlalchemy import select
    from . import models

    result = await session.execute(
        select(models.AIConfig).where(models.AIConfig.provider_name == "gemini")
    )
    if result.scalars().first():
        return

    gemini_config = models.AIConfig(
        provider_name="gemini",
        # Falls back to a placeholder string when GEMINI_API_KEY is unset;
        # presumably requests will fail until a real key is configured.
        api_key=os.getenv("GEMINI_API_KEY", "your-gemini-api-key"),
        api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        default_model="gemini-3-flash-preview",
        is_active=True,
    )
    session.add(gemini_config)
    await session.commit()
    print("✅ Default Gemini AI configuration created")


async def _seed_ai_roles(session):
    """Insert every default AI role whose id is not already stored."""
    from sqlalchemy import select
    from . import models

    roles = models.AIRole
    wanted_ids = [role["id"] for role in _DEFAULT_AI_ROLES]
    result = await session.execute(select(roles.id).where(roles.id.in_(wanted_ids)))
    existing_ids = set(result.scalars().all())

    for role_data in _DEFAULT_AI_ROLES:
        if role_data["id"] in existing_ids:
            continue
        # NOTE(review): explicit primary keys are inserted here. If AIRole.id is backed
        # by a PostgreSQL serial/identity sequence, these inserts do not advance it, and
        # later inserts without an id may collide — confirm against the model/migrations.
        session.add(roles(**role_data))
        print(f"✅ AI Role '{role_data['name']}' created")
    await session.commit()