Source code for extral.state

# Copyright 2025 Michael Anckaert
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
import json
import logging
from typing import TypedDict, Optional

from extral.encoder import CustomEncoder

logger = logging.getLogger(__name__)

DatasetState = TypedDict(
    "DatasetState",
    {
        "incremental": dict[str, str | int | None],
    },
)

PipelineState = TypedDict(
    "PipelineState",
    {
        "datasets": dict[str, DatasetState],
    },
)


# Singleton state class to manage the state of the application
[docs] class State: def __init__(self) -> None: self.pipelines: dict[str, PipelineState] = dict()
[docs] def get_dataset_state( self, pipeline_name: str, dataset_id: str ) -> Optional[DatasetState]: """Get the state for a specific dataset in a specific pipeline. Args: pipeline_name: Name of the pipeline dataset_id: Identifier for the dataset (table name or file name/logical name) """ if pipeline_name not in self.pipelines: return None pipeline_state = self.pipelines[pipeline_name] return pipeline_state.get("datasets", {}).get(dataset_id)
[docs] def set_dataset_state( self, pipeline_name: str, dataset_id: str, dataset_state: DatasetState ) -> None: """Set the state for a specific dataset in a specific pipeline. Args: pipeline_name: Name of the pipeline dataset_id: Identifier for the dataset (table name or file name/logical name) dataset_state: State data for the dataset """ if pipeline_name not in self.pipelines: self.pipelines[pipeline_name] = {"datasets": {}} if "datasets" not in self.pipelines[pipeline_name]: self.pipelines[pipeline_name]["datasets"] = {} self.pipelines[pipeline_name]["datasets"][dataset_id] = dataset_state
[docs] def get_pipeline_state(self, pipeline_name: str) -> Optional[PipelineState]: """Get the complete state for a specific pipeline.""" return self.pipelines.get(pipeline_name)
[docs] def list_pipelines(self) -> list[str]: """List all pipeline names that have state.""" return list(self.pipelines.keys())
[docs] def list_datasets(self, pipeline_name: str) -> list[str]: """List all dataset IDs for a specific pipeline.""" if pipeline_name not in self.pipelines: return [] pipeline_state = self.pipelines[pipeline_name] return list(pipeline_state.get("datasets", {}).keys())
[docs] def store_state(self): """Store the current state to a JSON file.""" # Create backup of the current state try: with open( f".state-backup-{datetime.now().isoformat()}.json", "w" ) as backup_file: json.dump( {"pipelines": self.pipelines}, backup_file, indent=4, cls=CustomEncoder, ) logger.debug("State backup created successfully.") except Exception as e: logger.error(f"Failed to create state backup: {e}") # Store the current state with open("state.json", "w") as state_file: json.dump( {"pipelines": self.pipelines}, state_file, indent=4, cls=CustomEncoder ) logger.debug("State stored successfully.")
[docs] def load_state(self): try: with open("state.json", "r") as state_file: data = json.load(state_file) self.pipelines = data.get("pipelines", {}) logger.debug("State loaded successfully.") except FileNotFoundError: logger.debug("State file not found. Starting with an empty state.") except json.JSONDecodeError as e: logger.error(f"Error decoding state file: {e}") self.pipelines = {}
state = State()