Fix list_persons N+1 (the ~4s person-page load) #246

Merged
justin merged 1 commit from fix-person-list-n-plus-one into main 2026-06-11 08:00:47 -04:00
3 changed files with 92 additions and 7 deletions
+45 -5
View File
@@ -45,6 +45,29 @@ async def _attach_primary_name(session: AsyncSession, person: Person) -> None:
person.primary_name = _format_name(name) if name is not None else None person.primary_name = _format_name(name) if name is not None else None
async def _attach_primary_names(session: AsyncSession, persons: list[Person]) -> None:
    """Batch version of ``_attach_primary_name``: set ``primary_name`` on every person.

    Issues one query per chunk of person ids instead of one query per person
    (a single query for any list up to the chunk size, e.g. a 2k-person tree).
    Rows are ordered by ``is_primary`` descending, then ``sort_order``, and the
    first row seen for a person wins. The original docstring states this matches
    the ordering of the single-person query, so the chosen name should be the
    one ``_attach_primary_name`` would pick.

    Args:
        session: Async session used to run the name query.
        persons: Persons to annotate in place. Each gets ``primary_name`` set to
            the formatted name, or ``None`` when no non-deleted name exists.
    """
    if not persons:
        return

    # Some drivers cap the number of bind parameters per statement (asyncpg
    # rejects more than 32767 arguments), so an unbounded IN list would make
    # the whole listing fail on a large enough tree. Chunking keeps every
    # statement under that cap while normal-sized trees still use one query.
    max_ids_per_query = 10_000

    # Dedupe (order-preserving) so repeated persons do not inflate the IN list.
    person_ids = list(dict.fromkeys(p.id for p in persons))

    best: dict[uuid.UUID, Name] = {}
    for start in range(0, len(person_ids), max_ids_per_query):
        chunk = person_ids[start:start + max_ids_per_query]
        rows = (
            await session.execute(
                select(Name)
                .where(Name.person_id.in_(chunk), Name.deleted_at.is_(None))
                .order_by(Name.is_primary.desc(), Name.sort_order)
            )
        ).scalars().all()
        # All names of a given person fall in the same chunk, so keeping the
        # first row per person is unaffected by the chunking.
        for name in rows:
            best.setdefault(name.person_id, name)

    for person in persons:
        name = best.get(person.id)
        person.primary_name = _format_name(name) if name is not None else None
