634 lines
20 KiB
Python
634 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import enum
|
|
from typing import List
|
|
from pydantic import BaseModel
|
|
from result import Result, Ok, Err, is_err, is_ok
|
|
from sqlmodel import Field, Session, SQLModel, create_engine, Relationship, select
|
|
|
|
import datetime
|
|
import pytest
|
|
|
|
def utc_now():
    """Return the current UTC time as a naive ``datetime``.

    The timezone information is stripped because SQLite does not support it.
    """
    aware_now = datetime.datetime.now(tz=datetime.timezone.utc)
    return aware_now.replace(tzinfo=None)
|
|
|
|
class User(SQLModel, table=True):
    """A person who can be asked to participate in tasks."""

    # The key is assigned by the database, so it is None until the row has
    # been flushed; callers create users without passing an id.
    id: int | None = Field(default=None, primary_key=True)
    name: str
    joined_at: datetime.datetime = Field(default_factory=utc_now)
    active: bool = True

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="user")

    def last_accepted_participation_request(self) -> Result["ParticipationRequest", None]:
        """Return the accepted request with the latest ``requested_at``.

        Returns:
            ``Ok(request)`` for the most recently requested request that is
            currently in state ACCEPTED, or ``Err(None)`` if the user has no
            accepted request.
        """
        accepted = [
            request
            for request in self.participation_requests
            if request.state == ParticipationState.ACCEPTED
        ]
        if not accepted:
            return Err(None)
        return Ok(max(accepted, key=lambda request: request.requested_at))

    def sort_time(self) -> datetime.datetime:
        """Return the timestamp used to order users in ``all_users_sorted``.

        This is the ``state_since`` of the last accepted request, or the
        ``joined_at`` time when the user has no accepted request.
        """
        last_accepted = self.last_accepted_participation_request()
        if not is_ok(last_accepted):
            return self.joined_at

        request = last_accepted.unwrap()
        # Invariant: _change_state() always sets state_since together with
        # the state, so an accepted request must carry a timestamp.
        assert request.state_since is not None
        return request.state_since
|
|
|
|
def all_users_sorted(session: Session) -> List[User]:
    """Return all active users, ordered ascending by ``User.sort_time()``."""
    active_users = session.exec(select(User).where(User.active)).all()
    active_users.sort(key=User.sort_time)
    return active_users
|
|
|
|
def next_user_to_send_request(session: Session, task: "Task") -> User | None:
    """Return the user who should receive the next request for ``task``.

    Walks the active users in ``all_users_sorted`` order and picks the first
    one who has no participation request for this task yet. Returns ``None``
    when requests for this task were already sent to all of those users.
    """
    # Collected before querying, exactly as the users are compared below.
    already_asked = [request.user for request in task.participation_requests]

    for candidate in all_users_sorted(session):
        if candidate not in already_asked:
            return candidate

    return None
|
|
|
|
def get_user_by_name(session: Session, name: str, only_active: bool = True) -> Result[User, None]:
    """Look up the first user with the given name.

    Args:
        session: Database session used for the query.
        name: Exact user name to match.
        only_active: When true (the default), inactive users are ignored.

    Returns:
        ``Ok(user)`` if a matching user exists, otherwise ``Err(None)``.
    """
    query = select(User).where(User.name == name)
    if only_active:
        query = query.where(User.active)

    match = session.exec(query).first()
    return Err(None) if match is None else Ok(match)
|
|
|
|
def get_participation_request_by_timestamp(session: Session, timestamp: datetime.datetime) -> Result["ParticipationRequest", None]:
    """Look up the first request whose ``requested_at`` equals ``timestamp``.

    Returns:
        ``Ok(request)`` if such a request exists, otherwise ``Err(None)``.
    """
    query = select(ParticipationRequest).where(
        ParticipationRequest.requested_at == timestamp
    )
    match = session.exec(query).first()
    return Err(None) if match is None else Ok(match)
