"""Tree service.

Creating a tree also creates the owner's TreeMembership (the authorization
basis) and an audit entry. Reads go through the privacy engine.
"""

import uuid
from datetime import UTC, datetime

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.enums import MembershipRole, TreeVisibility
from app.models.tree import Tree, TreeMembership
from app.models.user import User
from app.repositories.base import BaseRepository
from app.services import privacy
from app.services.audit import record_audit
from app.services.exceptions import Forbidden, NotFound

# Tree attributes that update_tree() may change; any other key in the caller's
# ``changes`` mapping is ignored.
_UPDATABLE_FIELDS = frozenset({"name", "description", "visibility", "home_person_id"})


async def create_tree(
    session: AsyncSession,
    *,
    owner: User,
    name: str,
    description: str | None = None,
    visibility: TreeVisibility = TreeVisibility.private,
) -> Tree:
    """Create a tree owned by ``owner`` and commit it.

    In the same transaction this also adds the owner's ``TreeMembership``
    (role ``owner``) and a ``create`` audit entry.

    Returns:
        The persisted tree, refreshed after the commit.
    """
    tree = Tree(owner_id=owner.id, name=name, description=description, visibility=visibility)
    session.add(tree)
    await session.flush()  # assign tree.id before it is referenced below
    session.add(TreeMembership(tree_id=tree.id, user_id=owner.id, role=MembershipRole.owner))
    record_audit(
        session,
        action="create",
        entity_type="Tree",
        entity_id=tree.id,
        tree_id=tree.id,
        actor_user_id=owner.id,
        after={"name": name, "visibility": visibility.value},
    )
    await session.commit()
    # Reload now, in case the session expires attributes on commit, so callers
    # do not trigger implicit lazy IO when reading the returned object.
    await session.refresh(tree)
    return tree


async def list_trees_for_user(session: AsyncSession, *, user: User) -> list[Tree]:
    """Return the non-deleted trees ``user`` has a membership in, oldest first."""
    stmt = (
        select(Tree)
        .join(TreeMembership, TreeMembership.tree_id == Tree.id)
        .where(TreeMembership.user_id == user.id, Tree.deleted_at.is_(None))
        .order_by(Tree.created_at)
    )
    return list((await session.execute(stmt)).scalars().all())


async def get_tree(session: AsyncSession, *, viewer_id: uuid.UUID, tree_id: uuid.UUID) -> Tree:
    """Load a tree and check that ``viewer_id`` may view it.

    Raises:
        NotFound: the repository returned no tree for ``tree_id``.
        Forbidden: the privacy engine denies view access.
    """
    tree = await BaseRepository(session, Tree).get(tree_id)
    if tree is None:
        raise NotFound("tree not found")
    if not await privacy.can_view_tree(session, user_id=viewer_id, tree=tree):
        raise Forbidden("not permitted to view this tree")
    return tree


async def update_tree(
    session: AsyncSession, *, actor: User, tree_id: uuid.UUID, changes: dict
) -> Tree:
    """Apply the whitelisted entries of ``changes`` to a tree and commit.

    Only keys in ``_UPDATABLE_FIELDS`` are applied; other keys are ignored.
    The ``update`` audit entry records exactly the applied changes, with
    ``TreeVisibility`` values stored as ``.value`` (as ``create_tree`` does).
    No audit entry is written when no applicable key was supplied.

    Returns:
        The updated tree, refreshed after the commit.

    Raises:
        NotFound: the repository returned no tree for ``tree_id``.
        Forbidden: the privacy engine denies edit access to ``actor``.
    """
    tree = await BaseRepository(session, Tree).get(tree_id)
    if tree is None:
        raise NotFound("tree not found")
    if not await privacy.can_edit_tree(session, user_id=actor.id, tree=tree):
        raise Forbidden("not an editor of this tree")

    applied = {key: value for key, value in changes.items() if key in _UPDATABLE_FIELDS}
    for key, value in applied.items():
        setattr(tree, key, value)

    if applied:
        # Audit only what was actually applied, in a fresh dict so later
        # mutation of the caller's ``changes`` cannot alter the record.
        record_audit(
            session,
            action="update",
            entity_type="Tree",
            entity_id=tree.id,
            tree_id=tree.id,
            actor_user_id=actor.id,
            after={
                key: value.value if isinstance(value, TreeVisibility) else value
                for key, value in applied.items()
            },
        )
    await session.commit()
    await session.refresh(tree)
    return tree


async def _owned_tree(session: AsyncSession, *, actor: User, tree_id: uuid.UUID) -> Tree:
    """Load a tree (including soft-deleted) and require the actor be its owner.

    Raises:
        NotFound: no tree exists for ``tree_id``, even among deleted ones.
        Forbidden: ``actor``'s membership role is not ``owner``.
    """
    tree = await BaseRepository(session, Tree).get(tree_id, include_deleted=True)
    if tree is None:
        raise NotFound("tree not found")
    role = await privacy.get_membership_role(session, actor.id, tree.id)
    if role is not MembershipRole.owner:
        raise Forbidden("only the owner can delete or restore a tree")
    return tree


async def delete_tree(session: AsyncSession, *, actor: User, tree_id: uuid.UUID) -> None:
    """Soft-delete a tree by stamping ``deleted_at`` (owner only).

    Deleting an already-deleted tree leaves it unchanged and writes no audit
    entry; the session is committed either way.
    """
    tree = await _owned_tree(session, actor=actor, tree_id=tree_id)
    if tree.deleted_at is None:
        tree.deleted_at = datetime.now(UTC)
        record_audit(
            session,
            action="delete",
            entity_type="Tree",
            entity_id=tree.id,
            tree_id=tree.id,
            actor_user_id=actor.id,
        )
    await session.commit()


async def restore_tree(session: AsyncSession, *, actor: User, tree_id: uuid.UUID) -> Tree:
    """Undo a soft delete by clearing ``deleted_at`` (owner only).

    Restoring a tree that is not deleted leaves it unchanged and writes no
    audit entry.

    Returns:
        The tree, refreshed after the commit.
    """
    tree = await _owned_tree(session, actor=actor, tree_id=tree_id)
    if tree.deleted_at is not None:
        tree.deleted_at = None
        record_audit(
            session,
            action="restore",
            entity_type="Tree",
            entity_id=tree.id,
            tree_id=tree.id,
            actor_user_id=actor.id,
        )
    await session.commit()
    # Match create_tree/update_tree: reload after commit so the returned object
    # is usable without implicit lazy IO if the session expires on commit.
    await session.refresh(tree)
    return tree


async def list_deleted_trees_for_user(session: AsyncSession, *, user: User) -> list[Tree]:
    """Return soft-deleted trees ``user`` has a membership in, most recently deleted first."""
    # NOTE(review): this matches any membership role, while only the owner may
    # restore (see _owned_tree) -- confirm non-owners should see these.
    stmt = (
        select(Tree)
        .join(TreeMembership, TreeMembership.tree_id == Tree.id)
        .where(TreeMembership.user_id == user.id, Tree.deleted_at.is_not(None))
        .order_by(Tree.deleted_at.desc())
    )
    return list((await session.execute(stmt)).scalars().all())