|
| 1 | +import pandas as pd |
| 2 | +from refinery import exceptions |
| 3 | + |
| 4 | + |
class ModelCallback:
    """Run an inference function over inputs in batches, paired with their indices.

    Each yielded item couples one batch of index records with the model
    outputs computed for the corresponding batch of inputs.
    """

    def __init__(
        self,
        client,
        inference_fn,
        preprocessing_fn=None,
        postprocessing_fn=None,
        batch_size=32,
    ):
        """Store the callables and read the primary keys from the client.

        Args:
            client: Object exposing ``get_primary_keys()``; its return value
                is the collection of column names every index record must
                provide.
            inference_fn: Callable applied to each (optionally preprocessed)
                batch of inputs.
            preprocessing_fn: Optional callable applied to each input batch
                before inference.
            postprocessing_fn: Optional callable applied to each batch of
                inference outputs.
            batch_size: Number of items per batch. Defaults to 32.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        self.client = client
        self.inference_fn = inference_fn
        self.preprocessing_fn = preprocessing_fn
        self.postprocessing_fn = postprocessing_fn
        self.batch_size = batch_size

        self.primary_keys = client.get_primary_keys()

    @staticmethod
    def __batch(documents, batch_size=32):
        """Yield consecutive slices of ``documents`` of at most ``batch_size`` items."""
        length = len(documents)
        for idx in range(0, length, batch_size):
            yield documents[idx : min(idx + batch_size, length)]

    def run(self, inputs, indices):
        """Yield batched model outputs together with their index records.

        This is a generator: nothing (including validation) runs until the
        first item is requested.

        Args:
            inputs: Sized, sliceable collection of model inputs.
            indices: Sized, sliceable collection of index records, one per
                input. It must be convertible to a ``pandas.DataFrame`` whose
                columns include every primary key.

        Yields:
            dict: ``{"index": <batch of indices>, "associations": <batch of
            outputs>}`` for each batch.

        Raises:
            exceptions.PrimaryKeyError: If a primary key is missing from the
                columns of ``indices``.
            ValueError: If ``inputs`` and ``indices`` differ in length.
        """
        indices_df = pd.DataFrame(indices)
        if not all(key in indices_df.columns for key in self.primary_keys):
            raise exceptions.PrimaryKeyError("Erroneous primary keys given for index.")

        # A length mismatch would either pair outputs with the wrong index
        # records or exhaust one batch stream early, so reject it up front.
        if len(inputs) != len(indices):
            raise ValueError(
                "inputs and indices must have the same length "
                f"(got {len(inputs)} and {len(indices)})."
            )

        input_batches = self.__batch(inputs, self.batch_size)
        index_batches = self.__batch(indices, self.batch_size)
        for batched_inputs, batched_indices in zip(input_batches, index_batches):
            if self.preprocessing_fn is not None:
                batched_inputs = self.preprocessing_fn(batched_inputs)

            batched_outputs = self.inference_fn(batched_inputs)

            if self.postprocessing_fn is not None:
                batched_outputs = self.postprocessing_fn(batched_outputs)

            yield {"index": batched_indices, "associations": batched_outputs}
0 commit comments