408 lines
12 KiB
Python
408 lines
12 KiB
Python
import time
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from typing import List

import httpx
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload

from . import models, schemas, auth, database
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Startup: awaits ``database.init_db()`` (which, per the original author,
    creates the database tables). Shutdown: nothing runs after the ``yield``.
    """
    # --- startup phase: create the tables ---
    await database.init_db()

    # Hand control back to FastAPI; requests are served until shutdown begins.
    yield None

    # --- shutdown phase: add teardown logic here if it is ever needed ---
|
|
|
|
# The ASGI application; ``lifespan`` runs database initialisation at startup.
app = FastAPI(lifespan=lifespan)

from fastapi.middleware.cors import CORSMiddleware

# Browser origins allowed to call this API: bare http://localhost plus the
# local dev-server ports 8081, 8080 and 3000 (in that order).
origins = [
    "http://localhost" + port_suffix
    for port_suffix in ("", ":8081", ":8080", ":3000")
]

# Credentialed CORS requests with any method/header, restricted to the
# origins listed above.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
|
@app.post("/register", response_model=schemas.UserOut)
async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database.get_db)):
    """Register a new user on the "Free" tier.

    The password is hashed and a fresh key pair is generated for the user;
    the hash and both keys are stored on the new ``User`` row.

    Args:
        user: Registration payload (username, email, password).
        db: Async database session.

    Returns:
        The persisted ``models.User`` (serialised through ``schemas.UserOut``).

    Raises:
        HTTPException(400): a unique constraint was violated on insert.
        HTTPException(500): any other error while committing.
    """
    # hash_password / generate_key_pair are plain synchronous calls, and
    # password hashing / key-pair generation are typically CPU-heavy. Running
    # them directly in this coroutine would stall the event loop (and every
    # other in-flight request) while they work, so use the threadpool.
    hashed_pwd = await run_in_threadpool(auth.hash_password, user.password)
    private_key, public_key = await run_in_threadpool(auth.generate_key_pair)

    new_user = models.User(
        username=user.username,
        email=user.email,
        hashed_password=hashed_pwd,
        private_key=private_key,
        public_key=public_key,
        tier="Free",
        last_active_at=datetime.utcnow(),
    )
    db.add(new_user)
    try:
        await db.commit()
        await db.refresh(new_user)
        return new_user
    except IntegrityError:
        # Unique-constraint violation: roll back and report it to the client.
        # NOTE(review): the message says "email already exists"; if username
        # is also unique, a duplicate username lands here too — confirm.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="邮箱已存在",
        )
    except Exception as e:
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"注册失败: {str(e)}",
        )
|
|
|
|
|
|
@app.post("/login", response_model=schemas.LoginResponse)
async def login(form_data: schemas.UserLogin, db: AsyncSession = Depends(database.get_db)):
    """Authenticate with username/password and issue a bearer access token.

    On success the user's ``last_active_at`` is updated and committed.

    Returns:
        ``{"access_token": ..., "token_type": "bearer", "user": <User>}``.

    Raises:
        HTTPException(400): unknown username or wrong password. Both cases
            use the same message, so the response does not reveal which.
    """
    result = await db.execute(select(models.User).where(models.User.username == form_data.username))
    user = result.scalars().first()

    # verify_password is a synchronous call and password-hash verification is
    # typically CPU-heavy; run it in the threadpool so the event loop stays free.
    if not user or not await run_in_threadpool(
        auth.verify_password, form_data.password, user.hashed_password
    ):
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    access_token = auth.create_access_token(data={"sub": user.username})

    # Record the login as activity.
    user.last_active_at = datetime.utcnow()
    await db.commit()

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "user": user,
    }
|
|
|
|
|
|
|
|
@app.get("/assets/get", response_model=List[schemas.AssetOut])
async def get_my_assets(
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Return every asset whose author is the authenticated user."""
    authored_by_me = models.Asset.author_id == current_user.id
    query = select(models.Asset).where(authored_by_me)
    rows = await db.execute(query)
    return rows.scalars().all()
