forked from LeineLab-Public/lab-signal-bot
395 lines
12 KiB
Python
395 lines
12 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
import os
|
||
|
import enum
|
||
|
from typing import List
|
||
|
from pydantic import BaseModel
|
||
|
from result import Result, Ok, Err, is_ok
|
||
|
from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select
|
||
|
|
||
|
import datetime
|
||
|
import pytest
|
||
|
|
||
|
def utc_now():
    """Return the current time as a timezone-aware datetime in UTC."""
    return datetime.datetime.now(tz=datetime.timezone.utc)
|
||
|
|
||
|
class User(SQLModel, table=True):
    """Database model of a user that can receive participation requests."""

    id: int = Field(primary_key=True)
    name: str
    # Filled in by utc_now() when no explicit value is given.
    joined_at: datetime.datetime = Field(default_factory=utc_now)

    # Every request addressed to this user, regardless of its state.
    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")

    def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
        """Return the accepted request with the latest ``requested_at``.

        Returns:
            ``Ok(request)`` for the ACCEPTED request that was requested most
            recently, or ``Err(None)`` if the user has no ACCEPTED request.
        """
        accepted = [
            request
            for request in self.participation_requests
            if request.state == ParticipationState.ACCEPTED
        ]
        if not accepted:
            return Err(None)
        return Ok(max(accepted, key=lambda request: request.requested_at))

    def sort_time(self):
        """Return the timestamp used as sort key by ``all_users_sorted()``.

        This is the ``state_since`` of the last accepted request, or
        ``joined_at`` when the user has not accepted any request yet.
        """
        outcome = self.last_accepted_participation_request()
        if not is_ok(outcome):
            return self.joined_at

        latest = outcome.unwrap()
        # An ACCEPTED request must have passed through a state change,
        # which always records state_since.
        assert latest.state_since is not None
        return latest.state_since
|
||
|
|
||
|
def all_users_sorted(session: Session) -> List[User]:
    """Load all users and return them ordered by ascending ``sort_time()``.

    Args:
        session: Database session used to query the users.

    Returns:
        A list of all ``User`` rows, earliest ``sort_time()`` first.
    """
    everyone = session.exec(select(User)).all()
    return sorted(everyone, key=lambda user: user.sort_time())
|
||
|
|
||
|
class Task(SQLModel, table=True):
    """Database model of a task that needs a number of participants."""

    id: int = Field(primary_key=True)
    name: str
    required_number_of_participants: int
    due: datetime.datetime
    # Seconds a request may stay unanswered before it counts as timed out
    # (compared against the request's age in is_timed_out()).
    timeout: int

    # Every request that was sent out for this task, regardless of its state.
    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")

    def _requests_in_state(self, state: "ParticipationState") -> List["ParticipationRequest"]:
        """Return the requests of this task that are currently in ``state``."""
        return [r for r in self.participation_requests if r.state == state]

    def accepted_requests(self) -> List["ParticipationRequest"]:
        """Return the requests of this task that are ACCEPTED."""
        return self._requests_in_state(ParticipationState.ACCEPTED)

    def rejected_requests(self) -> List["ParticipationRequest"]:
        """Return the requests of this task that are REJECTED."""
        return self._requests_in_state(ParticipationState.REJECTED)

    def requested_requests(self) -> List["ParticipationRequest"]:
        """Return the requests of this task that are still REQUESTED (unanswered)."""
        return self._requests_in_state(ParticipationState.REQUESTED)

    def timeout_requests(self) -> List["ParticipationRequest"]:
        """Return the requests of this task that are in state TIMEOUT."""
        return self._requests_in_state(ParticipationState.TIMEOUT)

    def additional_requests_to_be_sent(self) -> int:
        """Return the number of additional requests to be sent.

        Accepted requests and requests that are still pending both count
        towards ``required_number_of_participants``. The result is not
        clamped, so it is negative when more requests are accepted or
        pending than participants are required.
        """
        # Fixed: the original called accepted_participation_requests() and
        # requested_participation_requests(), which do not exist on Task.
        return (
            self.required_number_of_participants
            - len(self.accepted_requests())
            - len(self.requested_requests())
        )

    def freshly_expired_requests(self, now: datetime.datetime) -> List["ParticipationRequest"]:
        """Time out overdue requests and return the ones that just expired.

        Calls ``check_for_timeout(now)`` on every request of this task; that
        call moves a REQUESTED request to TIMEOUT when it is overdue. Only
        requests that changed state during this call are returned.

        Args:
            now: The point in time against which the timeouts are checked.

        Returns:
            The requests that were switched to TIMEOUT by this call.
        """
        expired_requests = []

        # Plain loop instead of a comprehension: check_for_timeout() has a
        # side effect on the request.
        for r in self.participation_requests:
            if r.check_for_timeout(now) is not None:
                expired_requests.append(r)

        # Fixed: the original built this list but never returned it.
        return expired_requests
