#!/usr/bin/env python3
import datetime
import enum
import os
from typing import List

import pytest
from pydantic import BaseModel
from result import Err, Ok, Result, is_err, is_ok
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select


def utc_now() -> datetime.datetime:
    """Return the current UTC time as a naive ``datetime``.

    The timezone information is removed because SQLite does not support it.
    Timestamps compared against this value should likewise be naive UTC.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


class User(SQLModel, table=True):
    """A person who can be asked to participate in tasks."""

    id: int = Field(primary_key=True)
    name: str
    joined_at: datetime.datetime = Field(default_factory=utc_now)
    # Inactive users are skipped by all_users_sorted() and, by default,
    # by get_user_by_name().
    active: bool = True

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")

    def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
        """Return this user's accepted request with the latest ``requested_at``.

        Returns:
            ``Ok(request)`` if the user has at least one ACCEPTED request,
            otherwise ``Err(None)``.
        """
        # NOTE(review): the request is selected by ``requested_at``, but
        # sort_time() then uses its ``state_since`` (the acceptance time).
        # Confirm whether the latest acceptance (``state_since``) was meant
        # to be the selection key instead.
        accepted_requests = [
            r for r in self.participation_requests
            if r.state == ParticipationState.ACCEPTED
        ]
        if not accepted_requests:
            return Err(None)
        return Ok(max(accepted_requests, key=lambda r: r.requested_at))

    def sort_time(self) -> datetime.datetime:
        """Return the key used to decide which user is asked next.

        This is the ``state_since`` of the request returned by
        last_accepted_participation_request(), or ``joined_at`` if the user
        never accepted a request.
        """
        last_accepted = self.last_accepted_participation_request()
        if is_ok(last_accepted):
            last_accepted_obj = last_accepted.unwrap()
            # Accepted requests are expected to have state_since set by
            # ParticipationRequest._change_state().
            assert last_accepted_obj.state_since is not None
            return last_accepted_obj.state_since
        return self.joined_at


def all_users_sorted(session: Session) -> List[User]:
    """Return all active users, ordered by who should be asked first.

    Users are sorted ascending by :meth:`User.sort_time`. Users whose last
    acceptance (or join time, if they never accepted) lies furthest back come
    first. The sort is stable, so ties keep the database order.
    """
    users = session.exec(select(User).where(User.active)).all()
    # ``.all()`` is typed as a Sequence; sorted() always yields a real list.
    return sorted(users, key=lambda u: u.sort_time())


def next_user_to_send_request(session: Session, task: "Task") -> User | None:
    """Return the first user in all_users_sorted() order not yet asked for ``task``.

    Returns:
        The user, or ``None`` if requests for this task were already sent to
        all active users.
    """
    requested_users = [r.user for r in task.participation_requests]
    for u in all_users_sorted(session):
        if u not in requested_users:
            return u
    return None


def get_user_by_name(session: Session, name: str, only_active: bool = True) -> Result[User, None]:
    """Look up a user by exact name.

    Args:
        session: Database session.
        name: Name to match exactly.
        only_active: If true (default), ignore inactive users.

    Returns:
        ``Ok(user)`` for the first match, otherwise ``Err(None)``.
    """
    query = select(User).where(User.name == name)
    if only_active:
        query = query.where(User.active)
    user = session.exec(query).first()
    if user is None:
        return Err(None)
    return Ok(user)


def get_participation_request_by_timestamp(session: Session, timestamp: datetime.datetime) -> Result["ParticipationRequest", None]:
    """Look up a participation request by its exact ``requested_at`` value.

    Returns:
        ``Ok(request)`` for the first match, otherwise ``Err(None)``.
    """
    query = select(ParticipationRequest).where(ParticipationRequest.requested_at == timestamp)
    request = session.exec(query).first()
    if request is None:
        return Err(None)
    return Ok(request)


class Task(SQLModel, table=True):
    """A task for which ``required_number_of_participants`` users must accept."""

    id: int = Field(primary_key=True)
    name: str
    required_number_of_participants: int
    due: datetime.datetime
    timeout: int  # in seconds; how long a user has to answer a request
    chatgroup: str | None = None  # None = to be created

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")

    def _requests_in_state(self, state: "ParticipationState") -> List["ParticipationRequest"]:
        """Return this task's requests that are currently in ``state``."""
        return [r for r in self.participation_requests if r.state == state]

    def accepted_requests(self) -> List["ParticipationRequest"]:
        """Return this task's ACCEPTED requests."""
        return self._requests_in_state(ParticipationState.ACCEPTED)

    def rejected_requests(self) -> List["ParticipationRequest"]:
        """Return this task's REJECTED requests."""
        return self._requests_in_state(ParticipationState.REJECTED)

    def requested_requests(self) -> List["ParticipationRequest"]:
        """Return this task's still-pending (REQUESTED) requests."""
        return self._requests_in_state(ParticipationState.REQUESTED)

    def timeout_requests(self) -> List["ParticipationRequest"]:
        """Return this task's timed-out requests."""
        return self._requests_in_state(ParticipationState.TIMEOUT)

    def additional_requests_to_be_sent(self) -> int:
        """Return how many more requests must be sent to fill the task.

        Pending (REQUESTED) requests count as potential participants. The
        result is negative when accepted plus pending requests exceed
        ``required_number_of_participants``.
        """
        return (
            self.required_number_of_participants
            - len(self.accepted_requests())
            - len(self.requested_requests())
        )

    def freshly_expired_requests(self, now: datetime.datetime) -> List["ParticipationRequest"]:
        """Move every overdue REQUESTED request to TIMEOUT and return those requests.

        Only requests that change state during this call are returned.
        Requests that were already in TIMEOUT are not included.
        """
        expired_requests = []
        for r in self.participation_requests:
            # check_for_timeout() mutates ``r`` and returns a Timeout on transition.
            if r.check_for_timeout(now) is not None:
                expired_requests.append(r)
        return expired_requests

    def create_additional_requests(self, now: datetime.datetime, session: Session) -> Result[List["ParticipationRequest"], List["ParticipationRequest"]]:
        """Create the missing requests for this task and add them to ``session``.

        Users are chosen via next_user_to_send_request(). The new requests are
        added to the session but not committed.

        Returns:
            ``Ok(new_requests)`` if all additional_requests_to_be_sent() could
            be created. ``Err(new_requests)`` with the requests created so far
            if no user was left to ask.
        """
        additional_requests = []
        for _ in range(self.additional_requests_to_be_sent()):
            user = next_user_to_send_request(session, self)
            if user is None:
                # Incomplete requests
                return Err(additional_requests)
            # Creating the request with task=self also links it into
            # self.participation_requests, so the next lookup skips this user.
            request = ParticipationRequest(user=user, task=self, requested_at=now)
            session.add(request)
            additional_requests.append(request)
        return Ok(additional_requests)


