Source code for oagdedupe.api

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import requests

from oagdedupe import db
from oagdedupe.base import BaseCluster
from oagdedupe.block.blocking import Blocking
from oagdedupe.block.optimizers import DynamicProgram
from oagdedupe.cluster.cluster import ConnectedComponents
from oagdedupe.settings import Settings

# Runs at import time: grab the root logger (``getLogger()`` with no name) and
# lower its threshold to DEBUG. The classes below log through the module-level
# ``logging.info(...)`` helpers, which dispatch to this same root logger.
root = logging.getLogger()
root.setLevel(logging.DEBUG)


@dataclass
class BaseModel(ABC):
    """Abstract base class from which all model classes inherit.

    Descendant classes must implement ``initialize``. ``predict`` and
    ``fit_blocks`` are shared implementations provided here.

    Parameters
    ----------
    settings: Settings
        Project settings; used to obtain the repository, to build the
        clusterer, and to locate the fast-api service
        (``settings.fast_api.url``).
    cluster: BaseCluster
        Clusterer *class* (defaults to ``ConnectedComponents``). It is
        instantiated in ``__post_init__`` and the attribute then holds the
        resulting instance.
    """

    settings: Settings
    # NOTE(review): the default is a class, not an instance; __post_init__
    # calls it with (settings, repo) and rebinds the attribute to the result.
    cluster: BaseCluster = ConnectedComponents

    def __post_init__(
        self,
    ):
        """Build the repository, the blocking helper and the clusterer."""
        self.repo = db.get_repository(settings=self.settings)
        self.blocking = Blocking(
            repo=self.repo.blocking,
            optimizer=DynamicProgram,
        )
        # Replace the clusterer class with a configured instance of it.
        self.cluster = self.cluster(settings=self.settings, repo=self.repo)

    @abstractmethod
    def initialize(self):
        """Prepare the model's data; concrete subclasses must override."""
        return

    def predict(
        self,
    ) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
        """Train via fast-api, save predictions and return the clusters.

        fast-api trains model on latest labels then submits scores to
        postgres; clusterer loads scores and uses comparison indices and
        predicted probabilities to generate clusters.

        Returns
        -------
        df: pd.DataFrame
            if dedupe, returns single df

        df,df2: tuple
            if recordlinkage, two dataframes

        Raises
        ------
        requests.HTTPError
            If the fast-api ``/train`` endpoint answers with a 4xx/5xx
            status.
        """
        logging.info("get clusters")
        response = requests.post(f"{self.settings.fast_api.url}/train")
        # Connection failures already raise from ``requests.post``; make an
        # HTTP error status fail just as loudly instead of silently going on
        # to save predictions after a training call that did not succeed.
        response.raise_for_status()
        self.repo.save_predictions()
        return self.cluster.get_df_cluster()

    def fit_blocks(self) -> None:
        """Save block comparisons, then distances, with ``full=True``."""
        logging.info("getting comparisons")
        self.blocking.save(full=True)

        # get distances
        logging.info("computing distances")
        self.repo.save_distances(full=True, labels=False)
@dataclass
class Dedupe(BaseModel):
    """Model for a single dataframe; inherits from BaseModel."""

    def __post_init__(self):
        """Defer entirely to ``BaseModel.__post_init__``."""
        super().__post_init__()

    def initialize(self, df: pd.DataFrame) -> None:
        """learn p(match)

        Sets up the repository with ``df`` only (``df2=None``), saves
        distances for the labels, saves the non-full block comparisons and
        finally saves the non-full, unlabelled distances.
        """
        repo = self.repo

        repo.setup(df=df, df2=None)
        repo.save_distances(full=False, labels=True)

        logging.info("getting comparisons")
        self.blocking.save(full=False)

        logging.info("get distance matrix")
        repo.save_distances(full=False, labels=False)
@dataclass
class RecordLinkage(BaseModel):
    """Model for a pair of dataframes; inherits from BaseModel."""

    def __post_init__(self):
        """Defer entirely to ``BaseModel.__post_init__``."""
        super().__post_init__()

    def initialize(
        self,
        df: pd.DataFrame,
        df2: pd.DataFrame,
    ) -> None:
        """learn p(match)

        Sets up the repository with both ``df`` and ``df2``, saves distances
        for the labels, saves the non-full block comparisons and finally
        saves the non-full, unlabelled distances.
        """
        repo = self.repo

        repo.setup(df=df, df2=df2)
        repo.save_distances(full=False, labels=True)

        logging.info("getting comparisons")
        self.blocking.save(full=False)

        logging.info("get distance matrix")
        repo.save_distances(full=False, labels=False)
@dataclass
class Fapi(BaseModel):
    """Model that resamples existing repository data; inherits from BaseModel.

    NOTE(review): presumably the variant used by the fast-api service --
    confirm against its caller.
    """

    def __post_init__(self):
        """Defer entirely to ``BaseModel.__post_init__``."""
        super().__post_init__()

    def initialize(self) -> None:
        """learn p(match)

        Takes no dataframes: asks the repository to resample, saves distances
        for the labels, saves the non-full block comparisons and finally
        saves the non-full, unlabelled distances.
        """
        repo = self.repo

        repo.resample()
        repo.save_distances(full=False, labels=True)

        logging.info("getting comparisons")
        self.blocking.save(full=False)

        logging.info("get distance matrix")
        repo.save_distances(full=False, labels=False)