168 lines
6.7 KiB
Python
168 lines
6.7 KiB
Python
# database.py
|
||
import os
|
||
from datetime import datetime
|
||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||
|
||
# Read the connection URL from the environment first (recommended for Docker
# deployments). The fallback URL carries placeholder credentials and points at a
# host named "db" (presumably a docker-compose service name — confirm).
SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://user:password@db:5432/fastapi_db")

# SQL echo logging. When enabled, SQLAlchemy logs every statement together with
# its bound parameters, which in this module includes password hashes and the
# key pairs inserted by init_db(). Defaults to on (the previous hard-coded value)
# for backward compatibility; set SQL_ECHO=false (or 0/no/off) in production.
_SQL_ECHO = os.getenv("SQL_ECHO", "true").strip().lower() in {"1", "true", "yes", "on"}

engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=_SQL_ECHO)

# Factory for AsyncSession objects bound to the engine above.
# expire_on_commit=False keeps ORM attributes loaded after commit, so reading
# them afterwards does not trigger an implicit (and, under asyncio, invalid)
# lazy refresh.
AsyncSessionLocal = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)

# Declarative base whose metadata init_db() uses to create tables.
Base = declarative_base()
|
||
|
||
async def get_db():
    """Yield a single AsyncSession for the duration of one consumer.

    The session comes from ``AsyncSessionLocal``. It is closed automatically
    when the ``async with`` block below exits, which also happens when the
    consumer raises. Nothing is committed here.
    """
    # Committing centrally in get_db is usually avoided; endpoints are expected
    # to commit explicitly when their unit of work is complete.
    session_factory = AsyncSessionLocal
    async with session_factory() as db:
        yield db
|
||
|
||
# Default subscription plans seeded on startup. They must exist before the
# admin user is created, because the admin row references the "Unlimited" tier
# (the original code notes a foreign-key constraint here).
_DEFAULT_SUBSCRIPTION_TIERS = (
    {
        "name": "Free",
        "max_heirs": 1,
        "weekly_token_limit": 1000,
        "max_assets": 5,
        "max_storage_mb": 10,
        "can_use_ai_proxy": False,
        "description": "Standard free tier",
    },
    {
        "name": "Pro",
        "max_heirs": 5,
        "weekly_token_limit": 10000,
        "max_assets": 50,
        "max_storage_mb": 100,
        "can_use_ai_proxy": True,
        "description": "Professional tier for active users",
    },
    {
        "name": "Ultra",
        "max_heirs": 100,
        "weekly_token_limit": 100000,
        "max_assets": 1000,
        "max_storage_mb": 1024,
        "can_use_ai_proxy": True,
        "description": "Ultimate tier for power users",
    },
    {
        "name": "Unlimited",
        "max_heirs": 9999,
        "weekly_token_limit": 999999,
        "max_assets": 9999,
        "max_storage_mb": 999999,
        "can_use_ai_proxy": True,
        "description": "Internal unlimited tier",
    },
)

# Default AI roles, keyed by explicit id. Keys are passed straight to
# models.AIRole(**...), so they must match the model's attribute names.
_DEFAULT_AI_ROLES = (
    {
        "id": 0,
        "name": "Reflective Assistant",
        "description": "Helps you dive deep into your thoughts and feelings through meaningful reflection.",
        "systemPrompt": "You are a helpful journal assistant. Help the user reflect on their thoughts and feelings.",
        "icon": "journal-outline",
        "iconFamily": "Ionicons",
    },
    {
        "id": 1,
        "name": "Creative Spark",
        "description": "A partner for brainstorming, creative writing, and exploring new ideas.",
        "systemPrompt": "You are a creative brainstorming partner. Help the user explore new ideas, write stories, or look at things from a fresh perspective.",
        "icon": "bulb-outline",
        "iconFamily": "Ionicons",
    },
    {
        "id": 2,
        "name": "Action Planner",
        "description": "Focused on turning thoughts into actionable plans and organized goals.",
        "systemPrompt": "You are a productivity coach. Help the user break down their thoughts into actionable steps and clear goals.",
        "icon": "list-outline",
        "iconFamily": "Ionicons",
    },
    {
        "id": 3,
        "name": "Empathetic Guide",
        "description": "Provides a safe, non-judgmental space for emotional support and empathy.",
        "systemPrompt": "You are a supportive and empathetic friend. Listen to the user's concerns and provide emotional support without judgment.",
        "icon": "heart-outline",
        "iconFamily": "Ionicons",
    },
)

