diff --git a/models.py b/models.py new file mode 100644 index 0000000..7bdb392 --- /dev/null +++ b/models.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 + +import os +import enum +from typing import List +from pydantic import BaseModel +from result import Result, Ok, Err, is_ok +from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select + +import datetime +import pytest + +def utc_now(): + return datetime.datetime.now(datetime.timezone.utc) + +class User(SQLModel, table=True): + id: int = Field(primary_key=True) + name: str + joined_at: datetime.datetime = Field(default_factory=utc_now) + + participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user") + + def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]: + accepted_requests = [r for r in self.participation_requests if r.state == ParticipationState.ACCEPTED] + if accepted_requests: + return Ok(max(accepted_requests, key=lambda r: r.requested_at)) + else: + return Err(None) + + def sort_time(self): + last_accepted = self.last_accepted_participation_request() + if is_ok(last_accepted): + last_accepted_obj = last_accepted.unwrap() + assert last_accepted_obj.state_since is not None + return last_accepted_obj.state_since + else: + return self.joined_at + +def all_users_sorted(session: Session) -> List[User]: + users = session.exec(select(User)).all() + users.sort(key=lambda u: u.sort_time()) + return users + +class Task(SQLModel, table=True): + id: int = Field(primary_key=True) + name: str + required_number_of_participants: int + due: datetime.datetime + timeout: int # in seconds + + participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task") + + def accepted_requests(self): + return [r for r in self.participation_requests if r.state == ParticipationState.ACCEPTED] + + def rejected_requests(self): + return [r for r in self.participation_requests if r.state == ParticipationState.REJECTED] + + def requested_requests(self): + return [r for r in self.participation_requests if r.state == ParticipationState.REQUESTED] + + def timeout_requests(self): + return [r for r in self.participation_requests if r.state == ParticipationState.TIMEOUT] + + def additional_requests_to_be_sent(self): + # return the number of additional requests to be sent + return self.required_number_of_participants \ + - len(self.accepted_participation_requests()) \ + - len(self.requested_participation_requests()) + + def freshly_expired_requests(self, now) -> List["ParticipationRequest"]: + expired_requests = [] + + for r in self.participation_requests: + maybe_timeout = r.check_for_timeout(now) + if maybe_timeout is not None: + expired_requests.append(r) + +class ParticipationState(enum.Enum): + REQUESTED = "requested" + ACCEPTED = "accepted" + REJECTED = "rejected" + TIMEOUT = "timeout" + +# Transitions for a ParticipationState: +# +# Accepting or rejecting a request within timeout: +# +# - REQUESTED -> ACCEPTED (allowed) +# - REQUESTED -> REJECTED (allowed) +# +# Normal timeout: +# +# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout) +# +# Changing from ACCEPTED to REJECTED is allowed: +# +# - ACCEPTED -> REJECTED (allowed) +# +# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only: +# +# - REJECTED -> ACCEPTED (allowed, within 5 minutes) +# - REJECTED -> ACCEPTED (not allowed, after 5 minutes) +# +# Answered requests can not timeout: +# +# - REJECTED -> TIMEOUT (should not happen, bug) +# - ACCEPTED -> TIMEOUT (should not happen, bug) +# +# Timed out requests can not be +# +# TIMEOUT -> ACCEPTED (not allowed) +# TIMEOUT -> REJECTED (not allowed) +# TIMEOUT -> REQUESTED (should not happen, bug) +# +# => TODO: What should happen after the due?!!! + + +class StateTransition: + pass + +class StateTransitionError: + pass + +class AcceptInTime(StateTransition): + pass + +class AcceptAfterRejectAllowed(StateTransition): + pass + +class AcceptAfterRejectExpired(StateTransitionError): + pass + +class AcceptAfterTimeout(StateTransitionError): + pass + +class AlreadyAccepted(StateTransitionError): + pass + +class RejectInTime(StateTransition): + pass + +class RejectAfterAccept(StateTransition): + pass + +class RejectAfterTimeout(StateTransitionError): + pass + +class AlreadyRejected(StateTransitionError): + pass + +class Timeout(StateTransition): + pass + + +class ParticipationRequest(SQLModel, table=True): + id: int = Field(primary_key=True) + user_id: int = Field(foreign_key="user.id") + task_id: int = Field(foreign_key="task.id") + requested_at: datetime.datetime + + state_since: datetime.datetime | None = None + state: ParticipationState = Field(default=ParticipationState.REQUESTED) + + task: Task = Relationship(back_populates="participation_requests") + user: User = Relationship(back_populates="participation_requests") + + def _change_state(self, new_state: ParticipationState, now: datetime.datetime): + # This method does not check any business rules, it just changes the state + self.state_since = now + self.state = new_state + + def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]: + if self.state == ParticipationState.REQUESTED or self.state == ParticipationState.TIMEOUT: + # We want to make the order in which check_for_timeout() and try_accept() is called irrelevant + if self.is_timed_out(now): + return Err(AcceptAfterTimeout()) + else: + self._change_state(ParticipationState.ACCEPTED, now) + return Ok(AcceptInTime()) + elif self.state == ParticipationState.REJECTED: + if (now - self.state_since).total_seconds() < 5 * 60: + self._change_state(ParticipationState.ACCEPTED, now) + return Ok(AcceptAfterRejectAllowed()) + else: + return Err(AcceptAfterRejectExpired()) + elif self.state == ParticipationState.ACCEPTED: + return Err(AlreadyAccepted()) + else: + raise Exception("Unknown old state " + str(self.state) + ".") + + def try_reject(self, now: datetime.datetime) -> Result[RejectInTime, RejectAfterAccept | RejectAfterTimeout]: + if self.state == ParticipationState.REQUESTED or self.state == ParticipationState.TIMEOUT: + # We want to make the order in which check_for_timeout() and try_reject() is called irrelevant + if self.is_timed_out(now): + return Err(RejectAfterTimeout()) + else: + self._change_state(ParticipationState.REJECTED, now) + return Ok(RejectInTime()) + elif self.state == ParticipationState.ACCEPTED: + self._change_state(ParticipationState.REJECTED, now) + return Ok(RejectAfterAccept()) + elif self.state == ParticipationState.REJECTED: + return Err(AlreadyRejected()) + else: + raise Exception("Unknown old state " + str(self.state) + ".") + + def check_for_timeout(self, now: datetime.datetime) -> Timeout | None: + if self.state == ParticipationState.REQUESTED: + if self.is_timed_out(now): + self._change_state(ParticipationState.TIMEOUT, now) + return Timeout() + + def is_timed_out(self, now: datetime.datetime) -> bool: + return (now - self.requested_at).total_seconds() > self.task.timeout + + +@pytest.fixture +def session(): + if os.path.exists("/tmp/test.db"): + os.remove("/tmp/test.db") + + engine = create_engine("sqlite:////tmp/test.db") + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + +def test_all_users_sorted(session): + u1 = User(name="u1") + u2 = User(name="u2") + u3 = User(name="u3") + + session.add(u1) + session.add(u2) + session.add(u3) + + session.commit() + + users = all_users_sorted(session) + + assert users == [u1, u2, u3] + + t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10) + + session.add(t1) + session.commit() + + r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now()) + + session.add(r1) + session.commit() + + assert t1.accepted_requests() == [] + + users = all_users_sorted(session) + + r1._change_state(ParticipationState.ACCEPTED, utc_now()) + session.commit() + + assert t1.accepted_requests() == [r1] + users = all_users_sorted(session) + assert users == [u2, u3, u1] + + r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now()) + session.add(r2) + + r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now()) + session.add(r3) + + session.commit() + + r3._change_state(ParticipationState.ACCEPTED, utc_now()) + + session.commit() + + assert t1.accepted_requests() == [r1, r3] + users = all_users_sorted(session) + assert users == [u2, u1, u3] + + r2._change_state(ParticipationState.ACCEPTED, utc_now()) + session.commit() + + users = all_users_sorted(session) + assert users == [u1, u3, u2] + +def test_accept(session): + now = datetime.datetime(2024, 12, 10, 0, 0) + dt = datetime.timedelta(days=1) + + u1 = User(name="u1") + + # - Task will happen at now + 10*dt + # - Each user has 1*dt after the request to answer (before timeout) + t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds()) + + r1 = ParticipationRequest(user=u1, task=t1, requested_at=now) + + session.add(u1) + session.add(t1) + session.add(r1) + session.commit() + + accept_result = r1.try_accept(now + 0.5*dt) + assert isinstance(accept_result.unwrap(), AcceptInTime) + session.commit() + + accept_result = r1.try_accept(now + 0.5*dt) + assert isinstance(accept_result.unwrap_err(), AlreadyAccepted) + session.commit() + + # Accepted should not timeout + timeout_result = r1.check_for_timeout(now + 1.6*dt) + assert timeout_result is None + + r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt) + session.add(r2) + + # AcceptAfterTimeout must be given even before check_for_timeout() + accept_result = r2.try_accept(now + 1.6*dt) + assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout) + + assert(r2.state == ParticipationState.REQUESTED) + + # Obtain Timeout + timeout_result = r2.check_for_timeout(now + 1.6*dt) + assert isinstance(timeout_result, Timeout) + + # Should still be a AcceptAfterTimeout + accept_result = r2.try_accept(now + 1.6*dt) + assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout) + + assert(r2.state == ParticipationState.TIMEOUT) + + # If we reiceive a late accept, we should accept it + accept_result = r2.try_accept(now + 0.5*dt) + assert isinstance(accept_result.unwrap(), AcceptInTime) + + assert(r2.state == ParticipationState.ACCEPTED) + + session.commit() + +def test_reject(session): + now = datetime.datetime(2024, 12, 10, 0, 0) + dt = datetime.timedelta(days=1) + + u1 = User(name="u1") + + # - Task will happen at now + 10*dt + # - Each user has 1*dt after the request to answer (before timeout) + t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds()) + + r1 = ParticipationRequest(user=u1, task=t1, requested_at=now) + + session.add(u1) + session.add(t1) + session.add(r1) + session.commit() + + reject_result = r1.try_reject(now + 0.5*dt) + assert isinstance(reject_result.unwrap(), RejectInTime) + session.commit() + + reject_result = r1.try_reject(now + 0.5*dt) + assert isinstance(reject_result.unwrap_err(), AlreadyRejected) + session.commit() + + r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt) + session.add(r2) + + # RejectAfterTimeout must be given even before check_for_timeout() + reject_result = r2.try_reject(now + 1.6*dt) + assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout) + + assert(r2.state == ParticipationState.REQUESTED) + + # Obtain Timeout + timeout_result = r2.check_for_timeout(now + 1.6*dt) + assert isinstance(timeout_result, Timeout) + + assert(r2.state == ParticipationState.TIMEOUT) + + # Should still be a RejectAfterTimeout + reject_result = r2.try_reject(now + 1.6*dt) + assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout) + + # If we receive a late reject, we should accept it + reject_result = r2.try_reject(now + 0.5*dt) + assert isinstance(reject_result.unwrap(), RejectInTime) + + assert(r2.state == ParticipationState.REJECTED) + + session.commit() + +