跳至內容

使用 Peewee 的 SQL (關聯式) 資料庫 (已棄用)

「已棄用」

本教學已棄用,並將於未來版本中移除。

警告

如果您剛開始使用,使用 SQLAlchemy 的教學 SQL (關聯式) 資料庫 應該就足夠了。

您可以略過此部分。

不建議在 FastAPI 中使用 Peewee,因為它與任何 Python 非同步功能都不相容。 有幾個更好的替代方案。

資訊

這些文件假設使用 Pydantic v1。

由於 Pewee 與任何非同步功能都不相容,並且有更好的替代方案,我不會為 Pydantic v2 更新這些文件,它們目前僅保留供歷史參考。

此處的範例不再於 CI 中進行測試(如以往)。

如果您要從頭開始一個專案,您可能最好使用 SQLAlchemy ORM (SQL (關聯式) 資料庫) 或任何其他非同步 ORM。

如果您已經有使用 Peewee ORM 的程式碼庫,您可以查看這裡如何將其與 FastAPI 搭配使用。

「需要 Python 3.7+」

您需要 Python 3.7 或更高版本才能安全地將 Peewee 與 FastAPI 搭配使用。

適用於非同步的 Peewee

Peewee 並非為非同步框架設計,也沒有考慮到它們。

Peewee 對其預設值以及如何使用有一些嚴格的假設。

如果您正在使用較舊的非同步框架開發應用程式,並且可以使用其所有預設值,它會是一個很棒的工具

但是,如果您需要更改某些預設值、支援多個預定義資料庫、使用非同步框架(例如 FastAPI)等等,您將需要新增一些相當複雜的額外程式碼來覆蓋這些預設值。

儘管如此,還是可以做到,在這裡您將看到需要新增哪些程式碼才能將 Peewee 與 FastAPI 搭配使用。

「技術細節」

您可以閱讀更多關於 Peewee 對 Python 非同步的立場 在文件中一個 issue一個 PR

相同的應用程式

我們將建立與 SQLAlchemy 教學 (SQL (關聯式) 資料庫) 中相同的應用程式。

大部分的程式碼實際上是相同的。

因此,我們將只關注差異。

檔案結構

假設你有一個名為 my_super_project 的目錄,其中包含一個名為 sql_app 的子目錄,結構如下:

.
└── sql_app
    ├── __init__.py
    ├── crud.py
    ├── database.py
    ├── main.py
    └── schemas.py

這與 SQLAlchemy 教學中的結構幾乎相同。

現在讓我們看看每個檔案/模組的功能。

建立 Peewee 部分

讓我們參考 sql_app/database.py 檔案。

標準的 Peewee 程式碼

首先讓我們檢查所有標準的 Peewee 程式碼,建立一個 Peewee 資料庫

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

提示

請記住,如果您想使用不同的資料庫,例如 PostgreSQL,您不能只更改字串。您需要使用不同的 Peewee 資料庫類別。

注意事項

參數

check_same_thread=False

等價於 SQLAlchemy 教學中的參數

connect_args={"check_same_thread": False}

...它僅適用於 SQLite

「技術細節」

SQL (關聯式) 資料庫 中完全相同的技術細節適用。

使 Peewee 與非同步相容 PeeweeConnectionState

Peewee 和 FastAPI 的主要問題是 Peewee 嚴重依賴 Python 的 threading.local,並且它沒有直接的方法來覆寫它或讓您直接處理連線/會話(就像在 SQLAlchemy 教學中所做的那樣)。

threading.local 與現代 Python 的新非同步功能不相容。

「技術細節」

threading.local 用於擁有一個「神奇」的變數,該變數對於每個執行緒具有不同的值。

這在設計為每個請求只有一個執行緒的舊框架中很有用,不多也不少。

使用這個,每個請求都會有自己的資料庫連線/會話,這才是最終目標。

但是 FastAPI 使用新的非同步功能,可以在同一個執行緒上處理多個請求。同時,對於單個請求,它可以在不同的執行緒(在執行緒池中)中運行多個任務,取決於您使用的是 async def 還是普通的 def。這就是 FastAPI 效能提升的原因。

但 Python 3.7 及更高版本提供了比 threading.local 更先進的替代方案,它也可以用於 threading.local 可以使用的地方,並且與新的非同步功能相容。

我們將使用它。它叫做 contextvars

