43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
from typing import Iterable, List, Union, overload
|
|
|
|
from modules.constrains.static import VTaskStaticConstraint
|
|
from modules.constrains.types import QTaskOrFabric
|
|
from modules.fabric import QTaskFabric
|
|
from modules.task import QTask
|
|
from utils.types import A, C, Q, V
|
|
|
|
|
|
class MustBeAnyConstraint(VTaskStaticConstraint[C, V, Q, A]):
    """Static constraint satisfied by a task matching any allowed source.

    A task satisfies the constraint when it was generated by one of the
    stored fabrics, or when its ``id`` equals the ``id`` of one of the
    stored tasks.
    """

    # Fabrics a task may have been generated by. ``__init__`` always rebinds
    # this on the instance; the class-level ``[]`` is only a fallback.
    must_be_generated_by: List[QTaskFabric[C, V, Q, A]] = []

    # Tasks (matched by ``id``) a task may be. ``__init__`` always rebinds
    # this on the instance; the class-level ``[]`` is only a fallback.
    must_be_one_of_tasks: List[QTask[C, V, Q, A]] = []

    @overload
    def __init__(self, item: Iterable[QTaskOrFabric[C, V, Q, A]]): ...

    @overload
    def __init__(
        self, item: QTaskOrFabric[C, V, Q, A], **kwargs: QTaskOrFabric[C, V, Q, A]
    ): ...

    def __init__(
        self,
        item: Union[Iterable[QTaskOrFabric[C, V, Q, A]], QTaskOrFabric[C, V, Q, A]],
        **kwargs: QTaskOrFabric[C, V, Q, A],
    ):
        """Collect the allowed fabrics and tasks.

        Args:
            item: Either a single task/fabric, or any iterable (list, tuple,
                set, generator, ...) of tasks/fabrics.
            **kwargs: Additional tasks/fabrics. Only the values are used;
                the keyword names are ignored.

        Values that are neither a ``QTaskFabric`` nor a ``QTask`` are
        silently dropped.
        """
        # Test for a single task/fabric first so one is never expanded as an
        # iterable, whatever its type provides. Any other iterable (not only
        # ``list``) is expanded; previously tuples/generators were kept as a
        # single item and then discarded by the filters below. Non-iterable
        # values are kept as-is and end up filtered out, as before.
        if isinstance(item, (QTask, QTaskFabric)) or not isinstance(item, Iterable):
            all_items = [item]
        else:
            all_items = list(item)
        all_items.extend(kwargs.values())

        self.must_be_generated_by = [
            v for v in all_items if isinstance(v, QTaskFabric)
        ]
        self.must_be_one_of_tasks = [v for v in all_items if isinstance(v, QTask)]

    def is_satisfied(self, task: QTask[C, V, Q, A]) -> bool:
        """Return True if ``task`` matches any allowed fabric or task.

        A fabric matches when ``task.fabric_metadata.unwrap_or(None)`` equals
        the fabric's ``metadata.id``. A task matches when its ``id`` equals
        ``task.id``.
        """
        if self.must_be_generated_by:
            # NOTE(review): this compares the unwrapped ``fabric_metadata``
            # value directly against ``metadata.id``; confirm that
            # ``fabric_metadata`` holds an id rather than a metadata object.
            fabric_ref = task.fabric_metadata.unwrap_or(None)
            if any(
                fabric_ref == fabric.metadata.id
                for fabric in self.must_be_generated_by
            ):
                return True
        return any(task.id == allowed.id for allowed in self.must_be_one_of_tasks)
|