diff --git a/app/database.py b/app/database.py index aaaff19..5cd2ea1 100644 --- a/app/database.py +++ b/app/database.py @@ -1,5 +1,6 @@ # database.py import os +from datetime import datetime from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base @@ -43,7 +44,9 @@ async def init_db(): private_key=private_key, public_key=public_key, is_admin=True, - guale=False + guale=False, + tier="Unlimited", + last_active_at=datetime.utcnow() ) session.add(admin_user) await session.commit() diff --git a/app/main.py b/app/main.py index 1fd127c..fd74305 100644 --- a/app/main.py +++ b/app/main.py @@ -7,6 +7,8 @@ from passlib.context import CryptContext from sqlalchemy.exc import IntegrityError from contextlib import asynccontextmanager import httpx +from datetime import datetime +from typing import List @asynccontextmanager async def lifespan(app: FastAPI): @@ -17,6 +19,23 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) +from fastapi.middleware.cors import CORSMiddleware + +origins = [ + "http://localhost", + "http://localhost:8081", + "http://localhost:8080", + "http://localhost:3000", +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + @app.post("/register", response_model=schemas.UserOut) async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database.get_db)): @@ -25,9 +44,12 @@ async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database new_user = models.User( username=user.username, + email=user.email, hashed_password=hashed_pwd, private_key=private_key, - public_key=public_key + public_key=public_key, + tier="Free", + last_active_at=datetime.utcnow() ) db.add(new_user) try: @@ -40,12 +62,17 @@ async def register(user: schemas.UserCreate, db: AsyncSession = Depends(database await db.rollback() raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="用户名已存在" + detail="邮箱已存在" + ) + except Exception as e: + await db.rollback() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"注册失败: {str(e)}" ) - -@app.post("/token") +@app.post("/login", response_model=schemas.LoginResponse) async def login(form_data: schemas.UserLogin, db: AsyncSession = Depends(database.get_db)): result = await db.execute(select(models.User).where(models.User.username == form_data.username)) user = result.scalars().first() @@ -53,10 +80,31 @@ async def login(form_data: schemas.UserLogin, db: AsyncSession = Depends(databas raise HTTPException(status_code=400, detail="Incorrect username or password") access_token = auth.create_access_token(data={"sub": user.username}) - return {"access_token": access_token, "token_type": "bearer"} + + # Update last_active_at + user.last_active_at = datetime.utcnow() + await db.commit() + + return { + "access_token": access_token, + "token_type": "bearer", + "user": user + } -@app.post("/assets/", response_model=schemas.AssetOut) + +@app.get("/assets/get", response_model=List[schemas.AssetOut]) +async def get_my_assets( + current_user: models.User = Depends(auth.get_current_user), + db: AsyncSession = Depends(database.get_db) +): + result = await db.execute( + select(models.Asset).where(models.Asset.author_id == current_user.id) + ) + return result.scalars().all() + + +@app.post("/assets/create", response_model=schemas.AssetOut) async def create_asset( asset: schemas.AssetCreate, current_user: models.User = Depends(auth.get_current_user), diff --git a/app/models.py b/app/models.py index ac9fa06..c0f0bc2 100644 --- a/app/models.py +++ b/app/models.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, ForeignKey, Text, Table, Boolean +from sqlalchemy import Column, Integer, String, ForeignKey, Text, Table, Boolean, DateTime from sqlalchemy.orm import relationship from .database import Base @@ -8,8 +8,14 @@ class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) - username = Column(String, unique=True, index=True) + username = Column(String, index=True) + email = Column(String, unique=True, index=True) hashed_password = Column(String) + + tier = Column(String) + tier_expires_at = Column(DateTime) + last_active_at = Column(DateTime) + # System keys public_key = Column(String) private_key = Column(String) # Encrypted or raw? Storing raw for now as per req @@ -34,4 +40,14 @@ class Asset(Base): private_key_shard = Column(String) author = relationship("User", foreign_keys=[author_id], back_populates="assets") - heir = relationship("User", foreign_keys=[heir_id], back_populates="inherited_assets") \ No newline at end of file + heir = relationship("User", foreign_keys=[heir_id], back_populates="inherited_assets") + +class AIConfig(Base): + __tablename__ = "ai_configs" + + id = Column(Integer, primary_key=True, index=True) + provider_name = Column(String, unique=True, index=True) + api_key = Column(String) + api_url = Column(String) + default_model = Column(String) + is_active = Column(Boolean, default=True) \ No newline at end of file diff --git a/app/schemas.py b/app/schemas.py index 2d3b854..d4d1257 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, ConfigDict from typing import List, Optional +from datetime import datetime # Heir Schemas class HeirBase(BaseModel): @@ -17,6 +18,7 @@ class HeirOut(HeirBase): class UserCreate(BaseModel): username: str password: str + email: str class UserLogin(BaseModel): username: str @@ -28,9 +30,18 @@ class UserOut(BaseModel): public_key: Optional[str] = None is_admin: bool = False guale: bool = False + tier: Optional[str] = None + tier_expires_at: Optional[datetime] = None + last_active_at: Optional[datetime] = None #heirs: List[HeirOut] = [] model_config = ConfigDict(from_attributes=True) +class LoginResponse(BaseModel): + access_token: str + token_type: str + user: UserOut + + # Asset Schemas (renamed from Article) class AssetBase(BaseModel): title: str diff --git a/test/test_scenario.py b/test/test_scenario.py index 0f7268c..dde0f7d 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -6,11 +6,12 @@ import ast BASE_URL = "http://localhost:8000" -def register_user(username, password): +def register_user(username, email, password): url = f"{BASE_URL}/register" data = { "username": username, - "password": password + "password": password, + "email": email, } response = requests.post(url, json=data) if response.status_code == 200: @@ -21,7 +22,7 @@ def register_user(username, password): return None def login_user(username, password): - url = f"{BASE_URL}/token" + url = f"{BASE_URL}/login" data = { "username": username, "password": password @@ -35,7 +36,7 @@ def login_user(username, password): return None def create_asset(token, title, private_key_shard, content_inner_encrypted): - url = f"{BASE_URL}/assets/" + url = f"{BASE_URL}/assets/create" headers = {"Authorization": f"Bearer {token}"} data = { "title": title, @@ -92,16 +93,28 @@ def claim_asset(token, asset_id, private_key_shard): print(f"Failed to claim asset: {response.text}") return None + +def get_my_assets(token): + url = f"{BASE_URL}/assets/get" + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(url, headers=headers) + if response.status_code == 200: + print(f"Assets retrieved successfully.") + return response.json() + else: + print(f"Failed to retrieve assets: {response.text}") + return None + def main(): # 1. 创建三个用户 users = [ - ("user1", "pass123"), - ("user2", "pass123"), - ("user3", "pass123") + ("user1", "pass123", "user1@example.com"), + ("user2", "pass123", "user2@example.com"), + ("user3", "pass123", "user3@example.com") ] - for username, password in users: - register_user(username, password) + for username, password, email in users: + register_user(username, email, password) # 1.1 用户一信息生成 key_engine = SentinelKeyEngine() @@ -131,22 +144,36 @@ def main(): return # 3. 创建一个 asset - asset = create_asset( + asset1 = create_asset( token1, - "My Secret Asset", + "My Secret Asset1", share_a, ciphertext_1 ) - if not asset: + asset2 = create_asset( + token1, + "My Secret Asset2", + share_a, + ciphertext_1 + ) + + if not asset1 or not asset2: + print(" [失败] 创建资产失败") return - asset_id = asset["id"] - print(f" [输出] Asset ID: {asset_id}") + # 3.1 测试 /assets/get + print("\n [测试] 获取用户资产列表") + my_assets = get_my_assets(token1) + if my_assets: + print(f" [输出] 成功获取 {len(my_assets)} 个资产") + else: + print(" [失败] 无法获取资产列表") + # 4. 指定用户 2 为继承人 print("用户 1 指定用户 2 为继承人") - assign_heir(token1, asset_id, "user2") + assign_heir(token1, asset1["id"], "user2") print("\n## 3. 继承流 (Inheritance Layer)") # 5. Admin 宣布用户 1 挂了 @@ -166,7 +193,7 @@ def main(): # 7. 用户 2 申领资产,并带上自己的分片 (share_c) print("用户 2 申领资产,并带上自己的分片 (share_c)") - claim_res = claim_asset(token2, asset_id, json.dumps(share_c)) + claim_res = claim_asset(token2, asset1["id"], json.dumps(share_c)) if not claim_res: return