lab-signal-bot/models.py
2024-12-19 10:31:45 +01:00

395 lines
12 KiB
Python

#!/usr/bin/env python3
import os
import enum
from typing import List
from pydantic import BaseModel
from result import Result, Ok, Err, is_ok
from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select
import datetime
import pytest
def utc_now():
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = datetime.timezone.utc
    return datetime.datetime.now(tz=utc)
class User(SQLModel, table=True):
    """A user that participation requests can be addressed to.

    NOTE(review): last_accepted_participation_request() selects the accepted
    request with the latest ``requested_at``, while sort_time() then uses that
    request's ``state_since``. If requests can be accepted in a different order
    than they were requested, confirm this is the intended choice.
    """
    id: int = Field(primary_key=True)
    name: str
    joined_at: datetime.datetime = Field(default_factory=utc_now)
    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")

    def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
        """Find this user's most recently requested ACCEPTED participation request.

        Returns ``Ok(request)`` for the ACCEPTED request with the greatest
        ``requested_at``, or ``Err(None)`` if the user has no ACCEPTED request.
        """
        accepted = [
            request
            for request in self.participation_requests
            if request.state == ParticipationState.ACCEPTED
        ]
        if not accepted:
            return Err(None)
        return Ok(max(accepted, key=lambda request: request.requested_at))

    def sort_time(self):
        """Return the key that all_users_sorted() orders users by.

        This is the ``state_since`` of the request returned by
        last_accepted_participation_request(), or ``joined_at`` when the user
        has no ACCEPTED request.
        """
        latest = self.last_accepted_participation_request()
        if not is_ok(latest):
            return self.joined_at
        request = latest.unwrap()
        # An ACCEPTED request is expected to have a state_since timestamp.
        assert request.state_since is not None
        return request.state_since
def all_users_sorted(session: Session) -> List[User]:
    """Load every user and return them ordered by ``User.sort_time()``, earliest first.

    Users whose last accepted request entered that state longest ago (or, with no
    accepted request, who joined earliest) come first. The sort is stable, so
    ties keep the order in which the query returned them.
    """
    return sorted(session.exec(select(User)).all(), key=lambda user: user.sort_time())
class Task(SQLModel, table=True):
    """A task that needs ``required_number_of_participants`` users.

    Every participation request sent for the task must be answered within
    ``timeout`` seconds of its ``requested_at``; see
    ParticipationRequest.is_timed_out().
    """
    id: int = Field(primary_key=True)
    name: str
    required_number_of_participants: int
    due: datetime.datetime
    timeout: int  # in seconds
    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")

    def _requests_in_state(self, state: "ParticipationState") -> List["ParticipationRequest"]:
        """Return this task's participation requests that are in ``state``, in relationship order."""
        return [r for r in self.participation_requests if r.state == state]

    def accepted_requests(self):
        """Return the requests in state ACCEPTED."""
        return self._requests_in_state(ParticipationState.ACCEPTED)

    def rejected_requests(self):
        """Return the requests in state REJECTED."""
        return self._requests_in_state(ParticipationState.REJECTED)

    def requested_requests(self):
        """Return the requests still in state REQUESTED (sent, not yet answered or timed out)."""
        return self._requests_in_state(ParticipationState.REQUESTED)

    def timeout_requests(self):
        """Return the requests in state TIMEOUT."""
        return self._requests_in_state(ParticipationState.TIMEOUT)

    def additional_requests_to_be_sent(self):
        """Return how many more participation requests must be sent for this task.

        Accepted requests and requests still awaiting an answer (REQUESTED) both
        count towards ``required_number_of_participants``; rejected and
        timed-out ones do not. The result is never negative.
        """
        outstanding = len(self.accepted_requests()) + len(self.requested_requests())
        # More accepted/pending requests than required means nothing more to send.
        return max(0, self.required_number_of_participants - outstanding)

    def freshly_expired_requests(self, now) -> List["ParticipationRequest"]:
        """Time out overdue pending requests and return the ones that just expired.

        Calls ``check_for_timeout(now)`` on every request of this task, which
        moves REQUESTED requests that are past their timeout to TIMEOUT. Only
        requests whose state changed during this call are returned; requests
        that were already TIMEOUT (or answered) are not.
        """
        expired_requests = []
        for r in self.participation_requests:
            if r.check_for_timeout(now) is not None:
                expired_requests.append(r)
        return expired_requests
class ParticipationState(enum.Enum):
    """State of a ParticipationRequest.

    Transitions between these states are performed by
    ParticipationRequest.try_accept(), try_reject() and check_for_timeout().
    """
    REQUESTED = "requested"  # initial state (default of ParticipationRequest.state)
    ACCEPTED = "accepted"  # set by ParticipationRequest.try_accept()
    REJECTED = "rejected"  # set by ParticipationRequest.try_reject()
    TIMEOUT = "timeout"  # set by check_for_timeout() once Task.timeout seconds have passed
# Transitions for a ParticipationState:
#
# Accepting or rejecting a request within timeout:
#
# - REQUESTED -> ACCEPTED (allowed)
# - REQUESTED -> REJECTED (allowed)
#
# Normal timeout:
#
# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout)
#
# Changing from ACCEPTED to REJECTED is allowed:
#
# - ACCEPTED -> REJECTED (allowed)
#
# Changing from REJECTED back to ACCEPTED is allowed only within 5 minutes of the rejection:
#
# - REJECTED -> ACCEPTED (allowed, within 5 minutes)
# - REJECTED -> ACCEPTED (not allowed, after 5 minutes)
#
# Answered requests can not timeout:
#
# - REJECTED -> TIMEOUT (should not happen, bug)
# - ACCEPTED -> TIMEOUT (should not happen, bug)
#
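# Timed out requests can no longer be answered:
#
# - TIMEOUT -> ACCEPTED (not allowed)
# - TIMEOUT -> REJECTED (not allowed)
# - TIMEOUT -> REQUESTED (should not happen, bug)
#
# Exception: an answer whose timestamp lies before the timeout is still
# honoured, even if it is processed after the request was marked TIMEOUT, so
# the order in which answers and timeout checks are processed does not matter.
#
# TODO: Define what should happen after the task's due date has passed.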
class StateTransition:
    """Base class for the outcome of a state change that was carried out.

    ParticipationRequest returns instances of its subclasses wrapped in
    ``Ok(...)`` (or, for Timeout, directly).
    """


class StateTransitionError:
    """Base class for the outcome of a state change that was refused.

    This is not an Exception subclass: ParticipationRequest returns instances
    of its subclasses wrapped in ``Err(...)`` instead of raising them.
    """
class AcceptInTime(StateTransition):
    """try_accept(): a pending request was accepted before its timeout."""


class AcceptAfterRejectAllowed(StateTransition):
    """try_accept(): a rejected request was accepted less than 5 minutes after the rejection."""


class AcceptAfterRejectExpired(StateTransitionError):
    """try_accept(): refused, the rejection happened 5 or more minutes ago."""


class AcceptAfterTimeout(StateTransitionError):
    """try_accept(): refused, the request's timeout had already passed."""


class AlreadyAccepted(StateTransitionError):
    """try_accept(): refused, the request is already ACCEPTED."""
class RejectInTime(StateTransition):
    """try_reject(): a pending request was rejected before its timeout."""


class RejectAfterAccept(StateTransition):
    """try_reject(): a previously accepted request was rejected."""


class RejectAfterTimeout(StateTransitionError):
    """try_reject(): refused, the request's timeout had already passed."""


class AlreadyRejected(StateTransitionError):
    """try_reject(): refused, the request is already REJECTED."""
class Timeout(StateTransition):
    """check_for_timeout(): a pending request passed its timeout and is now TIMEOUT."""
class ParticipationRequest(SQLModel, table=True):
    """A request asking one user to participate in one task.

    State changes go through try_accept(), try_reject() and check_for_timeout().
    try_accept() and try_reject() report refused transitions as ``Err(...)``
    values instead of raising.
    """
    id: int = Field(primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    task_id: int = Field(foreign_key="task.id")
    requested_at: datetime.datetime
    state_since: datetime.datetime | None = None
    state: ParticipationState = Field(default=ParticipationState.REQUESTED)
    task: Task = Relationship(back_populates="participation_requests")
    user: User = Relationship(back_populates="participation_requests")

    def _change_state(self, new_state: ParticipationState, now: datetime.datetime):
        """Set ``state`` to ``new_state`` and ``state_since`` to ``now``.

        This method does not check any business rules; callers must do that.
        """
        self.state_since = now
        self.state = new_state

    def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]:
        """Try to move this request to ACCEPTED at time ``now``.

        Returns:
            Ok(AcceptInTime()) if the request is REQUESTED or TIMEOUT and ``now``
            is not past its timeout.
            Ok(AcceptAfterRejectAllowed()) if it is REJECTED and the rejection
            happened less than 5 minutes before ``now``.
            Err(AcceptAfterTimeout()) if it is REQUESTED or TIMEOUT and ``now`` is
            past its timeout (the state is left unchanged).
            Err(AcceptAfterRejectExpired()) if it was rejected 5 or more minutes ago.
            Err(AlreadyAccepted()) if it is already ACCEPTED.

        Raises:
            Exception: if ``state`` holds a value not handled here.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # Decide by ``now`` rather than by the stored state, so the order in
            # which check_for_timeout() and try_accept() are called is irrelevant.
            if self.is_timed_out(now):
                return Err(AcceptAfterTimeout())
            self._change_state(ParticipationState.ACCEPTED, now)
            return Ok(AcceptInTime())
        if self.state == ParticipationState.REJECTED:
            # state_since is set whenever the state was changed via _change_state().
            assert self.state_since is not None
            if (now - self.state_since).total_seconds() < 5 * 60:
                self._change_state(ParticipationState.ACCEPTED, now)
                return Ok(AcceptAfterRejectAllowed())
            return Err(AcceptAfterRejectExpired())
        if self.state == ParticipationState.ACCEPTED:
            return Err(AlreadyAccepted())
        raise Exception("Unknown old state " + str(self.state) + ".")

    def try_reject(self, now: datetime.datetime) -> Result[RejectInTime | RejectAfterAccept, RejectAfterTimeout | AlreadyRejected]:
        """Try to move this request to REJECTED at time ``now``.

        Returns:
            Ok(RejectInTime()) if the request is REQUESTED or TIMEOUT and ``now``
            is not past its timeout.
            Ok(RejectAfterAccept()) if it is ACCEPTED (always allowed).
            Err(RejectAfterTimeout()) if it is REQUESTED or TIMEOUT and ``now`` is
            past its timeout (the state is left unchanged).
            Err(AlreadyRejected()) if it is already REJECTED.

        Raises:
            Exception: if ``state`` holds a value not handled here.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # Decide by ``now`` rather than by the stored state, so the order in
            # which check_for_timeout() and try_reject() are called is irrelevant.
            if self.is_timed_out(now):
                return Err(RejectAfterTimeout())
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectInTime())
        if self.state == ParticipationState.ACCEPTED:
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectAfterAccept())
        if self.state == ParticipationState.REJECTED:
            return Err(AlreadyRejected())
        raise Exception("Unknown old state " + str(self.state) + ".")

    def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
        """Move this request to TIMEOUT if it is still REQUESTED and past its timeout.

        Requests in any other state are left unchanged.

        Returns:
            Timeout() if this call changed the state to TIMEOUT, otherwise None.
        """
        if self.state == ParticipationState.REQUESTED and self.is_timed_out(now):
            self._change_state(ParticipationState.TIMEOUT, now)
            return Timeout()
        return None

    def is_timed_out(self, now: datetime.datetime) -> bool:
        """Return True if more than ``task.timeout`` seconds lie between ``requested_at`` and ``now``."""
        # NOTE(review): ``now`` and ``requested_at`` must both be naive or both
        # timezone-aware; mixing them makes the subtraction raise TypeError.
        return (now - self.requested_at).total_seconds() > self.task.timeout
@pytest.fixture
def session(tmp_path):
    """Yield a Session bound to a fresh, empty SQLite database.

    The database file lives in pytest's per-test ``tmp_path`` directory instead
    of a fixed ``/tmp/test.db``. This keeps it portable, keeps concurrent test
    runs from clobbering each other, and avoids the need to delete a leftover
    file first. The engine is disposed once the test is done.
    """
    engine = create_engine(f"sqlite:///{tmp_path / 'test.db'}")
    SQLModel.metadata.create_all(engine)
    try:
        with Session(engine) as db_session:
            yield db_session
    finally:
        # Release pooled connections so the database file is not held open.
        engine.dispose()
def test_all_users_sorted(session):
    """all_users_sorted() orders users by their latest acceptance, else by join time."""
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()
    # Nobody has accepted anything yet: ordered by join time.
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    session.commit()
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    session.commit()
    assert t1.accepted_requests() == []
    # A pending (REQUESTED) request does not change the order.
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]

    # Accepting moves u1 to the end: its sort time is now the acceptance time.
    r1._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    assert t1.accepted_requests() == [r1]
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)
    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)
    session.commit()
    r3._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    assert t1.accepted_requests() == [r1, r3]
    users = all_users_sorted(session)
    assert users == [u2, u1, u3]

    r2._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    users = all_users_sorted(session)
    assert users == [u1, u3, u2]
def test_accept(session):
    """Exercise ParticipationRequest.try_accept() around repeats and timeouts."""
    t0 = datetime.datetime(2024, 12, 10, 0, 0)
    day = datetime.timedelta(days=1)
    user = User(name="u1")
    # The task is due 10 days after t0; every request may be answered for one
    # day (its timeout) after it was sent.
    task = Task(name="t1", required_number_of_participants=2, due=t0 + 10*day, timeout=day.total_seconds())
    first = ParticipationRequest(user=user, task=task, requested_at=t0)
    session.add(user)
    session.add(task)
    session.add(first)
    session.commit()

    # Accepting half a day after the request is within the timeout.
    outcome = first.try_accept(t0 + 0.5*day)
    assert isinstance(outcome.unwrap(), AcceptInTime)
    session.commit()

    # Accepting the same request again is refused.
    outcome = first.try_accept(t0 + 0.5*day)
    assert isinstance(outcome.unwrap_err(), AlreadyAccepted)
    session.commit()

    # An accepted request never times out.
    assert first.check_for_timeout(t0 + 1.6*day) is None

    second = ParticipationRequest(user=user, task=task, requested_at=t0 + 0.5*day)
    session.add(second)

    # A too-late accept is refused even before check_for_timeout() has run...
    outcome = second.try_accept(t0 + 1.6*day)
    assert isinstance(outcome.unwrap_err(), AcceptAfterTimeout)
    assert second.state == ParticipationState.REQUESTED

    # ...check_for_timeout() then marks the request as timed out...
    assert isinstance(second.check_for_timeout(t0 + 1.6*day), Timeout)

    # ...and a too-late accept is still refused afterwards.
    outcome = second.try_accept(t0 + 1.6*day)
    assert isinstance(outcome.unwrap_err(), AcceptAfterTimeout)
    assert second.state == ParticipationState.TIMEOUT

    # An accept timestamped before the timeout is honoured even when it is
    # processed after the timeout check (a late-arriving answer).
    outcome = second.try_accept(t0 + 0.5*day)
    assert isinstance(outcome.unwrap(), AcceptInTime)
    assert second.state == ParticipationState.ACCEPTED
    session.commit()
def test_reject(session):
    """Exercise ParticipationRequest.try_reject() around repeats and timeouts."""
    t0 = datetime.datetime(2024, 12, 10, 0, 0)
    day = datetime.timedelta(days=1)
    user = User(name="u1")
    # The task is due 10 days after t0; every request may be answered for one
    # day (its timeout) after it was sent.
    task = Task(name="t1", required_number_of_participants=2, due=t0 + 10*day, timeout=day.total_seconds())
    first = ParticipationRequest(user=user, task=task, requested_at=t0)
    session.add(user)
    session.add(task)
    session.add(first)
    session.commit()

    # Rejecting half a day after the request is within the timeout.
    outcome = first.try_reject(t0 + 0.5*day)
    assert isinstance(outcome.unwrap(), RejectInTime)
    session.commit()

    # Rejecting the same request again is refused.
    outcome = first.try_reject(t0 + 0.5*day)
    assert isinstance(outcome.unwrap_err(), AlreadyRejected)
    session.commit()

    second = ParticipationRequest(user=user, task=task, requested_at=t0 + 0.5*day)
    session.add(second)

    # A too-late reject is refused even before check_for_timeout() has run...
    outcome = second.try_reject(t0 + 1.6*day)
    assert isinstance(outcome.unwrap_err(), RejectAfterTimeout)
    assert second.state == ParticipationState.REQUESTED

    # ...check_for_timeout() then marks the request as timed out...
    assert isinstance(second.check_for_timeout(t0 + 1.6*day), Timeout)
    assert second.state == ParticipationState.TIMEOUT

    # ...and a too-late reject is still refused afterwards.
    outcome = second.try_reject(t0 + 1.6*day)
    assert isinstance(outcome.unwrap_err(), RejectAfterTimeout)

    # A reject timestamped before the timeout is honoured even when it is
    # processed after the timeout check (a late-arriving answer).
    outcome = second.try_reject(t0 + 0.5*day)
    assert isinstance(outcome.unwrap(), RejectInTime)
    assert second.state == ParticipationState.REJECTED
    session.commit()