我們將覆寫使用 threading.local 的 Peewee 內部組件,並將它們替換為 contextvars,並進行相應的更新。

這看起來可能有點複雜(實際上也是如此),您並不需要完全理解它的工作原理就能使用它。

我們將建立一個 PeeweeConnectionState

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

這個類別繼承自 Peewee 使用的一個特殊的內部類別。

它包含了使 Peewee 使用 contextvars 而不是 threading.local 的所有邏輯。

contextvars 的工作方式與 threading.local 有些不同。但 Peewee 的其餘內部程式碼假設這個類別與 threading.local 一起使用。

因此,我們需要做一些額外的技巧,讓它像使用 threading.local 一樣工作。__init____setattr____getattr__ 實現了所有必要的技巧,以便 Peewee 在不知道它現在與 FastAPI 相容的情況下使用它。

提示

這只會讓 Peewee 在與 FastAPI 一起使用時正常運作。不會隨機開啟或關閉正在使用的連線,產生錯誤等等。

但它並沒有賦予 Peewee 非同步的超能力。您仍然應該使用普通的 def 函數,而不是 async def

使用自訂的 PeeweeConnectionState 類別

現在,使用新的 PeeweeConnectionState 覆寫 Peewee 資料庫 db 物件中的內部屬性 ._state

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

提示

確保在建立 db *之後* 覆寫 db._state

提示

對於任何其他 Peewee 資料庫,包括 PostgresqlDatabaseMySQLDatabase 等,您也需要執行相同的操作。

建立資料庫模型

現在讓我們來看一下檔案 sql_app/models.py

為我們的資料建立 Peewee 模型

現在為 UserItem 建立 Peewee 模型(類別)。

這與您遵循 Peewee 教學並將模型更新為與 SQLAlchemy 教學中的資料相同的操作相同。

提示

Peewee 也使用術語「**模型**」來指稱這些與資料庫互動的類別和實例。

但 Pydantic 也使用術語「**模型**」來指稱不同的東西,即資料驗證、轉換和文件類別及實例。

database(即上面的檔案 database.py)匯入 db 並在此處使用它。

import peewee

from .database import db


class User(peewee.Model):
    email = peewee.CharField(unique=True, index=True)
    hashed_password = peewee.CharField()
    is_active = peewee.BooleanField(default=True)

    class Meta:
        database = db


class Item(peewee.Model):
    title = peewee.CharField(index=True)
    description = peewee.CharField(index=True)
    owner = peewee.ForeignKeyField(User, backref="items")

    class Meta:
        database = db

提示

Peewee 會建立幾個魔術屬性。

它會自動新增一個 id 屬性作為整數主鍵。

它會根據類別名稱選擇表格的名稱。

對於 Item,它會建立一個具有 User 整數 ID 的 owner_id 屬性。但我們在任何地方都沒有宣告它。

建立 Pydantic 模型

現在讓我們檢查檔案 sql_app/schemas.py

提示

為了避免混淆 Peewee *模型* 和 Pydantic *模型*,我們將使用包含 Peewee 模型的檔案 models.py 和包含 Pydantic 模型的檔案 schemas.py

這些 Pydantic 模型或多或少定義了一個「schema」(一個有效的資料形狀)。

因此,這將有助於我們在同時使用兩者時避免混淆。

建立 Pydantic *模型* / schemas

建立與 SQLAlchemy 教學中所有相同的 Pydantic 模型

from typing import Any, List, Union

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

提示

這裡我們正在建立帶有 id 的模型。

我們沒有在 Peewee 模型中明確指定 id 屬性,但 Peewee 會自動新增一個。

我們也正在將魔術屬性 owner_id 新增到 Item

為 Pydantic *模型* / schemas 建立 PeeweeGetterDict

當您在 Peewee 物件中存取關係時,例如在 some_user.items 中,Peewee 並未提供 Itemlist(列表)。

它提供了一個特殊的自定義類別 ModelSelect 的物件。

可以使用 list(some_user.items) 建立其項目的 list

但物件本身並不是一個 list。它也不是一個實際的 Python 產生器。因此,Pydantic 預設情況下不知道如何將其轉換為 Pydantic *模型* / schemas 的 list

但最新版本的 Pydantic 允許提供一個繼承自 pydantic.utils.GetterDict 的自定義類別,以提供在使用 orm_mode = True 檢索 ORM 模型屬性值時使用的功能。

