diff --git a/app/database.py b/app/database.py index 5cd2ea1..1123d47 100644 --- a/app/database.py +++ b/app/database.py @@ -24,12 +24,62 @@ async def init_db(): # 自动创建表 await conn.run_sync(Base.metadata.create_all) - # 创建默认管理员用户 async with AsyncSessionLocal() as session: from . import auth from sqlalchemy.future import select - # 检查是否已存在 admin 用户 + # 1. 检查并创建默认订阅级别 (MUST BE FIRST because of FK constraints) + tiers = [ + { + "name": "Free", + "max_heirs": 1, + "weekly_token_limit": 1000, + "max_assets": 5, + "max_storage_mb": 10, + "can_use_ai_proxy": False, + "description": "Standard free tier" + }, + { + "name": "Pro", + "max_heirs": 5, + "weekly_token_limit": 10000, + "max_assets": 50, + "max_storage_mb": 100, + "can_use_ai_proxy": True, + "description": "Professional tier for active users" + }, + { + "name": "Ultra", + "max_heirs": 100, + "weekly_token_limit": 100000, + "max_assets": 1000, + "max_storage_mb": 1024, + "can_use_ai_proxy": True, + "description": "Ultimate tier for power users" + }, + { + "name": "Unlimited", + "max_heirs": 9999, + "weekly_token_limit": 999999, + "max_assets": 9999, + "max_storage_mb": 999999, + "can_use_ai_proxy": True, + "description": "Internal unlimited tier" + } + ] + + for tier_data in tiers: + result = await session.execute( + select(models.SubscriptionPlans).where(models.SubscriptionPlans.name == tier_data["name"]) + ) + if not result.scalars().first(): + new_tier = models.SubscriptionPlans(**tier_data) + session.add(new_tier) + print(f"✅ Default subscription tier '{tier_data['name']}' created") + + await session.commit() + + # 2. 检查并创建默认管理员用户 result = await session.execute( select(models.User).where(models.User.username == "admin") ) @@ -52,7 +102,7 @@ async def init_db(): await session.commit() print("✅ Default admin user created (username: admin, password: admin123)") - # 检查是否已存在 Gemini 配置 + # 3. 检查是否已存在 Gemini 配置 result = await session.execute( select(models.AIConfig).where(models.AIConfig.provider_name == "gemini") ) diff --git a/app/main.py b/app/main.py index fd74305..5cbefd4 100644 --- a/app/main.py +++ b/app/main.py @@ -7,7 +7,7 @@ from passlib.context import CryptContext from sqlalchemy.exc import IntegrityError from contextlib import asynccontextmanager import httpx -from datetime import datetime +from datetime import datetime, timedelta from typing import List @asynccontextmanager @@ -199,7 +199,7 @@ async def assign_asset( heir_result = await db.execute( select(models.User).where( - models.User.username == assignment.heir_name + models.User.email == assignment.heir_email ) ) heir_user = heir_result.scalars().first() @@ -216,7 +216,21 @@ async def assign_asset( asset.heir = heir_user await db.commit() - return {"message": f"Asset assigned to {assignment.heir_name}"} + return {"message": f"Asset assigned to {assignment.heir_email}"} + + +@app.get("/assets/designated", response_model=List[schemas.AssetOut]) +async def get_designated_assets( + current_user: models.User = Depends(auth.get_current_user), + db: AsyncSession = Depends(database.get_db) +): + """ + Query assets where the current user is the designated heir. + """ + result = await db.execute( + select(models.Asset).where(models.Asset.heir_id == current_user.id) + ) + return result.scalars().all() @@ -253,14 +267,34 @@ async def declare_user_guale( "guale": target_user.guale } -# 用于测试热加载 -@app.post("/post1") -async def test1(): - a=2 - b=3 - c = a+b - return {"msg": f"this is a msg {c}"} +async def get_or_create_token_usage(user_id: int, db: AsyncSession): + # Get current week start (Monday) + now = datetime.utcnow() + monday = now - timedelta(days=now.weekday()) + week_start = monday.replace(hour=0, minute=0, second=0, microsecond=0) + + result = await db.execute( + select(models.UserTokenUsage).where(models.UserTokenUsage.user_id == user_id) + ) + usage = result.scalars().first() + + if not usage: + usage = models.UserTokenUsage( + user_id=user_id, + tokens_used=0, + last_reset_at=week_start + ) + db.add(usage) + await db.commit() + await db.refresh(usage) + #每周重置token使用情况 + elif usage.last_reset_at < week_start: + usage.tokens_used = 0 + usage.last_reset_at = week_start + await db.commit() + + return usage @app.post("/ai/proxy", response_model=schemas.AIResponse) async def ai_proxy( @@ -272,6 +306,43 @@ async def ai_proxy( Proxy relay for AI requests. Fetches AI configuration from the database. """ + def get_quota_exceeded_response(): + return { + "id": f"chatcmpl-{int(datetime.utcnow().timestamp())}", + "object": "chat.completion", + "created": int(datetime.utcnow().timestamp()), + "model": "quota-manager", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "quota exceeded, please upgrade plan" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + } + + # 1. 检查 Tier 是否允许使用 AI + result = await db.execute( + select(models.SubscriptionPlans).where(models.SubscriptionPlans.name == current_user.tier) + ) + tier_plan = result.scalars().first() + + if not tier_plan or not tier_plan.can_use_ai_proxy: + return get_quota_exceeded_response() + + # 2. 检查本周 Token 使用是否超过限制 + usage_record = await get_or_create_token_usage(current_user.id, db) + if usage_record.tokens_used >= tier_plan.weekly_token_limit: + return get_quota_exceeded_response() + # Fetch active AI config result = await db.execute( select(models.AIConfig).where(models.AIConfig.is_active == True) @@ -290,6 +361,9 @@ async def ai_proxy( payload = ai_request.model_dump() payload["model"] = config.default_model + current_user.last_active_at = datetime.utcnow() + await db.commit() + async with httpx.AsyncClient() as client: try: response = await client.post( @@ -299,7 +373,15 @@ async def ai_proxy( timeout=30.0 ) response.raise_for_status() - return response.json() + ai_data = response.json() + + # 3. 记录使用的 Token + total_tokens = ai_data.get("usage", {}).get("total_tokens", 0) + if total_tokens > 0: + usage_record.tokens_used += total_tokens + await db.commit() + + return ai_data except httpx.HTTPStatusError as e: raise HTTPException( status_code=e.response.status_code, @@ -311,3 +393,12 @@ async def ai_proxy( detail=f"An error occurred while requesting AI provider: {str(e)}" ) +# 用于测试热加载 +@app.post("/post1") +async def test1(): + a=2 + b=3 + c = a+b + return {"msg": f"this is a msg {c}"} + + diff --git a/app/models.py b/app/models.py index c0f0bc2..e727ef0 100644 --- a/app/models.py +++ b/app/models.py @@ -4,6 +4,21 @@ from sqlalchemy.orm import relationship from .database import Base +class SubscriptionPlans(Base): + __tablename__ = "subscription_plans" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, unique=True, index=True) # "Free", "Pro", "Ultra" + max_heirs = Column(Integer, default=1) + weekly_token_limit = Column(Integer, default=1000) + max_assets = Column(Integer, default=5) + max_storage_mb = Column(Integer, default=10) + can_use_ai_proxy = Column(Boolean, default=False) + description = Column(Text, nullable=True) + + users = relationship("User", back_populates="subscription_plans") + + class User(Base): __tablename__ = "users" @@ -12,10 +27,13 @@ class User(Base): email = Column(String, unique=True, index=True) hashed_password = Column(String) - tier = Column(String) + + tier = Column(String, ForeignKey("subscription_plans.name"), default="Free") tier_expires_at = Column(DateTime) last_active_at = Column(DateTime) + subscription_plans = relationship("SubscriptionPlans", back_populates="users") + # System keys public_key = Column(String) private_key = Column(String) # Encrypted or raw? Storing raw for now as per req @@ -50,4 +68,14 @@ class AIConfig(Base): api_key = Column(String) api_url = Column(String) default_model = Column(String) - is_active = Column(Boolean, default=True) \ No newline at end of file + is_active = Column(Boolean, default=True) + +class UserTokenUsage(Base): + __tablename__ = "user_token_usage" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), unique=True) + tokens_used = Column(Integer, default=0) + last_reset_at = Column(DateTime) + + user = relationship("User", backref="token_usage", uselist=False) \ No newline at end of file diff --git a/app/schemas.py b/app/schemas.py index d4d1257..50231c6 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -67,7 +67,7 @@ class AssetClaimOut(AssetClaim): class AssetAssign(BaseModel): asset_id: int - heir_name: str + heir_email: str class DeclareGuale(BaseModel): username: str @@ -87,4 +87,19 @@ class AIResponse(BaseModel): created: int model: str choices: List[dict] - usage: dict \ No newline at end of file + usage: dict + + +# Subscription Plans Schemas +class SubscriptionPlansBase(BaseModel): + name: str + max_heirs: int + weekly_token_limit: int + max_assets: int + max_storage_mb: int + can_use_ai_proxy: bool + description: Optional[str] = None + +class SubscriptionPlansOut(SubscriptionPlansBase): + id: int + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/test/test_scenario.py b/test/test_scenario.py index dde0f7d..2c7c089 100644 --- a/test/test_scenario.py +++ b/test/test_scenario.py @@ -51,16 +51,16 @@ def create_asset(token, title, private_key_shard, content_inner_encrypted): print(f"Failed to create asset: {response.text}") return None -def assign_heir(token, asset_id, heir_name): +def assign_heir(token, asset_id, heir_email): url = f"{BASE_URL}/assets/assign" headers = {"Authorization": f"Bearer {token}"} data = { "asset_id": asset_id, - "heir_name": heir_name + "heir_email": heir_email } response = requests.post(url, json=data, headers=headers) if response.status_code == 200: - print(f"Asset {asset_id} assigned to heir {heir_name} successfully.") + print(f"Asset {asset_id} assigned to heir {heir_email} successfully.") return response.json() else: print(f"Failed to assign heir: {response.text}") @@ -105,6 +105,17 @@ def get_my_assets(token): print(f"Failed to retrieve assets: {response.text}") return None +def get_designated_assets(token): + url = f"{BASE_URL}/assets/designated" + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(url, headers=headers) + if response.status_code == 200: + print(f"Designated assets retrieved successfully.") + return response.json() + else: + print(f"Failed to retrieve designated assets: {response.text}") + return None + def main(): # 1. 创建三个用户 users = [ @@ -143,7 +154,7 @@ def main(): if not token1: return - # 3. 创建一个 asset + # 3. 创建三个 asset asset1 = create_asset( token1, "My Secret Asset1", @@ -158,23 +169,42 @@ def main(): ciphertext_1 ) - if not asset1 or not asset2: + asset3 = create_asset( + token1, + "My Secret Asset3", + share_a, + ciphertext_1 + ) + + if not asset1 or not asset2 or not asset3: print(" [失败] 创建资产失败") return # 3.1 测试 /assets/get print("\n [测试] 获取用户资产列表") - my_assets = get_my_assets(token1) - if my_assets: - print(f" [输出] 成功获取 {len(my_assets)} 个资产") + user1_assets = get_my_assets(token1) + if user1_assets: + print(f" [输出] 用户1共有 {len(user1_assets)} 个资产") else: print(" [失败] 无法获取资产列表") - # 4. 指定用户 2 为继承人 - print("用户 1 指定用户 2 为继承人") - assign_heir(token1, asset1["id"], "user2") - + print("用户 1 为用户 2 分配遗产") + assign_heir(token1, asset1["id"], "user2@example.com") + assign_heir(token1, asset2["id"], "user2@example.com") + + # 4.1 用户2查询自己能继承多少遗产 + print("\n [测试] 用户 2 查询自己被指定的资产") + token2_temp = login_user("user2", "pass123") + designated_assets = get_designated_assets(token2_temp) + if designated_assets: + print(f" [输出] 用户 2 共有 {len(designated_assets)} 个被指定的资产") + for asset in designated_assets: + print(f" - Asset ID: {asset['id']}, Title: {asset['title']}") + else: + print(" [失败] 无法获取被指定资产列表") + + print("\n## 3. 继承流 (Inheritance Layer)") # 5. Admin 宣布用户 1 挂了 print("Admin 宣布用户 1 挂了")