|
||
|
|
||
|
class ParticipationState(enum.Enum):
    """The states a participation request can be in."""

    # Sent to the user, no answer yet (the default state of a new request).
    REQUESTED = "requested"
    # The user agreed to participate.
    ACCEPTED = "accepted"
    # The user declined to participate.
    REJECTED = "rejected"
    # The request was not answered within the task's timeout.
    TIMEOUT = "timeout"
|
||
|
|
||
|
# Transitions for a ParticipationState:
|
||
|
#
|
||
|
# Accepting or rejecting a request within timeout:
|
||
|
#
|
||
|
# - REQUESTED -> ACCEPTED (allowed)
|
||
|
# - REQUESTED -> REJECTED (allowed)
|
||
|
#
|
||
|
# Normal timeout:
|
||
|
#
|
||
|
# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout)
|
||
|
#
|
||
|
# Changing from ACCEPTED to REJECTED is allowed:
|
||
|
#
|
||
|
# - ACCEPTED -> REJECTED (allowed)
|
||
|
#
|
||
|
# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only:
|
||
|
#
|
||
|
# - REJECTED -> ACCEPTED (allowed, within 5 minutes)
|
||
|
# - REJECTED -> ACCEPTED (not allowed, after 5 minutes)
|
||
|
#
|
||
|
# Answered requests can not timeout:
|
||
|
#
|
||
|
# - REJECTED -> TIMEOUT (should not happen, bug)
|
||
|
# - ACCEPTED -> TIMEOUT (should not happen, bug)
|
||
|
#
|
||
|
# Timed out requests can not be changed anymore:
|
||
|
#
|
||
|
# TIMEOUT -> ACCEPTED (not allowed)
|
||
|
# TIMEOUT -> REJECTED (not allowed)
|
||
|
# TIMEOUT -> REQUESTED (should not happen, bug)
|
||
|
#
|
||
|
# => TODO: What should happen once the task's due date has passed?
|
||
|
|
||
|
|
||
|
class StateTransition:
    """Base class of the markers that describe a performed state change."""


class StateTransitionError:
    """Base class of the markers that describe a refused state change.

    These are carried as ``Err`` values; they are not exceptions.
    """


class AcceptInTime(StateTransition):
    """An unanswered request was accepted before its timeout."""


class AcceptAfterRejectAllowed(StateTransition):
    """A rejected request was accepted within the allowed grace period."""


class AcceptAfterRejectExpired(StateTransitionError):
    """A rejected request could not be accepted: the grace period is over."""


class AcceptAfterTimeout(StateTransitionError):
    """An accept arrived after the request had timed out."""


class AlreadyAccepted(StateTransitionError):
    """An accept arrived for a request that is already accepted."""


class RejectInTime(StateTransition):
    """An unanswered request was rejected before its timeout."""


class RejectAfterAccept(StateTransition):
    """A previously accepted request was rejected."""


class RejectAfterTimeout(StateTransitionError):
    """A reject arrived after the request had timed out."""


class AlreadyRejected(StateTransitionError):
    """A reject arrived for a request that is already rejected."""


class Timeout(StateTransition):
    """An unanswered request was switched to the timeout state."""