|
|
|
|
class Task(SQLModel, table=True):
    """A task that needs a number of participants, recruited via requests."""

    # The key is assigned by the database, so it is None until the row has
    # been flushed; callers create tasks without passing an id.
    id: int | None = Field(default=None, primary_key=True)
    name: str
    required_number_of_participants: int
    due: datetime.datetime
    timeout: int  # in seconds
    chatgroup: str | None = None  # None = to be created
    pad_url: str | None = None  # optional pad url
    unfulfillable_message_sent: bool = False

    participation_requests: list["ParticipationRequest"] = Relationship(back_populates="task")

    def _requests_in_state(self, state: "ParticipationState") -> List["ParticipationRequest"]:
        """Return this task's requests that are currently in ``state``."""
        return [r for r in self.participation_requests if r.state == state]

    def accepted_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state ACCEPTED."""
        return self._requests_in_state(ParticipationState.ACCEPTED)

    def rejected_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state REJECTED."""
        return self._requests_in_state(ParticipationState.REJECTED)

    def requested_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state REQUESTED (still unanswered)."""
        return self._requests_in_state(ParticipationState.REQUESTED)

    def timeout_requests(self) -> List["ParticipationRequest"]:
        """Return the requests in state TIMEOUT."""
        return self._requests_in_state(ParticipationState.TIMEOUT)

    def additional_requests_to_be_sent(self) -> int:
        """Return the number of additional requests to be sent.

        Accepted and still-pending requests both count towards the required
        number of participants. The result is negative when more requests
        are accepted or pending than participants are required.
        """
        return (
            self.required_number_of_participants
            - len(self.accepted_requests())
            - len(self.requested_requests())
        )

    def freshly_expired_requests(self, now: datetime.datetime) -> List["ParticipationRequest"]:
        """Time out overdue requests and return the ones that just expired.

        Calls ``check_for_timeout(now)`` on every request of this task, which
        moves overdue REQUESTED requests to TIMEOUT as a side effect. Only
        the requests that changed state during this call are returned.
        """
        expired_requests = []
        for request in self.participation_requests:
            if request.check_for_timeout(now) is not None:
                expired_requests.append(request)
        return expired_requests

    def create_additional_requests(self, now: datetime.datetime, session: Session) -> Result[List["ParticipationRequest"], List["ParticipationRequest"]]:
        """Create as many new requests as are still needed for this task.

        Each new request goes to ``next_user_to_send_request`` and is added
        to ``session`` (not committed).

        Returns:
            ``Ok(requests)`` when all needed requests could be created, or
            ``Err(requests)`` with the requests created so far when no user
            was left to ask.
        """
        additional_requests = []

        for _ in range(self.additional_requests_to_be_sent()):
            user = next_user_to_send_request(session, self)
            if user is None:
                # Incomplete requests
                return Err(additional_requests)

            request = ParticipationRequest(user=user, task=self, requested_at=now)
            session.add(request)
            additional_requests.append(request)

        return Ok(additional_requests)
|
|
|
|
def get_active_tasks(session: Session, now: datetime.datetime) -> List[Task]:
    """Return all tasks whose ``due`` time is later than ``now``."""
    not_yet_due = select(Task).where(Task.due > now)
    return session.exec(not_yet_due).all()
|
|
|
|
def get_tasks(session: Session) -> List[Task]:
    """Return every task stored in the database."""
    every_task = select(Task)
    return session.exec(every_task).all()
|
|
|
|
class ParticipationState(enum.Enum):
    """State of a participation request."""

    # Initial state of a new request.
    REQUESTED = "requested"
    # The user agreed to participate.
    ACCEPTED = "accepted"
    # The user declined.
    REJECTED = "rejected"
    # The request was not answered within the task's timeout.
    TIMEOUT = "timeout"
|
|
|
|
# Transitions for a ParticipationState:
|
|
#
|
|
# Accepting or rejecting a request within timeout:
|
|
#
|
|
# - REQUESTED -> ACCEPTED (allowed)
|
|
# - REQUESTED -> REJECTED (allowed)
|
|
#
|
|
# Normal timeout:
|
|
#
|
|
# - REQUESTED -> TIMEOUT (allowed, happens automatically after timeout)
|
|
#
|
|
# Changing from ACCEPTED to REJECTED is allowed:
|
|
#
|
|
# - ACCEPTED -> REJECTED (allowed)
|
|
#
|
|
# Changing from REJECTED to ACCEPTED is allowed for 5 minutes only:
|
|
#
|
|
# - REJECTED -> ACCEPTED (allowed, within 5 minutes)
|
|
# - REJECTED -> ACCEPTED (not allowed, after 5 minutes)
|
|
#
|
|
# Answered requests can not timeout:
|
|
#
|
|
# - REJECTED -> TIMEOUT (should not happen, bug)
|
|
# - ACCEPTED -> TIMEOUT (should not happen, bug)
|
|
#
|
|
# Timed out requests can not be answered or re-requested:
|
|
#
|
|
# TIMEOUT -> ACCEPTED (not allowed)
|
|
# TIMEOUT -> REJECTED (not allowed)
|
|
# TIMEOUT -> REQUESTED (should not happen, bug)
|
|
#
|
|
# => TODO: What should happen after the due?!!!
|
|
# => TODO: What happens when a user has two requests?
|
|
|
|
|
|
class StateTransition:
    """Base class for markers describing a state change that was carried out."""
|
|
|
|
class StateTransitionError:
    """Base class for markers describing a refused state change.

    This is not an exception type; instances are returned wrapped in ``Err``.
    """
|
|
|
|
class AcceptInTime(StateTransition):
    """The request was accepted before its timeout elapsed."""
|
|
|
|
class AcceptAfterRejectAllowed(StateTransition):
    """A rejected request was accepted within the allowed 5 minutes."""
|
|
|
|
class AcceptAfterRejectExpired(StateTransitionError):
    """Accepting was refused: the rejection is 5 minutes old or older."""
|
|
|
|
class AcceptAfterTimeout(StateTransitionError):
    """Accepting was refused: the request had already timed out."""
|
|
|
|
class AlreadyAccepted(StateTransitionError):
    """Accepting was refused: the request is already in state ACCEPTED."""
|
|
|
|
class RejectInTime(StateTransition):
    """The request was rejected before its timeout elapsed."""
|
|
|
|
class RejectAfterAccept(StateTransition):
    """A previously accepted request was changed to rejected."""
|
|
|
|
class RejectAfterTimeout(StateTransitionError):
    """Rejecting was refused: the request had already timed out."""
|
|
|
|
class AlreadyRejected(StateTransitionError):
    """Rejecting was refused: the request is already in state REJECTED."""
|
|
|
|
class Timeout(StateTransition):
    """An unanswered request was moved to state TIMEOUT."""
|
|
|
|
class ParticipationRequest(SQLModel, table=True):
    """A request asking one user to participate in one task.

    State changes go through ``try_accept``, ``try_reject`` and
    ``check_for_timeout``, which enforce the transition rules and report the
    outcome as marker objects.
    """

    # The key is assigned by the database, so it is None until the row has
    # been flushed; callers create requests without passing an id.
    id: int | None = Field(default=None, primary_key=True)
    user_id: int = Field(foreign_key="user.id")
    task_id: int = Field(foreign_key="task.id")
    requested_at: datetime.datetime

    # Time of the last state change; None while the state was never changed.
    state_since: datetime.datetime | None = None
    state: ParticipationState = Field(default=ParticipationState.REQUESTED)

    task: Task = Relationship(back_populates="participation_requests")
    user: User = Relationship(back_populates="participation_requests")

    def _change_state(self, new_state: ParticipationState, now: datetime.datetime):
        """Set ``state`` and ``state_since`` without checking business rules."""
        self.state_since = now
        self.state = new_state

    def try_accept(self, now: datetime.datetime) -> Result[AcceptInTime | AcceptAfterRejectAllowed, AcceptAfterRejectExpired | AcceptAfterTimeout | AlreadyAccepted]:
        """Try to accept the request at time ``now``.

        Returns:
            ``Ok(AcceptInTime())`` for an unanswered request that is not
            timed out at ``now``; ``Ok(AcceptAfterRejectAllowed())`` when a
            rejection is less than 5 minutes old; otherwise an ``Err`` with
            the marker naming why the request was left unchanged.

        Raises:
            Exception: If the current state is none of the known states.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # We want to make the order in which check_for_timeout() and
            # try_accept() is called irrelevant, so the timeout is decided
            # from ``now`` and not from the stored state.
            if self.is_timed_out(now):
                return Err(AcceptAfterTimeout())
            self._change_state(ParticipationState.ACCEPTED, now)
            return Ok(AcceptInTime())

        if self.state == ParticipationState.REJECTED:
            # A rejection may be reverted for 5 minutes only.
            revert_window_seconds = 5 * 60
            if (now - self.state_since).total_seconds() < revert_window_seconds:
                self._change_state(ParticipationState.ACCEPTED, now)
                return Ok(AcceptAfterRejectAllowed())
            return Err(AcceptAfterRejectExpired())

        if self.state == ParticipationState.ACCEPTED:
            return Err(AlreadyAccepted())

        raise Exception("Unknown old state " + str(self.state) + ".")

    def try_reject(self, now: datetime.datetime) -> Result[RejectInTime | RejectAfterAccept, AlreadyRejected | RejectAfterTimeout]:
        """Try to reject the request at time ``now``.

        Returns:
            ``Ok(RejectInTime())`` for an unanswered request that is not
            timed out at ``now``; ``Ok(RejectAfterAccept())`` when an
            accepted request is withdrawn; otherwise an ``Err`` with the
            marker naming why the request was left unchanged.

        Raises:
            Exception: If the current state is none of the known states.
        """
        if self.state in (ParticipationState.REQUESTED, ParticipationState.TIMEOUT):
            # We want to make the order in which check_for_timeout() and
            # try_reject() is called irrelevant, so the timeout is decided
            # from ``now`` and not from the stored state.
            if self.is_timed_out(now):
                return Err(RejectAfterTimeout())
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectInTime())

        if self.state == ParticipationState.ACCEPTED:
            self._change_state(ParticipationState.REJECTED, now)
            return Ok(RejectAfterAccept())

        if self.state == ParticipationState.REJECTED:
            return Err(AlreadyRejected())

        raise Exception("Unknown old state " + str(self.state) + ".")

    def check_for_timeout(self, now: datetime.datetime) -> Timeout | None:
        """Move an overdue REQUESTED request to TIMEOUT.

        Returns:
            ``Timeout()`` if the state was changed by this call, else
            ``None`` (request already answered, already timed out, or still
            within its timeout).
        """
        if self.state == ParticipationState.REQUESTED and self.is_timed_out(now):
            self._change_state(ParticipationState.TIMEOUT, now)
            return Timeout()
        return None

    def is_timed_out(self, now: datetime.datetime) -> bool:
        """Return whether more than ``task.timeout`` seconds passed since the request."""
        elapsed_seconds = (now - self.requested_at).total_seconds()
        return elapsed_seconds > self.task.timeout