我們將建立一個自定義的 PeeweeGetterDict 類別,並在所有使用 orm_mode 的相同 Pydantic *模型* / schemas 中使用它

from typing import Any, List, Union

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

這裡我們正在檢查正在存取的屬性(例如 some_user.items 中的 .items)是否是 peewee.ModelSelect 的實例。

如果是這種情況,則只需返回一個包含它的 list

然後,我們在使用 orm_mode = True 的 Pydantic *模型* / schemas 中使用它,並使用設定變數 getter_dict = PeeweeGetterDict

提示

我們只需要建立一個 PeeweeGetterDict 類別,就可以在所有 Pydantic *模型*/schema 中使用它。

CRUD 工具函式

現在讓我們來看一下 sql_app/crud.py 檔案。

建立所有 CRUD 工具函式

建立所有與 SQLAlchemy 教學相同的 CRUD 工具函式,所有程式碼都非常相似。

from . import models, schemas


def get_user(user_id: int):
    return models.User.filter(models.User.id == user_id).first()


def get_user_by_email(email: str):
    return models.User.filter(models.User.email == email).first()


def get_users(skip: int = 0, limit: int = 100):
    return list(models.User.select().offset(skip).limit(limit))


def create_user(user: schemas.UserCreate):
    fake_hashed_password = user.password + "notreallyhashed"
    db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
    db_user.save()
    return db_user


def get_items(skip: int = 0, limit: int = 100):
    return list(models.Item.select().offset(skip).limit(limit))


def create_user_item(item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db_item.save()
    return db_item

與 SQLAlchemy 教學的程式碼有一些差異。

我們沒有傳遞 db 屬性。而是直接使用模型。這是因為 db 物件是一個全域物件,包含了所有連線邏輯。這就是為什麼我們必須在上面進行所有 contextvars 更新的原因。

此外,當返回多個物件時,例如在 get_users 中,我們直接呼叫 list,就像這樣:

list(models.User.select())

這與我們必須建立自訂 PeeweeGetterDict 的原因相同。但是藉由返回一個已經是 list 的東西,而不是 peewee.ModelSelect,在稍後我們會看到的 *路徑操作* 中使用 List[models.User]response_model 將能正常運作。

主要的 FastAPI 應用程式

現在在 sql_app/main.py 檔案中,讓我們整合並使用之前建立的所有其他部分。

建立資料庫表格

以非常簡化的方式建立資料庫表格。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

建立依賴項

建立一個依賴項,它會在請求開始時立即連線到資料庫,並在請求結束時斷開連線。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

這裡我們有一個空的 yield,因為我們實際上沒有直接使用資料庫物件。

它會連線到資料庫,並將連線資料儲存在一個獨立於每個請求的內部變數中(使用上面提到的 contextvars 技巧)。

由於資料庫連線可能造成 I/O 阻塞,因此此依賴項是使用一般的 def 函式建立的。

然後,在每個需要存取資料庫的 *路徑操作函式* 中,我們將其作為依賴項新增。

但我們沒有使用這個依賴項所提供的數值(它實際上沒有提供任何數值,因為它有一個空的 yield)。因此,我們沒有將它新增到 *路徑操作函式* 中,而是新增到 dependencies 參數中的 *路徑操作裝飾器*。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

Context 變數子依賴項

為了讓所有 contextvars 部分都能正常運作,我們需要確保在 ContextVar 中為每個使用資料庫的請求都有一個獨立的數值,並且該數值將在整個請求中用作資料庫狀態(連線、事務等)。

為此,我們需要建立另一個 async 依賴項 reset_db_state(),它在 get_db() 中作為子依賴項使用。它將設定 context 變數的數值(只是一個預設的 dict),該數值將在整個請求中用作資料庫狀態。然後,依賴項 get_db() 將在其中儲存資料庫狀態(連線、事務等)。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

對於**下一個請求**,由於我們將在 async 依賴項 reset_db_state() 中再次重置該 context 變數,然後在 get_db() 依賴項中建立新的連線,因此新的請求將擁有自己的資料庫狀態(連線、事務等)。

提示

由於 FastAPI 是一個非同步框架,一個請求可以在處理過程中,另一個請求也可以被接收並開始處理,而且它們都可以在同一個執行緒中處理。

但上下文變數知道這些非同步特性,因此,在 async 依賴項 reset_db_state() 中設定的 Peewee 資料庫狀態將在整個請求過程中保留其自身的資料。

同時,其他並行請求將擁有自己的資料庫狀態,該狀態在整個請求過程中都是獨立的。

Peewee 代理

如果您使用的是 Peewee 代理,實際的資料庫位於 db.obj

因此,您需要使用以下方式重置它

async def reset_db_state():
    database.db.obj._state._state.set(db_state_default.copy())
    database.db.obj._state.reset()

建立您的 FastAPI 路徑操作

現在,最後,這是標準的 FastAPI 路徑操作 程式碼。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

關於 defasync def

與 SQLAlchemy 相同,我們沒有執行類似以下的程式碼

user = await models.User.select().first()

...而是使用

user = models.User.select().first()

因此,同樣地,我們應該在沒有 async def 的情況下宣告 *路徑操作函式* 和依賴項,只需使用普通的 def,如下所示

# Something goes here
def read_users(skip: int = 0, limit: int = 100):
    # Something goes here

使用非同步測試 Peewee

此範例包含一個額外的 *路徑操作*,它使用 time.sleep(sleep_time) 模擬長時間處理的請求。

它會在開始時開啟資料庫連線,並在回覆之前等待幾秒鐘。每個新的請求都會減少一秒的等待時間。

這可以讓您輕鬆測試您的 Peewee 和 FastAPI 應用程式在所有關於執行緒的情況下是否正常運作。

如果您想檢查 Peewee 在未經修改的情況下如何破壞您的應用程式,請前往 sql_app/database.py 檔案並註解掉以下這一行

# db._state = PeeweeConnectionState()

並在 sql_app/main.py 檔案中,註解掉 async 依賴項 reset_db_state() 的主體,並將其替換為 pass

async def reset_db_state():
#     database.db._state._state.set(db_state_default.copy())
#     database.db._state.reset()
    pass

然後使用 Uvicorn 執行您的應用程式

$ uvicorn sql_app.main:app --reload

<span style="color: green;">INFO</span>:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)

