import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import pandas as pd
import requests
from oagdedupe import db
from oagdedupe.base import BaseCluster
from oagdedupe.block.blocking import Blocking
from oagdedupe.block.optimizers import DynamicProgram
from oagdedupe.cluster.cluster import ConnectedComponents
from oagdedupe.settings import Settings
# Import-time side effect: lower the *root* logger's threshold to DEBUG so the
# records emitted in this module via logging.info(...) are not filtered out by
# the root logger's level.
# NOTE(review): configuring the root logger from library code changes logging
# for the whole host application; consider moving this to an entry point.
root = logging.root  # the same object logging.getLogger() returns
root.setLevel(logging.DEBUG)
@dataclass
class BaseModel(ABC):
    """Abstract base class from which all model classes inherit.

    Descendent classes must implement ``initialize`` (each subclass chooses
    its own signature). ``predict`` and ``fit_blocks`` are shared
    implementations provided here.

    Attributes
    ----------
    settings : Settings
        Configuration used to build the repository and the clusterer.
    cluster : BaseCluster
        Passed in as a cluster *class* (default ``ConnectedComponents``);
        ``__post_init__`` replaces it with an instance of that class.
    """

    settings: Settings
    # Holds a class until __post_init__ instantiates it (see below).
    cluster: BaseCluster = ConnectedComponents

    def __post_init__(self):
        """Build the repository, the blocking helper and the clusterer."""
        self.repo = db.get_repository(settings=self.settings)
        self.blocking = Blocking(
            repo=self.repo.blocking,
            optimizer=DynamicProgram,
        )
        # ``self.cluster`` is still the class here; swap it for an instance.
        self.cluster = self.cluster(settings=self.settings, repo=self.repo)

    @abstractmethod
    def initialize(self):
        """Prepare the repository for a run; implemented by each subclass."""
        return

    def predict(self) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
        """Retrain via fast-api, save predictions, and return clusters.

        fast-api trains the model on the latest labels then submits scores to
        postgres; the clusterer then loads the scores and uses comparison
        indices and predicted probabilities to generate clusters.

        Returns
        -------
        df: pd.DataFrame
            if dedupe, returns single df
        df,df2: tuple
            if recordlinkage, two dataframes

        Raises
        ------
        requests.HTTPError
            If the fast-api ``/train`` endpoint answers with an error status.
            Without this check, ``save_predictions`` would run even though
            training reported a failure.
        """
        logging.info("get clusters")
        # NOTE(review): no timeout is set, so a hung service blocks here
        # indefinitely; training time is unknown, so none is imposed.
        response = requests.post(f"{self.settings.fast_api.url}/train")
        response.raise_for_status()
        self.repo.save_predictions()
        return self.cluster.get_df_cluster()

    def fit_blocks(self) -> None:
        """Save blocking comparisons on the full data, then their distances."""
        logging.info("getting comparisons")
        self.blocking.save(full=True)
        # get distances
        logging.info("computing distances")
        self.repo.save_distances(full=True, labels=False)
@dataclass
class Dedupe(BaseModel):
    """Deduplication model over a single dataframe; inherits from BaseModel."""

    def __post_init__(self):
        # Nothing extra to build: defer to the shared BaseModel setup.
        super().__post_init__()

    def initialize(self, df: pd.DataFrame) -> None:
        """Set up the repository with ``df`` and prepare sample distances.

        Runs, in order: repository setup (no second dataframe), labelled
        distances, sample blocking comparisons, then the sample distance
        matrix.

        Parameters
        ----------
        df : pd.DataFrame
            The single dataframe handed to ``repo.setup`` (``df2`` is None).
        """
        repo = self.repo
        repo.setup(df=df, df2=None)
        repo.save_distances(full=False, labels=True)

        logging.info("getting comparisons")
        self.blocking.save(full=False)

        logging.info("get distance matrix")
        repo.save_distances(full=False, labels=False)
@dataclass
class RecordLinkage(BaseModel):
    """General record-linkage block matching two dataframes.

    Inherits from BaseModel.
    """

    def __post_init__(self):
        super().__post_init__()

    def initialize(
        self,
        df: pd.DataFrame,
        df2: pd.DataFrame,
    ) -> None:
        """Set up the repository with both dataframes and prepare distances.

        Runs, in order: repository setup, labelled distances, sample blocking
        comparisons, then the sample distance matrix.

        Parameters
        ----------
        df : pd.DataFrame
            First dataframe, passed to ``repo.setup`` as ``df``.
        df2 : pd.DataFrame
            Second dataframe, passed to ``repo.setup`` as ``df2``.
        """
        self.repo.setup(df=df, df2=df2)
        self.repo.save_distances(full=False, labels=True)
        logging.info("getting comparisons")
        self.blocking.save(full=False)
        logging.info("get distance matrix")
        self.repo.save_distances(full=False, labels=False)
@dataclass
class Fapi(BaseModel):
    """Model whose ``initialize`` re-samples data already in the repository.

    Unlike ``Dedupe``/``RecordLinkage`` it takes no dataframes: it calls
    ``repo.resample()`` instead of ``repo.setup(...)``. Inherits from
    BaseModel.

    NOTE(review): the name suggests this variant is driven by the fast-api
    service; confirm against its callers.
    """

    def __post_init__(self):
        super().__post_init__()

    def initialize(self) -> None:
        """Re-sample repository data and recompute sample distances.

        Runs, in order: ``repo.resample()``, labelled distances, sample
        blocking comparisons, then the sample distance matrix.
        """
        self.repo.resample()
        self.repo.save_distances(full=False, labels=True)
        logging.info("getting comparisons")
        self.blocking.save(full=False)
        logging.info("get distance matrix")
        self.repo.save_distances(full=False, labels=False)