|
|
|
|
|
|
@pytest.fixture
def session(tmp_path):
    """Yield a SQLModel session on a fresh SQLite database file.

    The database lives in pytest's per-test ``tmp_path`` directory instead of
    a fixed file in ``/tmp``, so test runs cannot collide with each other
    and no stale file has to be removed first.
    """
    engine = create_engine(f"sqlite:///{tmp_path / 'test.db'}")
    SQLModel.metadata.create_all(engine)
    try:
        with Session(engine) as db_session:
            yield db_session
    finally:
        # Release the connection pool so the database file is not kept open.
        engine.dispose()
|
|
|
|
def test_all_users_sorted(session):
    """Users are ordered by join time until an accepted request moves them back."""
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    for user in (u1, u2, u3):
        session.add(user)
    session.commit()

    # Nobody accepted anything yet, so the join order decides.
    assert all_users_sorted(session) == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)
    session.commit()

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)
    session.commit()

    assert t1.accepted_requests() == []

    # Sorting with an unanswered request present must work.
    all_users_sorted(session)

    # u1 accepts and therefore moves to the end of the queue.
    r1._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert t1.accepted_requests() == [r1]
    assert all_users_sorted(session) == [u2, u3, u1]

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)

    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)

    session.commit()

    # u3 accepts after u1, so u3 is now last.
    r3._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert t1.accepted_requests() == [r1, r3]
    assert all_users_sorted(session) == [u2, u1, u3]

    # Finally u2 accepts and ends up behind everybody else.
    r2._change_state(ParticipationState.ACCEPTED, utc_now())
    session.commit()

    assert all_users_sorted(session) == [u1, u3, u2]