def get_active_tasks(session: Session, now: datetime.datetime) -> List[Task]:
    """Return all tasks whose ``due`` time lies after ``now``."""
    return list(session.exec(select(Task).where(Task.due > now)).all())


def get_tasks(session: Session) -> List[Task]:
    """Return all tasks."""
    return list(session.exec(select(Task)).all())


class ParticipationState(enum.Enum):
    """Lifecycle state of a ParticipationRequest (see the transition table below)."""

    REQUESTED = "requested"
    ACCEPTED = "accepted"
    REJECTED = "rejected"
    TIMEOUT = "timeout"


# Transitions for a ParticipationState.
#
# Every answer carries its own timestamp ``now``. Whether an answer is "in
# time" is decided by comparing that timestamp with the request's
# ``requested_at`` plus ``task.timeout`` (ParticipationRequest.is_timed_out),
# so the result does not depend on whether check_for_timeout() ran first.
#
# Accepting or rejecting a request within the timeout:
#
#   - REQUESTED -> ACCEPTED (allowed)
#   - REQUESTED -> REJECTED (allowed)
#
# Normal timeout:
#
#   - REQUESTED -> TIMEOUT (allowed, set by check_for_timeout() once the timeout has passed)
#
# Changing from ACCEPTED to REJECTED is allowed:
#
#   - ACCEPTED -> REJECTED (allowed)
#
# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only:
#
#   - REJECTED -> ACCEPTED (allowed, less than 5 minutes after the rejection)
#   - REJECTED -> ACCEPTED (not allowed, 5 minutes or more after the rejection)
#
# Answered requests cannot time out:
#
#   - REJECTED -> TIMEOUT (should not happen, bug)
#   - ACCEPTED -> TIMEOUT (should not happen, bug)
#
# Timed-out requests can only be answered by a delayed answer whose own
# timestamp still lies within the timeout:
#
#   - TIMEOUT -> ACCEPTED (only if the answer's timestamp is within the timeout)
#   - TIMEOUT -> REJECTED (only if the answer's timestamp is within the timeout)
#   - TIMEOUT -> REQUESTED (should not happen, bug)
#
# TODO: What should happen after the due date?
# TODO: What happens when a user has two requests?
class StateTransition:
    """Base class for successful outcomes of a ParticipationRequest state change."""


