Quizard/modules/constrains/static/must_be_any.py

43 lines
1.5 KiB
Python

from typing import Iterable, List, Union, overload
from modules.constrains.static import VTaskStaticConstraint
from modules.constrains.types import QTaskOrFabric
from modules.fabric import QTaskFabric
from modules.task import QTask
from utils.types import A, C, Q, V
class MustBeAnyConstraint(VTaskStaticConstraint[C, V, Q, A]):
    """Static constraint satisfied when a task matches ANY of the given items.

    Items may be task fabrics or concrete tasks. A task satisfies the
    constraint if its ``fabric_metadata`` (unwrapped, defaulting to ``None``)
    equals the ``metadata.id`` of one of the listed fabrics, or if its ``id``
    equals the ``id`` of one of the listed tasks.
    """

    # Class-level fallbacks kept for compatibility. ``__init__`` always rebinds
    # both attributes to fresh per-instance lists, so these shared list
    # objects are never mutated through an instance.
    must_be_generated_by: List[QTaskFabric[C, V, Q, A]] = []
    must_be_one_of_tasks: List[QTask[C, V, Q, A]] = []

    @overload
    def __init__(self, item: Iterable[QTaskOrFabric[C, V, Q, A]]): ...

    @overload
    def __init__(
        self, item: QTaskOrFabric[C, V, Q, A], **kwargs: QTaskOrFabric[C, V, Q, A]
    ): ...

    def __init__(
        self,
        item: Union[Iterable[QTaskOrFabric[C, V, Q, A]], QTaskOrFabric[C, V, Q, A]],
        **kwargs: QTaskOrFabric[C, V, Q, A],
    ):
        """Collect the fabrics and tasks that a task is allowed to match.

        Args:
            item: A single ``QTaskFabric``/``QTask``, or any iterable of them
                (list, tuple, set, generator, ...).
            **kwargs: Additional fabrics/tasks. Only the values are used; the
                keyword names are ignored.

        Values that are neither a ``QTaskFabric`` nor a ``QTask`` are
        silently dropped by the filtering below.
        """
        all_items: List[object] = []
        if isinstance(item, (QTaskFabric, QTask)):
            # Test for a single fabric/task first, so that a fabric or task
            # type that happens to be iterable is never unpacked below.
            all_items.append(item)
        elif isinstance(item, Iterable):
            # Accept any iterable, not just ``list``. Previously a tuple, set
            # or generator was kept as one opaque item and then filtered out,
            # silently leaving the constraint empty.
            all_items.extend(item)
        all_items.extend(kwargs.values())
        self.must_be_generated_by = [
            v for v in all_items if isinstance(v, QTaskFabric)
        ]
        self.must_be_one_of_tasks = [v for v in all_items if isinstance(v, QTask)]

    def is_satisfied(self, task: QTask[C, V, Q, A]) -> bool:
        """Return True if ``task`` matches any stored fabric or task.

        Fabrics are checked first, then tasks. Evaluation short-circuits on
        the first match.
        """
        if self.must_be_generated_by:
            # Loop-invariant: unwrap the task's fabric metadata only once.
            # NOTE(review): this compares ``fabric_metadata`` directly with
            # ``g.metadata.id``. Confirm that ``fabric_metadata`` holds the
            # fabric id rather than a metadata object.
            generated_by = task.fabric_metadata.unwrap_or(None)
            if any(generated_by == g.metadata.id for g in self.must_be_generated_by):
                return True
        return any(task.id == t.id for t in self.must_be_one_of_tasks)