|
|
|
|
def test_next_user(session):
    """The next user to ask follows the queue order and skips asked users."""
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    for user in (u1, u2, u3):
        session.add(user)
    session.commit()

    # Starting order
    assert all_users_sorted(session) == [u1, u2, u3]

    t1 = Task(name="t1", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t1)

    assert t1.additional_requests_to_be_sent() == 2
    assert next_user_to_send_request(session, t1) == u1

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=utc_now())
    session.add(r1)

    # u1 has been asked for t1, so u2 is next.
    assert next_user_to_send_request(session, t1) == u2

    # A pending request does not change the order.
    assert all_users_sorted(session) == [u1, u2, u3]

    # One request is pending, so one more should be sent out.
    assert t1.additional_requests_to_be_sent() == 1

    # After rejecting, two additional requests should be sent out.
    assert isinstance(r1.try_reject(utc_now()).unwrap(), RejectInTime)
    assert t1.additional_requests_to_be_sent() == 2

    # Create another task
    t2 = Task(name="t2", required_number_of_participants=2, due=utc_now(), timeout=10)
    session.add(t2)

    # Next user for t2 should be u1 (because u1 rejected in t1)
    assert next_user_to_send_request(session, t2) == u1

    # Still starting order (because u1 rejected)
    assert all_users_sorted(session) == [u1, u2, u3]

    # When u1 changes to accept, only one additional request should be sent out.
    assert isinstance(r1.try_accept(utc_now()).unwrap(), AcceptAfterRejectAllowed)
    assert t1.additional_requests_to_be_sent() == 1

    # Next user should be u2 for t2 (because u1 accepted for t1)
    assert next_user_to_send_request(session, t2) == u2

    # Next user should be u2 for t1 as well
    assert next_user_to_send_request(session, t1) == u2

    # u1 accepted, so u1 moved to the end of the order
    assert all_users_sorted(session) == [u2, u3, u1]

    # Next user should be u2
    assert next_user_to_send_request(session, t1) == u2

    r3 = ParticipationRequest(user=u3, task=t1, requested_at=utc_now())
    session.add(r3)

    # Next user should still be u2
    assert next_user_to_send_request(session, t1) == u2

    r2 = ParticipationRequest(user=u2, task=t1, requested_at=utc_now())
    session.add(r2)

    # No user should be left
    assert next_user_to_send_request(session, t1) is None

    session.commit()

    # We sent one request too many...
    assert t1.additional_requests_to_be_sent() == -1