class StateTransitionError:
    """Base class for refused state changes.

    Instances are returned inside ``Err(...)``; this is not an exception class.
    """


class AcceptInTime(StateTransition):
    """A REQUESTED/TIMEOUT request was accepted with an answer time within the timeout."""


class AcceptAfterRejectAllowed(StateTransition):
    """A REJECTED request was accepted within the accept-after-reject window."""


class AcceptAfterRejectExpired(StateTransitionError):
    """A REJECTED request could not be accepted: the window has passed."""


class AcceptAfterTimeout(StateTransitionError):
    """The accept's answer time lies past the request's timeout."""


class AlreadyAccepted(StateTransitionError):
    """The request is already ACCEPTED."""


class RejectInTime(StateTransition):
    """A REQUESTED/TIMEOUT request was rejected with an answer time within the timeout."""


class RejectAfterAccept(StateTransition):
    """A previously ACCEPTED request was rejected."""


class RejectAfterTimeout(StateTransitionError):
    """The reject's answer time lies past the request's timeout."""


class AlreadyRejected(StateTransitionError):
    """The request is already REJECTED."""


class Timeout(StateTransition):
    """A REQUESTED request was moved to TIMEOUT by check_for_timeout()."""


# A rejection can be turned into an acceptance only within this period.
ACCEPT_AFTER_REJECT_WINDOW = datetime.timedelta(minutes=5)