_DEFAULT_ADMIN_USERNAME = "admin"
# Fallback bootstrap password. It is public in source control; set the
# ADMIN_PASSWORD environment variable in any real deployment.
_DEFAULT_ADMIN_PASSWORD = "admin123"


async def init_db():
    """Create all tables, then seed default reference data.

    Steps, in order:
      1. create every table registered on ``Base.metadata``;
      2. insert missing subscription tiers (must precede the admin user);
      3. create the ``admin`` user if absent;
      4. create the default ``gemini`` AI configuration if absent;
      5. insert missing AI roles.

    Each seeding step looks rows up before inserting, so re-running this on
    every startup does not duplicate data. Two processes running it at the
    same time could both see a row as missing, because check-then-insert is
    not atomic.
    """
    # Import models so their tables are registered on Base.metadata before
    # create_all runs.
    from . import models

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    async with AsyncSessionLocal() as session:
        # Order matters: the admin user references a subscription tier.
        await _seed_subscription_tiers(session, models)
        await _seed_admin_user(session, models)
        await _seed_gemini_config(session, models)
        await _seed_ai_roles(session, models)


async def _fetch_first(session, statement):
    """Execute *statement* and return the first ORM entity it yields, or None."""
    result = await session.execute(statement)
    return result.scalars().first()


async def _seed_subscription_tiers(session, models):
    """Insert each default subscription tier whose name is not present, then commit."""
    from sqlalchemy.future import select

    plans = models.SubscriptionPlans
    for tier_data in _DEFAULT_SUBSCRIPTION_TIERS:
        existing = await _fetch_first(session, select(plans).where(plans.name == tier_data["name"]))
        if existing is None:
            session.add(plans(**tier_data))
            print(f"✅ Default subscription tier '{tier_data['name']}' created")

    await session.commit()


async def _seed_admin_user(session, models):
    """Create the ``admin`` user with a fresh key pair if no such user exists.

    The password comes from the ADMIN_PASSWORD environment variable when it is
    set and non-empty; otherwise the public fallback is used. A password taken
    from the environment is never printed.
    """
    from sqlalchemy.future import select
    from . import auth

    users = models.User
    existing = await _fetch_first(session, select(users).where(users.username == _DEFAULT_ADMIN_USERNAME))
    if existing is not None:
        return

    env_password = os.getenv("ADMIN_PASSWORD")
    password = env_password or _DEFAULT_ADMIN_PASSWORD

    private_key, public_key = auth.generate_key_pair()
    admin_user = users(
        username=_DEFAULT_ADMIN_USERNAME,
        hashed_password=auth.hash_password(password),
        private_key=private_key,
        public_key=public_key,
        is_admin=True,
        guale=False,
        tier="Unlimited",
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12 and
        # returns a naive datetime. It is kept because the column presumably
        # stores naive UTC values — confirm before switching to an aware
        # datetime, which asyncpg may reject for TIMESTAMP WITHOUT TIME ZONE.
        last_active_at=datetime.utcnow(),
    )
    session.add(admin_user)
    await session.commit()

    if env_password:
        print("✅ Default admin user created (username: admin, password from ADMIN_PASSWORD)")
    else:
        print("✅ Default admin user created (username: admin, password: admin123)")


async def _seed_gemini_config(session, models):
    """Create the default ``gemini`` AIConfig row if none exists.

    The API key comes from GEMINI_API_KEY. If that is unset, a placeholder
    string is stored and the config is still marked active.
    """
    from sqlalchemy.future import select

    configs = models.AIConfig
    existing = await _fetch_first(session, select(configs).where(configs.provider_name == "gemini"))
    if existing is not None:
        return

    gemini_config = configs(
        provider_name="gemini",
        api_key=os.getenv("GEMINI_API_KEY", "your-gemini-api-key"),
        api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
        default_model="gemini-3-flash-preview",
        is_active=True,
    )
    session.add(gemini_config)
    await session.commit()
    print("✅ Default Gemini AI configuration created")


async def _seed_ai_roles(session, models):
    """Insert each default AI role whose id is not present, then commit."""
    from sqlalchemy.future import select

    roles = models.AIRole
    for role_data in _DEFAULT_AI_ROLES:
        # NOTE(review): these rows use explicit ids. If AIRole.id is backed by
        # an auto-increment sequence (e.g. PostgreSQL SERIAL/IDENTITY),
        # explicit inserts do not advance it, and a later insert without an id
        # could collide — confirm against the model definition.
        existing = await _fetch_first(session, select(roles).where(roles.id == role_data["id"]))
        if existing is None:
            session.add(roles(**role_data))
            print(f"✅ AI Role '{role_data['name']}' created")

    await session.commit()