added_ai_proxy
This commit is contained in:
@@ -47,4 +47,22 @@ async def init_db():
|
|||||||
)
|
)
|
||||||
session.add(admin_user)
|
session.add(admin_user)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
print("✅ Default admin user created (username: admin, password: admin123)")
|
print("✅ Default admin user created (username: admin, password: admin123)")
|
||||||
|
|
||||||
|
# 检查是否已存在 Gemini 配置
|
||||||
|
result = await session.execute(
|
||||||
|
select(models.AIConfig).where(models.AIConfig.provider_name == "gemini")
|
||||||
|
)
|
||||||
|
existing_gemini = result.scalars().first()
|
||||||
|
|
||||||
|
if not existing_gemini:
|
||||||
|
gemini_config = models.AIConfig(
|
||||||
|
provider_name="gemini",
|
||||||
|
api_key=os.getenv("GEMINI_API_KEY", "your-gemini-api-key"),
|
||||||
|
api_url="https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
|
||||||
|
default_model="gemini-3-flash-preview",
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
session.add(gemini_config)
|
||||||
|
await session.commit()
|
||||||
|
print("✅ Default Gemini AI configuration created")
|
||||||
51
app/main.py
51
app/main.py
@@ -6,6 +6,7 @@ from . import models, schemas, auth, database
|
|||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import httpx
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -212,3 +213,53 @@ async def test1():
|
|||||||
c = a+b
|
c = a+b
|
||||||
return {"msg": f"this is a msg {c}"}
|
return {"msg": f"this is a msg {c}"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/ai/proxy", response_model=schemas.AIResponse)
|
||||||
|
async def ai_proxy(
|
||||||
|
ai_request: schemas.AIRequest,
|
||||||
|
current_user: models.User = Depends(auth.get_current_user),
|
||||||
|
db: AsyncSession = Depends(database.get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Proxy relay for AI requests.
|
||||||
|
Fetches AI configuration from the database.
|
||||||
|
"""
|
||||||
|
# Fetch active AI config
|
||||||
|
result = await db.execute(
|
||||||
|
select(models.AIConfig).where(models.AIConfig.is_active == True)
|
||||||
|
)
|
||||||
|
config = result.scalars().first()
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise HTTPException(status_code=500, detail="AI configuration not found")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {config.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prepare payload
|
||||||
|
payload = ai_request.model_dump()
|
||||||
|
payload["model"] = config.default_model
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
config.api_url,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=e.response.status_code,
|
||||||
|
detail=f"AI provider returned an error: {e.response.text}"
|
||||||
|
)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"An error occurred while requesting AI provider: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -59,4 +59,21 @@ class AssetAssign(BaseModel):
|
|||||||
heir_name: str
|
heir_name: str
|
||||||
|
|
||||||
class DeclareGuale(BaseModel):
|
class DeclareGuale(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
|
|
||||||
|
# AI Proxy Schemas
|
||||||
|
class AIMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class AIRequest(BaseModel):
|
||||||
|
messages: List[AIMessage]
|
||||||
|
model: Optional[str] = None
|
||||||
|
|
||||||
|
class AIResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
choices: List[dict]
|
||||||
|
usage: dict
|
||||||
@@ -24,5 +24,6 @@ services:
|
|||||||
- "5678:5678" # 暴露调试端口
|
- "5678:5678" # 暴露调试端口
|
||||||
environment:
|
environment:
|
||||||
- DATABASE_URL=postgresql+asyncpg://user:password@db:5432/fastapi_db
|
- DATABASE_URL=postgresql+asyncpg://user:password@db:5432/fastapi_db
|
||||||
|
- GEMINI_API_KEY=key_here
|
||||||
depends_on:
|
depends_on:
|
||||||
- db
|
- db
|
||||||
61
test/test_ai_proxy.py
Normal file
61
test/test_ai_proxy.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import httpx
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Testing against the running service on localhost
|
||||||
|
BASE_URL = "http://localhost:8000"
|
||||||
|
|
||||||
|
async def test_ai_proxy_integration():
|
||||||
|
async with httpx.AsyncClient(base_url=BASE_URL, timeout=30.0) as client:
|
||||||
|
print("--- Starting AI Proxy Integration Test ---")
|
||||||
|
|
||||||
|
# 1. Register a new user
|
||||||
|
username = f"user_{int(time.time())}"
|
||||||
|
print(f"1. Registering user: {username}")
|
||||||
|
reg_res = await client.post("/register", json={
|
||||||
|
"username": username,
|
||||||
|
"password": "testpassword"
|
||||||
|
})
|
||||||
|
if reg_res.status_code != 200:
|
||||||
|
print(f"Registration failed: {reg_res.text}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. Login to get token
|
||||||
|
print("2. Logging in...")
|
||||||
|
login_res = await client.post("/token", json={
|
||||||
|
"username": username,
|
||||||
|
"password": "testpassword"
|
||||||
|
})
|
||||||
|
if login_res.status_code != 200:
|
||||||
|
print(f"Login failed: {login_res.text}")
|
||||||
|
return
|
||||||
|
token = login_res.json()["access_token"]
|
||||||
|
|
||||||
|
# 3. Request AI Proxy
|
||||||
|
print("3. Sending AI Proxy request...")
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
ai_request = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Tell me a joke."}
|
||||||
|
],
|
||||||
|
"model": "some-model"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.post("/ai/proxy", json=ai_request, headers=headers)
|
||||||
|
print(f"Response Status Code: {response.status_code}")
|
||||||
|
print(f"Response Content: {response.text[:200]}...") # Print first 200 chars
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
print("✅ Success: AI Proxy returned 200 OK")
|
||||||
|
elif response.status_code in [400, 401]:
|
||||||
|
print("ℹ️ Proxy worked, but AI provider returned error (likely invalid/missing API key)")
|
||||||
|
else:
|
||||||
|
print(f"❌ Unexpected status code: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Request failed: {str(e)}")
|
||||||
|
|
||||||
|
print("--- Test Completed ---")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_ai_proxy_integration())
|
||||||
Reference in New Issue
Block a user