|
||
|
|
||
|
|
||
|
class ParticipationRequest(SQLModel, table=True):
    """Database model of a request asking one user to take part in one task.

    The ``try_*`` methods and ``check_for_timeout()`` implement the allowed
    state transitions. They take the relevant point in time as ``now`` and
    only modify the object; committing the session is left to the caller.
    """

    id: int = Field(primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    task_id: int = Field(foreign_key="task.id")
    requested_at: datetime.datetime

    # Time of the last state change; None until the first change happens.
    state_since: datetime.datetime | None = None
    state: ParticipationState = Field(default=ParticipationState.REQUESTED)

    task: Task = Relationship(back_populates="participation_requests")
    user: User = Relationship(back_populates="participation_requests")

    def _change_state(self, new_state: ParticipationState, now: datetime.datetime):
        """Set ``state`` to ``new_state`` and record ``now`` as ``state_since``.

        This method does not check any business rules, it just changes the
        state.
        """
        self.state_since = now
        self.state = new_state

    def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]:
        """Try to accept the request at time ``now``.

        Returns:
            ``Ok(AcceptInTime)`` if the request was unanswered and ``now`` is
            within the timeout; ``Ok(AcceptAfterRejectAllowed)`` if it was
            rejected less than 5 minutes before ``now``. Otherwise an ``Err``
            with the reason, and the state is left unchanged.

        Raises:
            Exception: If the current state is none of the known states.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # REQUESTED and TIMEOUT are treated alike and decided purely by
            # timestamps: we want to make the order in which
            # check_for_timeout() and try_accept() is called irrelevant.
            if self.is_timed_out(now):
                return Err(AcceptAfterTimeout())
            self._change_state(ParticipationState.ACCEPTED, now)
            return Ok(AcceptInTime())

        if self.state == ParticipationState.REJECTED:
            # A rejection may be taken back for 5 minutes only.
            # NOTE(review): assumes state_since is set whenever the state is
            # REJECTED (true when the state was reached via _change_state).
            if (now - self.state_since).total_seconds() < 5 * 60:
                self._change_state(ParticipationState.ACCEPTED, now)
                return Ok(AcceptAfterRejectAllowed())
            return Err(AcceptAfterRejectExpired())

        if self.state == ParticipationState.ACCEPTED:
            return Err(AlreadyAccepted())

        raise Exception("Unknown old state " + str(self.state) + ".")

    def try_reject(self, now: datetime.datetime) -> Result[RejectInTime | RejectAfterAccept, RejectAfterTimeout | AlreadyRejected]:
        """Try to reject the request at time ``now``.

        Returns:
            ``Ok(RejectInTime)`` if the request was unanswered and ``now`` is
            within the timeout; ``Ok(RejectAfterAccept)`` if it was accepted
            before. Otherwise an ``Err`` with the reason
            (``RejectAfterTimeout`` or ``AlreadyRejected``), and the state is
            left unchanged.

        Raises:
            Exception: If the current state is none of the known states.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # REQUESTED and TIMEOUT are treated alike and decided purely by
            # timestamps: we want to make the order in which
            # check_for_timeout() and try_reject() is called irrelevant.
            if self.is_timed_out(now):
                return Err(RejectAfterTimeout())
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectInTime())

        if self.state == ParticipationState.ACCEPTED:
            # Withdrawing an acceptance is always allowed.
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectAfterAccept())

        if self.state == ParticipationState.REJECTED:
            return Err(AlreadyRejected())

        raise Exception("Unknown old state " + str(self.state) + ".")

    def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
        """Switch an unanswered, overdue request to TIMEOUT.

        Returns:
            ``Timeout()`` if the request was REQUESTED and is timed out at
            ``now`` (the state is changed to TIMEOUT); ``None`` in every
            other case, leaving the request untouched.
        """
        if self.state == ParticipationState.REQUESTED and self.is_timed_out(now):
            self._change_state(ParticipationState.TIMEOUT, now)
            return Timeout()
        return None

    def is_timed_out(self, now: datetime.datetime) -> bool:
        """Return True if more than ``task.timeout`` seconds passed since ``requested_at``.

        Only timestamps are compared; the current state is not considered.
        """
        return (now - self.requested_at).total_seconds() > self.task.timeout
