forked from LeineLab-Public/lab-signal-bot
Add some initial (unfinished) models
This commit is contained in:
parent
3a6de92eee
commit
52112e1df8
394
models.py
Normal file
394
models.py
Normal file
@ -0,0 +1,394 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import enum
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from result import Result, Ok, Err, is_ok
|
||||
from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select
|
||||
|
||||
import datetime
|
||||
import pytest
|
||||
|
||||
def utc_now():
|
||||
return datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
name: str
|
||||
joined_at: datetime.datetime = Field(default_factory=utc_now)
|
||||
|
||||
participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")
|
||||
|
||||
def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
|
||||
accepted_requests = [r for r in self.participation_requests if r.state == ParticipationState.ACCEPTED]
|
||||
if accepted_requests:
|
||||
return Ok(max(accepted_requests, key=lambda r: r.requested_at))
|
||||
else:
|
||||
return Err(None)
|
||||
|
||||
def sort_time(self):
|
||||
last_accepted = self.last_accepted_participation_request()
|
||||
if is_ok(last_accepted):
|
||||
last_accepted_obj = last_accepted.unwrap()
|
||||
assert last_accepted_obj.state_since is not None
|
||||
return last_accepted_obj.state_since
|
||||
else:
|
||||
return self.joined_at
|
||||
|
||||
def all_users_sorted(session: Session) -> List[User]:
|
||||
users = session.exec(select(User)).all()
|
||||
users.sort(key=lambda u: u.sort_time())
|
||||
return users
|
||||
|
||||
class Task(SQLModel, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
name: str
|
||||
required_number_of_participants: int
|
||||
due: datetime.datetime
|
||||
timeout: int # in seconds
|
||||
|
||||
participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")
|
||||
|
||||
def accepted_requests(self):
|
||||
return [r for r in self.participation_requests if r.state == ParticipationState.ACCEPTED]
|
||||
|
||||
def rejected_requests(self):
|
||||
return [r for r in self.participation_requests if r.state == ParticipationState.REJECTED]
|
||||
|
||||
def requested_requests(self):
|
||||
return [r for r in self.participation_requests if r.state == ParticipationState.REQUESTED]
|
||||
|
||||
def timeout_requests(self):
|
||||
return [r for r in self.participation_requests if r.state == ParticipationState.TIMEOUT]
|
||||
|
||||
def additional_requests_to_be_sent(self):
|
||||
# return the number of additional requests to be sent
|
||||
return self.required_number_of_participants \
|
||||
- len(self.accepted_participation_requests()) \
|
||||
- len(self.requested_participation_requests())
|
||||
|
||||
def freshly_expired_requests(self, now) -> List["ParticipationRequest"]:
|
||||
expired_requests = []
|
||||
|
||||
for r in self.participation_requests:
|
||||
maybe_timeout = r.check_for_timeout(now)
|
||||
if maybe_timeout is not None:
|
||||
expired_requests.append(r)
|
||||
|
||||
class ParticipationState(enum.Enum):
|
||||
REQUESTED = "requested"
|
||||
ACCEPTED = "accepted"
|
||||
REJECTED = "rejected"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
# Transitions for a ParticipationState:
|
||||
#
|
||||
# Accepting or rejecting a request within timeout:
|
||||
#
|
||||
# - REQUESTED -> ACCEPTED (allowed)
|
||||
# - REQUESTED -> REJECTED (allowed)
|
||||
#
|
||||
# Normal timeout:
|
||||
#
|
||||
# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout)
|
||||
#
|
||||
# Changing from ACCEPTED to REJECTED is allowed:
|
||||
#
|
||||
# - ACCEPTED -> REJECTED (allowed)
|
||||
#
|
||||
# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only:
|
||||
#
|
||||
# - REJECTED -> ACCEPTED (allowed, within 5 minutes)
|
||||
# - REJECTED -> ACCEPTED (not allowed, after 5 minutes)
|
||||
#
|
||||
# Answered requests can not timeout:
|
||||
#
|
||||
# - REJECTED -> TIMEOUT (should not happen, bug)
|
||||
# - ACCEPTED -> TIMEOUT (should not happen, bug)
|
||||
#
|
||||
# Timed out requests can not be
|
||||
#
|
||||
# TIMEOUT -> ACCEPTED (not allowed)
|
||||
# TIMEOUT -> REJECTED (not allowed)
|
||||
# TIMEOUT -> REQUESTED (should not happen, bug)
|
||||
#
|
||||
# => TODO: What should happen after the due?!!!
|
||||
|
||||
|
||||
class StateTransition:
|
||||
pass
|
||||
|
||||
class StateTransitionError:
|
||||
pass
|
||||
|
||||
class AcceptInTime(StateTransition):
|
||||
pass
|
||||
|
||||
class AcceptAfterRejectAllowed(StateTransition):
|
||||
pass
|
||||
|
||||
class AcceptAfterRejectExpired(StateTransitionError):
|
||||
pass
|
||||
|
||||
class AcceptAfterTimeout(StateTransitionError):
|
||||
pass
|
||||
|
||||
class AlreadyAccepted(StateTransitionError):
|
||||
pass
|
||||
|
||||
class RejectInTime(StateTransition):
|
||||
pass
|
||||
|
||||
class RejectAfterAccept(StateTransition):
|
||||
pass
|
||||
|
||||
class RejectAfterTimeout(StateTransitionError):
|
||||
pass
|
||||
|
||||
class AlreadyRejected(StateTransitionError):
|
||||
pass
|
||||
|
||||
class Timeout(StateTransition):
|
||||
pass
|
||||
|
||||
|
||||
class ParticipationRequest(SQLModel, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
task_id: int = Field(foreign_key="task.id")
|
||||
requested_at: datetime.datetime
|
||||
|
||||
state_since: datetime.datetime | None = None
|
||||
state: ParticipationState = Field(default=ParticipationState.REQUESTED)
|
||||
|
||||
task: Task = Relationship(back_populates="participation_requests")
|
||||
user: User = Relationship(back_populates="participation_requests")
|
||||
|
||||
def _change_state(self, new_state: ParticipationState, now: datetime.datetime):
|
||||
# This method does not check any business rules, it just changes the state
|
||||
self.state_since = now
|
||||
self.state = new_state
|
||||
|
||||
def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]:
|
||||
if self.state == ParticipationState.REQUESTED or self.state == ParticipationState.TIMEOUT:
|
||||
# We want to make the order in which check_for_timeout() and try_accept() is called irrelevant
|
||||
if self.is_timed_out(now):
|
||||
return Err(AcceptAfterTimeout())
|
||||
else:
|
||||
self._change_state(ParticipationState.ACCEPTED, now)
|
||||
return Ok(AcceptInTime())
|
||||
elif self.state == ParticipationState.REJECTED:
|
||||
if (now - self.state_since).total_seconds() < 5 * 60:
|
||||
self._change_state(ParticipationState.ACCEPTED, now)
|
||||
return Ok(AcceptAfterRejectAllowed())
|
||||
else:
|
||||
return Err(AcceptAfterRejectExpired())
|
||||
elif self.state == ParticipationState.ACCEPTED:
|
||||
return Err(AlreadyAccepted())
|
||||
else:
|
||||
raise Exception("Unknown old state " + str(self.state) + ".")
|
||||
|
||||
def try_reject(self, now: datetime.datetime) -> Result[RejectInTime, RejectAfterAccept | RejectAfterTimeout]:
|
||||
if self.state == ParticipationState.REQUESTED or self.state == ParticipationState.TIMEOUT:
|
||||
# We want to make the order in which check_for_timeout() and try_reject() is called irrelevant
|
||||
if self.is_timed_out(now):
|
||||
return Err(RejectAfterTimeout())
|
||||
else:
|
||||
self._change_state(ParticipationState.REJECTED, now)
|
||||
return Ok(RejectInTime())
|
||||
elif self.state == ParticipationState.ACCEPTED:
|
||||
self._change_state(ParticipationState.REJECTED, now)
|
||||
return Ok(RejectAfterAccept())
|
||||
elif self.state == ParticipationState.REJECTED:
|
||||
return Err(AlreadyRejected())
|
||||
else:
|
||||
raise Exception("Unknown old state " + str(self.state) + ".")
|
||||
|
||||
def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
|
||||
if self.state == ParticipationState.REQUESTED:
|
||||
if self.is_timed_out(now):
|
||||
self._change_state(ParticipationState.TIMEOUT, now)
|
||||
return Timeout()
|
||||
|
||||
def is_timed_out(self, now: datetime.datetime) -> bool:
|
||||
return (now - self.requested_at).total_seconds() > self.task.timeout
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
if os.path.exists("/tmp/test.db"):
|
||||
os.remove("/tmp/test.db")
|
||||
|
||||
engine = create_engine("sqlite:////tmp/test.db")
|
||||
SQLModel.metadata.create_all(engine)
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
def test_all_users_sorted(session):
|
||||
u1 = User(name="u1")
|
||||
u2 = User(name="u2")
|
||||
u3 = User(name="u3")
|
||||
|
||||
session.add(u1)
|
||||
session.add(u2)
|
||||
session.add(u3)
|
||||
|
||||
session.commit()
|
||||
|
||||
users = all_users_sorted(session)
|
||||
|
||||
assert users == [u1, u2, u3]
|
||||
|
||||
t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
|
||||
|
||||
session.add(t1)
|
||||
session.commit()
|
||||
|
||||
r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
|
||||
|
||||
session.add(r1)
|
||||
session.commit()
|
||||
|
||||
assert t1.accepted_requests() == []
|
||||
|
||||
users = all_users_sorted(session)
|
||||
|
||||
r1._change_state(ParticipationState.ACCEPTED, utc_now())
|
||||
session.commit()
|
||||
|
||||
assert t1.accepted_requests() == [r1]
|
||||
users = all_users_sorted(session)
|
||||
assert users == [u2, u3, u1]
|
||||
|
||||
r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
|
||||
session.add(r2)
|
||||
|
||||
r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
|
||||
session.add(r3)
|
||||
|
||||
session.commit()
|
||||
|
||||
r3._change_state(ParticipationState.ACCEPTED, utc_now())
|
||||
|
||||
session.commit()
|
||||
|
||||
assert t1.accepted_requests() == [r1, r3]
|
||||
users = all_users_sorted(session)
|
||||
assert users == [u2, u1, u3]
|
||||
|
||||
r2._change_state(ParticipationState.ACCEPTED, utc_now())
|
||||
session.commit()
|
||||
|
||||
users = all_users_sorted(session)
|
||||
assert users == [u1, u3, u2]
|
||||
|
||||
def test_accept(session):
|
||||
now = datetime.datetime(2024, 12, 10, 0, 0)
|
||||
dt = datetime.timedelta(days=1)
|
||||
|
||||
u1 = User(name="u1")
|
||||
|
||||
# - Task will happen at now + 10*dt
|
||||
# - Each user has 1*dt after the request to answer (before timeout)
|
||||
t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
|
||||
|
||||
r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)
|
||||
|
||||
session.add(u1)
|
||||
session.add(t1)
|
||||
session.add(r1)
|
||||
session.commit()
|
||||
|
||||
accept_result = r1.try_accept(now + 0.5*dt)
|
||||
assert isinstance(accept_result.unwrap(), AcceptInTime)
|
||||
session.commit()
|
||||
|
||||
accept_result = r1.try_accept(now + 0.5*dt)
|
||||
assert isinstance(accept_result.unwrap_err(), AlreadyAccepted)
|
||||
session.commit()
|
||||
|
||||
# Accepted should not timeout
|
||||
timeout_result = r1.check_for_timeout(now + 1.6*dt)
|
||||
assert timeout_result is None
|
||||
|
||||
r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
|
||||
session.add(r2)
|
||||
|
||||
# AcceptAfterTimeout must be given even before check_for_timeout()
|
||||
accept_result = r2.try_accept(now + 1.6*dt)
|
||||
assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
|
||||
|
||||
assert(r2.state == ParticipationState.REQUESTED)
|
||||
|
||||
# Obtain Timeout
|
||||
timeout_result = r2.check_for_timeout(now + 1.6*dt)
|
||||
assert isinstance(timeout_result, Timeout)
|
||||
|
||||
# Should still be a AcceptAfterTimeout
|
||||
accept_result = r2.try_accept(now + 1.6*dt)
|
||||
assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
|
||||
|
||||
assert(r2.state == ParticipationState.TIMEOUT)
|
||||
|
||||
# If we reiceive a late accept, we should accept it
|
||||
accept_result = r2.try_accept(now + 0.5*dt)
|
||||
assert isinstance(accept_result.unwrap(), AcceptInTime)
|
||||
|
||||
assert(r2.state == ParticipationState.ACCEPTED)
|
||||
|
||||
session.commit()
|
||||
|
||||
def test_reject(session):
|
||||
now = datetime.datetime(2024, 12, 10, 0, 0)
|
||||
dt = datetime.timedelta(days=1)
|
||||
|
||||
u1 = User(name="u1")
|
||||
|
||||
# - Task will happen at now + 10*dt
|
||||
# - Each user has 1*dt after the request to answer (before timeout)
|
||||
t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
|
||||
|
||||
r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)
|
||||
|
||||
session.add(u1)
|
||||
session.add(t1)
|
||||
session.add(r1)
|
||||
session.commit()
|
||||
|
||||
reject_result = r1.try_reject(now + 0.5*dt)
|
||||
assert isinstance(reject_result.unwrap(), RejectInTime)
|
||||
session.commit()
|
||||
|
||||
reject_result = r1.try_reject(now + 0.5*dt)
|
||||
assert isinstance(reject_result.unwrap_err(), AlreadyRejected)
|
||||
session.commit()
|
||||
|
||||
r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
|
||||
session.add(r2)
|
||||
|
||||
# RejectAfterTimeout must be given even before check_for_timeout()
|
||||
reject_result = r2.try_reject(now + 1.6*dt)
|
||||
assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
|
||||
|
||||
assert(r2.state == ParticipationState.REQUESTED)
|
||||
|
||||
# Obtain Timeout
|
||||
timeout_result = r2.check_for_timeout(now + 1.6*dt)
|
||||
assert isinstance(timeout_result, Timeout)
|
||||
|
||||
assert(r2.state == ParticipationState.TIMEOUT)
|
||||
|
||||
# Should still be a RejectAfterTimeout
|
||||
reject_result = r2.try_reject(now + 1.6*dt)
|
||||
assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
|
||||
|
||||
# If we receive a late reject, we should accept it
|
||||
reject_result = r2.try_reject(now + 0.5*dt)
|
||||
assert isinstance(reject_result.unwrap(), RejectInTime)
|
||||
|
||||
assert(r2.state == ParticipationState.REJECTED)
|
||||
|
||||
session.commit()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user