lab-signal-bot/models.py

632 lines
19 KiB
Python
Raw Normal View History

2024-12-19 10:31:45 +01:00
#!/usr/bin/env python3
import os
import enum
from typing import List
from pydantic import BaseModel
2024-12-21 02:39:27 +01:00
from result import Result, Ok, Err, is_err, is_ok
2024-12-19 10:31:45 +01:00
from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select
import datetime
import pytest
def utc_now():
    """Return the current UTC time as a naive ``datetime``.

    The timezone information is removed on purpose, because SQLite does
    not support it.
    """
    aware_now = datetime.datetime.now(tz=datetime.timezone.utc)
    return aware_now.replace(tzinfo=None)
2024-12-19 10:31:45 +01:00
class User(SQLModel, table=True):
    """A user who can receive participation requests for tasks."""

    id: int = Field(primary_key=True)
    name: str
    joined_at: datetime.datetime = Field(default_factory=utc_now)
    # Inactive users are filtered out by all_users_sorted() and, by
    # default, by get_user_by_name().
    active: bool = True

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")

    def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
        """Return the accepted request with the latest ``requested_at``.

        Returns:
            ``Ok(request)`` if this user has at least one request in state
            ACCEPTED, ``Err(None)`` otherwise.
        """
        accepted = [
            request
            for request in self.participation_requests
            if request.state == ParticipationState.ACCEPTED
        ]
        if not accepted:
            return Err(None)
        return Ok(max(accepted, key=lambda request: request.requested_at))

    def sort_time(self):
        """Return the timestamp used to order users in all_users_sorted().

        This is the ``state_since`` of the last accepted request, or
        ``joined_at`` when the user has not accepted any request.
        """
        last_accepted = self.last_accepted_participation_request()
        if not is_ok(last_accepted):
            return self.joined_at

        request = last_accepted.unwrap()
        # An accepted request always went through a state change, so the
        # timestamp of that change must have been recorded.
        assert request.state_since is not None
        return request.state_since
def all_users_sorted(session: Session) -> List[User]:
    """Return all active users, ordered ascending by ``User.sort_time()``.

    Users who never accepted a request are ordered by their join time;
    the others by the time of their last accepted request.
    """
    active_users = session.exec(select(User).where(User.active)).all()
    return sorted(active_users, key=lambda user: user.sort_time())
def next_user_to_send_request(session: Session, task: "Task") -> User | None:
    """Return the user who should receive the next request for ``task``.

    Candidates are taken in the order of all_users_sorted(); users who
    already have a participation request for this task are skipped.

    Returns:
        The first eligible user, or ``None`` if requests for this task
        were already sent to all (active) users.
    """
    already_asked = [request.user for request in task.participation_requests]
    candidates = all_users_sorted(session)
    return next(
        (candidate for candidate in candidates if candidate not in already_asked),
        None,
    )
2024-12-21 01:45:10 +01:00
def get_user_by_name(session: Session, name: str, only_active: bool = True) -> Result[User, None]:
    """Look up a user by exact name.

    Args:
        session: Database session used for the query.
        name: Name to match against ``User.name``.
        only_active: When true (the default), inactive users are ignored.

    Returns:
        ``Ok(user)`` for the first match, ``Err(None)`` if there is none.
    """
    query = select(User).where(User.name == name)
    if only_active:
        query = query.where(User.active)

    match = session.exec(query).first()
    return Err(None) if match is None else Ok(match)
2024-12-27 01:59:48 +01:00
def get_participation_request_by_timestamp(session: Session, timestamp: datetime.datetime) -> Result["ParticipationRequest", None]:
    """Find the participation request whose ``requested_at`` equals ``timestamp``.

    Returns:
        ``Ok(request)`` for the first match, ``Err(None)`` if there is none.
    """
    query = select(ParticipationRequest).where(
        ParticipationRequest.requested_at == timestamp
    )
    match = session.exec(query).first()
    return Err(None) if match is None else Ok(match)
2024-12-19 10:31:45 +01:00
class Task(SQLModel, table=True):
    """A task that needs a number of participants before it is due."""

    id: int = Field(primary_key=True)
    name: str
    required_number_of_participants: int
    due: datetime.datetime
    timeout: int  # in seconds
    chatgroup: str | None = None  # None = to be created

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")

    def _requests_in_state(self, state):
        # Shared filter behind the four public per-state accessors below.
        return [
            request
            for request in self.participation_requests
            if request.state == state
        ]

    def accepted_requests(self):
        """Return the requests of this task in state ACCEPTED."""
        return self._requests_in_state(ParticipationState.ACCEPTED)

    def rejected_requests(self):
        """Return the requests of this task in state REJECTED."""
        return self._requests_in_state(ParticipationState.REJECTED)

    def requested_requests(self):
        """Return the requests of this task in state REQUESTED (unanswered)."""
        return self._requests_in_state(ParticipationState.REQUESTED)

    def timeout_requests(self):
        """Return the requests of this task in state TIMEOUT."""
        return self._requests_in_state(ParticipationState.TIMEOUT)

    def additional_requests_to_be_sent(self):
        """Return the number of additional requests to be sent.

        Accepted and still unanswered requests count towards the required
        number of participants. The result is negative when more requests
        are accepted/open than participants are required.
        """
        accepted_or_open = len(self.accepted_requests()) + len(self.requested_requests())
        return self.required_number_of_participants - accepted_or_open

    def freshly_expired_requests(self, now) -> List["ParticipationRequest"]:
        """Time out overdue requests and return the ones that just expired.

        check_for_timeout() is called on every request of this task; it
        switches overdue REQUESTED requests to TIMEOUT as a side effect.
        Only the requests that changed state during this call are returned.
        """
        return [
            request
            for request in self.participation_requests
            if request.check_for_timeout(now) is not None
        ]

    def create_additional_requests(self, now: datetime.datetime, session: Session) -> Result[List["ParticipationRequest"], List["ParticipationRequest"]]:
        """Create as many new requests as are still needed for this task.

        Each new request is added to ``session`` (not committed) with
        ``requested_at=now``.

        Returns:
            ``Ok(requests)`` when all needed requests could be created,
            ``Err(requests)`` with the requests created so far when no
            further user was left to ask.
        """
        created = []
        missing = self.additional_requests_to_be_sent()
        for _ in range(missing):
            recipient = next_user_to_send_request(session, self)
            if recipient is None:
                # Incomplete requests
                return Err(created)

            new_request = ParticipationRequest(
                user=recipient,
                task=self,
                requested_at=now,
            )
            session.add(new_request)
            created.append(new_request)
        return Ok(created)
2024-12-21 02:39:27 +01:00
2024-12-21 05:42:36 +01:00
def get_active_tasks(session: Session, now: datetime.datetime) -> List[Task]:
    """Return all tasks whose ``due`` time is strictly later than ``now``."""
    not_yet_due = select(Task).where(Task.due > now)
    return session.exec(not_yet_due).all()
2024-12-21 02:39:27 +01:00
def get_tasks(session: Session) -> List[Task]:
    """Return every task stored in the database."""
    every_task = select(Task)
    return session.exec(every_task).all()
2024-12-19 10:31:45 +01:00
@enum.unique
class ParticipationState(enum.Enum):
    """State of a participation request."""

    # Initial state of a newly created request (not answered yet).
    REQUESTED = "requested"
    # The user accepted the request.
    ACCEPTED = "accepted"
    # The user rejected the request.
    REJECTED = "rejected"
    # The request was not answered before the task's timeout elapsed.
    TIMEOUT = "timeout"
# Transitions for a ParticipationState:
#
# Accepting or rejecting a request within timeout:
#
# - REQUESTED -> ACCEPTED (allowed)
# - REQUESTED -> REJECTED (allowed)
#
# Normal timeout:
#
# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout)
#
# Changing from ACCEPTED to REJECTED is allowed:
#
# - ACCEPTED -> REJECTED (allowed)
#
# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only:
#
# - REJECTED -> ACCEPTED (allowed, within 5 minutes)
# - REJECTED -> ACCEPTED (not allowed, after 5 minutes)
#
# Answered requests can not timeout:
#
# - REJECTED -> TIMEOUT (should not happen, bug)
# - ACCEPTED -> TIMEOUT (should not happen, bug)
#
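# Timed out requests can not be answered anymore (unless the answer itself
# is timestamped before the timeout, see try_accept() / try_reject()):
#
# - TIMEOUT -> ACCEPTED (not allowed)
# - TIMEOUT -> REJECTED (not allowed)
# - TIMEOUT -> REQUESTED (should not happen, bug)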
#
# => TODO: What should happen after the due date?
2024-12-21 01:45:10 +01:00
# => TODO: What happens when a user has two requests?
2024-12-19 10:31:45 +01:00
class StateTransition:
    """Base class for state changes that were carried out (``Ok`` values)."""


class StateTransitionError:
    """Base class for refused state changes (``Err`` values).

    These are plain marker objects returned inside a ``Result``; they are
    not exceptions.
    """


class AcceptInTime(StateTransition):
    """The request was accepted before its timeout elapsed."""


class AcceptAfterRejectAllowed(StateTransition):
    """A rejected request was accepted within 5 minutes of the rejection."""


class AcceptAfterRejectExpired(StateTransitionError):
    """A rejected request could not be accepted: the 5 minutes are over."""


class AcceptAfterTimeout(StateTransitionError):
    """The accept came after the request's timeout had elapsed."""


class AlreadyAccepted(StateTransitionError):
    """The request was already in state ACCEPTED."""


class RejectInTime(StateTransition):
    """The request was rejected before its timeout elapsed."""


class RejectAfterAccept(StateTransition):
    """A previously accepted request was rejected."""


class RejectAfterTimeout(StateTransitionError):
    """The reject came after the request's timeout had elapsed."""


class AlreadyRejected(StateTransitionError):
    """The request was already in state REJECTED."""


class Timeout(StateTransition):
    """An unanswered request was switched to state TIMEOUT."""
class ParticipationRequest(SQLModel, table=True):
    """A request sent to one user asking them to participate in one task."""

    id: int = Field(primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    task_id: int = Field(foreign_key="task.id")
    requested_at: datetime.datetime
    # Time of the last state change; None until the state changes for the
    # first time.
    state_since: datetime.datetime | None = None
    state: ParticipationState = Field(default=ParticipationState.REQUESTED)
    task: Task = Relationship(back_populates="participation_requests")
    user: User = Relationship(back_populates="participation_requests")

    def _change_state(self, new_state: ParticipationState, now: datetime.datetime):
        """Set the state and record ``now`` as the time of the change.

        This method does not check any business rules, it just changes the
        state.
        """
        self.state = new_state
        self.state_since = now

    def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]:
        """Try to accept this request at time ``now``.

        Returns:
            ``Ok(...)`` describing the transition when the request is now
            ACCEPTED, ``Err(...)`` with the reason when the state was left
            unchanged.

        Raises:
            Exception: If the current state is not a known state.
        """
        current = self.state

        if current in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # TIMEOUT is handled like REQUESTED so that the order in which
            # check_for_timeout() and try_accept() are called is irrelevant.
            if self.is_timed_out(now):
                return Err(AcceptAfterTimeout())
            self._change_state(ParticipationState.ACCEPTED, now)
            return Ok(AcceptInTime())

        if current == ParticipationState.REJECTED:
            # A rejection may only be reverted within 5 minutes.
            seconds_since_reject = (now - self.state_since).total_seconds()
            if seconds_since_reject < 5 * 60:
                self._change_state(ParticipationState.ACCEPTED, now)
                return Ok(AcceptAfterRejectAllowed())
            return Err(AcceptAfterRejectExpired())

        if current == ParticipationState.ACCEPTED:
            return Err(AlreadyAccepted())

        raise Exception("Unknown old state " + str(self.state) + ".")

    def try_reject(self, now: datetime.datetime) -> Result[RejectInTime | RejectAfterAccept, AlreadyRejected | RejectAfterTimeout]:
        """Try to reject this request at time ``now``.

        Returns:
            ``Ok(...)`` describing the transition when the request is now
            REJECTED, ``Err(...)`` with the reason when the state was left
            unchanged.

        Raises:
            Exception: If the current state is not a known state.
        """
        current = self.state

        if current in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # TIMEOUT is handled like REQUESTED so that the order in which
            # check_for_timeout() and try_reject() are called is irrelevant.
            if self.is_timed_out(now):
                return Err(RejectAfterTimeout())
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectInTime())

        if current == ParticipationState.ACCEPTED:
            # Withdrawing an acceptance is always allowed.
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectAfterAccept())

        if current == ParticipationState.REJECTED:
            return Err(AlreadyRejected())

        raise Exception("Unknown old state " + str(self.state) + ".")

    def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
        """Switch an overdue, unanswered request to TIMEOUT.

        Returns:
            ``Timeout()`` if the state was changed by this call, ``None``
            otherwise (already answered, already timed out, or not overdue).
        """
        still_open = self.state == ParticipationState.REQUESTED
        if not still_open or not self.is_timed_out(now):
            return None
        self._change_state(ParticipationState.TIMEOUT, now)
        return Timeout()

    def is_timed_out(self, now: datetime.datetime) -> bool:
        """Return whether more than ``task.timeout`` seconds passed since ``requested_at``."""
        seconds_waited = (now - self.requested_at).total_seconds()
        return seconds_waited > self.task.timeout
@pytest.fixture
def session():
    """Yield a database session bound to a fresh SQLite file in /tmp.

    A database file left over from a previous run is deleted first, then
    all tables are created. After the test, the session is closed and the
    engine is disposed so that no pooled connection to the file is leaked.
    """
    db_path = "/tmp/test.db"

    # Just try to delete the file instead of checking for existence first;
    # this avoids the race between the check and the removal.
    try:
        os.remove(db_path)
    except FileNotFoundError:
        pass

    engine = create_engine("sqlite:///" + db_path)
    SQLModel.metadata.create_all(engine)
    try:
        with Session(engine) as db_session:
            yield db_session
    finally:
        # Release the connection pool; otherwise every test keeps a
        # connection to a file which the next test deletes.
        engine.dispose()
def test_all_users_sorted(session):
    """Check the ordering of all_users_sorted().

    Users without an accepted request are ordered by their join time; a
    user who accepts a request moves behind the others, because the time
    of the acceptance becomes their sort time.
    """
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()
    # Nobody accepted anything yet: order of joining
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]
    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    session.commit()
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    session.commit()
    # A request which was only sent (not accepted) is not counted as accepted
    assert t1.accepted_requests() == []
    users = all_users_sorted(session)
    # The state is changed directly, bypassing the checks in try_accept()
    r1._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    assert t1.accepted_requests() == [r1]
    # u1 accepted, so u1 moves to the end
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]
    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)
    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)
    session.commit()
    r3._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    assert t1.accepted_requests() == [r1, r3]
    # u3 accepted after u1, so u3 is last now; u2 did not accept yet
    users = all_users_sorted(session)
    assert users == [u2, u1, u3]
    r2._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()
    # All accepted: ordered by the time of acceptance (u1, then u3, then u2)
    users = all_users_sorted(session)
    assert users == [u1, u3, u2]
def test_next_user(session):
    """Check next_user_to_send_request() and additional_requests_to_be_sent().

    Covers how rejecting and accepting a request changes the number of
    missing requests and the user who is asked next.
    """
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()
    # Starting order
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]
    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    assert t1.additional_requests_to_be_sent() == 2
    assert next_user_to_send_request(session, t1) == u1
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    # u1 already has a request for t1, so u2 is next
    assert next_user_to_send_request(session, t1) == u2
    # Starting order
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]
    # Another additional request should be sent out.
    assert t1.additional_requests_to_be_sent() == 1
    # After rejecting, two additional requests should be sent out.
    assert isinstance(r1.try_reject(utc_now()).unwrap(), RejectInTime)
    assert t1.additional_requests_to_be_sent() == 2
    # Create another Task
    t2 = Task(name="t2", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t2)
    # Next user should be user1 (because we rejected in t1)
    assert next_user_to_send_request(session, t2) == u1
    # Still starting order (because we rejected)
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]
    # When we change to accept, only one additional request should be sent out.
    assert isinstance(r1.try_accept(utc_now()).unwrap(), AcceptAfterRejectAllowed)
    assert t1.additional_requests_to_be_sent() == 1
    # Next user should be u2 for t2 (because u1 accepted for t1)
    assert next_user_to_send_request(session, t2) == u2
    # Next user should be user2 for t1 also
    assert next_user_to_send_request(session, t1) == u2
    # The order changed: u1 accepted, so u1 moved to the end
    users = all_users_sorted(session)
    assert users == [u2, u3, u1]
    # Next user should be user2
    assert next_user_to_send_request(session, t1) == u2
    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)
    # Next user should still be user2
    assert next_user_to_send_request(session, t1) == u2
    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)
    # No user should be left
    assert next_user_to_send_request(session, t1) is None
    session.commit()
    # We sent one request too many (1 accepted + 2 open for 2 required)...
    assert t1.additional_requests_to_be_sent() == -1
2024-12-19 10:31:45 +01:00
def test_accept(session):
    """Check the outcomes of try_accept() around the timeout.

    Fixed timestamps are used, so the test does not depend on the wall
    clock.
    """
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    u1 = User(name="u1")
    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)
    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()
    # Accepting within the timeout works
    accept_result = r1.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    session.commit()
    # Accepting a second time is refused
    accept_result = r1.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap_err(), AlreadyAccepted)
    session.commit()
    # Accepted should not timeout
    timeout_result = r1.check_for_timeout(now + 1.6*dt)
    assert timeout_result is None
    # Second request, sent at now + 0.5*dt, so it times out after now + 1.5*dt
    r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
    session.add(r2)
    # AcceptAfterTimeout must be given even before check_for_timeout()
    accept_result = r2.try_accept(now + 1.6*dt)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert(r2.state == ParticipationState.REQUESTED)
    # Obtain Timeout
    timeout_result = r2.check_for_timeout(now + 1.6*dt)
    assert isinstance(timeout_result, Timeout)
    # Should still be a AcceptAfterTimeout
    accept_result = r2.try_accept(now + 1.6*dt)
    assert isinstance(accept_result.unwrap_err(), AcceptAfterTimeout)
    assert(r2.state == ParticipationState.TIMEOUT)
    # If we receive a late accept (timestamped before the timeout), it is
    # still accepted, even though the state is TIMEOUT already
    accept_result = r2.try_accept(now + 0.5*dt)
    assert isinstance(accept_result.unwrap(), AcceptInTime)
    assert(r2.state == ParticipationState.ACCEPTED)
    session.commit()
def test_reject(session):
    """Check the outcomes of try_reject() around the timeout.

    Fixed timestamps are used, so the test does not depend on the wall
    clock.
    """
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    u1 = User(name="u1")
    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())
    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)
    session.add(u1)
    session.add(t1)
    session.add(r1)
    session.commit()
    # Rejecting within the timeout works
    reject_result = r1.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    session.commit()
    # Rejecting a second time is refused
    reject_result = r1.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap_err(), AlreadyRejected)
    session.commit()
    # Second request, sent at now + 0.5*dt, so it times out after now + 1.5*dt
    r2 = ParticipationRequest(user=u1, task=t1, requested_at=now + 0.5*dt)
    session.add(r2)
    # RejectAfterTimeout must be given even before check_for_timeout()
    reject_result = r2.try_reject(now + 1.6*dt)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
    assert(r2.state == ParticipationState.REQUESTED)
    # Obtain Timeout
    timeout_result = r2.check_for_timeout(now + 1.6*dt)
    assert isinstance(timeout_result, Timeout)
    assert(r2.state == ParticipationState.TIMEOUT)
    # Should still be a RejectAfterTimeout
    reject_result = r2.try_reject(now + 1.6*dt)
    assert isinstance(reject_result.unwrap_err(), RejectAfterTimeout)
    # If we receive a late reject (timestamped before the timeout), it is
    # still applied, even though the state is TIMEOUT already
    reject_result = r2.try_reject(now + 0.5*dt)
    assert isinstance(reject_result.unwrap(), RejectInTime)
    assert(r2.state == ParticipationState.REJECTED)
    session.commit()
2024-12-21 01:45:10 +01:00
def test_active(session):
    """Check that all_users_sorted() only returns active users."""
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()
    # Users are active by default
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]
    # A deactivated user is left out...
    u2.active = False
    session.commit()
    users = all_users_sorted(session)
    assert users == [u1, u3]
    # ...and is back at the old position after being activated again
    u2.active = True
    session.commit()
    users = all_users_sorted(session)
    assert users == [u1, u2, u3]
2024-12-21 02:39:27 +01:00
def test_create_additional_requests(session):
    """Check Task.create_additional_requests().

    Covers the Ok result (all needed requests created) and the Err result
    (no user left to ask, only a part of the requests created).
    """
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    session.add(u1)
    session.add(u2)
    session.add(u3)
    session.commit()
    t1 = Task(name="t1", required_number_of_participants=2,
              due=now + 10*dt, timeout=dt.total_seconds())
    session.add(t1)
    # Two participants required, nobody asked yet: u1 and u2 are asked
    requests = t1.create_additional_requests(now + 0.5*dt, session)
    assert(is_ok(requests))
    requests = requests.unwrap()
    assert(len(requests) == 2)
    requested_users = [req.user for req in requests]
    assert(requested_users == [u1, u2])
    # Both reject, so two requests are missing again
    assert(isinstance(requests[0].try_reject(now + 0.75*dt).unwrap(), RejectInTime))
    assert(isinstance(requests[1].try_reject(now + 0.75*dt).unwrap(), RejectInTime))
    requests = t1.create_additional_requests(now + 0.5*dt, session)
    # could not create all requests, so we have Err now
    assert(is_err(requests))
    requests = requests.unwrap_err()
    assert(len(requests) == 1)
    requested_users = [req.user for req in requests]
    assert(requested_users == [u3])
    # No user left to ask...
    requests = t1.create_additional_requests(now + 0.5*dt, session)
    assert(is_err(requests))
    requests = requests.unwrap_err()
    assert(len(requests) == 0)
    u4 = User(name="u4")
    session.add(u4)
    # Now, we have a user again, so we can ask. Since u3 did not
    # reject (so far), we only need one additional request. So,
    # the result is Ok(...).
    requests = t1.create_additional_requests(now + 0.5*dt, session)
    assert(is_ok(requests))
    requests = requests.unwrap()
    assert(len(requests) == 1)
    requested_users = [req.user for req in requests]
    assert(requested_users == [u4])
    session.commit()