|
|
|
|
|
|
@app.post("/assets/create", response_model=schemas.AssetOut)
async def create_asset(
    asset: schemas.AssetCreate,
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Create an asset owned by the authenticated user.

    The client-supplied ``content_inner_encrypted`` is encrypted again with
    the author's public key and stored as ``content_outer_encrypted``. The
    client's ``private_key_shard`` is stored as given.

    Returns:
        The persisted ``models.Asset`` (serialised through ``schemas.AssetOut``).
    """
    # Wrap the inner content with an outer layer using the author's public key.
    encrypted_content = auth.encrypt_data(asset.content_inner_encrypted, current_user.public_key)

    # Take the clock once so a new asset has created_at == updated_at exactly;
    # two separate utcnow() calls can return different instants.
    now = datetime.utcnow()

    new_asset = models.Asset(
        title=asset.title,
        type=asset.type,
        content_outer_encrypted=encrypted_content,
        private_key_shard=asset.private_key_shard,
        author_id=current_user.id,
        created_at=now,
        updated_at=now,
    )
    db.add(new_asset)
    await db.commit()
    await db.refresh(new_asset)
    return new_asset
|
|
|
|
|
|
@app.post("/assets/claim")
async def claim_asset(
    asset_claim: schemas.AssetClaim,
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Let the designated heir claim an asset whose owner is flagged deceased.

    Checks, in order:
      1. the caller is the asset's designated heir (``heir_id``);
      2. the asset's author has ``guale`` ("passed away") set;
      then removes the outer encryption layer with the author's private key.

    Returns:
        dict with ``asset_id``, ``title``, ``decrypted_content`` and
        ``server_shard_key`` (the stored ``private_key_shard``).

    Raises:
        HTTPException(404): no asset with that id.
        HTTPException(403): caller is not the heir, or the owner is not
            flagged deceased.
        HTTPException(500): decryption failed.
    """
    # The author relationship is eager-loaded: both the guale flag and the
    # author's private key are needed below.
    stmt = (
        select(models.Asset)
        .options(selectinload(models.Asset.author))
        .where(models.Asset.id == asset_claim.asset_id)
    )
    asset = (await db.execute(stmt)).scalars().first()
    if not asset:
        raise HTTPException(status_code=404, detail="Asset not found")

    # 1. Only the designated heir may claim.
    is_heir = asset.heir_id == current_user.id
    if not is_heir:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="You are not the designated heir for this asset",
        )

    # 2. The owner must already have been declared deceased.
    owner = asset.author
    if not owner.guale:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="The owner of this asset is still alive. You cannot claim it yet.",
        )

    # 3. Both checks passed: decrypt with the owner's private key.
    try:
        plaintext = auth.decrypt_data(asset.content_outer_encrypted, owner.private_key)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to decrypt asset: {str(exc)}",
        )

    return dict(
        asset_id=asset.id,
        title=asset.title,
        decrypted_content=plaintext,
        server_shard_key=asset.private_key_shard,
    )