在瀏覽器中開啟 http://127.0.0.1:8000/docs 並建立幾個使用者。

然後同時在 10 個分頁中開啟 http://127.0.0.1:8000/docs#/default/read_slow_users_slowusers__get

在所有分頁中前往「取得 /slowusers/」*路徑操作*。使用「試用」按鈕並在每個分頁中依序執行請求。

這些分頁將會等待一段時間,然後其中一些會顯示「內部伺服器錯誤」。

發生了什麼事

第一個分頁將使您的應用程式建立與資料庫的連線,並在回覆和關閉資料庫連線之前等待幾秒鐘。

然後,對於下一個分頁中的請求,您的應用程式將會減少一秒的等待時間,依此類推。

這表示它最終會比之前的某些請求更早完成一些最後分頁的請求。

然後,等待時間較短的最後一個請求將嘗試開啟資料庫連線,但由於其他分頁的先前請求可能會與第一個請求在同一個執行緒中處理,它將擁有相同的已開啟資料庫連線,Peewee 將會拋出錯誤,您將會在終端機中看到它,並且回應將會出現「內部伺服器錯誤」。

這可能會發生在多個分頁上。

如果有多個客戶端同時與您的應用程式通訊,可能會發生以下情況。

隨著您的應用程式開始同時處理越來越多的客戶端,單個請求中的等待時間需要越來越短才能觸發錯誤。

使用 FastAPI 修正 Peewee

現在回到 sql_app/database.py 檔案,並取消註釋這一行

db._state = PeeweeConnectionState()

sql_app/main.py 檔案中,取消註釋 async 依賴項 reset_db_state() 的主體

async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()

終止正在執行的應用程式並重新啟動它。

用 10 個分頁重複相同的過程。這次它們都會等待,您將獲得所有結果而不會出錯。

...您已修復它!

檢閱所有檔案

請記住,您應該有一個名為 my_super_project(或任何您想要的名稱)的目錄,其中包含一個名為 sql_app 的子目錄。

sql_app 應該包含以下檔案

  • sql_app/__init__.py:是一個空檔案。

  • sql_app/database.py:

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()
  • sql_app/models.py:
import peewee

from .database import db


class User(peewee.Model):
    email = peewee.CharField(unique=True, index=True)
    hashed_password = peewee.CharField()
    is_active = peewee.BooleanField(default=True)

    class Meta:
        database = db


