From a1f533c6a803d5f2406f20f84c858b18469166ba Mon Sep 17 00:00:00 2001 From: Austin Godber Date: Wed, 13 Aug 2025 20:48:39 -0700 Subject: [PATCH] fix formatting --- src/embeddingbuddy/__init__.py | 2 +- src/embeddingbuddy/app.py | 19 +- src/embeddingbuddy/config/settings.py | 133 ++++++------ src/embeddingbuddy/data/parser.py | 31 ++- src/embeddingbuddy/data/processor.py | 37 ++-- src/embeddingbuddy/models/reducers.py | 51 ++--- src/embeddingbuddy/models/schemas.py | 8 +- .../ui/callbacks/data_processing.py | 62 +++--- .../ui/callbacks/interactions.py | 96 +++++---- .../ui/callbacks/visualization.py | 109 ++++++---- src/embeddingbuddy/ui/components/sidebar.py | 98 ++++----- src/embeddingbuddy/ui/components/upload.py | 69 +++--- src/embeddingbuddy/ui/layout.py | 64 +++--- src/embeddingbuddy/visualization/colors.py | 37 ++-- src/embeddingbuddy/visualization/plots.py | 199 ++++++++++-------- tests/test_data_processing.py | 42 ++-- tests/test_reducers.py | 53 ++--- 17 files changed, 592 insertions(+), 518 deletions(-) diff --git a/src/embeddingbuddy/__init__.py b/src/embeddingbuddy/__init__.py index 3a705fd..6e27f08 100644 --- a/src/embeddingbuddy/__init__.py +++ b/src/embeddingbuddy/__init__.py @@ -1,3 +1,3 @@ """EmbeddingBuddy - Interactive exploration and visualization of embedding vectors.""" -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.1.0" diff --git a/src/embeddingbuddy/app.py b/src/embeddingbuddy/app.py index 9c386fa..bf326d8 100644 --- a/src/embeddingbuddy/app.py +++ b/src/embeddingbuddy/app.py @@ -8,32 +8,29 @@ from .ui.callbacks.interactions import InteractionCallbacks def create_app(): - app = dash.Dash( - __name__, - external_stylesheets=[dbc.themes.BOOTSTRAP] - ) - + app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) + layout_manager = AppLayout() app.layout = layout_manager.create_layout() - + DataProcessingCallbacks() VisualizationCallbacks() InteractionCallbacks() - + return app def run_app(app=None, debug=None, host=None, port=None): if app is None: app = create_app() - + app.run( debug=debug if debug is not None else AppSettings.DEBUG, host=host if host is not None else AppSettings.HOST, - port=port if port is not None else AppSettings.PORT + port=port if port is not None else AppSettings.PORT, ) -if __name__ == '__main__': +if __name__ == "__main__": app = create_app() - run_app(app) \ No newline at end of file + run_app(app) diff --git a/src/embeddingbuddy/config/settings.py b/src/embeddingbuddy/config/settings.py index 48db0b1..8839430 100644 --- a/src/embeddingbuddy/config/settings.py +++ b/src/embeddingbuddy/config/settings.py @@ -3,105 +3,100 @@ import os class AppSettings: - # UI Configuration UPLOAD_STYLE = { - 'width': '100%', - 'height': '60px', - 'lineHeight': '60px', - 'borderWidth': '1px', - 'borderStyle': 'dashed', - 'borderRadius': '5px', - 'textAlign': 'center', - 'margin-bottom': '20px' + "width": "100%", + "height": "60px", + "lineHeight": "60px", + "borderWidth": "1px", + "borderStyle": "dashed", + "borderRadius": "5px", + "textAlign": "center", + "margin-bottom": "20px", } - - PROMPTS_UPLOAD_STYLE = { - **UPLOAD_STYLE, - 'borderColor': '#28a745' - } - - PLOT_CONFIG = { - 'responsive': True, - 'displayModeBar': True - } - - PLOT_STYLE = { - 'height': '85vh', - 'width': '100%' - } - + + PROMPTS_UPLOAD_STYLE = {**UPLOAD_STYLE, "borderColor": "#28a745"} + + PLOT_CONFIG = {"responsive": True, "displayModeBar": True} + + PLOT_STYLE = {"height": "85vh", "width": "100%"} + PLOT_LAYOUT_CONFIG = { - 'height': None, - 'autosize': True, - 'margin': dict(l=0, r=0, t=50, b=0) + "height": None, + "autosize": True, + "margin": dict(l=0, r=0, t=50, b=0), } - + # Dimensionality Reduction Settings DEFAULT_N_COMPONENTS_3D = 3 DEFAULT_N_COMPONENTS_2D = 2 DEFAULT_RANDOM_STATE = 42 - + # Available Methods REDUCTION_METHODS = [ - {'label': 'PCA', 'value': 'pca'}, - {'label': 't-SNE', 'value': 'tsne'}, - {'label': 'UMAP', 'value': 'umap'} + {"label": "PCA", "value": "pca"}, + {"label": "t-SNE", "value": "tsne"}, + {"label": "UMAP", "value": "umap"}, ] - + COLOR_OPTIONS = [ - {'label': 'Category', 'value': 'category'}, - {'label': 'Subcategory', 'value': 'subcategory'}, - {'label': 'Tags', 'value': 'tags'} + {"label": "Category", "value": "category"}, + {"label": "Subcategory", "value": "subcategory"}, + {"label": "Tags", "value": "tags"}, ] - - DIMENSION_OPTIONS = [ - {'label': '2D', 'value': '2d'}, - {'label': '3D', 'value': '3d'} - ] - + + DIMENSION_OPTIONS = [{"label": "2D", "value": "2d"}, {"label": "3D", "value": "3d"}] + # Default Values - DEFAULT_METHOD = 'pca' - DEFAULT_COLOR_BY = 'category' - DEFAULT_DIMENSIONS = '3d' - DEFAULT_SHOW_PROMPTS = ['show'] - + DEFAULT_METHOD = "pca" + DEFAULT_COLOR_BY = "category" + DEFAULT_DIMENSIONS = "3d" + DEFAULT_SHOW_PROMPTS = ["show"] + # Plot Marker Settings DOCUMENT_MARKER_SIZE_2D = 8 DOCUMENT_MARKER_SIZE_3D = 5 PROMPT_MARKER_SIZE_2D = 10 PROMPT_MARKER_SIZE_3D = 6 - - DOCUMENT_MARKER_SYMBOL = 'circle' - PROMPT_MARKER_SYMBOL = 'diamond' - + + DOCUMENT_MARKER_SYMBOL = "circle" + PROMPT_MARKER_SYMBOL = "diamond" + DOCUMENT_OPACITY = 1.0 PROMPT_OPACITY = 0.8 - + # Text Processing TEXT_PREVIEW_LENGTH = 100 - + # App Configuration - DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true' - HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1') - PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050')) - + DEBUG = os.getenv("EMBEDDINGBUDDY_DEBUG", "True").lower() == "true" + HOST = os.getenv("EMBEDDINGBUDDY_HOST", "127.0.0.1") + PORT = int(os.getenv("EMBEDDINGBUDDY_PORT", "8050")) + # Bootstrap Theme - EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css'] - + EXTERNAL_STYLESHEETS = [ + "https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" + ] + @classmethod - def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]: + def get_plot_marker_config( + cls, dimensions: str, is_prompt: bool = False + ) -> Dict[str, Any]: if is_prompt: - size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D + size = ( + cls.PROMPT_MARKER_SIZE_3D + if dimensions == "3d" + else cls.PROMPT_MARKER_SIZE_2D + ) symbol = cls.PROMPT_MARKER_SYMBOL opacity = cls.PROMPT_OPACITY else: - size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D + size = ( + cls.DOCUMENT_MARKER_SIZE_3D + if dimensions == "3d" + else cls.DOCUMENT_MARKER_SIZE_2D + ) symbol = cls.DOCUMENT_MARKER_SYMBOL opacity = cls.DOCUMENT_OPACITY - - return { - 'size': size, - 'symbol': symbol, - 'opacity': opacity - } \ No newline at end of file + + return {"size": size, "symbol": symbol, "opacity": opacity} diff --git a/src/embeddingbuddy/data/parser.py b/src/embeddingbuddy/data/parser.py index af241cb..ed76bb2 100644 --- a/src/embeddingbuddy/data/parser.py +++ b/src/embeddingbuddy/data/parser.py @@ -6,34 +6,33 @@ from ..models.schemas import Document class NDJSONParser: - @staticmethod def parse_upload_contents(contents: str) -> List[Document]: - content_type, content_string = contents.split(',') + content_type, content_string = contents.split(",") decoded = base64.b64decode(content_string) - text_content = decoded.decode('utf-8') + text_content = decoded.decode("utf-8") return NDJSONParser.parse_text(text_content) - + @staticmethod def parse_text(text_content: str) -> List[Document]: documents = [] - for line in text_content.strip().split('\n'): + for line in text_content.strip().split("\n"): if line.strip(): doc_dict = json.loads(line) doc = NDJSONParser._dict_to_document(doc_dict) documents.append(doc) return documents - + @staticmethod def _dict_to_document(doc_dict: dict) -> Document: - if 'id' not in doc_dict: - doc_dict['id'] = str(uuid.uuid4()) - + if "id" not in doc_dict: + doc_dict["id"] = str(uuid.uuid4()) + return Document( - id=doc_dict['id'], - text=doc_dict['text'], - embedding=doc_dict['embedding'], - category=doc_dict.get('category'), - subcategory=doc_dict.get('subcategory'), - tags=doc_dict.get('tags') - ) \ No newline at end of file + id=doc_dict["id"], + text=doc_dict["text"], + embedding=doc_dict["embedding"], + category=doc_dict.get("category"), + subcategory=doc_dict.get("subcategory"), + tags=doc_dict.get("tags"), + ) diff --git a/src/embeddingbuddy/data/processor.py b/src/embeddingbuddy/data/processor.py index 7c8ac87..a9eb683 100644 --- a/src/embeddingbuddy/data/processor.py +++ b/src/embeddingbuddy/data/processor.py @@ -5,18 +5,19 @@ from .parser import NDJSONParser class DataProcessor: - def __init__(self): self.parser = NDJSONParser() - - def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData: + + def process_upload( + self, contents: str, filename: Optional[str] = None + ) -> ProcessedData: try: documents = self.parser.parse_upload_contents(contents) embeddings = self._extract_embeddings(documents) return ProcessedData(documents=documents, embeddings=embeddings) except Exception as e: return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) - + def process_text(self, text_content: str) -> ProcessedData: try: documents = self.parser.parse_text(text_content) @@ -24,31 +25,35 @@ class DataProcessor: return ProcessedData(documents=documents, embeddings=embeddings) except Exception as e: return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) - + def _extract_embeddings(self, documents: List[Document]) -> np.ndarray: if not documents: return np.array([]) return np.array([doc.embedding for doc in documents]) - - def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]: + + def combine_data( + self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None + ) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]: if not doc_data or doc_data.error: raise ValueError("Invalid document data") - + all_embeddings = doc_data.embeddings documents = doc_data.documents prompts = None - + if prompt_data and not prompt_data.error and prompt_data.documents: all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings]) prompts = prompt_data.documents - + return all_embeddings, documents, prompts - - def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]: + + def split_reduced_data( + self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0 + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: doc_reduced = reduced_embeddings[:n_documents] prompt_reduced = None - + if n_prompts > 0: - prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts] - - return doc_reduced, prompt_reduced \ No newline at end of file + prompt_reduced = reduced_embeddings[n_documents : n_documents + n_prompts] + + return doc_reduced, prompt_reduced diff --git a/src/embeddingbuddy/models/reducers.py b/src/embeddingbuddy/models/reducers.py index bbfcd05..6e3a9bd 100644 --- a/src/embeddingbuddy/models/reducers.py +++ b/src/embeddingbuddy/models/reducers.py @@ -7,88 +7,89 @@ from .schemas import ReducedData class DimensionalityReducer(ABC): - def __init__(self, n_components: int = 3, random_state: int = 42): self.n_components = n_components self.random_state = random_state self._reducer = None - + @abstractmethod def fit_transform(self, embeddings: np.ndarray) -> ReducedData: pass - + @abstractmethod def get_method_name(self) -> str: pass class PCAReducer(DimensionalityReducer): - def fit_transform(self, embeddings: np.ndarray) -> ReducedData: self._reducer = PCA(n_components=self.n_components) reduced = self._reducer.fit_transform(embeddings) variance_explained = self._reducer.explained_variance_ratio_ - + return ReducedData( reduced_embeddings=reduced, variance_explained=variance_explained, method=self.get_method_name(), - n_components=self.n_components + n_components=self.n_components, ) - + def get_method_name(self) -> str: return "PCA" class TSNEReducer(DimensionalityReducer): - def fit_transform(self, embeddings: np.ndarray) -> ReducedData: - self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state) + self._reducer = TSNE( + n_components=self.n_components, random_state=self.random_state + ) reduced = self._reducer.fit(embeddings) - + return ReducedData( reduced_embeddings=reduced, variance_explained=None, method=self.get_method_name(), - n_components=self.n_components + n_components=self.n_components, ) - + def get_method_name(self) -> str: return "t-SNE" class UMAPReducer(DimensionalityReducer): - def fit_transform(self, embeddings: np.ndarray) -> ReducedData: - self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state) + self._reducer = umap.UMAP( + n_components=self.n_components, random_state=self.random_state + ) reduced = self._reducer.fit_transform(embeddings) - + return ReducedData( reduced_embeddings=reduced, variance_explained=None, method=self.get_method_name(), - n_components=self.n_components + n_components=self.n_components, ) - + def get_method_name(self) -> str: return "UMAP" class ReducerFactory: - @staticmethod - def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer: + def create_reducer( + method: str, n_components: int = 3, random_state: int = 42 + ) -> DimensionalityReducer: method_lower = method.lower() - - if method_lower == 'pca': + + if method_lower == "pca": return PCAReducer(n_components=n_components, random_state=random_state) - elif method_lower == 'tsne': + elif method_lower == "tsne": return TSNEReducer(n_components=n_components, random_state=random_state) - elif method_lower == 'umap': + elif method_lower == "umap": return UMAPReducer(n_components=n_components, random_state=random_state) else: raise ValueError(f"Unknown reduction method: {method}") - + @staticmethod def get_available_methods() -> list: - return ['pca', 'tsne', 'umap'] \ No newline at end of file + return ["pca", "tsne", "umap"] diff --git a/src/embeddingbuddy/models/schemas.py b/src/embeddingbuddy/models/schemas.py index 36f9089..a3baf8f 100644 --- a/src/embeddingbuddy/models/schemas.py +++ b/src/embeddingbuddy/models/schemas.py @@ -50,9 +50,11 @@ class PlotData: coordinates: np.ndarray prompts: Optional[List[Document]] = None prompt_coordinates: Optional[np.ndarray] = None - + def __post_init__(self): if not isinstance(self.coordinates, np.ndarray): self.coordinates = np.array(self.coordinates) - if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray): - self.prompt_coordinates = np.array(self.prompt_coordinates) \ No newline at end of file + if self.prompt_coordinates is not None and not isinstance( + self.prompt_coordinates, np.ndarray + ): + self.prompt_coordinates = np.array(self.prompt_coordinates) diff --git a/src/embeddingbuddy/ui/callbacks/data_processing.py b/src/embeddingbuddy/ui/callbacks/data_processing.py index 09f7e47..2a2bf7a 100644 --- a/src/embeddingbuddy/ui/callbacks/data_processing.py +++ b/src/embeddingbuddy/ui/callbacks/data_processing.py @@ -3,58 +3,60 @@ from ...data.processor import DataProcessor class DataProcessingCallbacks: - def __init__(self): self.processor = DataProcessor() self._register_callbacks() - + def _register_callbacks(self): - @callback( - Output('processed-data', 'data'), - Input('upload-data', 'contents'), - State('upload-data', 'filename') + Output("processed-data", "data"), + Input("upload-data", "contents"), + State("upload-data", "filename"), ) def process_uploaded_file(contents, filename): if contents is None: return None - + processed_data = self.processor.process_upload(contents, filename) - + if processed_data.error: - return {'error': processed_data.error} - + return {"error": processed_data.error} + return { - 'documents': [self._document_to_dict(doc) for doc in processed_data.documents], - 'embeddings': processed_data.embeddings.tolist() + "documents": [ + self._document_to_dict(doc) for doc in processed_data.documents + ], + "embeddings": processed_data.embeddings.tolist(), } - + @callback( - Output('processed-prompts', 'data'), - Input('upload-prompts', 'contents'), - State('upload-prompts', 'filename') + Output("processed-prompts", "data"), + Input("upload-prompts", "contents"), + State("upload-prompts", "filename"), ) def process_uploaded_prompts(contents, filename): if contents is None: return None - + processed_data = self.processor.process_upload(contents, filename) - + if processed_data.error: - return {'error': processed_data.error} - + return {"error": processed_data.error} + return { - 'prompts': [self._document_to_dict(doc) for doc in processed_data.documents], - 'embeddings': processed_data.embeddings.tolist() + "prompts": [ + self._document_to_dict(doc) for doc in processed_data.documents + ], + "embeddings": processed_data.embeddings.tolist(), } - + @staticmethod def _document_to_dict(doc): return { - 'id': doc.id, - 'text': doc.text, - 'embedding': doc.embedding, - 'category': doc.category, - 'subcategory': doc.subcategory, - 'tags': doc.tags - } \ No newline at end of file + "id": doc.id, + "text": doc.text, + "embedding": doc.embedding, + "category": doc.category, + "subcategory": doc.subcategory, + "tags": doc.tags, + } diff --git a/src/embeddingbuddy/ui/callbacks/interactions.py b/src/embeddingbuddy/ui/callbacks/interactions.py index d01f125..55aec8f 100644 --- a/src/embeddingbuddy/ui/callbacks/interactions.py +++ b/src/embeddingbuddy/ui/callbacks/interactions.py @@ -4,63 +4,79 @@ import dash_bootstrap_components as dbc class InteractionCallbacks: - def __init__(self): self._register_callbacks() - + def _register_callbacks(self): - @callback( - Output('point-details', 'children'), - Input('embedding-plot', 'clickData'), - [State('processed-data', 'data'), - State('processed-prompts', 'data')] + Output("point-details", "children"), + Input("embedding-plot", "clickData"), + [State("processed-data", "data"), State("processed-prompts", "data")], ) def display_click_data(clickData, data, prompts_data): if not clickData or not data: return "Click on a point to see details" - - point_data = clickData['points'][0] - trace_name = point_data.get('fullData', {}).get('name', 'Documents') - - if 'pointIndex' in point_data: - point_index = point_data['pointIndex'] - elif 'pointNumber' in point_data: - point_index = point_data['pointNumber'] + + point_data = clickData["points"][0] + trace_name = point_data.get("fullData", {}).get("name", "Documents") + + if "pointIndex" in point_data: + point_index = point_data["pointIndex"] + elif "pointNumber" in point_data: + point_index = point_data["pointNumber"] else: return "Could not identify clicked point" - - if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data: - item = prompts_data['prompts'][point_index] - item_type = 'Prompt' + + if ( + trace_name.startswith("Prompts") + and prompts_data + and "prompts" in prompts_data + ): + item = prompts_data["prompts"][point_index] + item_type = "Prompt" else: - item = data['documents'][point_index] - item_type = 'Document' - + item = data["documents"][point_index] + item_type = "Document" + return self._create_detail_card(item, item_type) - + @callback( - [Output('processed-data', 'data', allow_duplicate=True), - Output('processed-prompts', 'data', allow_duplicate=True), - Output('point-details', 'children', allow_duplicate=True)], - Input('reset-button', 'n_clicks'), - prevent_initial_call=True + [ + Output("processed-data", "data", allow_duplicate=True), + Output("processed-prompts", "data", allow_duplicate=True), + Output("point-details", "children", allow_duplicate=True), + ], + Input("reset-button", "n_clicks"), + prevent_initial_call=True, ) def reset_data(n_clicks): if n_clicks is None or n_clicks == 0: return dash.no_update, dash.no_update, dash.no_update - + return None, None, "Click on a point to see details" - + @staticmethod def _create_detail_card(item, item_type): - return dbc.Card([ - dbc.CardBody([ - html.H5(f"{item_type}: {item['id']}", className="card-title"), - html.P(f"Text: {item['text']}", className="card-text"), - html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"), - html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"), - html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"), - html.P(f"Type: {item_type}", className="card-text text-muted") - ]) - ]) \ No newline at end of file + return dbc.Card( + [ + dbc.CardBody( + [ + html.H5(f"{item_type}: {item['id']}", className="card-title"), + html.P(f"Text: {item['text']}", className="card-text"), + html.P( + f"Category: {item.get('category', 'Unknown')}", + className="card-text", + ), + html.P( + f"Subcategory: {item.get('subcategory', 'Unknown')}", + className="card-text", + ), + html.P( + f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", + className="card-text", + ), + html.P(f"Type: {item_type}", className="card-text text-muted"), + ] + ) + ] + ) diff --git a/src/embeddingbuddy/ui/callbacks/visualization.py b/src/embeddingbuddy/ui/callbacks/visualization.py index b10f4ae..201bc70 100644 --- a/src/embeddingbuddy/ui/callbacks/visualization.py +++ b/src/embeddingbuddy/ui/callbacks/visualization.py @@ -7,81 +7,102 @@ from ...visualization.plots import PlotFactory class VisualizationCallbacks: - def __init__(self): self.plot_factory = PlotFactory() self._register_callbacks() - + def _register_callbacks(self): - @callback( - Output('embedding-plot', 'figure'), - [Input('processed-data', 'data'), - Input('processed-prompts', 'data'), - Input('method-dropdown', 'value'), - Input('color-dropdown', 'value'), - Input('dimension-toggle', 'value'), - Input('show-prompts-toggle', 'value')] + Output("embedding-plot", "figure"), + [ + Input("processed-data", "data"), + Input("processed-prompts", "data"), + Input("method-dropdown", "value"), + Input("color-dropdown", "value"), + Input("dimension-toggle", "value"), + Input("show-prompts-toggle", "value"), + ], ) def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts): - if not data or 'error' in data: + if not data or "error" in data: return go.Figure().add_annotation( text="Upload a valid NDJSON file to see visualization", - xref="paper", yref="paper", - x=0.5, y=0.5, xanchor='center', yanchor='middle', - showarrow=False, font=dict(size=16) + xref="paper", + yref="paper", + x=0.5, + y=0.5, + xanchor="center", + yanchor="middle", + showarrow=False, + font=dict(size=16), ) - + try: - doc_embeddings = np.array(data['embeddings']) + doc_embeddings = np.array(data["embeddings"]) all_embeddings = doc_embeddings - has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts') - + has_prompts = ( + prompts_data + and "error" not in prompts_data + and prompts_data.get("prompts") + ) + if has_prompts: - prompt_embeddings = np.array(prompts_data['embeddings']) + prompt_embeddings = np.array(prompts_data["embeddings"]) all_embeddings = np.vstack([doc_embeddings, prompt_embeddings]) - - n_components = 3 if dimensions == '3d' else 2 - - reducer = ReducerFactory.create_reducer(method, n_components=n_components) + + n_components = 3 if dimensions == "3d" else 2 + + reducer = ReducerFactory.create_reducer( + method, n_components=n_components + ) reduced_data = reducer.fit_transform(all_embeddings) - - doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)] + + doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)] prompt_reduced = None if has_prompts: - prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):] - - documents = [self._dict_to_document(doc) for doc in data['documents']] + prompt_reduced = reduced_data.reduced_embeddings[ + len(doc_embeddings) : + ] + + documents = [self._dict_to_document(doc) for doc in data["documents"]] prompts = None if has_prompts: - prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']] - + prompts = [ + self._dict_to_document(prompt) + for prompt in prompts_data["prompts"] + ] + plot_data = PlotData( documents=documents, coordinates=doc_reduced, prompts=prompts, - prompt_coordinates=prompt_reduced + prompt_coordinates=prompt_reduced, ) - + return self.plot_factory.create_plot( plot_data, dimensions, color_by, reduced_data.method, show_prompts ) - + except Exception as e: return go.Figure().add_annotation( text=f"Error creating visualization: {str(e)}", - xref="paper", yref="paper", - x=0.5, y=0.5, xanchor='center', yanchor='middle', - showarrow=False, font=dict(size=16) + xref="paper", + yref="paper", + x=0.5, + y=0.5, + xanchor="center", + yanchor="middle", + showarrow=False, + font=dict(size=16), ) - + @staticmethod def _dict_to_document(doc_dict): return Document( - id=doc_dict['id'], - text=doc_dict['text'], - embedding=doc_dict['embedding'], - category=doc_dict.get('category'), - subcategory=doc_dict.get('subcategory'), - tags=doc_dict.get('tags', []) - ) \ No newline at end of file + id=doc_dict["id"], + text=doc_dict["text"], + embedding=doc_dict["embedding"], + category=doc_dict.get("category"), + subcategory=doc_dict.get("subcategory"), + tags=doc_dict.get("tags", []), + ) diff --git a/src/embeddingbuddy/ui/components/sidebar.py b/src/embeddingbuddy/ui/components/sidebar.py index 1160f22..2ea6487 100644 --- a/src/embeddingbuddy/ui/components/sidebar.py +++ b/src/embeddingbuddy/ui/components/sidebar.py @@ -4,79 +4,81 @@ from .upload import UploadComponent class SidebarComponent: - def __init__(self): self.upload_component = UploadComponent() - + def create_layout(self): - return dbc.Col([ - html.H5("Upload Data", className="mb-3"), - self.upload_component.create_data_upload(), - self.upload_component.create_prompts_upload(), - self.upload_component.create_reset_button(), - - html.H5("Visualization Controls", className="mb-3"), - self._create_method_dropdown(), - self._create_color_dropdown(), - self._create_dimension_toggle(), - self._create_prompts_toggle(), - - html.H5("Point Details", className="mb-3"), - html.Div(id='point-details', children="Click on a point to see details") - - ], width=3, style={'padding-right': '20px'}) - + return dbc.Col( + [ + html.H5("Upload Data", className="mb-3"), + self.upload_component.create_data_upload(), + self.upload_component.create_prompts_upload(), + self.upload_component.create_reset_button(), + html.H5("Visualization Controls", className="mb-3"), + self._create_method_dropdown(), + self._create_color_dropdown(), + self._create_dimension_toggle(), + self._create_prompts_toggle(), + html.H5("Point Details", className="mb-3"), + html.Div( + id="point-details", children="Click on a point to see details" + ), + ], + width=3, + style={"padding-right": "20px"}, + ) + def _create_method_dropdown(self): return [ dbc.Label("Method:"), dcc.Dropdown( - id='method-dropdown', + id="method-dropdown", options=[ - {'label': 'PCA', 'value': 'pca'}, - {'label': 't-SNE', 'value': 'tsne'}, - {'label': 'UMAP', 'value': 'umap'} + {"label": "PCA", "value": "pca"}, + {"label": "t-SNE", "value": "tsne"}, + {"label": "UMAP", "value": "umap"}, ], - value='pca', - style={'margin-bottom': '15px'} - ) + value="pca", + style={"margin-bottom": "15px"}, + ), ] - + def _create_color_dropdown(self): return [ dbc.Label("Color by:"), dcc.Dropdown( - id='color-dropdown', + id="color-dropdown", options=[ - {'label': 'Category', 'value': 'category'}, - {'label': 'Subcategory', 'value': 'subcategory'}, - {'label': 'Tags', 'value': 'tags'} + {"label": "Category", "value": "category"}, + {"label": "Subcategory", "value": "subcategory"}, + {"label": "Tags", "value": "tags"}, ], - value='category', - style={'margin-bottom': '15px'} - ) + value="category", + style={"margin-bottom": "15px"}, + ), ] - + def _create_dimension_toggle(self): return [ dbc.Label("Dimensions:"), dcc.RadioItems( - id='dimension-toggle', + id="dimension-toggle", options=[ - {'label': '2D', 'value': '2d'}, - {'label': '3D', 'value': '3d'} + {"label": "2D", "value": "2d"}, + {"label": "3D", "value": "3d"}, ], - value='3d', - style={'margin-bottom': '20px'} - ) + value="3d", + style={"margin-bottom": "20px"}, + ), ] - + def _create_prompts_toggle(self): return [ dbc.Label("Show Prompts:"), dcc.Checklist( - id='show-prompts-toggle', - options=[{'label': 'Show prompts on plot', 'value': 'show'}], - value=['show'], - style={'margin-bottom': '20px'} - ) - ] \ No newline at end of file + id="show-prompts-toggle", + options=[{"label": "Show prompts on plot", "value": "show"}], + value=["show"], + style={"margin-bottom": "20px"}, + ), + ] diff --git a/src/embeddingbuddy/ui/components/upload.py b/src/embeddingbuddy/ui/components/upload.py index 9aace94..9a25092 100644 --- a/src/embeddingbuddy/ui/components/upload.py +++ b/src/embeddingbuddy/ui/components/upload.py @@ -3,58 +3,51 @@ import dash_bootstrap_components as dbc class UploadComponent: - @staticmethod def create_data_upload(): return dcc.Upload( - id='upload-data', - children=html.Div([ - 'Drag and Drop or ', - html.A('Select Files') - ]), + id="upload-data", + children=html.Div(["Drag and Drop or ", html.A("Select Files")]), style={ - 'width': '100%', - 'height': '60px', - 'lineHeight': '60px', - 'borderWidth': '1px', - 'borderStyle': 'dashed', - 'borderRadius': '5px', - 'textAlign': 'center', - 'margin-bottom': '20px' + "width": "100%", + "height": "60px", + "lineHeight": "60px", + "borderWidth": "1px", + "borderStyle": "dashed", + "borderRadius": "5px", + "textAlign": "center", + "margin-bottom": "20px", }, - multiple=False + multiple=False, ) - + @staticmethod def create_prompts_upload(): return dcc.Upload( - id='upload-prompts', - children=html.Div([ - 'Drag and Drop Prompts or ', - html.A('Select Files') - ]), + id="upload-prompts", + children=html.Div(["Drag and Drop Prompts or ", html.A("Select Files")]), style={ - 'width': '100%', - 'height': '60px', - 'lineHeight': '60px', - 'borderWidth': '1px', - 'borderStyle': 'dashed', - 'borderRadius': '5px', - 'textAlign': 'center', - 'margin-bottom': '20px', - 'borderColor': '#28a745' + "width": "100%", + "height": "60px", + "lineHeight": "60px", + "borderWidth": "1px", + "borderStyle": "dashed", + "borderRadius": "5px", + "textAlign": "center", + "margin-bottom": "20px", + "borderColor": "#28a745", }, - multiple=False + multiple=False, ) - + @staticmethod def create_reset_button(): return dbc.Button( "Reset All Data", - id='reset-button', - color='danger', + id="reset-button", + color="danger", outline=True, - size='sm', - className='mb-3', - style={'width': '100%'} - ) \ No newline at end of file + size="sm", + className="mb-3", + style={"width": "100%"}, + ) diff --git a/src/embeddingbuddy/ui/layout.py b/src/embeddingbuddy/ui/layout.py index 4ed4f24..960bc4a 100644 --- a/src/embeddingbuddy/ui/layout.py +++ b/src/embeddingbuddy/ui/layout.py @@ -4,41 +4,43 @@ from .components.sidebar import SidebarComponent class AppLayout: - def __init__(self): self.sidebar = SidebarComponent() - + def create_layout(self): - return dbc.Container([ - self._create_header(), - self._create_main_content(), - self._create_stores() - ], fluid=True) - + return dbc.Container( + [self._create_header(), self._create_main_content(), self._create_stores()], + fluid=True, + ) + def _create_header(self): - return dbc.Row([ - dbc.Col([ - html.H1("EmbeddingBuddy", className="text-center mb-4"), - ], width=12) - ]) - + return dbc.Row( + [ + dbc.Col( + [ + html.H1("EmbeddingBuddy", className="text-center mb-4"), + ], + width=12, + ) + ] + ) + def _create_main_content(self): - return dbc.Row([ - self.sidebar.create_layout(), - self._create_visualization_area() - ]) - + return dbc.Row( + [self.sidebar.create_layout(), self._create_visualization_area()] + ) + def _create_visualization_area(self): - return dbc.Col([ - dcc.Graph( - id='embedding-plot', - style={'height': '85vh', 'width': '100%'}, - config={'responsive': True, 'displayModeBar': True} - ) - ], width=9) - + return dbc.Col( + [ + dcc.Graph( + id="embedding-plot", + style={"height": "85vh", "width": "100%"}, + config={"responsive": True, "displayModeBar": True}, + ) + ], + width=9, + ) + def _create_stores(self): - return [ - dcc.Store(id='processed-data'), - dcc.Store(id='processed-prompts') - ] \ No newline at end of file + return [dcc.Store(id="processed-data"), dcc.Store(id="processed-prompts")] diff --git a/src/embeddingbuddy/visualization/colors.py b/src/embeddingbuddy/visualization/colors.py index d263228..5ba6ddf 100644 --- a/src/embeddingbuddy/visualization/colors.py +++ b/src/embeddingbuddy/visualization/colors.py @@ -4,30 +4,33 @@ from ..models.schemas import Document class ColorMapper: - @staticmethod def create_color_mapping(documents: List[Document], color_by: str) -> List[str]: - if color_by == 'category': + if color_by == "category": return [doc.category for doc in documents] - elif color_by == 'subcategory': + elif color_by == "subcategory": return [doc.subcategory for doc in documents] - elif color_by == 'tags': - return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents] + elif color_by == "tags": + return [", ".join(doc.tags) if doc.tags else "No tags" for doc in documents] else: - return ['All'] * len(documents) - + return ["All"] * len(documents) + @staticmethod def to_grayscale_hex(color_str: str) -> str: try: - if color_str.startswith('#'): - rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5)) + if color_str.startswith("#"): + rgb = tuple(int(color_str[i : i + 2], 16) for i in (1, 3, 5)) else: - rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0]) - + rgb = pc.hex_to_rgb( + pc.convert_colors_to_same_type([color_str], colortype="hex")[0][0] + ) + gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) - gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3, - gray_value * 0.7 + rgb[1] * 0.3, - gray_value * 0.7 + rgb[2] * 0.3) - return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})' - except: # noqa: E722 - return 'rgb(128,128,128)' \ No newline at end of file + gray_rgb = ( + gray_value * 0.7 + rgb[0] * 0.3, + gray_value * 0.7 + rgb[1] * 0.3, + gray_value * 0.7 + rgb[2] * 0.3, + ) + return f"rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})" + except: # noqa: E722 + return "rgb(128,128,128)" diff --git a/src/embeddingbuddy/visualization/plots.py b/src/embeddingbuddy/visualization/plots.py index d472b1b..0b32445 100644 --- a/src/embeddingbuddy/visualization/plots.py +++ b/src/embeddingbuddy/visualization/plots.py @@ -7,139 +7,172 @@ from .colors import ColorMapper class PlotFactory: - def __init__(self): self.color_mapper = ColorMapper() - - def create_plot(self, plot_data: PlotData, dimensions: str = '3d', - color_by: str = 'category', method: str = 'PCA', - show_prompts: Optional[List[str]] = None) -> go.Figure: - - if plot_data.prompts and show_prompts and 'show' in show_prompts: + + def create_plot( + self, + plot_data: PlotData, + dimensions: str = "3d", + color_by: str = "category", + method: str = "PCA", + show_prompts: Optional[List[str]] = None, + ) -> go.Figure: + if plot_data.prompts and show_prompts and "show" in show_prompts: return self._create_dual_plot(plot_data, dimensions, color_by, method) else: return self._create_single_plot(plot_data, dimensions, color_by, method) - - def _create_single_plot(self, plot_data: PlotData, dimensions: str, - color_by: str, method: str) -> go.Figure: - df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) - color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) - - hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] - - if dimensions == '3d': + + def _create_single_plot( + self, plot_data: PlotData, dimensions: str, color_by: str, method: str + ) -> go.Figure: + df = self._prepare_dataframe( + plot_data.documents, plot_data.coordinates, dimensions + ) + color_values = self.color_mapper.create_color_mapping( + plot_data.documents, color_by + ) + + hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"] + + if dimensions == "3d": fig = px.scatter_3d( - df, x='dim_1', y='dim_2', z='dim_3', + df, + x="dim_1", + y="dim_2", + z="dim_3", color=color_values, hover_data=hover_fields, - title=f'3D Embedding Visualization - {method} (colored by {color_by})' + title=f"3D Embedding Visualization - {method} (colored by {color_by})", ) fig.update_traces(marker=dict(size=5)) else: fig = px.scatter( - df, x='dim_1', y='dim_2', + df, + x="dim_1", + y="dim_2", color=color_values, hover_data=hover_fields, - title=f'2D Embedding Visualization - {method} (colored by {color_by})' + title=f"2D Embedding Visualization - {method} (colored by {color_by})", ) fig.update_traces(marker=dict(size=8)) - - fig.update_layout( - height=None, - autosize=True, - margin=dict(l=0, r=0, t=50, b=0) - ) + + fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)) return fig - - def _create_dual_plot(self, plot_data: PlotData, dimensions: str, - color_by: str, method: str) -> go.Figure: + + def _create_dual_plot( + self, plot_data: PlotData, dimensions: str, color_by: str, method: str + ) -> go.Figure: fig = go.Figure() - - doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) - doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) - - hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] - - if dimensions == '3d': + + doc_df = self._prepare_dataframe( + plot_data.documents, plot_data.coordinates, dimensions + ) + doc_color_values = self.color_mapper.create_color_mapping( + plot_data.documents, color_by + ) + + hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"] + + if dimensions == "3d": doc_fig = px.scatter_3d( - doc_df, x='dim_1', y='dim_2', z='dim_3', + doc_df, + x="dim_1", + y="dim_2", + z="dim_3", color=doc_color_values, - hover_data=hover_fields + hover_data=hover_fields, ) else: doc_fig = px.scatter( - doc_df, x='dim_1', y='dim_2', + doc_df, + x="dim_1", + y="dim_2", color=doc_color_values, - hover_data=hover_fields + hover_data=hover_fields, ) - + for trace in doc_fig.data: - trace.name = f'Documents - {trace.name}' - if dimensions == '3d': + trace.name = f"Documents - {trace.name}" + if dimensions == "3d": trace.marker.size = 5 - trace.marker.symbol = 'circle' + trace.marker.symbol = "circle" else: trace.marker.size = 8 - trace.marker.symbol = 'circle' + trace.marker.symbol = "circle" trace.marker.opacity = 1.0 fig.add_trace(trace) - + if plot_data.prompts and plot_data.prompt_coordinates is not None: - prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions) - prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by) - - if dimensions == '3d': + prompt_df = self._prepare_dataframe( + plot_data.prompts, plot_data.prompt_coordinates, dimensions + ) + prompt_color_values = self.color_mapper.create_color_mapping( + plot_data.prompts, color_by + ) + + if dimensions == "3d": prompt_fig = px.scatter_3d( - prompt_df, x='dim_1', y='dim_2', z='dim_3', + prompt_df, + x="dim_1", + y="dim_2", + z="dim_3", color=prompt_color_values, - hover_data=hover_fields + hover_data=hover_fields, ) else: prompt_fig = px.scatter( - prompt_df, x='dim_1', y='dim_2', + prompt_df, + x="dim_1", + y="dim_2", color=prompt_color_values, - hover_data=hover_fields + hover_data=hover_fields, ) - + for trace in prompt_fig.data: - if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str): - trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color) - - trace.name = f'Prompts - {trace.name}' - if dimensions == '3d': + if hasattr(trace.marker, "color") and isinstance( + trace.marker.color, str + ): + trace.marker.color = self.color_mapper.to_grayscale_hex( + trace.marker.color + ) + + trace.name = f"Prompts - {trace.name}" + if dimensions == "3d": trace.marker.size = 6 - trace.marker.symbol = 'diamond' + trace.marker.symbol = "diamond" else: trace.marker.size = 10 - trace.marker.symbol = 'diamond' + trace.marker.symbol = "diamond" trace.marker.opacity = 0.8 fig.add_trace(trace) - - title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})' + + title = f"{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})" fig.update_layout( - title=title, - height=None, - autosize=True, - margin=dict(l=0, r=0, t=50, b=0) + title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0) ) - + return fig - - def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame: + + def _prepare_dataframe( + self, documents: List[Document], coordinates, dimensions: str + ) -> pd.DataFrame: df_data = [] for i, doc in enumerate(documents): row = { - 'id': doc.id, - 'text': doc.text, - 'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text, - 'category': doc.category, - 'subcategory': doc.subcategory, - 'tags_str': ', '.join(doc.tags) if doc.tags else 'None', - 'dim_1': coordinates[i, 0], - 'dim_2': coordinates[i, 1], + "id": doc.id, + "text": doc.text, + "text_preview": doc.text[:100] + "..." + if len(doc.text) > 100 + else doc.text, + "category": doc.category, + "subcategory": doc.subcategory, + "tags_str": ", ".join(doc.tags) if doc.tags else "None", + "dim_1": coordinates[i, 0], + "dim_2": coordinates[i, 1], } - if dimensions == '3d': - row['dim_3'] = coordinates[i, 2] + if dimensions == "3d": + row["dim_3"] = coordinates[i, 2] df_data.append(row) - - return pd.DataFrame(df_data) \ No newline at end of file + + return pd.DataFrame(df_data) diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index edfe278..041f043 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -6,62 +6,64 @@ from src.embeddingbuddy.models.schemas import Document class TestNDJSONParser: - def test_parse_text_basic(self): - text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}' + text_content = ( + '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}' + ) documents = NDJSONParser.parse_text(text_content) - + assert len(documents) == 1 assert documents[0].id == "test1" assert documents[0].text == "Hello world" assert documents[0].embedding == [0.1, 0.2, 0.3] - + def test_parse_text_with_metadata(self): text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}' documents = NDJSONParser.parse_text(text_content) - + assert documents[0].category == "greeting" assert documents[0].tags == ["test"] - + def test_parse_text_missing_id(self): text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}' documents = NDJSONParser.parse_text(text_content) - + assert len(documents) == 1 assert documents[0].id is not None # Should be auto-generated class TestDataProcessor: - def test_extract_embeddings(self): documents = [ Document(id="1", text="test1", embedding=[0.1, 0.2]), - Document(id="2", text="test2", embedding=[0.3, 0.4]) + Document(id="2", text="test2", embedding=[0.3, 0.4]), ] - + processor = DataProcessor() embeddings = processor._extract_embeddings(documents) - + assert embeddings.shape == (2, 2) assert np.allclose(embeddings[0], [0.1, 0.2]) assert np.allclose(embeddings[1], [0.3, 0.4]) - + def test_combine_data(self): from src.embeddingbuddy.models.schemas import ProcessedData - + doc_data = ProcessedData( documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])], - embeddings=np.array([[0.1, 0.2]]) + embeddings=np.array([[0.1, 0.2]]), ) - + prompt_data = ProcessedData( documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])], - embeddings=np.array([[0.3, 0.4]]) + embeddings=np.array([[0.3, 0.4]]), ) - + processor = DataProcessor() - all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data) - + all_embeddings, documents, prompts = processor.combine_data( + doc_data, prompt_data + ) + assert all_embeddings.shape == (2, 2) assert len(documents) == 1 assert len(prompts) == 1 @@ -70,4 +72,4 @@ class TestDataProcessor: if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/tests/test_reducers.py b/tests/test_reducers.py index ff8dba1..6959541 100644 --- a/tests/test_reducers.py +++ b/tests/test_reducers.py @@ -1,89 +1,90 @@ import pytest import numpy as np -from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer +from src.embeddingbuddy.models.reducers import ( + ReducerFactory, + PCAReducer, + TSNEReducer, + UMAPReducer, +) class TestReducerFactory: - def test_create_pca_reducer(self): - reducer = ReducerFactory.create_reducer('pca', n_components=2) + reducer = ReducerFactory.create_reducer("pca", n_components=2) assert isinstance(reducer, PCAReducer) assert reducer.n_components == 2 - + def test_create_tsne_reducer(self): - reducer = ReducerFactory.create_reducer('tsne', n_components=3) + reducer = ReducerFactory.create_reducer("tsne", n_components=3) assert isinstance(reducer, TSNEReducer) assert reducer.n_components == 3 - + def test_create_umap_reducer(self): - reducer = ReducerFactory.create_reducer('umap', n_components=2) + reducer = ReducerFactory.create_reducer("umap", n_components=2) assert isinstance(reducer, UMAPReducer) assert reducer.n_components == 2 - + def test_invalid_method(self): with pytest.raises(ValueError, match="Unknown reduction method"): - ReducerFactory.create_reducer('invalid_method') - + ReducerFactory.create_reducer("invalid_method") + def test_available_methods(self): methods = ReducerFactory.get_available_methods() - assert 'pca' in methods - assert 'tsne' in methods - assert 'umap' in methods + assert "pca" in methods + assert "tsne" in methods + assert "umap" in methods class TestPCAReducer: - def test_fit_transform(self): embeddings = np.random.rand(100, 512) reducer = PCAReducer(n_components=2) - + result = reducer.fit_transform(embeddings) - + assert result.reduced_embeddings.shape == (100, 2) assert result.variance_explained is not None assert result.method == "PCA" assert result.n_components == 2 - + def test_method_name(self): reducer = PCAReducer() assert reducer.get_method_name() == "PCA" class TestTSNEReducer: - def test_fit_transform_small_dataset(self): embeddings = np.random.rand(30, 10) # Small dataset for faster testing reducer = TSNEReducer(n_components=2) - + result = reducer.fit_transform(embeddings) - + assert result.reduced_embeddings.shape == (30, 2) assert result.variance_explained is None # t-SNE doesn't provide this assert result.method == "t-SNE" assert result.n_components == 2 - + def test_method_name(self): reducer = TSNEReducer() assert reducer.get_method_name() == "t-SNE" class TestUMAPReducer: - def test_fit_transform(self): embeddings = np.random.rand(50, 10) reducer = UMAPReducer(n_components=2) - + result = reducer.fit_transform(embeddings) - + assert result.reduced_embeddings.shape == (50, 2) assert result.variance_explained is None # UMAP doesn't provide this assert result.method == "UMAP" assert result.n_components == 2 - + def test_method_name(self): reducer = UMAPReducer() assert reducer.get_method_name() == "UMAP" if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__])