diff --git a/app/database.py b/app/database.py index 5cd2ea1..1123d47 100644 --- a/app/database.py +++ b/app/database.py @@ -24,12 +24,62 @@ async def init_db(): # 自动创建表 await conn.run_sync(Base.metadata.create_all) - # 创建默认管理员用户 async with AsyncSessionLocal() as session: from . import auth from sqlalchemy.future import select - # 检查是否已存在 admin 用户 + # 1. 检查并创建默认订阅级别 (MUST BE FIRST because of FK constraints) + tiers = [ + { + "name": "Free", + "max_heirs": 1, + "weekly_token_limit": 1000, + "max_assets": 5, + "max_storage_mb": 10, + "can_use_ai_proxy": False, + "description": "Standard free tier" + }, + { + "name": "Pro", + "max_heirs": 5, + "weekly_token_limit": 10000, + "max_assets": 50, + "max_storage_mb": 100, + "can_use_ai_proxy": True, + "description": "Professional tier for active users" + }, + { + "name": "Ultra", + "max_heirs": 100, + "weekly_token_limit": 100000, + "max_assets": 1000, + "max_storage_mb": 1024, + "can_use_ai_proxy": True, + "description": "Ultimate tier for power users" + }, + { + "name": "Unlimited", + "max_heirs": 9999, + "weekly_token_limit": 999999, + "max_assets": 9999, + "max_storage_mb": 999999, + "can_use_ai_proxy": True, + "description": "Internal unlimited tier" + } + ] + + for tier_data in tiers: + result = await session.execute( + select(models.SubscriptionPlans).where(models.SubscriptionPlans.name == tier_data["name"]) + ) + if not result.scalars().first(): + new_tier = models.SubscriptionPlans(**tier_data) + session.add(new_tier) + print(f"✅ Default subscription tier '{tier_data['name']}' created") + + await session.commit() + + # 2. 检查并创建默认管理员用户 result = await session.execute( select(models.User).where(models.User.username == "admin") ) @@ -52,7 +102,7 @@ async def init_db(): await session.commit() print("✅ Default admin user created (username: admin, password: admin123)") - # 检查是否已存在 Gemini 配置 + # 3. 检查是否已存在 Gemini 配置 result = await session.execute( select(models.AIConfig).where(models.AIConfig.provider_name == "gemini") ) diff --git a/app/main.py b/app/main.py index fd74305..828fe1d 100644 --- a/app/main.py +++ b/app/main.py @@ -7,7 +7,7 @@ from passlib.context import CryptContext from sqlalchemy.exc import IntegrityError from contextlib import asynccontextmanager import httpx -from datetime import datetime +from datetime import datetime, timedelta from typing import List @asynccontextmanager @@ -115,9 +115,12 @@ async def create_asset( new_asset = models.Asset( title=asset.title, + type=asset.type, content_outer_encrypted=encrypted_content, private_key_shard=asset.private_key_shard, - author_id=current_user.id + author_id=current_user.id, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() ) db.add(new_asset) await db.commit() @@ -199,7 +202,7 @@ async def assign_asset( heir_result = await db.execute( select(models.User).where( - models.User.username == assignment.heir_name + models.User.email == assignment.heir_email ) ) heir_user = heir_result.scalars().first() @@ -216,7 +219,21 @@ async def assign_asset( asset.heir = heir_user await db.commit() - return {"message": f"Asset assigned to {assignment.heir_name}"} + return {"message": f"Asset assigned to {assignment.heir_email}"} + + +@app.get("/assets/designated", response_model=List[schemas.AssetOut]) +async def get_designated_assets( + current_user: models.User = Depends(auth.get_current_user), + db: AsyncSession = Depends(database.get_db) +): + """ + Query assets where the current user is the designated heir. + """ + result = await db.execute( + select(models.Asset).where(models.Asset.heir_id == current_user.id) + ) + return result.scalars().all() @@ -253,14 +270,34 @@ async def declare_user_guale( "guale": target_user.guale } -# 用于测试热加载 -@app.post("/post1") -async def test1(): - a=2 - b=3 - c = a+b - return {"msg": f"this is a msg {c}"} +async def get_or_create_token_usage(user_id: int, db: AsyncSession): + # Get current week start (Monday) + now = datetime.utcnow() + monday = now - timedelta(days=now.weekday()) + week_start = monday.replace(hour=0, minute=0, second=0, microsecond=0) + + result = await db.execute( + select(models.UserTokenUsage).where(models.UserTokenUsage.user_id == user_id) + ) + usage = result.scalars().first() + + if not usage: + usage = models.UserTokenUsage( + user_id=user_id, + tokens_used=0, + last_reset_at=week_start + ) + db.add(usage) + await db.commit() + await db.refresh(usage) + #每周重置token使用情况 + elif usage.last_reset_at < week_start: + usage.tokens_used = 0 + usage.last_reset_at = week_start + await db.commit() + + return usage @app.post("/ai/proxy", response_model=schemas.AIResponse) async def ai_proxy( @@ -272,6 +309,43 @@ async def ai_proxy( Proxy relay for AI requests. Fetches AI configuration from the database. """ + def get_quota_exceeded_response(): + return { + "id": f"chatcmpl-{int(datetime.utcnow().timestamp())}", + "object": "chat.completion", + "created": int(datetime.utcnow().timestamp()), + "model": "quota-manager", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "quota exceeded, please upgrade plan" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + + # 1. 检查 Tier 是否允许使用 AI + result = await db.execute( + select(models.SubscriptionPlans).where(models.SubscriptionPlans.name == current_user.tier) + ) + tier_plan = result.scalars().first() + + if not tier_plan or not tier_plan.can_use_ai_proxy: + return get_quota_exceeded_response() + + # 2. 检查本周 Token 使用是否超过限制 + usage_record = await get_or_create_token_usage(current_user.id, db) + if usage_record.tokens_used >= tier_plan.weekly_token_limit: + return get_quota_exceeded_response() + # Fetch active AI config result = await db.execute( select(models.AIConfig).where(models.AIConfig.is_active == True) @@ -290,6 +364,9 @@ async def ai_proxy( payload = ai_request.model_dump() payload["model"] = config.default_model + current_user.last_active_at = datetime.utcnow() + await db.commit() + async with httpx.AsyncClient() as client: try: response = await client.post( @@ -299,7 +376,15 @@ async def ai_proxy( timeout=30.0 ) response.raise_for_status() - return response.json() + ai_data = response.json() + + # 3. 记录使用的 Token + total_tokens = ai_data.get("usage", {}).get("total_tokens", 0) + if total_tokens > 0: + usage_record.tokens_used += total_tokens + await db.commit() + + return ai_data except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, @@ -311,3 +396,12 @@ async def ai_proxy( detail=f"An error occurred while requesting AI provider: {str(e)}" ) +# 用于测试热加载 +@app.post("/post1") +async def test1(): + a=2 + b=3 + c = a+b + return {"msg": f"this is a msg {c}"} + + diff --git a/app/models.py b/app/models.py index c0f0bc2..50e26ec 100644 --- a/app/models.py +++ b/app/models.py @@ -2,6 +2,23 @@ from sqlalchemy import Column, Integer, String, ForeignKey, Text, Table, Boolean from sqlalchemy.orm import relationship from .database import Base +from datetime import datetime + + + +class SubscriptionPlans(Base): + __tablename__ = "subscription_plans" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, unique=True, index=True) # "Free", "Pro", "Ultra" + max_heirs = Column(Integer, default=1) + weekly_token_limit = Column(Integer, default=1000) + max_assets = Column(Integer, default=5) + max_storage_mb = Column(Integer, default=10) + can_use_ai_proxy = Column(Boolean, default=False) + description = Column(Text, nullable=True) + + users = relationship("User", back_populates="subscription_plans") class User(Base): @@ -12,10 +29,13 @@ class User(Base): email = Column(String, unique=True, index=True) hashed_password = Column(String) - tier = Column(String) + + tier = Column(String, ForeignKey("subscription_plans.name"), default="Free") tier_expires_at = Column(DateTime) last_active_at = Column(DateTime) + subscription_plans = relationship("SubscriptionPlans", back_populates="users") + # System keys public_key = Column(String) private_key = Column(String) # Encrypted or raw? Storing raw for now as per req @@ -35,6 +55,7 @@ class Asset(Base): content_outer_encrypted = Column(Text) author_id = Column(Integer, ForeignKey("users.id")) heir_id = Column(Integer, ForeignKey("users.id")) + type = Column(String, index=True, nullable=True) # Key shard for this asset private_key_shard = Column(String) @@ -42,6 +63,9 @@ class Asset(Base): author = relationship("User", foreign_keys=[author_id], back_populates="assets") heir = relationship("User", foreign_keys=[heir_id], back_populates="inherited_assets") + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + class AIConfig(Base): __tablename__ = "ai_configs" @@ -50,4 +74,14 @@ class AIConfig(Base): api_key = Column(String) api_url = Column(String) default_model = Column(String) - is_active = Column(Boolean, default=True) \ No newline at end of file + is_active = Column(Boolean, default=True) + +class UserTokenUsage(Base): + __tablename__ = "user_token_usage" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), unique=True) + tokens_used = Column(Integer, default=0) + last_reset_at = Column(DateTime) + + user = relationship("User", backref="token_usage", uselist=False) \ No newline at end of file diff --git a/app/schemas.py b/app/schemas.py index d4d1257..6061758 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -45,6 +45,7 @@ class LoginResponse(BaseModel): # Asset Schemas (renamed from Article) class AssetBase(BaseModel): title: str + type: Optional[str] = "note" class AssetCreate(AssetBase): private_key_shard: str @@ -55,6 +56,8 @@ class AssetOut(AssetBase): author_id: int private_key_shard: str content_outer_encrypted: str + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) class AssetClaim(BaseModel): @@ -67,7 +70,7 @@ class AssetClaimOut(AssetClaim): class AssetAssign(BaseModel): asset_id: int - heir_name: str + heir_email: str class DeclareGuale(BaseModel): username: str @@ -87,4 +90,19 @@ class AIResponse(BaseModel): created: int model: str choices: List[dict] - usage: dict \ No newline at end of file + usage: dict + + +# Subscription Plans Schemas +class SubscriptionPlansBase(BaseModel): + name: str + max_heirs: int + weekly_token_limit: int + max_assets: int + max_storage_mb: int + can_use_ai_proxy: bool + description: Optional[str] = None + +class SubscriptionPlansOut(SubscriptionPlansBase): + id: int + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/test/test_scenario.py b/test/test_scenario.py index dde0f7d..6bb0f7f 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -35,32 +35,35 @@ def login_user(username, password): print(f"Failed to login {username}: {response.text}") return None -def create_asset(token, title, private_key_shard, content_inner_encrypted): +def create_asset(token, title, private_key_shard, content_inner_encrypted, asset_type="note"): url = f"{BASE_URL}/assets/create" headers = {"Authorization": f"Bearer {token}"} data = { "title": title, + "type": asset_type, "private_key_shard": str(private_key_shard), "content_inner_encrypted": str(content_inner_encrypted) } response = requests.post(url, json=data, headers=headers) if response.status_code == 200: - print(f"Asset '{title}' created successfully.") - return response.json() + asset_data = response.json() + print(f"Asset '{title}' (type: {asset_type}) created successfully.") + print(f" [校验] Timestamps: created_at={asset_data.get('created_at')}, updated_at={asset_data.get('updated_at')}") + return asset_data else: print(f"Failed to create asset: {response.text}") return None -def assign_heir(token, asset_id, heir_name): +def assign_heir(token, asset_id, heir_email): url = f"{BASE_URL}/assets/assign" headers = {"Authorization": f"Bearer {token}"} data = { "asset_id": asset_id, - "heir_name": heir_name + "heir_email": heir_email } response = requests.post(url, json=data, headers=headers) if response.status_code == 200: - print(f"Asset {asset_id} assigned to heir {heir_name} successfully.") + print(f"Asset {asset_id} assigned to heir {heir_email} successfully.") return response.json() else: print(f"Failed to assign heir: {response.text}") @@ -105,6 +108,17 @@ def get_my_assets(token): print(f"Failed to retrieve assets: {response.text}") return None +def get_designated_assets(token): + url = f"{BASE_URL}/assets/designated" + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(url, headers=headers) + if response.status_code == 200: + print(f"Designated assets retrieved successfully.") + return response.json() + else: + print(f"Failed to retrieve designated assets: {response.text}") + return None + def main(): # 1. 创建三个用户 users = [ @@ -143,38 +157,60 @@ def main(): if not token1: return - # 3. 创建一个 asset + # 3. 创建三个 asset asset1 = create_asset( token1, "My Secret Asset1", share_a, - ciphertext_1 + ciphertext_1, + "note" ) asset2 = create_asset( token1, "My Secret Asset2", share_a, - ciphertext_1 + ciphertext_1, + "note" ) - if not asset1 or not asset2: + asset3 = create_asset( + token1, + "My Secret Asset3", + share_a, + ciphertext_1, + "note" + ) + + if not asset1 or not asset2 or not asset3: print(" [失败] 创建资产失败") return # 3.1 测试 /assets/get print("\n [测试] 获取用户资产列表") - my_assets = get_my_assets(token1) - if my_assets: - print(f" [输出] 成功获取 {len(my_assets)} 个资产") + user1_assets = get_my_assets(token1) + if user1_assets: + print(f" [输出] 用户1共有 {len(user1_assets)} 个资产") else: print(" [失败] 无法获取资产列表") - # 4. 指定用户 2 为继承人 - print("用户 1 指定用户 2 为继承人") - assign_heir(token1, asset1["id"], "user2") - + print("用户 1 为用户 2 分配遗产") + assign_heir(token1, asset1["id"], "user2@example.com") + assign_heir(token1, asset2["id"], "user2@example.com") + + # 4.1 用户2查询自己能继承多少遗产 + print("\n [测试] 用户 2 查询自己被指定的资产") + token2_temp = login_user("user2", "pass123") + designated_assets = get_designated_assets(token2_temp) + if designated_assets: + print(f" [输出] 用户 2 共有 {len(designated_assets)} 个被指定的资产") + for asset in designated_assets: + print(f" - Asset ID: {asset['id']}, Title: {asset['title']}") + else: + print(" [失败] 无法获取被指定资产列表") + + print("\n## 3. 继承流 (Inheritance Layer)") # 5. Admin 宣布用户 1 挂了 print("Admin 宣布用户 1 挂了")