#!/usr/bin/env python3

import datetime
import enum
import os
from typing import List

import pytest
from pydantic import BaseModel
from result import Result, Ok, Err, is_ok
from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select


def utc_now() -> datetime.datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.datetime.now(datetime.timezone.utc)


class User(SQLModel, table=True):
    """A user who can be asked to participate in tasks."""

    id: int = Field(primary_key=True)
    name: str
    joined_at: datetime.datetime = Field(default_factory=utc_now)

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")

    def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
        """Return the accepted request with the latest ``requested_at``.

        Returns:
            ``Ok(request)`` if the user has at least one request in state
            ``ACCEPTED``, otherwise ``Err(None)``.
        """
        accepted_requests = [
            r for r in self.participation_requests
            if r.state == ParticipationState.ACCEPTED
        ]
        if accepted_requests:
            return Ok(max(accepted_requests, key=lambda r: r.requested_at))
        return Err(None)

    def sort_time(self) -> datetime.datetime:
        """Return the timestamp used to order users in ``all_users_sorted``.

        This is the ``state_since`` of the last accepted participation
        request, or ``joined_at`` if the user has never accepted a request.
        """
        last_accepted = self.last_accepted_participation_request()
        if is_ok(last_accepted):
            last_accepted_obj = last_accepted.unwrap()
            # Invariant: a request only becomes ACCEPTED via _change_state(),
            # which always records state_since.
            assert last_accepted_obj.state_since is not None
            return last_accepted_obj.state_since
        return self.joined_at


def all_users_sorted(session: Session) -> List[User]:
    """Return all users ordered ascending by ``User.sort_time()``.

    Users that accepted a request longest ago (or never, in which case their
    join time counts) come first.
    """
    users = session.exec(select(User)).all()
    # sorted() is stable like list.sort() but always hands back a fresh list,
    # independent of the concrete sequence type returned by .all().
    return sorted(users, key=lambda u: u.sort_time())


class Task(SQLModel, table=True):
    """A task that needs a fixed number of participants."""

    id: int = Field(primary_key=True)
    name: str
    required_number_of_participants: int
    due: datetime.datetime
    timeout: int  # in seconds

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")

    def _requests_in_state(self, state: "ParticipationState") -> List["ParticipationRequest"]:
        """Return this task's requests that are currently in ``state``."""
        return [r for r in self.participation_requests if r.state == state]

    def accepted_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state ``ACCEPTED``."""
        return self._requests_in_state(ParticipationState.ACCEPTED)

    def rejected_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state ``REJECTED``."""
        return self._requests_in_state(ParticipationState.REJECTED)

    def requested_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state ``REQUESTED`` (still unanswered)."""
        return self._requests_in_state(ParticipationState.REQUESTED)

    def timeout_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state ``TIMEOUT``."""
        return self._requests_in_state(ParticipationState.TIMEOUT)

    def additional_requests_to_be_sent(self) -> int:
        """Return the number of additional requests to be sent.

        Computed as the required number of participants minus the requests
        that are already accepted or still pending. The value is not clamped,
        so it is negative when there are more accepted/pending requests than
        required participants.
        """
        # Fixed: previously called the non-existent methods
        # accepted_participation_requests()/requested_participation_requests().
        return (
            self.required_number_of_participants
            - len(self.accepted_requests())
            - len(self.requested_requests())
        )

    def freshly_expired_requests(self, now) -> List["ParticipationRequest"]:
        """Time out overdue requests and return the ones that just expired.

        Calls ``check_for_timeout(now)`` on every request of this task, which
        moves overdue ``REQUESTED`` requests to ``TIMEOUT`` as a side effect.

        Args:
            now: The current time used for the timeout check.

        Returns:
            The requests whose state changed to ``TIMEOUT`` during this call.
        """
        expired_requests = []
        # Plain loop on purpose: check_for_timeout() mutates the request.
        for r in self.participation_requests:
            maybe_timeout = r.check_for_timeout(now)
            if maybe_timeout is not None:
                expired_requests.append(r)
        # Fixed: the collected list was previously never returned.
        return expired_requests


class ParticipationState(enum.Enum):
    """Lifecycle state of a ``ParticipationRequest``."""

    REQUESTED = "requested"
    ACCEPTED = "accepted"
    REJECTED = "rejected"
    TIMEOUT = "timeout"


