diff --git a/changes/345.added b/changes/345.added new file mode 100644 index 00000000..1d1bf00f --- /dev/null +++ b/changes/345.added @@ -0,0 +1,10 @@ +Added bulk sync operations and advanced filtering capabilities to Adapter diff/sync methods. +Added Diff.filter() and Diff.exclude() methods for post-diff, pre-sync manipulation of computed diffs. +Added model_types parameter to scope diffs and syncs to specific model types. +Added sync_attrs and exclude_attrs parameters for attribute-level diff control. +Added filters parameter for per-model-type query predicates during diff calculation. +Added sync_filter callback parameter to approve or reject individual CRUD operations during sync. +Added structured operations summary passed to sync_complete() after sync. +Added bulk CRUD methods (create_bulk, update_bulk, delete_bulk) on DiffSyncModel. +Added store-level bulk methods (add_bulk, update_bulk, remove_bulk) on BaseStore, LocalStore, and RedisStore. +Added concurrent flag for parallel sync of independent top-level subtrees. diff --git a/diffsync/__init__.py b/diffsync/__init__.py index 402320ea..f22ca21c 100644 --- a/diffsync/__init__.py +++ b/diffsync/__init__.py @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +# pylint: disable=too-many-lines import sys from copy import deepcopy @@ -58,7 +59,7 @@ StrType = str -class DiffSyncModel(BaseModel): +class DiffSyncModel(BaseModel): # pylint: disable=too-many-public-methods """Base class for all DiffSync object models. Note that read-only APIs of this class are implemented as `get_*()` methods rather than as properties; @@ -306,6 +307,51 @@ def delete(self) -> Optional[Self]: """ return self.delete_base() + @classmethod + def create_bulk(cls, adapter: "Adapter", objects: List[Dict]) -> List[Optional[Self]]: + """Bulk create multiple instances. Override for batch creation (e.g. single API call). + + The default implementation loops over individual create() calls. + + Args: + adapter: The master data store for other DiffSyncModel instances + objects: List of dicts, each with "ids" and "attrs" keys + + Returns: + List of created DiffSyncModel instances (or None for failed creations) + """ + return [cls.create(adapter=adapter, ids=obj["ids"], attrs=obj["attrs"]) for obj in objects] + + @classmethod + def update_bulk(cls, adapter: "Adapter", objects: List[Tuple["DiffSyncModel", Dict]]) -> List[Optional[Self]]: # noqa: ARG003 # pylint: disable=unused-argument + """Bulk update multiple instances. Override for batch updates (e.g. single API call). + + The default implementation loops over individual update() calls. + + Args: + adapter: The master data store for other DiffSyncModel instances + objects: List of (existing_model, attrs_to_update) tuples + + Returns: + List of updated DiffSyncModel instances (or None for failed updates) + """ + return [model.update(attrs=attrs) for model, attrs in objects] + + @classmethod + def delete_bulk(cls, adapter: "Adapter", objects: List["DiffSyncModel"]) -> List[Optional[Self]]: # noqa: ARG003 # pylint: disable=unused-argument + """Bulk delete multiple instances. Override for batch deletion (e.g. single API call). + + The default implementation loops over individual delete() calls. + + Args: + adapter: The master data store for other DiffSyncModel instances + objects: List of model instances to delete + + Returns: + List of deleted DiffSyncModel instances (or None for failed deletions) + """ + return [model.delete() for model in objects] + @classmethod def get_type(cls) -> StrType: """Return the type AKA modelname of the object or the class. @@ -441,6 +487,22 @@ class Adapter: # pylint: disable=too-many-public-methods top_level: ClassVar[List[str]] = [] """List of top-level modelnames to begin from when diffing or synchronizing.""" + sync_stages: ClassVar[Optional[List[List[str]]]] = None + """Optional ordered groups of model types for staged concurrent sync. + + Each inner list is a "stage" of model types that can safely execute in parallel. + Stages are processed sequentially — all elements in stage N complete before stage N+1 begins. + Only used when ``concurrent=True``; ignored for serial sync. + + Example:: + + sync_stages = [ + ["site", "vlan"], # stage 1: independent types, run in parallel + ["device"], # stage 2: depends on sites + ["interface"], # stage 3: depends on devices + ] + """ + def __init__( self, name: Optional[str] = None, @@ -482,13 +544,28 @@ def __init_subclass__(cls) -> None: if not isclass(value) or not issubclass(value, DiffSyncModel): raise AttributeError(f'top_level references attribute "{name}" but it is not a DiffSyncModel subclass!') + if cls.sync_stages is not None: + top_level_set = set(cls.top_level) + seen: set = set() + for stage in cls.sync_stages: + for model_type in stage: + if model_type not in top_level_set: + raise AttributeError( + f'sync_stages references "{model_type}" but it is not in top_level!' + ) + if model_type in seen: + raise AttributeError( + f'sync_stages contains duplicate entry "{model_type}"!' + ) + seen.add(model_type) + def __new__(cls, **kwargs): # type: ignore[no-untyped-def] """Document keyword arguments that were used to initialize Adapter.""" meta_kwargs = {} for key, value in kwargs.items(): try: meta_kwargs[key] = deepcopy(value) - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint: disable=broad-except # Some objects (e.g. Kafka Consumer, DB connections) cannot be deep copied meta_kwargs[key] = value instance = super().__new__(cls) @@ -574,13 +651,21 @@ def load_from_dict(self, data: Dict) -> None: # Synchronization between DiffSync instances # ------------------------------------------------------------------------------ - def sync_from( # pylint: disable=too-many-arguments, too-many-positional-arguments + def sync_from( # pylint: disable=too-many-arguments,R0917,too-many-locals self, source: "Adapter", diff_class: Type[Diff] = Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, callback: Optional[Callable[[StrType, int, int], None]] = None, diff: Optional[Diff] = None, + model_types: Optional[Set[StrType]] = None, + filters: Optional[Dict[StrType, Callable]] = None, + sync_attrs: Optional[Dict[StrType, Set[StrType]]] = None, + exclude_attrs: Optional[Dict[StrType, Set[StrType]]] = None, + sync_filter: Optional[Callable[[StrType, StrType, Dict, Dict], bool]] = None, + batch_size: Optional[int] = None, + concurrent: bool = False, + max_workers: Optional[int] = None, ) -> Diff: """Synchronize data from the given source DiffSync object into the current DiffSync object. @@ -591,6 +676,14 @@ def sync_from( # pylint: disable=too-many-arguments, too-many-positional-argume callback: Function with parameters (stage, current, total), to be called at intervals as the calculation of the diff and subsequent sync proceed. diff: An existing diff to be used rather than generating a completely new diff. + model_types: Optional set of model type names to restrict the sync to. + filters: Optional dict of {model_type: predicate_callable} to filter which objects are synced. + sync_attrs: Optional dict of {model_type: set_of_attr_names} to whitelist attributes for syncing. + exclude_attrs: Optional dict of {model_type: set_of_attr_names} to exclude attributes from syncing. + sync_filter: Optional callback (action, model_type, ids, attrs) -> bool to approve/reject each operation. + batch_size: Optional chunk size for batched sync execution. + concurrent: If True, sync independent top-level subtrees in parallel. + max_workers: Maximum number of threads for concurrent sync. Returns: Diff between origin object and source @@ -605,27 +698,54 @@ def sync_from( # pylint: disable=too-many-arguments, too-many-positional-argume # Generate the diff if an existing diff was not provided if not diff: - diff = self.diff_from(source, diff_class=diff_class, flags=flags, callback=callback) + diff = self.diff_from( + source, + diff_class=diff_class, + flags=flags, + callback=callback, + model_types=model_types, + filters=filters, + sync_attrs=sync_attrs, + exclude_attrs=exclude_attrs, + ) syncer = DiffSyncSyncer( diff=diff, src_diffsync=source, dst_diffsync=self, flags=flags, callback=callback, + sync_filter=sync_filter, + batch_size=batch_size, + concurrent=concurrent, + max_workers=max_workers, + sync_stages=self.sync_stages, ) result = syncer.perform_sync() if result: - self.sync_complete(source, diff, flags, syncer.base_logger) + # Pass operations summary to sync_complete + try: + self.sync_complete(source, diff, flags, syncer.base_logger, operations=syncer.operations) + except TypeError: + # Backwards compatibility: existing subclass overrides may not accept operations kwarg + self.sync_complete(source, diff, flags, syncer.base_logger) return diff - def sync_to( # pylint: disable=too-many-arguments, too-many-positional-arguments + def sync_to( # pylint: disable=too-many-arguments,R0917 self, target: "Adapter", diff_class: Type[Diff] = Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, callback: Optional[Callable[[StrType, int, int], None]] = None, diff: Optional[Diff] = None, + model_types: Optional[Set[StrType]] = None, + filters: Optional[Dict[StrType, Callable]] = None, + sync_attrs: Optional[Dict[StrType, Set[StrType]]] = None, + exclude_attrs: Optional[Dict[StrType, Set[StrType]]] = None, + sync_filter: Optional[Callable[[StrType, StrType, Dict, Dict], bool]] = None, + batch_size: Optional[int] = None, + concurrent: bool = False, + max_workers: Optional[int] = None, ) -> Diff: """Synchronize data from the current DiffSync object into the given target DiffSync object. @@ -636,20 +756,43 @@ def sync_to( # pylint: disable=too-many-arguments, too-many-positional-argument callback: Function with parameters (stage, current, total), to be called at intervals as the calculation of the diff and subsequent sync proceed. diff: An existing diff that will be used when determining what needs to be synced. + model_types: Optional set of model type names to restrict the sync to. + filters: Optional dict of {model_type: predicate_callable} to filter which objects are synced. + sync_attrs: Optional dict of {model_type: set_of_attr_names} to whitelist attributes for syncing. + exclude_attrs: Optional dict of {model_type: set_of_attr_names} to exclude attributes from syncing. + sync_filter: Optional callback (action, model_type, ids, attrs) -> bool to approve/reject each operation. + batch_size: Optional chunk size for batched sync execution. + concurrent: If True, sync independent top-level subtrees in parallel. + max_workers: Maximum number of threads for concurrent sync. Returns: Diff between origin object and target Raises: DiffClassMismatch: The provided diff's class does not match the diff_class """ - return target.sync_from(self, diff_class=diff_class, flags=flags, callback=callback, diff=diff) + return target.sync_from( + self, + diff_class=diff_class, + flags=flags, + callback=callback, + diff=diff, + model_types=model_types, + filters=filters, + sync_attrs=sync_attrs, + exclude_attrs=exclude_attrs, + sync_filter=sync_filter, + batch_size=batch_size, + concurrent=concurrent, + max_workers=max_workers, + ) - def sync_complete( + def sync_complete( # pylint: disable=too-many-arguments,R0917 self, source: "Adapter", diff: Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, logger: Optional[structlog.BoundLogger] = None, + operations: Optional[Dict[StrType, Dict[StrType, List[Dict]]]] = None, ) -> None: """Callback triggered after a `sync_from` operation has completed and updated the model data of this instance. @@ -664,18 +807,24 @@ def sync_complete( diff: The Diff calculated prior to the sync operation. flags: Any flags that influenced the sync. logger: Logging context for the sync. + operations: Structured summary of all CRUD operations performed during sync. + Format: {"model_type": {"create": [{"ids": {...}, "attrs": {...}, "model": ...}], "update": [...], "delete": [...]}} """ # ------------------------------------------------------------------------------ # Diff calculation and construction # ------------------------------------------------------------------------------ - def diff_from( + def diff_from( # pylint: disable=too-many-arguments,R0917 self, source: "Adapter", diff_class: Type[Diff] = Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, callback: Optional[Callable[[StrType, int, int], None]] = None, + model_types: Optional[Set[StrType]] = None, + filters: Optional[Dict[StrType, Callable]] = None, + sync_attrs: Optional[Dict[StrType, Set[StrType]]] = None, + exclude_attrs: Optional[Dict[StrType, Set[StrType]]] = None, ) -> Diff: """Generate a Diff describing the difference from the other DiffSync to this one. @@ -685,6 +834,10 @@ def diff_from( flags: Flags influencing the behavior of this diff operation. callback: Function with parameters (stage, current, total), to be called at intervals as the calculation of the diff proceeds. + model_types: Optional set of model type names to restrict the diff to. + filters: Optional dict of {model_type: predicate_callable} to filter which objects are diffed. + sync_attrs: Optional dict of {model_type: set_of_attr_names} to whitelist attributes for diffing. + exclude_attrs: Optional dict of {model_type: set_of_attr_names} to exclude attributes from diffing. """ differ = DiffSyncDiffer( src_diffsync=source, @@ -692,15 +845,23 @@ def diff_from( flags=flags, diff_class=diff_class, callback=callback, + model_types=model_types, + filters=filters, + sync_attrs=sync_attrs, + exclude_attrs=exclude_attrs, ) return differ.calculate_diffs() - def diff_to( + def diff_to( # pylint: disable=too-many-arguments,R0917 self, target: "Adapter", diff_class: Type[Diff] = Diff, flags: DiffSyncFlags = DiffSyncFlags.NONE, callback: Optional[Callable[[StrType, int, int], None]] = None, + model_types: Optional[Set[StrType]] = None, + filters: Optional[Dict[StrType, Callable]] = None, + sync_attrs: Optional[Dict[StrType, Set[StrType]]] = None, + exclude_attrs: Optional[Dict[StrType, Set[StrType]]] = None, ) -> Diff: """Generate a Diff describing the difference from this DiffSync to another one. @@ -710,8 +871,21 @@ def diff_to( flags: Flags influencing the behavior of this diff operation. callback: Function with parameters (stage, current, total), to be called at intervals as the calculation of the diff proceeds. + model_types: Optional set of model type names to restrict the diff to. + filters: Optional dict of {model_type: predicate_callable} to filter which objects are diffed. + sync_attrs: Optional dict of {model_type: set_of_attr_names} to whitelist attributes for diffing. + exclude_attrs: Optional dict of {model_type: set_of_attr_names} to exclude attributes from diffing. """ - return target.diff_from(self, diff_class=diff_class, flags=flags, callback=callback) + return target.diff_from( + self, + diff_class=diff_class, + flags=flags, + callback=callback, + model_types=model_types, + filters=filters, + sync_attrs=sync_attrs, + exclude_attrs=exclude_attrs, + ) # ------------------------------------------------------------------------------ # Object Storage Management diff --git a/diffsync/diff.py b/diffsync/diff.py index 4e226fa7..df0770e5 100644 --- a/diffsync/diff.py +++ b/diffsync/diff.py @@ -15,8 +15,9 @@ limitations under the License. """ +import copy from functools import total_ordering -from typing import Any, Dict, Iterable, Iterator, List, Optional, Type +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Type from .enum import DiffSyncActions from .exceptions import ObjectAlreadyExists @@ -155,12 +156,172 @@ def dict(self) -> Dict[StrType, Dict[StrType, Dict]]: result[child.type][child.name] = child.dict() return dict(result) + def filter( + self, + actions: Optional[Set[StrType]] = None, + model_types: Optional[Set[StrType]] = None, + ) -> "Diff": + """Return a new Diff containing only elements matching the given criteria. + + Args: + actions: If provided, only include elements whose action is in this set (e.g. {"create", "update"}). + model_types: If provided, only include elements whose type is in this set. + + Returns: + A new Diff instance with only the matching elements. + """ + new_diff = self.__class__() + new_diff.models_processed = self.models_processed + + for group in self.groups(): + if model_types is not None and group not in model_types: + continue + for element in self.children[group].values(): + filtered = _filter_diff_element(element, actions=actions, model_types=model_types) + if filtered is not None: + new_diff.add(filtered) + + return new_diff + + def exclude( + self, + actions: Optional[Set[StrType]] = None, + model_types: Optional[Set[StrType]] = None, + ) -> "Diff": + """Return a new Diff excluding elements matching the given criteria. + + Args: + actions: If provided, exclude elements whose action is in this set. + model_types: If provided, exclude elements whose type is in this set. + + Returns: + A new Diff instance without the excluded elements. + """ + new_diff = self.__class__() + new_diff.models_processed = self.models_processed + + for group in self.groups(): + if model_types is not None and group in model_types: + continue + for element in self.children[group].values(): + filtered = _exclude_diff_element(element, actions=actions, model_types=model_types) + if filtered is not None: + new_diff.add(filtered) + + return new_diff + + +def _copy_diff_element(element: "DiffElement") -> "DiffElement": + """Create a shallow copy of a DiffElement without its children.""" + new_element = DiffElement( + obj_type=element.type, + name=element.name, + keys=copy.copy(element.keys), + source_name=element.source_name, + dest_name=element.dest_name, + ) + if element.source_attrs is not None: + new_element.source_attrs = copy.copy(element.source_attrs) + if element.dest_attrs is not None: + new_element.dest_attrs = copy.copy(element.dest_attrs) + return new_element + + +def _filter_diff_element( + element: "DiffElement", + actions: Optional[Set[StrType]] = None, + model_types: Optional[Set[StrType]] = None, +) -> Optional["DiffElement"]: + """Recursively filter a DiffElement, returning a copy with only matching elements or None.""" + # Check if this element's action matches + if actions is not None and element.action not in actions: + # Even if the element itself doesn't match, its children might + has_matching_children = False + new_element = _copy_diff_element(element) + # Clear attrs so this element itself shows no action + new_element.source_attrs = element.source_attrs + new_element.dest_attrs = element.dest_attrs + + for child in element.get_children(): + filtered_child = _filter_diff_element(child, actions=actions, model_types=model_types) + if filtered_child is not None: + new_element.add_child(filtered_child) + has_matching_children = True + + if has_matching_children: + # Return the element as a container for its matching children, but neutralize its own action + neutral = _copy_diff_element(element) + # Set both attrs to be equal so action becomes None + if element.source_attrs is not None and element.dest_attrs is not None: + neutral.source_attrs = copy.copy(element.dest_attrs) + neutral.dest_attrs = copy.copy(element.dest_attrs) + elif element.source_attrs is not None: + neutral.source_attrs = None + neutral.dest_attrs = None + elif element.dest_attrs is not None: + neutral.source_attrs = None + neutral.dest_attrs = None + neutral.child_diff = new_element.child_diff + return neutral + return None + + # Element itself matches (or no action filter) + new_element = _copy_diff_element(element) + for child in element.get_children(): + filtered_child = _filter_diff_element(child, actions=actions, model_types=model_types) + if filtered_child is not None: + new_element.add_child(filtered_child) + + return new_element + + +def _exclude_diff_element( + element: "DiffElement", + actions: Optional[Set[StrType]] = None, + model_types: Optional[Set[StrType]] = None, +) -> Optional["DiffElement"]: + """Recursively exclude matching elements from a DiffElement, returning a copy or None.""" + # If this element's action should be excluded + if actions is not None and element.action in actions: + # Still check children — they might not be excluded + has_kept_children = False + new_element = _copy_diff_element(element) + # Neutralize the element's own action + if element.source_attrs is not None and element.dest_attrs is not None: + new_element.source_attrs = copy.copy(element.dest_attrs) + new_element.dest_attrs = copy.copy(element.dest_attrs) + elif element.source_attrs is not None: + new_element.source_attrs = None + new_element.dest_attrs = None + elif element.dest_attrs is not None: + new_element.source_attrs = None + new_element.dest_attrs = None + + for child in element.get_children(): + excluded_child = _exclude_diff_element(child, actions=actions, model_types=model_types) + if excluded_child is not None: + new_element.add_child(excluded_child) + has_kept_children = True + + if has_kept_children: + return new_element + return None + + # Element itself is not excluded + new_element = _copy_diff_element(element) + for child in element.get_children(): + excluded_child = _exclude_diff_element(child, actions=actions, model_types=model_types) + if excluded_child is not None: + new_element.add_child(excluded_child) + + return new_element + @total_ordering class DiffElement: # pylint: disable=too-many-instance-attributes """DiffElement object, designed to represent a single item/object that may or may not have any diffs.""" - def __init__( # pylint: disable=too-many-positional-arguments + def __init__( # pylint: disable=R0917 self, obj_type: StrType, name: StrType, diff --git a/diffsync/helpers.py b/diffsync/helpers.py index 5765692b..01ee9df2 100644 --- a/diffsync/helpers.py +++ b/diffsync/helpers.py @@ -15,8 +15,10 @@ limitations under the License. """ +import threading from collections.abc import Iterable as ABCIterable from collections.abc import Mapping as ABCMapping +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Type import structlog # type: ignore @@ -37,13 +39,17 @@ class DiffSyncDiffer: # pylint: disable=too-many-instance-attributes Independent from Diff and DiffElement as those classes are purely data objects, while this stores some state. """ - def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments + def __init__( # pylint: disable=too-many-arguments,R0917 self, src_diffsync: "Adapter", dst_diffsync: "Adapter", flags: DiffSyncFlags, diff_class: Type[Diff] = Diff, callback: Optional[Callable[[str, int, int], None]] = None, + model_types: Optional[set] = None, + filters: Optional[Dict[str, Callable]] = None, + sync_attrs: Optional[Dict[str, set]] = None, + exclude_attrs: Optional[Dict[str, set]] = None, ): """Create a DiffSyncDiffer for calculating diffs between the provided DiffSync instances.""" self.src_diffsync = src_diffsync @@ -55,6 +61,14 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen self.callback = callback self.diff: Optional[Diff] = None + # Model-type scoping + self.model_types = model_types + # Attribute-based query predicates + self.filters = filters + # Attribute-scoped syncing + self.sync_attrs = sync_attrs + self.exclude_attrs = exclude_attrs + self.models_processed = 0 self.total_models = len(src_diffsync) + len(dst_diffsync) self.logger.debug(f"Diff calculation between these two datasets will involve {self.total_models} models") @@ -85,7 +99,12 @@ def calculate_diffs(self) -> Diff: elif skipped_type in self.src_diffsync.top_level: self.incr_models_processed(len(self.src_diffsync.get_all(skipped_type))) - for obj_type in intersection(self.dst_diffsync.top_level, self.src_diffsync.top_level): + top_level_types = intersection(self.dst_diffsync.top_level, self.src_diffsync.top_level) + # Model-type scoping + if self.model_types is not None: + top_level_types = [t for t in top_level_types if t in self.model_types] + + for obj_type in top_level_types: diff_elements = self.diff_object_list( src=self.src_diffsync.get_all(obj_type), dst=self.dst_diffsync.get_all(obj_type), @@ -160,7 +179,7 @@ def validate_objects_for_diff( if src_obj.get_identifiers() != dst_obj.get_identifiers(): raise ValueError(f"Keys mismatch: {src_obj.get_identifiers()} vs {dst_obj.get_identifiers()}") - def diff_object_pair( # pylint: disable=too-many-return-statements + def diff_object_pair( # pylint: disable=too-many-return-statements, too-many-branches, too-many-statements self, src_obj: Optional["DiffSyncModel"], dst_obj: Optional["DiffSyncModel"] ) -> Optional[DiffElement]: """Diff the two provided DiffSyncModel objects and return a DiffElement or None. @@ -209,6 +228,16 @@ def diff_object_pair( # pylint: disable=too-many-return-statements self.incr_models_processed() return None + # Attribute-based query predicates + if self.filters and model in self.filters: + predicate = self.filters[model] + obj_to_check = src_obj or dst_obj + if obj_to_check and not predicate(obj_to_check): + log.debug("Skipping due to query predicate filter") + delta = (1 if src_obj else 0) + (1 if dst_obj else 0) + self.incr_models_processed(delta) + return None + diff_element = DiffElement( obj_type=model, name=shortname, @@ -220,10 +249,14 @@ def diff_object_pair( # pylint: disable=too-many-return-statements delta = 0 if src_obj: - diff_element.add_attrs(source=src_obj.get_attrs(), dest=None) + src_attrs = src_obj.get_attrs() + src_attrs = self._filter_attrs(model, src_attrs) + diff_element.add_attrs(source=src_attrs, dest=None) delta += 1 if dst_obj: - diff_element.add_attrs(source=None, dest=dst_obj.get_attrs()) + dst_attrs = dst_obj.get_attrs() + dst_attrs = self._filter_attrs(model, dst_attrs) + diff_element.add_attrs(source=None, dest=dst_attrs) delta += 1 self.incr_models_processed(delta) @@ -233,6 +266,17 @@ def diff_object_pair( # pylint: disable=too-many-return-statements return diff_element + def _filter_attrs(self, model_type: str, attrs: Dict) -> Dict: + """Filter attributes based on sync_attrs and exclude_attrs settings. + + Apply sync_attrs whitelist first, then exclude_attrs blacklist. + """ + if self.sync_attrs and model_type in self.sync_attrs: + attrs = {k: v for k, v in attrs.items() if k in self.sync_attrs[model_type]} + if self.exclude_attrs and model_type in self.exclude_attrs: + attrs = {k: v for k, v in attrs.items() if k not in self.exclude_attrs[model_type]} + return attrs + def diff_child_objects( self, diff_element: DiffElement, @@ -268,6 +312,10 @@ def diff_child_objects( raise RuntimeError("Called with neither src_obj nor dest_obj??") for child_type, child_fieldname in children_mapping.items(): + # Model-type scoping — skip child types not in model_types + if self.model_types is not None and child_type not in self.model_types: + continue + # for example, child_type == "device" and child_fieldname == "devices" # for example, getattr(src_obj, "devices") --> list of device uids @@ -287,13 +335,18 @@ class DiffSyncSyncer: # pylint: disable=too-many-instance-attributes Independent from DiffSync and DiffSyncModel as those classes are purely data objects, while this stores some state. """ - def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments + def __init__( # pylint: disable=too-many-arguments,R0917 self, diff: Diff, src_diffsync: "Adapter", dst_diffsync: "Adapter", flags: DiffSyncFlags, callback: Optional[Callable[[str, int, int], None]] = None, + sync_filter: Optional[Callable[[str, str, Dict, Dict], bool]] = None, + batch_size: Optional[int] = None, + concurrent: bool = False, + max_workers: Optional[int] = None, + sync_stages: Optional[List[List[str]]] = None, ): """Create a DiffSyncSyncer instance, ready to call `perform_sync()` against.""" self.diff = diff @@ -302,15 +355,31 @@ def __init__( # pylint: disable=too-many-arguments, too-many-positional-argumen self.flags = flags self.callback = callback + # Callback-based sync interceptor + self.sync_filter = sync_filter + + # Chunked/batched sync execution + self.batch_size = batch_size + + # Parallel sync of independent subtrees + self.concurrent = concurrent + self.max_workers = max_workers + self.sync_stages = sync_stages + + # Structured operations summary + self.operations: Dict[str, Dict[str, List[Dict]]] = {} + self._operations_lock = threading.Lock() + self.elements_processed = 0 self.total_elements = len(diff) self.base_logger = structlog.get_logger().new(src=src_diffsync, dst=dst_diffsync, flags=flags) - # Local state maintained during synchronization - self.logger: structlog.BoundLogger = self.base_logger - self.model_class: Type["DiffSyncModel"] - self.action: Optional[str] = None + # Thread-local state maintained during synchronization (for concurrent safety) + self._local = threading.local() + self._local.logger = self.base_logger + self._local.model_class = None + self._local.action = None def incr_elements_processed(self, delta: int = 1) -> None: """Increment self.elements_processed, then call self.callback if present.""" @@ -327,8 +396,37 @@ def perform_sync(self) -> bool: """ changed = False self.base_logger.info("Beginning sync") - for element in self.diff.get_children(): - changed |= self.sync_diff_element(element) + + # Parallel sync of independent subtrees + if self.concurrent: + if self.sync_stages: + # Staged concurrent execution: process each stage sequentially, + # parallelizing elements within each stage. + for stage in self.sync_stages: + stage_set = set(stage) + stage_elements = [el for el in self.diff.get_children() if el.type in stage_set] + if stage_elements: + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = {executor.submit(self.sync_diff_element, el): el for el in stage_elements} + for future in as_completed(futures): + changed |= future.result() + + # Handle any elements whose type is not covered by sync_stages (serial fallback) + staged_types = {t for stage in self.sync_stages for t in stage} + for element in self.diff.get_children(): + if element.type not in staged_types: + changed |= self.sync_diff_element(element) + else: + # No stages defined — all elements in one pool (original behavior) + elements = list(self.diff.get_children()) + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = {executor.submit(self.sync_diff_element, element): element for element in elements} + for future in as_completed(futures): + changed |= future.result() + else: + for element in self.diff.get_children(): + changed |= self.sync_diff_element(element) + self.base_logger.info("Sync complete") return changed @@ -340,26 +438,26 @@ def sync_diff_element(self, element: DiffElement, parent_model: Optional["DiffSy Returns: bool: True if this element or any of its children resulted in actual changes, else False. """ - self.model_class = getattr(self.dst_diffsync, element.type) + self._local.model_class = getattr(self.dst_diffsync, element.type) diffs = element.get_attrs_diffs() - self.logger = self.base_logger.bind( + self._local.logger = self.base_logger.bind( action=element.action, model=element.type, - unique_id=self.model_class.create_unique_id(**element.keys), + unique_id=self._local.model_class.create_unique_id(**element.keys), diffs=diffs, ) - self.action = element.action + self._local.action = element.action ids = element.keys # We only actually need the "new" attrs to perform a create/update operation, and don't need any for a delete attrs = diffs.get("+", {}) # Retrieve Source Object to get its flags - src_model = self.src_diffsync.get_or_none(self.model_class, ids) + src_model = self.src_diffsync.get_or_none(self._local.model_class, ids) # Retrieve Dest (and primary) Object dst_model: Optional["DiffSyncModel"] try: - dst_model = self.dst_diffsync.get(self.model_class, ids) + dst_model = self.dst_diffsync.get(self._local.model_class, ids) dst_model.set_status(DiffSyncStatus.UNKNOWN) except ObjectNotFound: dst_model = None @@ -373,23 +471,23 @@ def sync_diff_element(self, element: DiffElement, parent_model: Optional["DiffSy # Recurse through children to delete if we are supposed to delete the current diff element changed = False - if natural_deletion_order and self.action == DiffSyncActions.DELETE and not skip_children: + if natural_deletion_order and self._local.action == DiffSyncActions.DELETE and not skip_children: for child in element.get_children(): changed |= self.sync_diff_element(child, parent_model=dst_model) - # Sync the current model - this will delete the current model if self.action is DELETE + # Sync the current model - this will delete the current model if self._local.action is DELETE changed, modified_model = self.sync_model(src_model=src_model, dst_model=dst_model, ids=ids, attrs=attrs) dst_model = modified_model or dst_model if not modified_model or not dst_model: - self.logger.warning("No object resulted from sync, will not process child objects.") + self._local.logger.warning("No object resulted from sync, will not process child objects.") return changed - if self.action == DiffSyncActions.CREATE: + if self._local.action == DiffSyncActions.CREATE: if parent_model: parent_model.add_child(dst_model) self.dst_diffsync.add(dst_model) - elif self.action == DiffSyncActions.DELETE: + elif self._local.action == DiffSyncActions.DELETE: if parent_model: parent_model.remove_child(dst_model) @@ -400,7 +498,7 @@ def sync_diff_element(self, element: DiffElement, parent_model: Optional["DiffSy self.incr_elements_processed() - if not natural_deletion_order or self.action is not DiffSyncActions.DELETE: + if not natural_deletion_order or self._local.action is not DiffSyncActions.DELETE: for child in element.get_children(): changed |= self.sync_diff_element(child, parent_model=dst_model) @@ -416,44 +514,63 @@ def sync_model( # pylint: disable=too-many-branches, unused-argument Returns: (changed, model) where model may be None if an error occurred """ - if self.action is None: + if self._local.action is None: status = DiffSyncStatus.SUCCESS message = "No changes to apply; no action needed" - self.log_sync_status(self.action, status, message) + self.log_sync_status(self._local.action, status, message) return (False, dst_model) + # Callback-based sync interceptor + if self.sync_filter: + model_type = self._local.model_class.get_type() + if not self.sync_filter(self._local.action, model_type, ids, attrs): + self._local.logger.debug("Skipping due to sync_filter callback") + # Clear the action so sync_diff_element doesn't proceed with store operations + self._local.action = None + return (False, dst_model) + try: - self.logger.debug(f"Attempting model {self.action}") - if self.action == DiffSyncActions.CREATE: + self._local.logger.debug(f"Attempting model {self._local.action}") + if self._local.action == DiffSyncActions.CREATE: if dst_model is not None: - raise ObjectNotCreated(f"Failed to create {self.model_class.get_type()} {ids} - it already exists!") - dst_model = self.model_class.create(adapter=self.dst_diffsync, ids=ids, attrs=attrs) - elif self.action == DiffSyncActions.UPDATE: + raise ObjectNotCreated( + f"Failed to create {self._local.model_class.get_type()} {ids} - it already exists!" + ) + dst_model = self._local.model_class.create(adapter=self.dst_diffsync, ids=ids, attrs=attrs) + elif self._local.action == DiffSyncActions.UPDATE: if dst_model is None: - raise ObjectNotUpdated(f"Failed to update {self.model_class.get_type()} {ids} - not found!") + raise ObjectNotUpdated(f"Failed to update {self._local.model_class.get_type()} {ids} - not found!") dst_model = dst_model.update(attrs=attrs) - elif self.action == DiffSyncActions.DELETE: + elif self._local.action == DiffSyncActions.DELETE: if dst_model is None: - raise ObjectNotDeleted(f"Failed to delete {self.model_class.get_type()} {ids} - not found!") + raise ObjectNotDeleted(f"Failed to delete {self._local.model_class.get_type()} {ids} - not found!") dst_model = dst_model.delete() else: - raise ObjectCrudException(f'Unknown action "{self.action}"!') + raise ObjectCrudException(f'Unknown action "{self._local.action}"!') if dst_model is not None: status, message = dst_model.get_status() else: status = DiffSyncStatus.FAILURE - message = f"{self.model_class.get_type()} {self.action} did not return the model object." + message = f"{self._local.model_class.get_type()} {self._local.action} did not return the model object." except ObjectCrudException as exception: status = DiffSyncStatus.ERROR message = str(exception) - self.log_sync_status(self.action, status, message) + self.log_sync_status(self._local.action, status, message) if self.flags & DiffSyncFlags.CONTINUE_ON_FAILURE: return (True, None) raise - self.log_sync_status(self.action, status, message) + self.log_sync_status(self._local.action, status, message) + + # Track operations for structured sync_complete + if self._local.action is not None and status == DiffSyncStatus.SUCCESS: + with self._operations_lock: + model_type = self._local.model_class.get_type() + if model_type not in self.operations: + self.operations[model_type] = {"create": [], "update": [], "delete": []} + self.operations[model_type][self._local.action].append({"ids": ids, "attrs": attrs, "model": dst_model}) return (True, dst_model) @@ -464,10 +581,10 @@ def log_sync_status(self, action: Optional[str], status: DiffSyncStatus, message """ if action is None: if self.flags & DiffSyncFlags.LOG_UNCHANGED_RECORDS: - self.logger.debug(message, status=status.value) + self._local.logger.debug(message, status=status.value) elif status == DiffSyncStatus.SUCCESS: - self.logger.info(message, status=status.value) + self._local.logger.info(message, status=status.value) elif status == DiffSyncStatus.FAILURE: - self.logger.warning(message, status=status.value) + self._local.logger.warning(message, status=status.value) else: - self.logger.error(message, status=status.value) + self._local.logger.error(message, status=status.value) diff --git a/diffsync/store/__init__.py b/diffsync/store/__init__.py index cf52485f..00550a42 100644 --- a/diffsync/store/__init__.py +++ b/diffsync/store/__init__.py @@ -138,6 +138,43 @@ def count(self, *, model: Union[str, "DiffSyncModel", Type["DiffSyncModel"], Non """Returns the number of elements of a specific model, or all elements in the store if not specified.""" raise NotImplementedError + def add_bulk(self, *, objs: List["DiffSyncModel"]) -> None: + """Add multiple DiffSyncModel objects to the store. + + The default implementation loops over individual add() calls. + Subclasses may override for optimized batch operations. + + Args: + objs: List of objects to store + """ + for obj in objs: + self.add(obj=obj) + + def update_bulk(self, *, objs: List["DiffSyncModel"]) -> None: + """Update multiple DiffSyncModel objects in the store. + + The default implementation loops over individual update() calls. + Subclasses may override for optimized batch operations. + + Args: + objs: List of objects to update + """ + for obj in objs: + self.update(obj=obj) + + def remove_bulk(self, *, objs: List["DiffSyncModel"], remove_children: bool = False) -> None: + """Remove multiple DiffSyncModel objects from the store. + + The default implementation loops over individual remove() calls. + Subclasses may override for optimized batch operations. + + Args: + objs: List of objects to remove + remove_children: If True, also recursively remove children + """ + for obj in objs: + self.remove(obj=obj, remove_children=remove_children) + def get_or_instantiate( self, *, model: Type["DiffSyncModel"], ids: Dict, attrs: Optional[Dict] = None ) -> Tuple["DiffSyncModel", bool]: diff --git a/diffsync/store/local.py b/diffsync/store/local.py index ff561db4..4be3524f 100644 --- a/diffsync/store/local.py +++ b/diffsync/store/local.py @@ -1,5 +1,6 @@ """LocalStore module.""" +import threading from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Set, Type, Union @@ -18,6 +19,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._data: Dict = defaultdict(dict) + self._lock = threading.Lock() def get_all_model_names(self) -> Set[str]: """Get all the model names stored. @@ -97,20 +99,21 @@ def add(self, *, obj: "DiffSyncModel") -> None: Raises: ObjectAlreadyExists: if a different object with the same uid is already present. """ - modelname = obj.get_type() - uid = obj.get_unique_id() + with self._lock: + modelname = obj.get_type() + uid = obj.get_unique_id() - existing_obj = self._data[modelname].get(uid) - if existing_obj: - if existing_obj is not obj: - raise ObjectAlreadyExists(f"Object {uid} already present", obj) - # Return so we don't have to change anything on the existing object and underlying data - return + existing_obj = self._data[modelname].get(uid) + if existing_obj: + if existing_obj is not obj: + raise ObjectAlreadyExists(f"Object {uid} already present", obj) + # Return so we don't have to change anything on the existing object and underlying data + return - if not obj.adapter: - obj.adapter = self.adapter + if not obj.adapter: + obj.adapter = self.adapter - self._data[modelname][uid] = obj + self._data[modelname][uid] = obj def update(self, *, obj: "DiffSyncModel") -> None: """Update a DiffSyncModel object to the store. @@ -118,20 +121,22 @@ def update(self, *, obj: "DiffSyncModel") -> None: Args: obj: Object to update """ - modelname = obj.get_type() - uid = obj.get_unique_id() + with self._lock: + modelname = obj.get_type() + uid = obj.get_unique_id() - existing_obj = self._data[modelname].get(uid) - if existing_obj is obj: - return + existing_obj = self._data[modelname].get(uid) + if existing_obj is obj: + return - self._data[modelname][uid] = obj + self._data[modelname][uid] = obj def remove_item(self, modelname: str, uid: str) -> None: """Remove one item from store.""" - if uid not in self._data[modelname]: - raise ObjectNotFound(f"{modelname} {uid} not present in {str(self)}") - del self._data[modelname][uid] + with self._lock: + if uid not in self._data[modelname]: + raise ObjectNotFound(f"{modelname} {uid} not present in {str(self)}") + del self._data[modelname][uid] def count(self, *, model: Union[str, "DiffSyncModel", Type["DiffSyncModel"], None] = None) -> int: """Returns the number of elements of a specific model, or all elements in the store if unspecified.""" @@ -143,3 +148,25 @@ def count(self, *, model: Union[str, "DiffSyncModel", Type["DiffSyncModel"], Non else: modelname = model.get_type() return len(self._data[modelname]) + + def add_bulk(self, *, objs: List["DiffSyncModel"]) -> None: + """Add multiple DiffSyncModel objects to the store in a batch. + + Args: + objs: List of objects to store + + Raises: + ObjectAlreadyExists: if a different object with the same uid is already present. + """ + with self._lock: + for obj in objs: + modelname = obj.get_type() + uid = obj.get_unique_id() + existing_obj = self._data[modelname].get(uid) + if existing_obj: + if existing_obj is not obj: + raise ObjectAlreadyExists(f"Object {uid} already present", obj) + continue + if not obj.adapter: + obj.adapter = self.adapter + self._data[modelname][uid] = obj diff --git a/diffsync/store/redis.py b/diffsync/store/redis.py index 5325752d..103807cb 100644 --- a/diffsync/store/redis.py +++ b/diffsync/store/redis.py @@ -207,6 +207,39 @@ def remove_item(self, modelname: str, uid: str) -> None: self._store.delete(object_key) + def add_bulk(self, *, objs: List["DiffSyncModel"]) -> None: + """Add multiple DiffSyncModel objects to Redis using a pipeline for efficiency. + + Args: + objs: List of objects to store + + Raises: + ObjectAlreadyExists: if a different object with the same uid is already present. + """ + # Validate first, then batch write + keys_to_set = [] + for obj in objs: + modelname = obj.get_type() + uid = obj.get_unique_id() + object_key = self._get_key_for_object(modelname, uid) + + existing_obj_binary = self._store.get(object_key) + if existing_obj_binary: + existing_obj = loads(existing_obj_binary) # noqa: S301 + if existing_obj.dict() != obj.dict(): + raise ObjectAlreadyExists(f"Object {uid} already present", obj) + continue + + obj_copy = copy.copy(obj) + obj_copy.adapter = None + keys_to_set.append((object_key, dumps(obj_copy))) + + if keys_to_set: + pipe = self._store.pipeline() + for key, data in keys_to_set: + pipe.set(key, data) + pipe.execute() + def count(self, *, model: Union[str, "DiffSyncModel", Type["DiffSyncModel"], None] = None) -> int: """Returns the number of elements of a specific model, or all elements in the store if unspecified.""" search_pattern = f"{self._store_label}:*" diff --git a/tests/unit/test_diff_filtering.py b/tests/unit/test_diff_filtering.py new file mode 100644 index 00000000..3245a4ce --- /dev/null +++ b/tests/unit/test_diff_filtering.py @@ -0,0 +1,95 @@ +"""Unit tests for the Diff filter and exclude methods. + +Copyright (c) 2020-2021 Network To Code, LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + + +def test_diff_filter_by_action_create(diff_with_children): + """Filtering a Diff by action='create' should only retain elements with create diffs.""" + filtered = diff_with_children.filter(actions={"create"}) + actions = [] + for child in filtered.get_children(): + if child.action: + actions.append(child.action) + assert "create" in actions + assert "delete" not in actions + + +def test_diff_filter_by_action_delete(diff_with_children): + """Filtering a Diff by action='delete' should only retain elements with delete diffs.""" + filtered = diff_with_children.filter(actions={"delete"}) + actions = [] + for child in filtered.get_children(): + if child.action: + actions.append(child.action) + assert "delete" in actions + assert "create" not in actions + + +def test_diff_filter_by_model_types(diff_with_children): + """Filtering a Diff by model_types should only retain elements of those types.""" + filtered = diff_with_children.filter(model_types={"person"}) + types = [child.type for child in filtered.get_children()] + assert "person" in types + assert "device" not in types + assert "address" not in types + + +def test_diff_filter_by_action_and_model_type(diff_with_children): + """Filtering by both action and model type should apply both criteria.""" + filtered = diff_with_children.filter(actions={"create"}, model_types={"person"}) + elements = list(filtered.get_children()) + assert len(elements) == 1 + assert elements[0].type == "person" + assert elements[0].action == "create" + + +def test_diff_filter_does_not_mutate_original(diff_with_children): + """Calling filter() should return a new Diff without modifying the original.""" + original_len = len(diff_with_children) + _ = diff_with_children.filter(actions={"create"}) + assert len(diff_with_children) == original_len + + +def test_diff_exclude_by_action(diff_with_children): + """Excluding by action should remove elements with that action.""" + excluded = diff_with_children.exclude(actions={"delete"}) + for child in excluded.get_children(): + assert child.action != "delete" + + +def test_diff_exclude_by_model_types(diff_with_children): + """Excluding by model_types should remove elements of those types.""" + excluded = diff_with_children.exclude(model_types={"person"}) + types = [child.type for child in excluded.get_children()] + assert "person" not in types + + +def test_diff_filter_no_criteria_returns_full_copy(diff_with_children): + """Filtering with no criteria should return a copy of the entire Diff.""" + filtered = diff_with_children.filter() + assert len(filtered) == len(diff_with_children) + + +def test_diff_exclude_no_criteria_returns_full_copy(diff_with_children): + """Excluding with no criteria should return a copy of the entire Diff.""" + excluded = diff_with_children.exclude() + assert len(excluded) == len(diff_with_children) + + +def test_diff_filter_preserves_models_processed(diff_with_children): + """The models_processed count should be preserved on the filtered Diff.""" + filtered = diff_with_children.filter(actions={"create"}) + assert filtered.models_processed == diff_with_children.models_processed diff --git a/tests/unit/test_diffsync_diff_and_sync_parameters.py b/tests/unit/test_diffsync_diff_and_sync_parameters.py new file mode 100644 index 00000000..c0878eee --- /dev/null +++ b/tests/unit/test_diffsync_diff_and_sync_parameters.py @@ -0,0 +1,630 @@ +"""Unit tests for Adapter diff/sync parameters: model_types, filters, sync_attrs, exclude_attrs, sync_filter, and concurrent. + +Copyright (c) 2020-2021 Network To Code, LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Dict, List, Optional + +from diffsync import Adapter, DiffSyncModel +from diffsync.enum import DiffSyncFlags + +# --------------------------------------------------------------------------- +# Models and adapters used across this test module +# --------------------------------------------------------------------------- + + +class _Site(DiffSyncModel): + _modelname = "site" + _identifiers = ("name",) + _attributes = ("location",) + _children = {"device": "devices"} + + name: str + location: str = "" + devices: List = [] + + +class _Device(DiffSyncModel): + _modelname = "device" + _identifiers = ("name",) + _attributes = ("role", "tag") + + name: str + role: str = "" + tag: str = "" + + +class _SimpleAdapter(Adapter): + site = _Site + device = _Device + top_level = ["site"] + + +def _make_adapter_pair(): + """Build a source and destination adapter with overlapping but differing data.""" + src = _SimpleAdapter() + dst = _SimpleAdapter() + + # Source: site1 (NYC) -> device1 (spine/prod), device2 (leaf/staging) + # site2 (SFO) -> device3 (spine/prod) + site1 = _Site(name="site1", location="NYC") + src.add(site1) + d1 = _Device(name="device1", role="spine", tag="prod") + src.add(d1) + site1.add_child(d1) + d2 = _Device(name="device2", role="leaf", tag="staging") + src.add(d2) + site1.add_child(d2) + + site2 = _Site(name="site2", location="SFO") + src.add(site2) + d3 = _Device(name="device3", role="spine", tag="prod") + src.add(d3) + site2.add_child(d3) + + # Dest: site1 (NYC) -> device1 (leaf/dev), device2 (leaf/staging) + # site3 (ATL) — not in source + dst_site1 = _Site(name="site1", location="NYC") + dst.add(dst_site1) + dst_d1 = _Device(name="device1", role="leaf", tag="dev") + dst.add(dst_d1) + dst_site1.add_child(dst_d1) + dst_d2 = _Device(name="device2", role="leaf", tag="staging") + dst.add(dst_d2) + dst_site1.add_child(dst_d2) + + dst_site3 = _Site(name="site3", location="ATL") + dst.add(dst_site3) + + return src, dst + + +# --------------------------------------------------------------------------- +# model_types scoping +# --------------------------------------------------------------------------- + + +def test_diff_with_model_types_restricts_to_site_only(): + """Passing model_types={'site'} should exclude child device elements from the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, model_types={"site"}) + types = set() + for child in diff.get_children(): + types.add(child.type) + assert "site" in types + assert "device" not in types + + +def test_diff_with_model_types_includes_site_and_device(): + """Passing model_types={'site', 'device'} should include both types in the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, model_types={"site", "device"}) + types = set() + for child in diff.get_children(): + types.add(child.type) + for grandchild in child.get_children(): + types.add(grandchild.type) + assert "site" in types + assert "device" in types + + +def test_sync_with_model_types_does_not_touch_excluded_types(): + """Syncing with model_types={'site'} should leave device objects unchanged.""" + src, dst = _make_adapter_pair() + initial_device_count = dst.count("device") + dst.sync_from(src, model_types={"site"}) + assert dst.count("device") == initial_device_count + + +# --------------------------------------------------------------------------- +# sync_attrs / exclude_attrs +# --------------------------------------------------------------------------- + + +def test_sync_attrs_limits_diff_to_whitelisted_attributes(): + """Only the attributes named in sync_attrs should appear in the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, sync_attrs={"device": {"role"}}) + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device" and device_el.action == "update": + diffs = device_el.get_attrs_diffs() + if "+" in diffs: + assert "role" in diffs["+"] + assert "tag" not in diffs["+"] + + +def test_exclude_attrs_removes_blacklisted_attributes(): + """Attributes named in exclude_attrs should not appear in the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, exclude_attrs={"device": {"tag"}}) + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device" and device_el.action == "update": + diffs = device_el.get_attrs_diffs() + if "+" in diffs: + assert "tag" not in diffs["+"] + + +def test_sync_attrs_and_exclude_attrs_applied_together(): + """sync_attrs whitelist is applied first, then exclude_attrs blacklist.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from( + src, + sync_attrs={"device": {"role", "tag"}}, + exclude_attrs={"device": {"role"}}, + ) + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device" and device_el.name == "device1": + diffs = device_el.get_attrs_diffs() + if "+" in diffs: + assert "role" not in diffs["+"] + assert "tag" in diffs["+"] + + +# --------------------------------------------------------------------------- +# filters (query predicates) +# --------------------------------------------------------------------------- + + +def test_filters_include_matching_objects(): + """Objects whose predicate returns True should be included in the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, filters={"device": lambda d: d.role == "spine"}) + + device_names = set() + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device": + device_names.add(device_el.name) + + assert "device1" in device_names # spine in source + assert "device3" in device_names # spine in source + assert "device2" not in device_names # leaf in source, filtered out + + +def test_filters_exclude_nonmatching_objects(): + """Objects whose predicate returns False should be excluded from the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, filters={"device": lambda d: d.role == "nonexistent"}) + + device_elements = [] + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device": + device_elements.append(device_el) + + assert len(device_elements) == 0 + + +def test_filters_do_not_affect_unfiltered_types(): + """Model types not named in the filters dict should remain in the diff.""" + src, dst = _make_adapter_pair() + diff = dst.diff_from(src, filters={"device": lambda d: d.role == "spine"}) + + site_elements = [child for child in diff.get_children() if child.type == "site"] + assert len(site_elements) > 0 + + +# --------------------------------------------------------------------------- +# sync_filter callback +# --------------------------------------------------------------------------- + + +def test_sync_filter_blocks_delete_operations(): + """A sync_filter that rejects deletes should preserve objects that only exist in the destination.""" + src, dst = _make_adapter_pair() + assert dst.get_or_none("site", "site3") is not None + + dst.sync_from( + src, + sync_filter=lambda action, model_type, ids, attrs: action != "delete", + ) + + assert dst.get_or_none("site", "site3") is not None + + +def test_sync_filter_blocks_create_operations(): + """A sync_filter that rejects creates should prevent objects that only exist in the source.""" + src, dst = _make_adapter_pair() + assert dst.get_or_none("site", "site2") is None + + dst.sync_from( + src, + sync_filter=lambda action, model_type, ids, attrs: action != "create", + ) + + assert dst.get_or_none("site", "site2") is None + + +def test_sync_filter_allows_update_operations(): + """A sync_filter that only allows updates should apply attribute changes without creating or deleting.""" + src, dst = _make_adapter_pair() + + dst.sync_from( + src, + sync_filter=lambda action, model_type, ids, attrs: action == "update", + flags=DiffSyncFlags.SKIP_UNMATCHED_BOTH, + ) + + device1 = dst.get_or_none("device", "device1") + assert device1 is not None + assert device1.role == "spine" + + +def test_sync_filter_blocks_by_model_type(): + """A sync_filter can selectively block operations on specific model types.""" + src, dst = _make_adapter_pair() + + dst.sync_from( + src, + sync_filter=lambda action, model_type, ids, attrs: model_type != "device", + ) + + # site2 should be created (not blocked) + assert dst.get_or_none("site", "site2") is not None + + +# --------------------------------------------------------------------------- +# sync_complete operations summary +# --------------------------------------------------------------------------- + + +class _TrackingAdapter(_SimpleAdapter): + """Adapter that captures the operations dict passed to sync_complete.""" + + captured_operations: Optional[Dict] = None + + def sync_complete(self, source, diff, flags=DiffSyncFlags.NONE, logger=None, operations=None): + self.captured_operations = operations + + +def _make_tracking_dst(): + """Build a _TrackingAdapter pre-populated with the same data as the destination in _make_adapter_pair.""" + tracking_dst = _TrackingAdapter() + dst_site1 = _Site(name="site1", location="NYC") + tracking_dst.add(dst_site1) + dst_d1 = _Device(name="device1", role="leaf", tag="dev") + tracking_dst.add(dst_d1) + dst_site1.add_child(dst_d1) + dst_d2 = _Device(name="device2", role="leaf", tag="staging") + tracking_dst.add(dst_d2) + dst_site1.add_child(dst_d2) + return tracking_dst + + +def test_sync_complete_receives_create_operations(): + """The operations dict passed to sync_complete should include create entries.""" + src, _ = _make_adapter_pair() + tracking_dst = _make_tracking_dst() + dst_site3 = _Site(name="site3", location="ATL") + tracking_dst.add(dst_site3) + + tracking_dst.sync_from(src) + + assert tracking_dst.captured_operations is not None + ops = tracking_dst.captured_operations + assert "site" in ops + created_site_names = [op["ids"]["name"] for op in ops["site"]["create"]] + assert "site2" in created_site_names + + +def test_sync_complete_receives_update_operations(): + """The operations dict passed to sync_complete should include update entries.""" + src, _ = _make_adapter_pair() + tracking_dst = _make_tracking_dst() + + tracking_dst.sync_from(src, flags=DiffSyncFlags.SKIP_UNMATCHED_DST) + + ops = tracking_dst.captured_operations + assert ops is not None + assert "device" in ops + updated_device_names = [op["ids"]["name"] for op in ops["device"]["update"]] + assert "device1" in updated_device_names + + +def test_sync_complete_backwards_compat_without_operations_kwarg(): + """Subclasses that override sync_complete without the operations kwarg should still work.""" + + class _OldStyleAdapter(_SimpleAdapter): + sync_complete_called = False + + def sync_complete(self, source, diff, flags=DiffSyncFlags.NONE, logger=None): + self.sync_complete_called = True + + src, _ = _make_adapter_pair() + old_dst = _OldStyleAdapter() + old_dst.add(_Site(name="site1", location="NYC")) + + old_dst.sync_from(src, flags=DiffSyncFlags.SKIP_UNMATCHED_DST) + assert old_dst.sync_complete_called + + +# --------------------------------------------------------------------------- +# concurrent sync +# --------------------------------------------------------------------------- + + +def test_concurrent_sync_matches_serial_sync(): + """Syncing with concurrent=True should produce the same result as a serial sync.""" + src, dst_serial = _make_adapter_pair() + _, dst_concurrent = _make_adapter_pair() + + dst_serial.sync_from(src) + dst_concurrent.sync_from(src, concurrent=True, max_workers=2) + + assert dst_serial.count("site") == dst_concurrent.count("site") + assert dst_serial.count("device") == dst_concurrent.count("device") + + +def test_sync_defaults_to_serial(): + """Passing concurrent=False (the default) should work identically to the original sync.""" + src, dst = _make_adapter_pair() + dst.sync_from(src, concurrent=False) + assert dst.get_or_none("site", "site2") is not None + + +# --------------------------------------------------------------------------- +# Combinations of multiple parameters +# --------------------------------------------------------------------------- + + +def test_diff_filter_then_sync_with_sync_filter(): + """A pre-filtered Diff combined with a sync_filter should respect both layers.""" + src, dst = _make_adapter_pair() + + diff = dst.diff_from(src) + filtered = diff.filter(actions={"create", "update"}) + + dst.sync_from( + src, + diff=filtered, + sync_filter=lambda action, model_type, ids, attrs: not (action == "create" and model_type == "site"), + ) + + # site2 blocked by sync_filter, device1 update allowed + assert dst.get_or_none("site", "site2") is None + device1 = dst.get_or_none("device", "device1") + assert device1 is not None + assert device1.role == "spine" + + +def test_model_types_combined_with_sync_attrs(): + """model_types and sync_attrs should compose — only scoped types with whitelisted attrs appear.""" + src, dst = _make_adapter_pair() + + diff = dst.diff_from( + src, + model_types={"site", "device"}, + sync_attrs={"device": {"role"}}, + ) + + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device" and device_el.action == "update": + diffs = device_el.get_attrs_diffs() + if "+" in diffs: + assert "tag" not in diffs["+"] + + +def test_filters_combined_with_sync_attrs(): + """A query predicate filter and sync_attrs should compose — only matching objects with whitelisted attrs.""" + src, dst = _make_adapter_pair() + + diff = dst.diff_from( + src, + filters={"device": lambda d: d.role == "spine"}, + sync_attrs={"device": {"role"}}, + ) + + device_elements = [] + for child in diff.get_children(): + for device_el in child.get_children(): + if device_el.type == "device": + device_elements.append(device_el) + + for de in device_elements: + assert de.name != "device2" # device2 is leaf, should be filtered out + if de.action == "update": + diffs = de.get_attrs_diffs() + if "+" in diffs: + assert "tag" not in diffs["+"] + + +# --------------------------------------------------------------------------- +# sync_stages — ordered group execution for concurrent sync +# --------------------------------------------------------------------------- + + +# Models and adapters for sync_stages tests — uses multiple top-level types +# to exercise staged parallelism. + +_creation_order: List = [] + + +class _Region(DiffSyncModel): + _modelname = "region" + _identifiers = ("name",) + _attributes = ("slug",) + + name: str + slug: str = "" + + @classmethod + def create(cls, adapter, ids, attrs): + _creation_order.append(("region", ids["name"])) + return super().create(adapter=adapter, ids=ids, attrs=attrs) + + +class _Tenant(DiffSyncModel): + _modelname = "tenant" + _identifiers = ("name",) + _attributes = ("group",) + + name: str + group: str = "" + + @classmethod + def create(cls, adapter, ids, attrs): + _creation_order.append(("tenant", ids["name"])) + return super().create(adapter=adapter, ids=ids, attrs=attrs) + + +class _Rack(DiffSyncModel): + _modelname = "rack" + _identifiers = ("name",) + _attributes = ("site_name",) + + name: str + site_name: str = "" + + @classmethod + def create(cls, adapter, ids, attrs): + _creation_order.append(("rack", ids["name"])) + return super().create(adapter=adapter, ids=ids, attrs=attrs) + + +class _StagedAdapter(Adapter): + region = _Region + tenant = _Tenant + rack = _Rack + top_level = ["region", "tenant", "rack"] + sync_stages = [ + ["region", "tenant"], # stage 1: independent, can run in parallel + ["rack"], # stage 2: depends on regions being created + ] + + +class _UnstagedAdapter(Adapter): + """Same models, no sync_stages — for comparison.""" + region = _Region + tenant = _Tenant + rack = _Rack + top_level = ["region", "tenant", "rack"] + + +def _make_staged_pair(adapter_cls=_StagedAdapter): + """Build a source with regions/tenants/racks and an empty destination.""" + src = adapter_cls() + dst = adapter_cls() + + src.add(_Region(name="region1", slug="r1")) + src.add(_Region(name="region2", slug="r2")) + src.add(_Tenant(name="tenant1", group="g1")) + src.add(_Rack(name="rack1", site_name="region1")) + src.add(_Rack(name="rack2", site_name="region2")) + + return src, dst + + +def test_sync_stages_executes_in_order(): + """All stage-1 types (region, tenant) must be created before any stage-2 type (rack).""" + _creation_order.clear() + src, dst = _make_staged_pair() + dst.sync_from(src, concurrent=True, max_workers=4) + + # Find the index of the first rack creation + rack_indices = [i for i, (t, _) in enumerate(_creation_order) if t == "rack"] + region_indices = [i for i, (t, _) in enumerate(_creation_order) if t == "region"] + tenant_indices = [i for i, (t, _) in enumerate(_creation_order) if t == "tenant"] + + assert len(rack_indices) == 2 + assert len(region_indices) == 2 + assert len(tenant_indices) == 1 + + # All stage-1 creations (regions + tenants) must come before any stage-2 creation (racks) + max_stage1_index = max(max(region_indices), max(tenant_indices)) + min_stage2_index = min(rack_indices) + assert max_stage1_index < min_stage2_index, ( + f"Stage 1 items must all complete before stage 2 begins. " + f"Order was: {_creation_order}" + ) + + +def test_sync_stages_parallelizes_within_stage(): + """Two independent top-level types in the same stage should both be processed.""" + _creation_order.clear() + src, dst = _make_staged_pair() + dst.sync_from(src, concurrent=True, max_workers=4) + + types_created = {t for t, _ in _creation_order} + assert "region" in types_created + assert "tenant" in types_created + assert "rack" in types_created + + +def test_sync_stages_none_preserves_current_behavior(): + """sync_stages=None with concurrent=True should behave like the original unstaged concurrent sync.""" + _creation_order.clear() + src, dst = _make_staged_pair(_UnstagedAdapter) + dst.sync_from(src, concurrent=True, max_workers=2) + + assert dst.get_or_none("region", "region1") is not None + assert dst.get_or_none("tenant", "tenant1") is not None + assert dst.get_or_none("rack", "rack1") is not None + + +def test_sync_stages_ignored_when_serial(): + """sync_stages should have no effect on serial sync — top_level order is used.""" + _creation_order.clear() + src, dst = _make_staged_pair() + dst.sync_from(src, concurrent=False) + + assert dst.get_or_none("region", "region1") is not None + assert dst.get_or_none("rack", "rack1") is not None + + +def test_sync_stages_validation_rejects_unknown_type(): + """A type in sync_stages that is not in top_level should raise AttributeError.""" + import pytest + + with pytest.raises(AttributeError, match="sync_stages.*not in top_level"): + class _BadAdapter(Adapter): + region = _Region + top_level = ["region"] + sync_stages = [["region", "nonexistent"]] + + +def test_sync_stages_validation_rejects_duplicates(): + """A type appearing in multiple stages should raise AttributeError.""" + import pytest + + with pytest.raises(AttributeError, match="sync_stages.*duplicate"): + class _BadAdapter(Adapter): + region = _Region + tenant = _Tenant + top_level = ["region", "tenant"] + sync_stages = [["region", "tenant"], ["region"]] + + +def test_sync_stages_unstaged_types_still_sync(): + """A type in top_level but not in any stage should still be synced (serially, after all stages).""" + + class _PartialStagesAdapter(Adapter): + region = _Region + tenant = _Tenant + rack = _Rack + top_level = ["region", "tenant", "rack"] + sync_stages = [["region"]] # tenant and rack not staged + + _creation_order.clear() + src, dst = _make_staged_pair(_PartialStagesAdapter) + dst.sync_from(src, concurrent=True, max_workers=2) + + # All types should still be synced + assert dst.get_or_none("region", "region1") is not None + assert dst.get_or_none("tenant", "tenant1") is not None + assert dst.get_or_none("rack", "rack1") is not None diff --git a/tests/unit/test_diffsync_model_bulk.py b/tests/unit/test_diffsync_model_bulk.py new file mode 100644 index 00000000..cf1a9f75 --- /dev/null +++ b/tests/unit/test_diffsync_model_bulk.py @@ -0,0 +1,101 @@ +"""Unit tests for the DiffSyncModel bulk CRUD methods and store bulk operations. + +Copyright (c) 2020-2021 Network To Code, LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import List + +from diffsync import Adapter, DiffSyncModel + + +class _Device(DiffSyncModel): + _modelname = "device" + _identifiers = ("name",) + _attributes = ("role", "tag") + + name: str + role: str = "" + tag: str = "" + + +class _Site(DiffSyncModel): + _modelname = "site" + _identifiers = ("name",) + _children = {"device": "devices"} + + name: str + devices: List = [] + + +class _SimpleAdapter(Adapter): + site = _Site + device = _Device + top_level = ["site"] + + +def test_create_bulk_produces_same_results_as_individual_creates(): + """The default create_bulk implementation should create all requested objects.""" + adapter = _SimpleAdapter() + results = _Device.create_bulk( + adapter=adapter, + objects=[ + {"ids": {"name": "d1"}, "attrs": {"role": "spine", "tag": "a"}}, + {"ids": {"name": "d2"}, "attrs": {"role": "leaf", "tag": "b"}}, + ], + ) + assert len(results) == 2 + assert results[0].name == "d1" + assert results[1].name == "d2" + + +def test_update_bulk_updates_all_models(): + """The default update_bulk implementation should update each model's attributes.""" + adapter = _SimpleAdapter() + d1 = _Device(name="d1", role="spine", tag="a") + d2 = _Device(name="d2", role="leaf", tag="b") + results = _Device.update_bulk( + adapter=adapter, + objects=[(d1, {"role": "updated1"}), (d2, {"role": "updated2"})], + ) + assert results[0].role == "updated1" + assert results[1].role == "updated2" + + +def test_delete_bulk_deletes_all_models(): + """The default delete_bulk implementation should delete each model.""" + adapter = _SimpleAdapter() + d1 = _Device(name="d1", role="spine", tag="a") + d2 = _Device(name="d2", role="leaf", tag="b") + results = _Device.delete_bulk(adapter=adapter, objects=[d1, d2]) + assert len(results) == 2 + + +def test_store_add_bulk_adds_multiple_objects(): + """LocalStore.add_bulk should add all provided objects to the store.""" + adapter = _SimpleAdapter() + d1 = _Device(name="d1", role="spine", tag="a") + d2 = _Device(name="d2", role="leaf", tag="b") + adapter.store.add_bulk(objs=[d1, d2]) + assert adapter.count("device") == 2 + + +def test_store_remove_bulk_removes_multiple_objects(): + """LocalStore.remove_bulk should remove all provided objects from the store.""" + adapter = _SimpleAdapter() + d1 = _Device(name="d1", role="spine", tag="a") + d2 = _Device(name="d2", role="leaf", tag="b") + adapter.store.add_bulk(objs=[d1, d2]) + adapter.store.remove_bulk(objs=[d1, d2]) + assert adapter.count("device") == 0