class ParticipationRequest(SQLModel, table=True):
    """A request asking one user to participate in one task."""

    id: int = Field(primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    task_id: int = Field(foreign_key="task.id")
    requested_at: datetime.datetime
    # Time of the last state change; None until _change_state() first runs.
    state_since: datetime.datetime | None = None
    state: ParticipationState = Field(default=ParticipationState.REQUESTED)

    task: Task = Relationship(back_populates="participation_requests")
    user: User = Relationship(back_populates="participation_requests")

    def _change_state(self, new_state: ParticipationState, now: datetime.datetime):
        """Set ``state`` and ``state_since`` without checking any business rules."""
        self.state_since = now
        self.state = new_state

    def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]:
        """Try to accept this request with an answer given at ``now``.

        Returns:
            ``Ok(AcceptInTime())`` for a REQUESTED/TIMEOUT request answered within
            the timeout. ``Ok(AcceptAfterRejectAllowed())`` for a REJECTED
            request accepted within ACCEPT_AFTER_REJECT_WINDOW of the rejection.
            Otherwise ``Err(...)`` describing why the accept was refused; the
            state is then left unchanged.

        Raises:
            ValueError: if ``state`` is not a known ParticipationState.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # Decide by the answer's own timestamp so that the order in which
            # check_for_timeout() and try_accept() are called is irrelevant.
            if self.is_timed_out(now):
                return Err(AcceptAfterTimeout())
            self._change_state(ParticipationState.ACCEPTED, now)
            return Ok(AcceptInTime())
        if self.state == ParticipationState.REJECTED:
            # A REJECTED request always went through _change_state().
            assert self.state_since is not None
            # NOTE(review): if ``now`` is earlier than the rejection (an older
            # accept delivered out of order), the difference is negative and the
            # accept still wins; confirm this is intended.
            if now - self.state_since < ACCEPT_AFTER_REJECT_WINDOW:
                self._change_state(ParticipationState.ACCEPTED, now)
                return Ok(AcceptAfterRejectAllowed())
            return Err(AcceptAfterRejectExpired())
        if self.state == ParticipationState.ACCEPTED:
            return Err(AlreadyAccepted())
        raise ValueError(f"Unknown old state {self.state!s}.")

    def try_reject(self, now: datetime.datetime) -> Result[RejectInTime | RejectAfterAccept, AlreadyRejected | RejectAfterTimeout]:
        """Try to reject this request with an answer given at ``now``.

        Returns:
            ``Ok(RejectInTime())`` for a REQUESTED/TIMEOUT request answered within
            the timeout. ``Ok(RejectAfterAccept())`` for an ACCEPTED request
            (always allowed). Otherwise ``Err(...)``; the state is then left
            unchanged.

        Raises:
            ValueError: if ``state`` is not a known ParticipationState.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # Decide by the answer's own timestamp so that the order in which
            # check_for_timeout() and try_reject() are called is irrelevant.
            if self.is_timed_out(now):
                return Err(RejectAfterTimeout())
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectInTime())
        if self.state == ParticipationState.ACCEPTED:
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectAfterAccept())
        if self.state == ParticipationState.REJECTED:
            return Err(AlreadyRejected())
        raise ValueError(f"Unknown old state {self.state!s}.")

    def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
        """Move a REQUESTED request to TIMEOUT if it is overdue at ``now``.

        Returns:
            ``Timeout()`` if the state was changed, otherwise ``None``.
            Answered (ACCEPTED/REJECTED) and already timed-out requests are
            never changed.
        """
        if self.state == ParticipationState.REQUESTED and self.is_timed_out(now):
            self._change_state(ParticipationState.TIMEOUT, now)
            return Timeout()
        return None

    def is_timed_out(self, now: datetime.datetime) -> bool:
        """Return True if ``now`` is more than ``task.timeout`` seconds after ``requested_at``."""
        return (now - self.requested_at).total_seconds() > self.task.timeout


@pytest.fixture
def session(tmp_path):
    """Yield a Session bound to a fresh SQLite database in a per-test temp dir."""
    engine = create_engine(f"sqlite:///{tmp_path / 'test.db'}")
    SQLModel.metadata.create_all(engine)
    try:
        with Session(engine) as session:
            yield session
    finally:
        engine.dispose()


def test_all_users_sorted(session):
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    session.commit()

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    session.commit()
    assert t1.accepted_requests() == []

    r1._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    assert t1.accepted_requests() == [r1]

    # u1 accepted most recently, so u1 moves to the end.
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)
    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)
    session.commit()

    r3._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    assert t1.accepted_requests() == [r1, r3]

    users = all_users_sorted(session)
    assert users == [u2, u1, u3]

    r2._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u3, u2]


def test_next_user(session):
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()

    # Starting order
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)

    assert t1.additional_requests_to_be_sent() == 2
    assert next_user_to_send_request(session, t1) == u1

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    assert next_user_to_send_request(session, t1) == u2

    # Still the starting order (a pending request does not change it)
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    # One more request should be sent out.
    assert t1.additional_requests_to_be_sent() == 1

    # After rejecting, two additional requests should be sent out.
    assert isinstance(r1.try_reject(utc_now()).unwrap(), RejectInTime)
    assert t1.additional_requests_to_be_sent() == 2

    # Create another Task
    t2 = Task(name="t2", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t2)

    # Next user should be u1 (because u1 rejected in t1)
    assert next_user_to_send_request(session, t2) == u1

    # Still the starting order (because u1 rejected)
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    # When we change to accept, only one additional request should be sent out.
    assert isinstance(r1.try_accept(utc_now()).unwrap(), AcceptAfterRejectAllowed)
    assert t1.additional_requests_to_be_sent() == 1

    # Next user should be u2 for t2 (because u1 accepted for t1)
    assert next_user_to_send_request(session, t2) == u2

    # Next user should be u2 for t1 as well
    assert next_user_to_send_request(session, t1) == u2

    # u1 moved to the end (because u1 accepted)
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]

    # Next user should be u2
    assert next_user_to_send_request(session, t1) == u2

    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)

    # Next user should still be u2
    assert next_user_to_send_request(session, t1) == u2

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)

    # No user should be left
    assert next_user_to_send_request(session, t1) is None

    session.commit()

    # We sent one request too many...
    assert t1.additional_requests_to_be_sent() == -1