# Transitions for a ParticipationState:
#
# Accepting or rejecting a request within timeout:
#
# - REQUESTED -> ACCEPTED (allowed)
# - REQUESTED -> REJECTED (allowed)
#
# Normal timeout:
#
# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout)
#
# Changing from ACCEPTED to REJECTED is allowed:
#
# - ACCEPTED -> REJECTED (allowed)
#
# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only:
#
# - REJECTED -> ACCEPTED (allowed, within 5 minutes)
# - REJECTED -> ACCEPTED (not allowed, after 5 minutes)
#
# Answered requests can not time out:
#
# - REJECTED -> TIMEOUT (should not happen, bug)
# - ACCEPTED -> TIMEOUT (should not happen, bug)
#
# Timed out requests can not be answered or re-opened:
#
# - TIMEOUT -> ACCEPTED (not allowed)
# - TIMEOUT -> REJECTED (not allowed)
# - TIMEOUT -> REQUESTED (should not happen, bug)
#
# => TODO: What should happen after the due?!!!
class StateTransition:
    """Base class for outcomes describing a successful state change."""


class StateTransitionError:
    """Base class for outcomes describing a refused state change.

    Instances are returned inside ``Err(...)``; they are not raised.
    """


class AcceptInTime(StateTransition):
    """The request was accepted before its timeout."""


class AcceptAfterRejectAllowed(StateTransition):
    """A rejected request was accepted within the allowed window."""


class AcceptAfterRejectExpired(StateTransitionError):
    """A rejected request could not be accepted: the window has passed."""


class AcceptAfterTimeout(StateTransitionError):
    """The accept arrived after the request had timed out."""


class AlreadyAccepted(StateTransitionError):
    """The request was already in state ``ACCEPTED``."""


class RejectInTime(StateTransition):
    """The request was rejected before its timeout."""


class RejectAfterAccept(StateTransition):
    """A previously accepted request was rejected."""


class RejectAfterTimeout(StateTransitionError):
    """The reject arrived after the request had timed out."""


class AlreadyRejected(StateTransitionError):
    """The request was already in state ``REJECTED``."""


class Timeout(StateTransition):
    """The request moved from ``REQUESTED`` to ``TIMEOUT``."""


# How long after a rejection the user may still change their mind and accept.
ACCEPT_AFTER_REJECT_WINDOW = datetime.timedelta(minutes=5)