|
||
|
|
||
|
|
||
|
@pytest.fixture
def session():
    """Yield a Session bound to a freshly created SQLite database file."""
    # Start from an empty database: drop the file a previous run left behind.
    if os.path.exists("/tmp/test.db"):
        os.remove("/tmp/test.db")

    engine = create_engine("sqlite:////tmp/test.db")
    # Create the tables of all models registered on SQLModel.metadata.
    SQLModel.metadata.create_all(engine)

    with Session(engine) as db_session:
        yield db_session
|
||
|
|
||
|
def test_all_users_sorted(session):
    """Users are ordered by join time until they accept; accepting moves them back."""
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")

    for user in (u1, u2, u3):
        session.add(user)
    session.commit()

    # Without any accepted request the order follows joined_at.
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    session.commit()

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    session.commit()

    # A request that is merely REQUESTED does not count as accepted.
    assert t1.accepted_requests() == []
    users = all_users_sorted(session)

    # u1 accepts and therefore moves to the end of the list.
    r1._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert t1.accepted_requests() == [r1]
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)
    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)
    session.commit()

    # u3 accepts after u1, so u3 is now last; u2 has not accepted yet.
    r3._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert t1.accepted_requests() == [r1, r3]
    users = all_users_sorted(session)
    assert users == [u2, u1, u3]

    # Finally u2 accepts as well and becomes the most recent acceptor.
    r2._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    users = all_users_sorted(session)
    assert users == [u1, u3, u2]
|
||
|
|
||
|
def test_accept(session):
    """Exercise try_accept() together with check_for_timeout()."""
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    in_time = now + 0.5*dt
    too_late = now + 1.6*dt

    u1 = User(name="u1")

    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)

    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()

    # First accept within the timeout succeeds.
    accept_result = r1.try_accept(in_time)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    session.commit()

    # Accepting a second time is reported as an error.
    accept_result = r1.try_accept(in_time)
    assert isinstance(accept_result.unwrap_err(), AlreadyAccepted)
    session.commit()

    # Accepted should not timeout
    timeout_result = r1.check_for_timeout(too_late)
    assert timeout_result is None

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=in_time)
    session.add(r2)

    # AcceptAfterTimeout must be given even before check_for_timeout()
    accept_result = r2.try_accept(too_late)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    timeout_result = r2.check_for_timeout(too_late)
    assert isinstance(timeout_result, Timeout)

    # Should still be a AcceptAfterTimeout
    accept_result = r2.try_accept(too_late)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.TIMEOUT

    # If we receive a late accept (one whose timestamp is still within the
    # timeout), we should accept it
    accept_result = r2.try_accept(in_time)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    assert r2.state == ParticipationState.ACCEPTED

    session.commit()
|
||
|
|
||
|
def test_reject(session):
    """Exercise try_reject() together with check_for_timeout()."""
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    in_time = now + 0.5*dt
    too_late = now + 1.6*dt

    u1 = User(name="u1")

    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)

    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()

    # First reject within the timeout succeeds.
    reject_result = r1.try_reject(in_time)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    session.commit()

    # Rejecting a second time is reported as an error.
    reject_result = r1.try_reject(in_time)
    assert isinstance(reject_result.unwrap_err(), AlreadyRejected)
    session.commit()

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=in_time)
    session.add(r2)

    # RejectAfterTimeout must be given even before check_for_timeout()
    reject_result = r2.try_reject(too_late)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    timeout_result = r2.check_for_timeout(too_late)
    assert isinstance(timeout_result, Timeout)
    assert r2.state == ParticipationState.TIMEOUT

    # Should still be a RejectAfterTimeout
    reject_result = r2.try_reject(too_late)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)

    # If we receive a late reject (one whose timestamp is still within the
    # timeout), we should honour it
    reject_result = r2.try_reject(in_time)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    assert r2.state == ParticipationState.REJECTED

    session.commit()
|
||
|
|
||
|
|