backend update 260127

This commit is contained in:
lusixing
2026-01-27 20:01:29 -08:00
parent e3fa788318
commit 4b5b6fb976
5 changed files with 131 additions and 26 deletions

View File

@@ -1,5 +1,6 @@
# database.py # database.py
import os import os
from datetime import datetime
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.orm import sessionmaker, declarative_base
@@ -43,7 +44,9 @@ async def init_db():
private_key=private_key, private_key=private_key,
public_key=public_key, public_key=public_key,
is_admin=True, is_admin=True,
guale=False guale=False,
tier="Unlimited",
last_active_at=datetime.utcnow()
) )
session.add(admin_user) session.add(admin_user)
await session.commit() await session.commit()

View File

@@ -7,6 +7,8 @@ from passlib.context import CryptContext
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import httpx import httpx
from datetime import datetime
from typing import List
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -17,6 +19,23 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
from fastapi.middleware.cors import CORSMiddleware
origins = [
"http://localhost",
"http://localhost:8081",
"http://localhost:8080",
"http://localhost:3000",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/register", response_model=schemas.UserOut) @app.post("/register", response_model=schemas.UserOut)
async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database.get_db)): async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database.get_db)):
@@ -25,9 +44,12 @@ async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database
new_user = models.User( new_user = models.User(
username=user.username, username=user.username,
email=user.email,
hashed_password=hashed_pwd, hashed_password=hashed_pwd,
private_key=private_key, private_key=private_key,
public_key=public_key public_key=public_key,
tier="Free",
last_active_at=datetime.utcnow()
) )
db.add(new_user) db.add(new_user)
try: try:
@@ -40,12 +62,17 @@ async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="用户名已存在" detail="邮箱已存在"
)
except Exception as e:
await db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"注册失败: {str(e)}"
) )
@app.post("/login", response_model=schemas.LoginResponse)
@app.post("/token")
async def login(form_data: schemas.UserLogin, db: AsyncSession = Depends(database.get_db)): async def login(form_data: schemas.UserLogin, db: AsyncSession = Depends(database.get_db)):
result = await db.execute(select(models.User).where(models.User.username == form_data.username)) result = await db.execute(select(models.User).where(models.User.username == form_data.username))
user = result.scalars().first() user = result.scalars().first()
@@ -53,10 +80,31 @@ async def login(form_data: schemas.UserLogin, db: AsyncSession = Depends(databas
raise HTTPException(status_code=400, detail="Incorrect username or password") raise HTTPException(status_code=400, detail="Incorrect username or password")
access_token = auth.create_access_token(data={"sub": user.username}) access_token = auth.create_access_token(data={"sub": user.username})
return {"access_token": access_token, "token_type": "bearer"}
# Update last_active_at
user.last_active_at = datetime.utcnow()
await db.commit()
return {
"access_token": access_token,
"token_type": "bearer",
"user": user
}
@app.post("/assets/", response_model=schemas.AssetOut)
@app.get("/assets/get", response_model=List[schemas.AssetOut])
async def get_my_assets(
current_user: models.User = Depends(auth.get_current_user),
db: AsyncSession = Depends(database.get_db)
):
result = await db.execute(
select(models.Asset).where(models.Asset.author_id == current_user.id)
)
return result.scalars().all()
@app.post("/assets/create", response_model=schemas.AssetOut)
async def create_asset( async def create_asset(
asset: schemas.AssetCreate, asset: schemas.AssetCreate,
current_user: models.User = Depends(auth.get_current_user), current_user: models.User = Depends(auth.get_current_user),

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, ForeignKey, Text, Table, Boolean from sqlalchemy import Column, Integer, String, ForeignKey, Text, Table, Boolean, DateTime
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .database import Base from .database import Base
@@ -8,8 +8,14 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True) username = Column(String, index=True)
email = Column(String, unique=True, index=True)
hashed_password = Column(String) hashed_password = Column(String)
tier = Column(String)
tier_expires_at = Column(DateTime)
last_active_at = Column(DateTime)
# System keys # System keys
public_key = Column(String) public_key = Column(String)
private_key = Column(String) # Encrypted or raw? Storing raw for now as per req private_key = Column(String) # Encrypted or raw? Storing raw for now as per req
@@ -35,3 +41,13 @@ class Asset(Base):
author = relationship("User", foreign_keys=[author_id], back_populates="assets") author = relationship("User", foreign_keys=[author_id], back_populates="assets")
heir = relationship("User", foreign_keys=[heir_id], back_populates="inherited_assets") heir = relationship("User", foreign_keys=[heir_id], back_populates="inherited_assets")
class AIConfig(Base):
__tablename__ = "ai_configs"
id = Column(Integer, primary_key=True, index=True)
provider_name = Column(String, unique=True, index=True)
api_key = Column(String)
api_url = Column(String)
default_model = Column(String)
is_active = Column(Boolean, default=True)

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from typing import List, Optional from typing import List, Optional
from datetime import datetime
# Heir Schemas # Heir Schemas
class HeirBase(BaseModel): class HeirBase(BaseModel):
@@ -17,6 +18,7 @@ class HeirOut(HeirBase):
class UserCreate(BaseModel): class UserCreate(BaseModel):
username: str username: str
password: str password: str
email: str
class UserLogin(BaseModel): class UserLogin(BaseModel):
username: str username: str
@@ -28,9 +30,18 @@ class UserOut(BaseModel):
public_key: Optional[str] = None public_key: Optional[str] = None
is_admin: bool = False is_admin: bool = False
guale: bool = False guale: bool = False
tier: Optional[str] = None
tier_expires_at: Optional[datetime] = None
last_active_at: Optional[datetime] = None
#heirs: List[HeirOut] = [] #heirs: List[HeirOut] = []
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
class LoginResponse(BaseModel):
access_token: str
token_type: str
user: UserOut
# Asset Schemas (renamed from Article) # Asset Schemas (renamed from Article)
class AssetBase(BaseModel): class AssetBase(BaseModel):
title: str title: str

View File

@@ -6,11 +6,12 @@ import ast
BASE_URL = "http://localhost:8000" BASE_URL = "http://localhost:8000"
def register_user(username, password): def register_user(username, email, password):
url = f"{BASE_URL}/register" url = f"{BASE_URL}/register"
data = { data = {
"username": username, "username": username,
"password": password "password": password,
"email": email,
} }
response = requests.post(url, json=data) response = requests.post(url, json=data)
if response.status_code == 200: if response.status_code == 200:
@@ -21,7 +22,7 @@ def register_user(username, password):
return None return None
def login_user(username, password): def login_user(username, password):
url = f"{BASE_URL}/token" url = f"{BASE_URL}/login"
data = { data = {
"username": username, "username": username,
"password": password "password": password
@@ -35,7 +36,7 @@ def login_user(username, password):
return None return None
def create_asset(token, title, private_key_shard, content_inner_encrypted): def create_asset(token, title, private_key_shard, content_inner_encrypted):
url = f"{BASE_URL}/assets/" url = f"{BASE_URL}/assets/create"
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
data = { data = {
"title": title, "title": title,
@@ -92,16 +93,28 @@ def claim_asset(token, asset_id, private_key_shard):
print(f"Failed to claim asset: {response.text}") print(f"Failed to claim asset: {response.text}")
return None return None
def get_my_assets(token):
url = f"{BASE_URL}/assets/get"
headers = {"Authorization": f"Bearer {token}"}
response = requests.get(url, headers=headers)
if response.status_code == 200:
print(f"Assets retrieved successfully.")
return response.json()
else:
print(f"Failed to retrieve assets: {response.text}")
return None
def main(): def main():
# 1. 创建三个用户 # 1. 创建三个用户
users = [ users = [
("user1", "pass123"), ("user1", "pass123", "user1@example.com"),
("user2", "pass123"), ("user2", "pass123", "user2@example.com"),
("user3", "pass123") ("user3", "pass123", "user3@example.com")
] ]
for username, password in users: for username, password, email in users:
register_user(username, password) register_user(username, email, password)
# 1.1 用户一信息生成 # 1.1 用户一信息生成
key_engine = SentinelKeyEngine() key_engine = SentinelKeyEngine()
@@ -131,22 +144,36 @@ def main():
return return
# 3. 创建一个 asset # 3. 创建一个 asset
asset = create_asset( asset1 = create_asset(
token1, token1,
"My Secret Asset", "My Secret Asset1",
share_a, share_a,
ciphertext_1 ciphertext_1
) )
if not asset: asset2 = create_asset(
token1,
"My Secret Asset2",
share_a,
ciphertext_1
)
if not asset1 or not asset2:
print(" [失败] 创建资产失败")
return return
asset_id = asset["id"]
print(f" [输出] Asset ID: {asset_id}") # 3.1 测试 /assets/get
print("\n [测试] 获取用户资产列表")
my_assets = get_my_assets(token1)
if my_assets:
print(f" [输出] 成功获取 {len(my_assets)} 个资产")
else:
print(" [失败] 无法获取资产列表")
# 4. 指定用户 2 为继承人 # 4. 指定用户 2 为继承人
print("用户 1 指定用户 2 为继承人") print("用户 1 指定用户 2 为继承人")
assign_heir(token1, asset_id, "user2") assign_heir(token1, asset1["id"], "user2")
print("\n## 3. 继承流 (Inheritance Layer)") print("\n## 3. 继承流 (Inheritance Layer)")
# 5. Admin 宣布用户 1 挂了 # 5. Admin 宣布用户 1 挂了
@@ -166,7 +193,7 @@ def main():
# 7. 用户 2 申领资产,并带上自己的分片 (share_c) # 7. 用户 2 申领资产,并带上自己的分片 (share_c)
print("用户 2 申领资产,并带上自己的分片 (share_c)") print("用户 2 申领资产,并带上自己的分片 (share_c)")
claim_res = claim_asset(token2, asset_id, json.dumps(share_c)) claim_res = claim_asset(token2, asset1["id"], json.dumps(share_c))
if not claim_res: if not claim_res:
return return