跳至內容

測試資料庫

資訊

這些文件即將更新。 🎉

目前的版本假設使用 Pydantic v1 和 SQLAlchemy 2.0 以下的版本。

新的文件將包含 Pydantic v2,並將使用 SQLModel(它也基於 SQLAlchemy),一旦它也更新為使用 Pydantic v2。

您可以使用與 使用覆蓋測試依賴項 中相同的依賴項覆蓋來更改資料庫以進行測試。

您可能想要設定不同的資料庫以進行測試,在測試後回滾資料,預先填入一些測試資料等等。

主要概念與您在上一章中看到的完全相同。

為 SQL 應用程式新增測試

讓我們更新 SQL (關聯式) 資料庫 的範例以使用測試資料庫。

所有應用程式程式碼都相同,您可以回到該章節查看它是如何的。

這裡唯一的變更是在新的測試檔案中。

您正常的依賴項 get_db() 會返回一個資料庫工作階段。

在測試中,您可以使用依賴項覆蓋來返回您的*自訂*資料庫工作階段,而不是通常使用的資料庫工作階段。

在此範例中,我們將建立一個僅用於測試的臨時資料庫。

檔案結構

我們在 sql_app/tests/test_sql_app.py 建立一個新檔案。

所以新的檔案結構看起來像

.
└── sql_app
    ├── __init__.py
    ├── crud.py
    ├── database.py
    ├── main.py
    ├── models.py
    ├── schemas.py
    └── tests
        ├── __init__.py
        └── test_sql_app.py

建立新的資料庫工作階段

首先,我們使用新的資料庫建立一個新的資料庫工作階段。

我們將使用在測試期間持續存在的記憶體資料庫,而不是本地檔案 sql_app.db

但其餘的工作階段程式碼大致相同,我們只是複製它。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

提示

您可以透過將程式碼放入函式並從 database.pytests/test_sql_app.py 使用它來減少程式碼重複。

為了簡潔起見並專注於特定的測試程式碼,我們只是複製它。

建立資料庫

因為現在我們將在新檔案中使用新的資料庫,所以我們需要確保我們使用以下程式碼建立資料庫

Base.metadata.create_all(bind=engine)

這通常在 main.py 中呼叫,但 main.py 中的行使用資料庫檔案 sql_app.db,而我們需要確保我們為測試建立 test.db

所以我們在這裡添加該行,以及新檔案。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

依賴項覆蓋

現在我們建立依賴項覆蓋並將其添加到我們應用程式的覆蓋中。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

提示

override_get_db() 的程式碼與 get_db() 幾乎完全相同,但在 override_get_db() 中,我們使用測試資料庫的 TestingSessionLocal

測試應用程式

然後我們就可以像平常一樣測試應用程式了。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

而且我們在測試期間對資料庫所做的所有修改都將在 test.db 資料庫中,而不是主要的 sql_app.db 中。