from typing import Iterable, List, Union, overload

from modules.constrains.static import VTaskStaticConstraint
from modules.constrains.types import QTaskOrFabric
from modules.fabric import QTaskFabric
from modules.task import QTask
from utils.types import A, C, Q, V


class MustBeAnyConstraint(VTaskStaticConstraint[C, V, Q, A]):
    """Static constraint satisfied by a task that matches any allowed source.

    A task satisfies the constraint when either:

    * its unwrapped ``fabric_metadata`` (``None`` when absent) equals the
      ``metadata.id`` of one of the registered fabrics, or
    * its ``id`` equals the ``id`` of one of the registered tasks.
    """

    # Fabrics whose generated tasks satisfy the constraint.
    must_be_generated_by: List[QTaskFabric[C, V, Q, A]] = []
    # Concrete tasks (matched by ``id``) that satisfy the constraint.
    must_be_one_of_tasks: List[QTask[C, V, Q, A]] = []

    @overload
    def __init__(self, item: Iterable[QTaskOrFabric[C, V, Q, A]]): ...

    @overload
    def __init__(
        self, item: QTaskOrFabric[C, V, Q, A], **kwargs: QTaskOrFabric[C, V, Q, A]
    ): ...

    def __init__(
        self,
        item: Union[Iterable[QTaskOrFabric[C, V, Q, A]], QTaskOrFabric[C, V, Q, A]],
        **kwargs: QTaskOrFabric[C, V, Q, A],
    ):
        """Register the fabrics and tasks that satisfy this constraint.

        Args:
            item: Either a single task/fabric, or any iterable of them
                (list, tuple, set, generator, ...).
            **kwargs: Additional tasks/fabrics; only the values are used.

        Items that are neither a ``QTaskFabric`` nor a ``QTask`` are ignored.
        """
        all_items: List[QTaskOrFabric[C, V, Q, A]] = []
        if isinstance(item, (QTaskFabric, QTask)):
            # Check for a single task/fabric first so it is never mistaken for
            # a collection, even if the type happens to be iterable.
            all_items.append(item)
        elif isinstance(item, Iterable):
            # Honour the Iterable overload: accept tuples, sets and generators,
            # not just lists.
            all_items.extend(item)
        else:
            all_items.append(item)
        all_items.extend(kwargs.values())

        self.must_be_generated_by = [v for v in all_items if isinstance(v, QTaskFabric)]
        self.must_be_one_of_tasks = [v for v in all_items if isinstance(v, QTask)]

    def is_satisfied(self, task: QTask[C, V, Q, A]) -> bool:
        """Return True if ``task`` came from a registered fabric or is a registered task.

        Args:
            task: The task to check.

        Returns:
            True when the task's unwrapped fabric metadata equals a registered
            fabric's ``metadata.id``, or when its ``id`` equals a registered
            task's ``id``; False otherwise.
        """
        # Unwrap once rather than once per registered fabric.
        # NOTE(review): the unwrapped fabric_metadata is compared directly to
        # g.metadata.id; confirm it holds the fabric id, not a metadata object.
        fabric_ref = task.fabric_metadata.unwrap_or(None)
        if any(fabric_ref == g.metadata.id for g in self.must_be_generated_by):
            return True
        return any(task.id == t.id for t in self.must_be_one_of_tasks)