class ParticipationRequest(SQLModel, table=True):
    """A request asking one user to participate in one task."""

    id: int = Field(primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    task_id: int = Field(foreign_key="task.id")
    requested_at: datetime.datetime
    # Time of the last state change; None until the first change.
    state_since: datetime.datetime | None = None
    state: ParticipationState = Field(default=ParticipationState.REQUESTED)

    task: Task = Relationship(back_populates="participation_requests")
    user: User = Relationship(back_populates="participation_requests")

    def _change_state(self, new_state: ParticipationState, now: datetime.datetime) -> None:
        """Set ``state`` and ``state_since`` without checking business rules."""
        self.state_since = now
        self.state = new_state

    def try_accept(self, now: datetime.datetime) -> Result[
        AcceptInTime | AcceptAfterRejectAllowed,
        AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted,
    ]:
        """Try to accept the request at time ``now``.

        Args:
            now: The time at which the accept was issued.

        Returns:
            ``Ok`` with the transition that was applied, or ``Err`` with the
            reason the accept was refused (state is left unchanged then).

        Raises:
            Exception: If the current state is not a known
                ``ParticipationState``.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # We want to make the order in which check_for_timeout() and
            # try_accept() are called irrelevant, so the decision depends on
            # `now` only and not on whether TIMEOUT was already recorded.
            if self.is_timed_out(now):
                return Err(AcceptAfterTimeout())
            self._change_state(ParticipationState.ACCEPTED, now)
            return Ok(AcceptInTime())
        if self.state == ParticipationState.REJECTED:
            if now - self.state_since < ACCEPT_AFTER_REJECT_WINDOW:
                self._change_state(ParticipationState.ACCEPTED, now)
                return Ok(AcceptAfterRejectAllowed())
            return Err(AcceptAfterRejectExpired())
        if self.state == ParticipationState.ACCEPTED:
            return Err(AlreadyAccepted())
        raise Exception(f"Unknown old state {self.state!s}.")

    def try_reject(self, now: datetime.datetime) -> Result[
        RejectInTime | RejectAfterAccept,
        RejectAfterTimeout | AlreadyRejected,
    ]:
        """Try to reject the request at time ``now``.

        Args:
            now: The time at which the reject was issued.

        Returns:
            ``Ok`` with the transition that was applied, or ``Err`` with the
            reason the reject was refused (state is left unchanged then).

        Raises:
            Exception: If the current state is not a known
                ``ParticipationState``.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # We want to make the order in which check_for_timeout() and
            # try_reject() are called irrelevant, so the decision depends on
            # `now` only and not on whether TIMEOUT was already recorded.
            if self.is_timed_out(now):
                return Err(RejectAfterTimeout())
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectInTime())
        if self.state == ParticipationState.ACCEPTED:
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectAfterAccept())
        if self.state == ParticipationState.REJECTED:
            return Err(AlreadyRejected())
        raise Exception(f"Unknown old state {self.state!s}.")

    def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
        """Move a still-unanswered request to ``TIMEOUT`` if it is overdue.

        Returns:
            ``Timeout()`` if the state was changed by this call, else ``None``
            (request not in state ``REQUESTED``, or not yet overdue).
        """
        if self.state == ParticipationState.REQUESTED and self.is_timed_out(now):
            self._change_state(ParticipationState.TIMEOUT, now)
            return Timeout()
        return None

    def is_timed_out(self, now: datetime.datetime) -> bool:
        """Return True if more than ``task.timeout`` seconds passed since ``requested_at``."""
        return (now - self.requested_at).total_seconds() > self.task.timeout


@pytest.fixture
def session():
    """Yield a session bound to a freshly created SQLite file database."""
    if os.path.exists("/tmp/test.db"):
        os.remove("/tmp/test.db")
    engine = create_engine("sqlite:////tmp/test.db")
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        yield session


def test_all_users_sorted(session):
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    session.commit()

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    session.commit()

    assert t1.accepted_requests() == []
    users = all_users_sorted(session)

    r1._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert t1.accepted_requests() == [r1]
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)
    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)
    session.commit()

    r3._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert t1.accepted_requests() == [r1, r3]
    users = all_users_sorted(session)
    assert users == [u2, u1, u3]

    r2._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u3, u2]


def test_accept(session):
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)

    u1 = User(name="u1")

    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)

    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()

    accept_result = r1.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    session.commit()

    accept_result = r1.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap_err(), AlreadyAccepted)
    session.commit()

    # Accepted should not timeout
    timeout_result = r1.check_for_timeout(now + 1.6*dt)
    assert timeout_result is None

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
    session.add(r2)

    # AcceptAfterTimeout must be given even before check_for_timeout()
    accept_result = r2.try_accept(now + 1.6*dt)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    timeout_result = r2.check_for_timeout(now + 1.6*dt)
    assert isinstance(timeout_result, Timeout)

    # Should still be a AcceptAfterTimeout
    accept_result = r2.try_accept(now + 1.6*dt)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.TIMEOUT

    # If we receive a late accept, we should accept it
    accept_result = r2.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    assert r2.state == ParticipationState.ACCEPTED

    session.commit()


def test_reject(session):
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)

    u1 = User(name="u1")

    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)

    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()

    reject_result = r1.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    session.commit()

    reject_result = r1.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap_err(), AlreadyRejected)
    session.commit()

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
    session.add(r2)

    # RejectAfterTimeout must be given even before check_for_timeout()
    reject_result = r2.try_reject(now + 1.6*dt)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    timeout_result = r2.check_for_timeout(now + 1.6*dt)
    assert isinstance(timeout_result, Timeout)
    assert r2.state == ParticipationState.TIMEOUT

    # Should still be a RejectAfterTimeout
    reject_result = r2.try_reject(now + 1.6*dt)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)

    # If we receive a late reject, we should accept it
    reject_result = r2.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    assert r2.state == ParticipationState.REJECTED

    session.commit()