|
|
|
|
|
|
@app.post("/assets/assign")
async def assign_asset(
    assignment: schemas.AssetAssign,
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Set (or clear) the heir of one of the caller's own assets.

    The heir is looked up by e-mail. Naming yourself as heir is not an
    error; it clears the current assignment instead.

    Raises:
        HTTPException(404): asset or heir user not found.
        HTTPException(403): the caller is not the asset's author.
    """
    asset_rows = await db.execute(
        select(models.Asset)
        .options(selectinload(models.Asset.heir))
        .where(models.Asset.id == assignment.asset_id)
    )
    asset = asset_rows.scalars().first()
    if not asset:
        raise HTTPException(status_code=404, detail="Asset not found")
    if asset.author_id != current_user.id:
        raise HTTPException(status_code=403, detail="Not authorized to assign this asset")

    heir_rows = await db.execute(
        select(models.User).where(models.User.email == assignment.heir_email)
    )
    heir_user = heir_rows.scalars().first()
    if not heir_user:
        raise HTTPException(status_code=404, detail="Heir not found")

    # Assigning to yourself is treated as "unassign" rather than rejected.
    assigning_to_self = heir_user.id == current_user.id
    asset.heir = None if assigning_to_self else heir_user
    await db.commit()

    if assigning_to_self:
        return {"message": "Asset unassigned"}
    return {"message": f"Asset assigned to {assignment.heir_email}"}
|
|
|
|
|
|
@app.get("/assets/designated", response_model=List[schemas.AssetOut])
async def get_designated_assets(
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Return all assets for which the authenticated user is the designated heir."""
    heir_is_me = models.Asset.heir_id == current_user.id
    rows = await db.execute(select(models.Asset).where(heir_is_me))
    designated = rows.scalars().all()
    return designated
|
|
|
|
|
|
|
|
|
|
@app.post("/admin/declare-guale")
async def declare_user_guale(
    declare: schemas.DeclareGuale,
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Admin-only: flag the user with the given username as deceased (``guale``).

    ``claim_asset`` requires this flag on an asset's author before the heir
    can claim it.

    Raises:
        HTTPException(403): the caller is not an administrator.
        HTTPException(404): no user with that username.
    """
    if not current_user.is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only administrators can declare users as deceased",
        )

    lookup = await db.execute(
        select(models.User).where(models.User.username == declare.username)
    )
    target_user = lookup.scalars().first()
    if not target_user:
        raise HTTPException(status_code=404, detail="User not found")

    target_user.guale = True
    await db.commit()

    return dict(
        message=f"User {declare.username} has been declared as deceased",
        username=target_user.username,
        guale=target_user.guale,
    )
|
|
|
|
|
|
async def get_or_create_token_usage(user_id: int, db: AsyncSession):
    """Return the user's ``UserTokenUsage`` row, creating or resetting it as needed.

    Weeks start on Monday 00:00, computed from naive ``datetime.utcnow()``.
    - No row yet: insert one with ``tokens_used=0`` and
      ``last_reset_at=<this Monday>``, commit, and refresh it.
    - Row last reset before this Monday: zero the counter, move
      ``last_reset_at`` to this Monday, and commit.

    NOTE(review): two concurrent first requests for the same user could both
    take the insert path; confirm whether ``user_id`` is unique in the table.
    """
    today = datetime.utcnow()
    week_start = (today - timedelta(days=today.weekday())).replace(
        hour=0, minute=0, second=0, microsecond=0
    )

    query = select(models.UserTokenUsage).where(models.UserTokenUsage.user_id == user_id)
    usage = (await db.execute(query)).scalars().first()

    if not usage:
        usage = models.UserTokenUsage(
            user_id=user_id,
            tokens_used=0,
            last_reset_at=week_start,
        )
        db.add(usage)
        await db.commit()
        await db.refresh(usage)
        return usage

    # Weekly reset of token usage.
    if usage.last_reset_at < week_start:
        usage.tokens_used = 0
        usage.last_reset_at = week_start
        await db.commit()

    return usage
|
|
|
|
@app.post("/ai/proxy", response_model=schemas.AIResponse)
async def ai_proxy(
    ai_request: schemas.AIRequest,
    current_user: models.User = Depends(auth.get_current_user),
    db: AsyncSession = Depends(database.get_db)
):
    """Relay a chat request to the configured AI provider, enforcing quotas.

    Flow:
      1. The ``SubscriptionPlans`` row named ``current_user.tier`` must exist
         and have ``can_use_ai_proxy`` set.
      2. The user's token usage for this week must be below the plan's
         ``weekly_token_limit``.
      3. The request is forwarded to the active ``AIConfig`` (its API key and
         ``default_model`` replace anything the client sent). The provider's
         ``usage.total_tokens`` is then added to the weekly counter.

    If a quota check fails, this returns a synthetic chat-completion-shaped
    reply ("quota exceeded, please upgrade plan") instead of an error.

    Raises:
        HTTPException(500): no active AI configuration, or the provider
            could not be reached.
        HTTPException(<provider status>): the provider returned an error status.
        HTTPException(502): the provider returned a body that is not JSON.
    """

    def get_quota_exceeded_response():
        # Shaped like a normal chat.completion so clients render it as an
        # assistant reply. Read the clock once so "id" and "created" agree.
        # Use time.time(): datetime.utcnow().timestamp() interprets the naive
        # UTC value as local time and is off by the UTC offset on non-UTC hosts.
        now_ts = int(time.time())
        return {
            "id": f"chatcmpl-{now_ts}",
            "object": "chat.completion",
            "created": now_ts,
            "model": "quota-manager",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "quota exceeded, please upgrade plan",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            },
        }

    # 1. Check that the user's tier allows AI access at all.
    result = await db.execute(
        select(models.SubscriptionPlans).where(models.SubscriptionPlans.name == current_user.tier)
    )
    tier_plan = result.scalars().first()
    if not tier_plan or not tier_plan.can_use_ai_proxy:
        return get_quota_exceeded_response()

    # 2. Check that this week's token usage is still under the plan's limit.
    usage_record = await get_or_create_token_usage(current_user.id, db)
    if usage_record.tokens_used >= tier_plan.weekly_token_limit:
        return get_quota_exceeded_response()

    # Fetch the active AI provider configuration.
    result = await db.execute(
        select(models.AIConfig).where(models.AIConfig.is_active == True)  # noqa: E712 (SQL expression)
    )
    config = result.scalars().first()
    if not config:
        raise HTTPException(status_code=500, detail="AI configuration not found")

    headers = {
        "Authorization": f"Bearer {config.api_key}",
        "Content-Type": "application/json",
    }

    # Prepare the payload; the server-side config decides which model is used.
    payload = ai_request.model_dump()
    payload["model"] = config.default_model

    current_user.last_active_at = datetime.utcnow()
    await db.commit()

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                config.api_url,
                json=payload,
                headers=headers,
                timeout=30.0,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise HTTPException(
                status_code=e.response.status_code,
                detail=f"AI provider returned an error: {e.response.text}",
            )
        except httpx.RequestError as e:
            raise HTTPException(
                status_code=500,
                detail=f"An error occurred while requesting AI provider: {str(e)}",
            )

    # The (non-streamed) body has already been read, so it can be parsed after
    # the client is closed. A 2xx response with a non-JSON body is a gateway
    # error, not an unhandled crash.
    try:
        ai_data = response.json()
    except ValueError as e:
        raise HTTPException(
            status_code=502,
            detail="AI provider returned a non-JSON response",
        ) from e

    # 3. Record the tokens used. Providers may omit "usage" or send it (or
    # "total_tokens") as null, so fall back to 0 rather than crashing after
    # the provider has already served (and billed) the request.
    usage_info = ai_data.get("usage") or {}
    total_tokens = usage_info.get("total_tokens") or 0
    if total_tokens > 0:
        # NOTE(review): read-modify-write on the ORM attribute; concurrent
        # requests from the same user can lose increments. Consider an atomic
        # UPDATE ... SET tokens_used = tokens_used + :n.
        usage_record.tokens_used += total_tokens
        await db.commit()

    return ai_data
|
|
|
|
# Endpoint used only to check that hot reloading picks up code changes.
@app.post("/post1")
async def test1():
    """Return a fixed message built from 2 + 3; exists only for hot-reload testing."""
    total = sum((2, 3))
    return {"msg": f"this is a msg {total}"}
|
|
|
|
|