added_ai_proxy

This commit is contained in:
lusixing
2026-01-26 15:57:52 -08:00
parent 22117cf9e8
commit e3fa788318
5 changed files with 150 additions and 2 deletions

View File

@@ -47,4 +47,22 @@ async def init_db():
) )
session.add(admin_user) session.add(admin_user)
await session.commit() await session.commit()
print("✅ Default admin user created (username: admin, password: admin123)") print("✅ Default admin user created (username: admin, password: admin123)")
# 检查是否已存在 Gemini 配置
result = await session.execute(
select(models.AIConfig).where(models.AIConfig.provider_name == "gemini")
)
existing_gemini = result.scalars().first()
if not existing_gemini:
gemini_config = models.AIConfig(
provider_name="gemini",
api_key=os.getenv("GEMINI_API_KEY", "your-gemini-api-key"),
api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
default_model="gemini-3-flash-preview",
is_active=True
)
session.add(gemini_config)
await session.commit()
print("✅ Default Gemini AI configuration created")

View File

@@ -6,6 +6,7 @@ from . import models, schemas, auth, database
from passlib.context import CryptContext from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import httpx
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -212,3 +213,53 @@ async def test1():
c = a+b c = a+b
return {"msg": f"this is a msg {c}"} return {"msg": f"this is a msg {c}"}
@app.post("/ai/proxy", response_model=schemas.AIResponse)
async def ai_proxy(
ai_request: schemas.AIRequest,
current_user: models.User = Depends(auth.get_current_user),
db: AsyncSession = Depends(database.get_db)
):
"""
Proxy relay for AI requests.
Fetches AI configuration from the database.
"""
# Fetch active AI config
result = await db.execute(
select(models.AIConfig).where(models.AIConfig.is_active == True)
)
config = result.scalars().first()
if not config:
raise HTTPException(status_code=500, detail="AI configuration not found")
headers = {
"Authorization": f"Bearer {config.api_key}",
"Content-Type": "application/json"
}
# Prepare payload
payload = ai_request.model_dump()
payload["model"] = config.default_model
async with httpx.AsyncClient() as client:
try:
response = await client.post(
config.api_url,
json=payload,
headers=headers,
timeout=30.0
)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=e.response.status_code,
detail=f"AI provider returned an error: {e.response.text}"
)
except httpx.RequestError as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while requesting AI provider: {str(e)}"
)

View File

@@ -59,4 +59,21 @@ class AssetAssign(BaseModel):
heir_name: str heir_name: str
class DeclareGuale(BaseModel): class DeclareGuale(BaseModel):
username: str username: str
# AI Proxy Schemas
class AIMessage(BaseModel):
role: str
content: str
class AIRequest(BaseModel):
messages: List[AIMessage]
model: Optional[str] = None
class AIResponse(BaseModel):
id: str
object: str
created: int
model: str
choices: List[dict]
usage: dict

View File

@@ -24,5 +24,6 @@ services:
- "5678:5678" # 暴露调试端口 - "5678:5678" # 暴露调试端口
environment: environment:
- DATABASE_URL=postgresql+asyncpg://user:password@db:5432/fastapi_db - DATABASE_URL=postgresql+asyncpg://user:password@db:5432/fastapi_db
- GEMINI_API_KEY=key_here
depends_on: depends_on:
- db - db

61
test/test_ai_proxy.py Normal file
View File

@@ -0,0 +1,61 @@
import httpx
import asyncio
import time
# Testing against the running service on localhost
BASE_URL = "http://localhost:8000"
async def test_ai_proxy_integration():
async with httpx.AsyncClient(base_url=BASE_URL, timeout=30.0) as client:
print("--- Starting AI Proxy Integration Test ---")
# 1. Register a new user
username = f"user_{int(time.time())}"
print(f"1. Registering user: {username}")
reg_res = await client.post("/register", json={
"username": username,
"password": "testpassword"
})
if reg_res.status_code != 200:
print(f"Registration failed: {reg_res.text}")
return
# 2. Login to get token
print("2. Logging in...")
login_res = await client.post("/token", json={
"username": username,
"password": "testpassword"
})
if login_res.status_code != 200:
print(f"Login failed: {login_res.text}")
return
token = login_res.json()["access_token"]
# 3. Request AI Proxy
print("3. Sending AI Proxy request...")
headers = {"Authorization": f"Bearer {token}"}
ai_request = {
"messages": [
{"role": "user", "content": "Tell me a joke."}
],
"model": "some-model"
}
try:
response = await client.post("/ai/proxy", json=ai_request, headers=headers)
print(f"Response Status Code: {response.status_code}")
print(f"Response Content: {response.text[:200]}...") # Print first 200 chars
if response.status_code == 200:
print("✅ Success: AI Proxy returned 200 OK")
elif response.status_code in [400, 401]:
print(" Proxy worked, but AI provider returned error (likely invalid/missing API key)")
else:
print(f"❌ Unexpected status code: {response.status_code}")
except Exception as e:
print(f"❌ Request failed: {str(e)}")
print("--- Test Completed ---")
if __name__ == "__main__":
asyncio.run(test_ai_proxy_integration())