|
|
|
|
def test_accept(session):
    """Accepting works in time, is refused after timeout, and is idempotent."""
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    in_time = now + 0.5*dt
    too_late = now + 1.6*dt

    u1 = User(name="u1")

    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)

    for obj in (u1, t1, r1):
        session.add(obj)
    session.commit()

    first_accept = r1.try_accept(in_time)
    assert isinstance(first_accept.unwrap(), AcceptInTime)
    session.commit()

    second_accept = r1.try_accept(in_time)
    assert isinstance(second_accept.unwrap_err(), AlreadyAccepted)
    session.commit()

    # Accepted should not timeout
    assert r1.check_for_timeout(too_late) is None

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=in_time)
    session.add(r2)

    # AcceptAfterTimeout must be given even before check_for_timeout()
    outcome = r2.try_accept(too_late)
    assert isinstance(outcome.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    assert isinstance(r2.check_for_timeout(too_late), Timeout)

    # Should still be an AcceptAfterTimeout
    outcome = r2.try_accept(too_late)
    assert isinstance(outcome.unwrap_err(), AcceptAfterTimeout)
    assert r2.state == ParticipationState.TIMEOUT

    # If we receive a late accept, we should accept it
    outcome = r2.try_accept(in_time)
    assert isinstance(outcome.unwrap(), AcceptInTime)
    assert r2.state == ParticipationState.ACCEPTED

    session.commit()
|
|
|
|
def test_reject(session):
    """Rejecting works in time, is refused after timeout, and is idempotent."""
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    in_time = now + 0.5*dt
    too_late = now + 1.6*dt

    u1 = User(name="u1")

    # - Task will happen at now + 10*dt
    # - Each user has 1*dt after the request to answer (before timeout)
    t1 = Task(name="t1", required_number_of_participants=2, due=now + 10*dt, timeout=dt.total_seconds())

    r1 = ParticipationRequest(user=u1, task=t1, requested_at=now)

    for obj in (u1, t1, r1):
        session.add(obj)
    session.commit()

    first_reject = r1.try_reject(in_time)
    assert isinstance(first_reject.unwrap(), RejectInTime)
    session.commit()

    second_reject = r1.try_reject(in_time)
    assert isinstance(second_reject.unwrap_err(), AlreadyRejected)
    session.commit()

    r2 = ParticipationRequest(user=u1, task=t1, requested_at=in_time)
    session.add(r2)

    # RejectAfterTimeout must be given even before check_for_timeout()
    outcome = r2.try_reject(too_late)
    assert isinstance(outcome.unwrap_err(), RejectAfterTimeout)
    assert r2.state == ParticipationState.REQUESTED

    # Obtain Timeout
    assert isinstance(r2.check_for_timeout(too_late), Timeout)
    assert r2.state == ParticipationState.TIMEOUT

    # Should still be a RejectAfterTimeout
    outcome = r2.try_reject(too_late)
    assert isinstance(outcome.unwrap_err(), RejectAfterTimeout)

    # If we receive a late reject, we should accept it
    outcome = r2.try_reject(in_time)
    assert isinstance(outcome.unwrap(), RejectInTime)
    assert r2.state == ParticipationState.REJECTED

    session.commit()
|
|
|
|
def test_active(session):
    """Inactive users are hidden from the sorted list and return when reactivated."""
    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    for user in (u1, u2, u3):
        session.add(user)
    session.commit()

    assert all_users_sorted(session) == [u1, u2, u3]

    # Deactivating u2 removes it from the list.
    u2.active = False
    session.commit()
    assert all_users_sorted(session) == [u1, u3]

    # Reactivating u2 puts it back at its old position.
    u2.active = True
    session.commit()
    assert all_users_sorted(session) == [u1, u2, u3]
|
|
|
|
def test_create_additional_requests(session):
    """Requests are created for the next users until nobody is left to ask."""
    now = datetime.datetime(2024, 12, 10, 0, 0)
    dt = datetime.timedelta(days=1)
    request_time = now + 0.5*dt
    answer_time = now + 0.75*dt

    u1 = User(name="u1")
    u2 = User(name="u2")
    u3 = User(name="u3")
    for user in (u1, u2, u3):
        session.add(user)
    session.commit()

    t1 = Task(name="t1", required_number_of_participants=2,
              due=now + 10*dt, timeout=dt.total_seconds())
    session.add(t1)

    # Two participants are required, so the first two users are asked.
    result = t1.create_additional_requests(request_time, session)
    assert is_ok(result)
    created = result.unwrap()
    assert len(created) == 2
    assert [req.user for req in created] == [u1, u2]

    assert isinstance(created[0].try_reject(answer_time).unwrap(), RejectInTime)
    assert isinstance(created[1].try_reject(answer_time).unwrap(), RejectInTime)

    result = t1.create_additional_requests(request_time, session)

    # could not create all requests, so we have Err now
    assert is_err(result)
    created = result.unwrap_err()
    assert len(created) == 1
    assert [req.user for req in created] == [u3]

    # No user left to ask...
    result = t1.create_additional_requests(request_time, session)
    assert is_err(result)
    created = result.unwrap_err()
    assert len(created) == 0

    u4 = User(name="u4")
    session.add(u4)

    # Now, we have a user again, so we can ask. Since u3 did not
    # reject (so far), we only need one additional request. So,
    # the result is Ok(...).
    result = t1.create_additional_requests(request_time, session)
    assert is_ok(result)
    created = result.unwrap()
    assert len(created) == 1
    assert [req.user for req in created] == [u4]

    session.commit()
|