class Item(peewee.Model):
    title = peewee.CharField(index=True)
    description = peewee.CharField(index=True)
    owner = peewee.ForeignKeyField(User, backref="items")

    class Meta:
        database = db
  • sql_app/schemas.py:
from typing import Any, List, Union

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict
  • sql_app/crud.py:
from . import models, schemas


def get_user(user_id: int):
    return models.User.filter(models.User.id == user_id).first()


def get_user_by_email(email: str):
    return models.User.filter(models.User.email == email).first()


def get_users(skip: int = 0, limit: int = 100):
    return list(models.User.select().offset(skip).limit(limit))


def create_user(user: schemas.UserCreate):
    fake_hashed_password = user.password + "notreallyhashed"
    db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
    db_user.save()
    return db_user


def get_items(skip: int = 0, limit: int = 100):
    return list(models.Item.select().offset(skip).limit(limit))


def create_user_item(item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db_item.save()
    return db_item
  • sql_app/main.py:
import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

技術細節

警告

這些是非常技術性的細節,您可能不需要。

問題所在

Peewee 預設使用 threading.local 來儲存其資料庫「狀態」資料(連線、交易等)。

threading.local 會建立一個專屬於目前執行緒的值,但非同步框架會在同一個執行緒中執行所有程式碼(例如,每個請求),而且可能不是按順序執行。

此外,非同步框架可以在執行緒池中執行一些同步程式碼(使用 asyncio.run_in_executor),但屬於同一個請求。

這意味著,使用 Peewee 目前的實作方式,多個任務可能會使用相同的 threading.local 變數,並最終共用相同的連線和資料(它們不應該這樣做),同時,如果它們在執行緒池中執行同步 I/O 阻塞程式碼(例如 FastAPI 中的普通 def 函式,在*路徑操作*和依賴項中),即使該程式碼屬於同一個請求,而且應該能夠存取相同的資料庫狀態,該程式碼也無法存取資料庫狀態變數。

上下文變數

Python 3.7 具有 contextvars,它可以建立與 threading.local 非常相似的區域變數,但也支援這些非同步功能。

有幾件事需要注意。

ContextVar 必須在模組的頂部建立,例如

some_var = ContextVar("some_var", default="default value")

要設定在目前「上下文」中使用的值(例如,針對目前請求),請使用

some_var.set("new value")

要在上下文中的任何位置取得值(例如,在處理目前請求的任何部分),請使用

some_var.get()

async 依賴項 reset_db_state() 中設定上下文變數

如果非同步程式碼的某些部分使用 some_var.set("updated in function") 設定值(例如,像 async 依賴項),則其中的其餘程式碼以及後續的程式碼(包括使用 await 呼叫的 async 函式內的程式碼)將會看到新的值。

因此,在我們的案例中,如果我們在 async 依賴項中設定 Peewee 狀態變數(使用預設的 dict),我們應用程式中的所有其他內部程式碼都將看到此值,並能夠在整個請求中重複使用它。

即使請求是並發的,上下文變數也會在下一個請求中重新設置。

在依賴項 get_db() 中設置資料庫狀態

由於 get_db() 是一個普通的 def 函式,FastAPI 會讓它在一個執行緒池中運行,並帶有「上下文」的*副本*,其中包含上下文變數(帶有重置資料庫狀態的 dict)的相同值。然後它可以將資料庫狀態添加到該 dict 中,例如連線等。

但是,如果上下文變數(預設的 dict)的值是在該普通 def 函式中設置的,它將會創建一個新值,該值只會保留在執行緒池的該執行緒中,而其餘程式碼(例如*路徑操作函式*)將無法訪問它。在 get_db() 中,我們只能設置 dict 中的值,而不能設置整個 dict 本身。

因此,我們需要非同步依賴項 reset_db_state() 來在上下文變數中設置 dict。這樣,所有程式碼都可以訪問同一個用於單個請求的資料庫狀態的 dict

在依賴項 get_db() 中連線和斷開連線

那麼下一個問題是,為什麼不直接在非同步依賴項本身中連線和斷開資料庫,而是在 get_db() 中呢?

非同步依賴項必須是 async 才能將上下文變數保留用於請求的其餘部分,但創建和關閉資料庫連線可能會造成阻塞,因此如果它在那裡,可能會降低效能。

所以我們也需要普通的 def 依賴項 get_db()