async def create_person( async def create_person(
session: AsyncSession, session: AsyncSession,
*, *,
@@ -336,15 +359,18 @@ async def list_deleted_persons(
.order_by(Person.deleted_at.desc()) .order_by(Person.deleted_at.desc())
) )
persons = list((await session.execute(stmt)).scalars().all()) persons = list((await session.execute(stmt)).scalars().all())
for person in persons: await _attach_primary_names(session, persons)
await _attach_primary_name(session, person)
return persons return persons
async def list_persons( async def list_persons(
session: AsyncSession, *, viewer_id: uuid.UUID, tree: Tree session: AsyncSession, *, viewer_id: uuid.UUID, tree: Tree
) -> list[Person]: ) -> list[Person]:
if not await privacy.can_view_tree(session, user_id=viewer_id, tree=tree): # Resolve the viewer's role ONCE. Members see the whole tree (full), so we
# skip the per-person privacy engine entirely and batch the name fetch — the
# difference between ~3 queries and ~3·N queries on a 2k-person tree.
role = await privacy.get_membership_role(session, viewer_id, tree.id)
if role is None and not await privacy.can_view_tree(session, user_id=viewer_id, tree=tree):
raise Forbidden("not permitted to view this tree") raise Forbidden("not permitted to view this tree")
stmt = ( stmt = (
@@ -354,7 +380,15 @@ async def list_persons(
) )
persons = list((await session.execute(stmt)).scalars().all()) persons = list((await session.execute(stmt)).scalars().all())
if role is not None:
await _attach_primary_names(session, persons)
return persons
# Non-member on a viewable (public/unlisted/site_members) tree: redact per
# person. Names are batched for the non-redacted ones; redacted ones already
# have their display name overwritten by _redact.
visible: list[Person] = [] visible: list[Person] = []
full: list[Person] = []
for person in persons: for person in persons:
vis = await privacy.person_visibility( vis = await privacy.person_visibility(
session, user_id=viewer_id, tree=tree, person=person session, user_id=viewer_id, tree=tree, person=person
@@ -364,8 +398,9 @@ async def list_persons(
if vis == Visibility.redacted: if vis == Visibility.redacted:
_redact(person) _redact(person)
else: else:
await _attach_primary_name(session, person) full.append(person)
visible.append(person) visible.append(person)
await _attach_primary_names(session, full)
return visible return visible
@@ -406,7 +441,11 @@ async def search_persons(
.order_by(sub.c.score.desc()) .order_by(sub.c.score.desc())
) )
persons = list((await session.execute(stmt)).scalars().all()) persons = list((await session.execute(stmt)).scalars().all())
if await privacy.get_membership_role(session, viewer_id, tree.id) is not None:
await _attach_primary_names(session, persons)
return persons
out: list[Person] = [] out: list[Person] = []
full: list[Person] = []
for person in persons: for person in persons:
vis = await privacy.person_visibility( vis = await privacy.person_visibility(
session, user_id=viewer_id, tree=tree, person=person session, user_id=viewer_id, tree=tree, person=person
@@ -416,6 +455,7 @@ async def search_persons(
if vis == Visibility.redacted: if vis == Visibility.redacted:
_redact(person) _redact(person)
else: else:
await _attach_primary_name(session, person) full.append(person)
out.append(person) out.append(person)
await _attach_primary_names(session, full)
return out return out
+8 -2
View File
@@ -33,7 +33,11 @@ from app.models.source import Citation, Source
from app.models.tree import Tree from app.models.tree import Tree
from app.services import privacy from app.services import privacy
from app.services.exceptions import NotFound from app.services.exceptions import NotFound
from app.services.person_service import _attach_primary_name, _redact from app.services.person_service import (
_attach_primary_name,
_attach_primary_names,
_redact,
)
from app.services.privacy import Visibility from app.services.privacy import Visibility
@@ -78,6 +82,7 @@ async def list_public_persons(
session: AsyncSession, *, viewer_id: uuid.UUID | None, tree: Tree session: AsyncSession, *, viewer_id: uuid.UUID | None, tree: Tree
) -> list[Person]: ) -> list[Person]:
out: list[Person] = [] out: list[Person] = []
full: list[Person] = []
for p in await _persons(session, tree): for p in await _persons(session, tree):
vis = await privacy.person_visibility(session, user_id=viewer_id, tree=tree, person=p) vis = await privacy.person_visibility(session, user_id=viewer_id, tree=tree, person=p)
if vis == Visibility.hidden: if vis == Visibility.hidden:
@@ -85,8 +90,9 @@ async def list_public_persons(
if vis == Visibility.redacted: if vis == Visibility.redacted:
_redact(p) _redact(p)
else: else:
await _attach_primary_name(session, p) full.append(p)
out.append(p) out.append(p)
await _attach_primary_names(session, full) # one query, not one per person
return out return out
@@ -0,0 +1,39 @@
"""Regression guard: list_persons must batch — a constant number of queries,
not one (or three) per person. A 2k-person tree took ~4s before this was fixed."""
import sqlalchemy as sa
from tests.conftest import auth, register
async def test_list_persons_does_not_n_plus_one(client, engine):
    """Listing a tree's persons must cost a constant number of SELECTs.

    Creates a tree with ``n`` persons, records every SELECT statement executed
    while the list endpoint is served, and asserts the total stays below ``n``.
    """
    headers = auth(await register(client, "perf-owner@ex.com"))
    tree_resp = await client.post("/api/v1/trees", json={"name": "Perf"}, headers=headers)
    tid = tree_resp.json()["id"]

    n = 25
    for i in range(n):
        payload = {"given": f"P{i}", "surname": "X"}
        await client.post(f"/api/v1/trees/{tid}/persons", json=payload, headers=headers)

    select_statements = []

    def _record(conn, cursor, statement, params, context, executemany):
        # Only reads matter here; writes issued by the request are ignored.
        is_select = statement.lstrip().upper().startswith("SELECT")
        if is_select:
            select_statements.append(statement)

    sa.event.listen(engine.sync_engine, "before_cursor_execute", _record)
    try:
        resp = await client.get(f"/api/v1/trees/{tid}/persons", headers=headers)
    finally:
        # Always detach the listener so later tests are not counted.
        sa.event.remove(engine.sync_engine, "before_cursor_execute", _record)

    selects = len(select_statements)

    assert resp.status_code == 200
    people = resp.json()
    assert len(people) == n
    # Every person must still come back with a resolved name.
    assert all(entry["primary_name"] for entry in people)

    # A batched implementation needs only a handful of SELECTs (auth, role,
    # persons, one names query, ...), independent of n. The old per-person
    # path was ~3*n SELECTs.
    assert 0 < selects < n, f"expected a constant query count, got {selects} for {n} people"