def test_accept(session):
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)

    u1 = User(name="u1")
    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)
    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()

    accept_result = r1.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    session.commit()

    accept_result = r1.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap_err(), AlreadyAccepted)
    session.commit()

    # Accepted requests must not time out
    timeout_result = r1.check_for_timeout(now + 1.6*dt)
    assert timeout_result is None

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
    session.add(r2)

    # AcceptAfterTimeout must be given even before check_for_timeout()
    accept_result = r2.try_accept(now + 1.6*dt)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    timeout_result = r2.check_for_timeout(now + 1.6*dt)
    assert isinstance(timeout_result, Timeout)

    # Should still be an AcceptAfterTimeout
    accept_result = r2.try_accept(now + 1.6*dt)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.TIMEOUT

    # A delayed accept whose timestamp lies within the timeout is still accepted
    accept_result = r2.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    assert r2.state == ParticipationState.ACCEPTED

    session.commit()


def test_reject(session):
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)

    u1 = User(name="u1")
    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)
    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()

    reject_result = r1.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    session.commit()

    reject_result = r1.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap_err(), AlreadyRejected)
    session.commit()

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
    session.add(r2)

    # RejectAfterTimeout must be given even before check_for_timeout()
    reject_result = r2.try_reject(now + 1.6*dt)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    timeout_result = r2.check_for_timeout(now + 1.6*dt)
    assert isinstance(timeout_result, Timeout)
    assert r2.state == ParticipationState.TIMEOUT

    # Should still be a RejectAfterTimeout
    reject_result = r2.try_reject(now + 1.6*dt)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)

    # A delayed reject whose timestamp lies within the timeout is still honored
    reject_result = r2.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    assert r2.state == ParticipationState.REJECTED

    session.commit()


def test_active(session):
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    u2.active = False
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u3]

    u2.active = True
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u2, u3]


def test_create_additional_requests(session):
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)

    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()

    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    session.add(t1)

    requests = t1.create_additional_requests(now + 0.5*dt, session)
    assert is_ok(requests)
    requests = requests.unwrap()
    assert len(requests) == 2
    requested_users = [req.user for req in requests]
    assert requested_users == [u1, u2]

    assert isinstance(requests[0].try_reject(now + 0.75*dt).unwrap(), RejectInTime)
    assert isinstance(requests[1].try_reject(now + 0.75*dt).unwrap(), RejectInTime)

    requests = t1.create_additional_requests(now + 0.5*dt, session)
    # Could not create all requests, so we get Err now
    assert is_err(requests)
    requests = requests.unwrap_err()
    assert len(requests) == 1
    requested_users = [req.user for req in requests]
    assert requested_users == [u3]

    # No user left to ask...
    requests = t1.create_additional_requests(now + 0.5*dt, session)
    assert is_err(requests)
    requests = requests.unwrap_err()
    assert len(requests) == 0

    u4 = User(name="u4")
    session.add(u4)

    # Now we have a user again, so we can ask. Since u3 has not rejected
    # (so far), only one additional request is needed, so the result is
    # Ok(...).
    requests = t1.create_additional_requests(now + 0.5*dt, session)
    assert is_ok(requests)
    requests = requests.unwrap()
    assert len(requests) == 1
    requested_users = [req.user for req in requests]
    assert requested_users == [u4]

    session.commit()