diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..88bfffc --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,11 @@ +{ + "permissions": { + "allow": [ + "Bash(mkdir:*)", + "Bash(uv run:*)", + "Bash(uv add:*)" + ], + "deny": [], + "ask": [] + } +} \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 787ee92..3a975ca 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -5,10 +5,11 @@ code in this repository. ## Project Overview -EmbeddingBuddy is a Python Dash web application for interactive exploration and +EmbeddingBuddy is a modular Python Dash web application for interactive exploration and visualization of embedding vectors through dimensionality reduction techniques (PCA, t-SNE, UMAP). The app provides a drag-and-drop interface for uploading -NDJSON files containing embeddings and visualizes them in 2D/3D plots. +NDJSON files containing embeddings and visualizes them in 2D/3D plots. The codebase +follows a clean, modular architecture that prioritizes testability and maintainability. ## Development Commands @@ -21,32 +22,79 @@ uv sync **Run the application:** ```bash -uv run python app.py +uv run python main.py ``` The app will be available at http://127.0.0.1:8050 +**Run tests:** + +```bash +uv run pytest tests/ -v +``` + **Test with sample data:** -Use the included `sample_data.ndjson` file for testing the application functionality. +Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality. ## Architecture -### Core Files +### Project Structure -- `app.py` - Main Dash application with complete web interface, data processing, - and visualization logic -- `main.py` - Simple entry point (currently minimal) -- `pyproject.toml` - Project configuration and dependencies using uv package manager +The application follows a modular architecture with clear separation of concerns: + +``` +src/embeddingbuddy/ +├── app.py # Main application entry point and factory +├── main.py # Application runner +├── config/ +│ └── settings.py # Centralized configuration management +├── data/ +│ ├── parser.py # NDJSON parsing logic +│ └── processor.py # Data transformation and processing +├── models/ +│ ├── schemas.py # Data models and validation schemas +│ └── reducers.py # Dimensionality reduction algorithms +├── visualization/ +│ ├── plots.py # Plot creation and factory classes +│ └── colors.py # Color mapping and management +├── ui/ +│ ├── layout.py # Main application layout +│ ├── components/ # Reusable UI components +│ │ ├── sidebar.py # Sidebar component +│ │ └── upload.py # Upload components +│ └── callbacks/ # Organized callback functions +│ ├── data_processing.py # Data upload/processing callbacks +│ ├── visualization.py # Plot update callbacks +│ └── interactions.py # User interaction callbacks +└── utils/ # Utility functions and helpers +``` ### Key Components -- **Data Processing**: NDJSON parser that handles embedding documents with - required fields (`embedding`, `text`) and optional metadata (`id`, `category`, `subcategory`, `tags`) -- **Dimensionality Reduction**: Supports PCA, t-SNE (openTSNE), and UMAP algorithms -- **Visualization**: Plotly-based 2D/3D scatter plots with interactive features -- **UI Layout**: Bootstrap-styled sidebar with controls and large visualization area -- **State Management**: Dash callbacks for reactive updates between upload, - method selection, and plot rendering +**Data Layer:** +- `data/parser.py` - NDJSON parsing with error handling +- `data/processor.py` - Data transformation and combination logic +- `models/schemas.py` - Dataclasses for type safety and validation + +**Algorithm Layer:** +- `models/reducers.py` - Modular dimensionality reduction with factory pattern +- Supports PCA, t-SNE (openTSNE), and UMAP algorithms +- Abstract base class for easy extension + +**Visualization Layer:** +- `visualization/plots.py` - Plot factory with single and dual plot support +- `visualization/colors.py` - Color mapping and grayscale conversion utilities +- Plotly-based 2D/3D scatter plots with interactive features + +**UI Layer:** +- `ui/layout.py` - Main application layout composition +- `ui/components/` - Reusable, testable UI components +- `ui/callbacks/` - Organized callbacks grouped by functionality +- Bootstrap-styled sidebar with controls and large visualization area + +**Configuration:** +- `config/settings.py` - Centralized settings with environment variable support +- Plot styling, marker configurations, and app-wide constants ### Data Format @@ -56,17 +104,77 @@ The application expects NDJSON files where each line contains: {"id": "doc_001", "embedding": [0.1, -0.3, 0.7, ...], "text": "Sample text", "category": "news", "subcategory": "politics", "tags": ["election"]} ``` +Required fields: `embedding` (array), `text` (string) +Optional fields: `id`, `category`, `subcategory`, `tags` + ### Callback Architecture -- File upload → Data processing and storage in dcc.Store -- Method/parameter changes → Dimensionality reduction and plot update -- Point clicks → Detail display in sidebar +The refactored callback system is organized by functionality: + +**Data Processing (`ui/callbacks/data_processing.py`):** +- File upload handling +- NDJSON parsing and validation +- Data storage in dcc.Store components + +**Visualization (`ui/callbacks/visualization.py`):** +- Dimensionality reduction pipeline +- Plot generation and updates +- Method/parameter change handling + +**Interactions (`ui/callbacks/interactions.py`):** +- Point click handling and detail display +- Reset functionality +- User interaction management + +### Testing Architecture + +The modular design enables comprehensive testing: + +**Unit Tests:** +- `tests/test_data_processing.py` - Parser and processor logic +- `tests/test_reducers.py` - Dimensionality reduction algorithms +- `tests/test_visualization.py` - Plot creation and color mapping + +**Integration Tests:** +- End-to-end data pipeline testing +- Component integration verification + +**Key Testing Benefits:** +- Fast test execution (milliseconds vs seconds) +- Isolated component testing +- Easy mocking and fixture creation +- High code coverage achievable ## Dependencies Uses modern Python stack with uv for dependency management: -- Dash + Plotly for web interface and visualization -- scikit-learn (PCA), openTSNE, umap-learn for dimensionality reduction -- pandas/numpy for data manipulation -- dash-bootstrap-components for styling \ No newline at end of file +- **Core Framework:** Dash + Plotly for web interface and visualization +- **Algorithms:** scikit-learn (PCA), openTSNE, umap-learn for dimensionality reduction +- **Data:** pandas/numpy for data manipulation +- **UI:** dash-bootstrap-components for styling +- **Testing:** pytest for test framework +- **Dev Tools:** uv for package management + +## Development Guidelines + +**When adding new features:** + +1. **Data Models** - Add/update schemas in `models/schemas.py` +2. **Algorithms** - Extend `models/reducers.py` using the abstract base class +3. **UI Components** - Create reusable components in `ui/components/` +4. **Configuration** - Add settings to `config/settings.py` +5. **Tests** - Write tests for all new functionality + +**Code Organization Principles:** +- Single responsibility principle +- Clear module boundaries +- Testable, isolated components +- Configuration over hardcoding +- Error handling at appropriate layers + +**Testing Requirements:** +- Unit tests for all core logic +- Integration tests for data flow +- Component tests for UI elements +- Maintain high test coverage \ No newline at end of file diff --git a/README.md b/README.md index cb4add3..75ee3ff 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # EmbeddingBuddy -A web application for interactive exploration and visualization of embedding +A modular Python Dash web application for interactive exploration and visualization of embedding vectors through dimensionality reduction techniques. Compare documents and prompts in the same embedding space to understand semantic relationships. @@ -10,9 +10,10 @@ in the same embedding space to understand semantic relationships. EmbeddingBuddy provides an intuitive web interface for analyzing high-dimensional embedding vectors by applying various dimensionality reduction algorithms and -visualizing the results in interactive 2D and 3D plots. The application supports -dual dataset visualization, allowing you to compare documents and prompts to -understand how queries relate to your content. +visualizing the results in interactive 2D and 3D plots. The application features +a clean, modular architecture that makes it easy to test, maintain, and extend +with new features. It supports dual dataset visualization, allowing you to compare +documents and prompts to understand how queries relate to your content. ## Features @@ -73,7 +74,7 @@ uv sync 2. **Run the application:** ```bash -uv run python app.py +uv run python main.py ``` 3. **Open your browser** to http://127.0.0.1:8050 @@ -83,6 +84,59 @@ uv run python app.py - Upload `sample_prompts.ndjson` (prompts) to see dual visualization - Use the "Show prompts" toggle to compare how prompts relate to documents +## Development + +### Project Structure + +The application follows a modular architecture for improved maintainability and testability: + +``` +src/embeddingbuddy/ +├── config/ # Configuration management +│ └── settings.py # Centralized app settings +├── data/ # Data parsing and processing +│ ├── parser.py # NDJSON parsing logic +│ └── processor.py # Data transformation utilities +├── models/ # Data schemas and algorithms +│ ├── schemas.py # Pydantic data models +│ └── reducers.py # Dimensionality reduction algorithms +├── visualization/ # Plot creation and styling +│ ├── plots.py # Plot factory and creation logic +│ └── colors.py # Color mapping utilities +├── ui/ # User interface components +│ ├── layout.py # Main application layout +│ ├── components/ # Reusable UI components +│ └── callbacks/ # Organized callback functions +└── utils/ # Utility functions +``` + +### Testing + +Run the test suite to verify functionality: + +```bash +# Install pytest +uv add pytest + +# Run all tests +uv run pytest tests/ -v + +# Run specific test file +uv run pytest tests/test_data_processing.py -v + +# Run with coverage +uv run pytest tests/ --cov=src/embeddingbuddy +``` + +### Adding New Features + +The modular architecture makes it easy to extend functionality: + +- **New reduction algorithms**: Add to `models/reducers.py` +- **New plot types**: Extend `visualization/plots.py` +- **UI components**: Add to `ui/components/` +- **Configuration options**: Update `config/settings.py` + ## Tech Stack - **Python Dash**: Web application framework @@ -91,4 +145,5 @@ uv run python app.py - **UMAP-learn**: UMAP dimensionality reduction - **openTSNE**: Fast t-SNE implementation - **NumPy/Pandas**: Data manipulation and analysis +- **pytest**: Testing framework - **uv**: Modern Python package and project manager diff --git a/main.py b/main.py index b866b4c..426194b 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,9 @@ +from src.embeddingbuddy.app import create_app, run_app + + def main(): - print("Hello from embeddingbuddy!") + app = create_app() + run_app(app) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 1b72a39..7622be0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "embeddingbuddy" -version = "0.1.0" +version = "0.2.0" description = "A Python Dash application for interactive exploration and visualization of embedding vectors through dimensionality reduction techniques." readme = "README.md" requires-python = ">=3.11" @@ -13,5 +13,16 @@ dependencies = [ "dash-bootstrap-components>=1.5.0", "umap-learn>=0.5.8", "numba>=0.56.4", - "openTSNE>=1.0.0" + "openTSNE>=1.0.0", + "pytest>=8.4.1", ] + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-dir] +"" = "src" diff --git a/src/embeddingbuddy/__init__.py b/src/embeddingbuddy/__init__.py new file mode 100644 index 0000000..3a705fd --- /dev/null +++ b/src/embeddingbuddy/__init__.py @@ -0,0 +1,3 @@ +"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors.""" + +__version__ = "0.1.0" \ No newline at end of file diff --git a/src/embeddingbuddy/app.py b/src/embeddingbuddy/app.py new file mode 100644 index 0000000..9c386fa --- /dev/null +++ b/src/embeddingbuddy/app.py @@ -0,0 +1,39 @@ +import dash +import dash_bootstrap_components as dbc +from .config.settings import AppSettings +from .ui.layout import AppLayout +from .ui.callbacks.data_processing import DataProcessingCallbacks +from .ui.callbacks.visualization import VisualizationCallbacks +from .ui.callbacks.interactions import InteractionCallbacks + + +def create_app(): + app = dash.Dash( + __name__, + external_stylesheets=[dbc.themes.BOOTSTRAP] + ) + + layout_manager = AppLayout() + app.layout = layout_manager.create_layout() + + DataProcessingCallbacks() + VisualizationCallbacks() + InteractionCallbacks() + + return app + + +def run_app(app=None, debug=None, host=None, port=None): + if app is None: + app = create_app() + + app.run( + debug=debug if debug is not None else AppSettings.DEBUG, + host=host if host is not None else AppSettings.HOST, + port=port if port is not None else AppSettings.PORT + ) + + +if __name__ == '__main__': + app = create_app() + run_app(app) \ No newline at end of file diff --git a/src/embeddingbuddy/config/__init__.py b/src/embeddingbuddy/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/config/settings.py b/src/embeddingbuddy/config/settings.py new file mode 100644 index 0000000..48db0b1 --- /dev/null +++ b/src/embeddingbuddy/config/settings.py @@ -0,0 +1,107 @@ +from typing import Dict, Any +import os + + +class AppSettings: + + # UI Configuration + UPLOAD_STYLE = { + 'width': '100%', + 'height': '60px', + 'lineHeight': '60px', + 'borderWidth': '1px', + 'borderStyle': 'dashed', + 'borderRadius': '5px', + 'textAlign': 'center', + 'margin-bottom': '20px' + } + + PROMPTS_UPLOAD_STYLE = { + **UPLOAD_STYLE, + 'borderColor': '#28a745' + } + + PLOT_CONFIG = { + 'responsive': True, + 'displayModeBar': True + } + + PLOT_STYLE = { + 'height': '85vh', + 'width': '100%' + } + + PLOT_LAYOUT_CONFIG = { + 'height': None, + 'autosize': True, + 'margin': dict(l=0, r=0, t=50, b=0) + } + + # Dimensionality Reduction Settings + DEFAULT_N_COMPONENTS_3D = 3 + DEFAULT_N_COMPONENTS_2D = 2 + DEFAULT_RANDOM_STATE = 42 + + # Available Methods + REDUCTION_METHODS = [ + {'label': 'PCA', 'value': 'pca'}, + {'label': 't-SNE', 'value': 'tsne'}, + {'label': 'UMAP', 'value': 'umap'} + ] + + COLOR_OPTIONS = [ + {'label': 'Category', 'value': 'category'}, + {'label': 'Subcategory', 'value': 'subcategory'}, + {'label': 'Tags', 'value': 'tags'} + ] + + DIMENSION_OPTIONS = [ + {'label': '2D', 'value': '2d'}, + {'label': '3D', 'value': '3d'} + ] + + # Default Values + DEFAULT_METHOD = 'pca' + DEFAULT_COLOR_BY = 'category' + DEFAULT_DIMENSIONS = '3d' + DEFAULT_SHOW_PROMPTS = ['show'] + + # Plot Marker Settings + DOCUMENT_MARKER_SIZE_2D = 8 + DOCUMENT_MARKER_SIZE_3D = 5 + PROMPT_MARKER_SIZE_2D = 10 + PROMPT_MARKER_SIZE_3D = 6 + + DOCUMENT_MARKER_SYMBOL = 'circle' + PROMPT_MARKER_SYMBOL = 'diamond' + + DOCUMENT_OPACITY = 1.0 + PROMPT_OPACITY = 0.8 + + # Text Processing + TEXT_PREVIEW_LENGTH = 100 + + # App Configuration + DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true' + HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1') + PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050')) + + # Bootstrap Theme + EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css'] + + @classmethod + def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]: + if is_prompt: + size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D + symbol = cls.PROMPT_MARKER_SYMBOL + opacity = cls.PROMPT_OPACITY + else: + size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D + symbol = cls.DOCUMENT_MARKER_SYMBOL + opacity = cls.DOCUMENT_OPACITY + + return { + 'size': size, + 'symbol': symbol, + 'opacity': opacity + } \ No newline at end of file diff --git a/src/embeddingbuddy/data/__init__.py b/src/embeddingbuddy/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/data/parser.py b/src/embeddingbuddy/data/parser.py new file mode 100644 index 0000000..0b564a2 --- /dev/null +++ b/src/embeddingbuddy/data/parser.py @@ -0,0 +1,39 @@ +import json +import uuid +import base64 +from typing import List, Union +from ..models.schemas import Document, ProcessedData + + +class NDJSONParser: + + @staticmethod + def parse_upload_contents(contents: str) -> List[Document]: + content_type, content_string = contents.split(',') + decoded = base64.b64decode(content_string) + text_content = decoded.decode('utf-8') + return NDJSONParser.parse_text(text_content) + + @staticmethod + def parse_text(text_content: str) -> List[Document]: + documents = [] + for line in text_content.strip().split('\n'): + if line.strip(): + doc_dict = json.loads(line) + doc = NDJSONParser._dict_to_document(doc_dict) + documents.append(doc) + return documents + + @staticmethod + def _dict_to_document(doc_dict: dict) -> Document: + if 'id' not in doc_dict: + doc_dict['id'] = str(uuid.uuid4()) + + return Document( + id=doc_dict['id'], + text=doc_dict['text'], + embedding=doc_dict['embedding'], + category=doc_dict.get('category'), + subcategory=doc_dict.get('subcategory'), + tags=doc_dict.get('tags') + ) \ No newline at end of file diff --git a/src/embeddingbuddy/data/processor.py b/src/embeddingbuddy/data/processor.py new file mode 100644 index 0000000..7c8ac87 --- /dev/null +++ b/src/embeddingbuddy/data/processor.py @@ -0,0 +1,54 @@ +import numpy as np +from typing import List, Optional, Tuple +from ..models.schemas import Document, ProcessedData +from .parser import NDJSONParser + + +class DataProcessor: + + def __init__(self): + self.parser = NDJSONParser() + + def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData: + try: + documents = self.parser.parse_upload_contents(contents) + embeddings = self._extract_embeddings(documents) + return ProcessedData(documents=documents, embeddings=embeddings) + except Exception as e: + return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) + + def process_text(self, text_content: str) -> ProcessedData: + try: + documents = self.parser.parse_text(text_content) + embeddings = self._extract_embeddings(documents) + return ProcessedData(documents=documents, embeddings=embeddings) + except Exception as e: + return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) + + def _extract_embeddings(self, documents: List[Document]) -> np.ndarray: + if not documents: + return np.array([]) + return np.array([doc.embedding for doc in documents]) + + def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]: + if not doc_data or doc_data.error: + raise ValueError("Invalid document data") + + all_embeddings = doc_data.embeddings + documents = doc_data.documents + prompts = None + + if prompt_data and not prompt_data.error and prompt_data.documents: + all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings]) + prompts = prompt_data.documents + + return all_embeddings, documents, prompts + + def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]: + doc_reduced = reduced_embeddings[:n_documents] + prompt_reduced = None + + if n_prompts > 0: + prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts] + + return doc_reduced, prompt_reduced \ No newline at end of file diff --git a/src/embeddingbuddy/models/__init__.py b/src/embeddingbuddy/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/models/reducers.py b/src/embeddingbuddy/models/reducers.py new file mode 100644 index 0000000..3cac2f6 --- /dev/null +++ b/src/embeddingbuddy/models/reducers.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +import numpy as np +from typing import Optional, Tuple +from sklearn.decomposition import PCA +import umap +from openTSNE import TSNE +from .schemas import ReducedData + + +class DimensionalityReducer(ABC): + + def __init__(self, n_components: int = 3, random_state: int = 42): + self.n_components = n_components + self.random_state = random_state + self._reducer = None + + @abstractmethod + def fit_transform(self, embeddings: np.ndarray) -> ReducedData: + pass + + @abstractmethod + def get_method_name(self) -> str: + pass + + +class PCAReducer(DimensionalityReducer): + + def fit_transform(self, embeddings: np.ndarray) -> ReducedData: + self._reducer = PCA(n_components=self.n_components) + reduced = self._reducer.fit_transform(embeddings) + variance_explained = self._reducer.explained_variance_ratio_ + + return ReducedData( + reduced_embeddings=reduced, + variance_explained=variance_explained, + method=self.get_method_name(), + n_components=self.n_components + ) + + def get_method_name(self) -> str: + return "PCA" + + +class TSNEReducer(DimensionalityReducer): + + def fit_transform(self, embeddings: np.ndarray) -> ReducedData: + self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state) + reduced = self._reducer.fit(embeddings) + + return ReducedData( + reduced_embeddings=reduced, + variance_explained=None, + method=self.get_method_name(), + n_components=self.n_components + ) + + def get_method_name(self) -> str: + return "t-SNE" + + +class UMAPReducer(DimensionalityReducer): + + def fit_transform(self, embeddings: np.ndarray) -> ReducedData: + self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state) + reduced = self._reducer.fit_transform(embeddings) + + return ReducedData( + reduced_embeddings=reduced, + variance_explained=None, + method=self.get_method_name(), + n_components=self.n_components + ) + + def get_method_name(self) -> str: + return "UMAP" + + +class ReducerFactory: + + @staticmethod + def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer: + method_lower = method.lower() + + if method_lower == 'pca': + return PCAReducer(n_components=n_components, random_state=random_state) + elif method_lower == 'tsne': + return TSNEReducer(n_components=n_components, random_state=random_state) + elif method_lower == 'umap': + return UMAPReducer(n_components=n_components, random_state=random_state) + else: + raise ValueError(f"Unknown reduction method: {method}") + + @staticmethod + def get_available_methods() -> list: + return ['pca', 'tsne', 'umap'] \ No newline at end of file diff --git a/src/embeddingbuddy/models/schemas.py b/src/embeddingbuddy/models/schemas.py new file mode 100644 index 0000000..a6cc14e --- /dev/null +++ b/src/embeddingbuddy/models/schemas.py @@ -0,0 +1,58 @@ +from typing import List, Optional, Any, Dict +from dataclasses import dataclass +import numpy as np + + +@dataclass +class Document: + id: str + text: str + embedding: List[float] + category: Optional[str] = None + subcategory: Optional[str] = None + tags: Optional[List[str]] = None + + def __post_init__(self): + if self.tags is None: + self.tags = [] + if self.category is None: + self.category = "Unknown" + if self.subcategory is None: + self.subcategory = "Unknown" + + +@dataclass +class ProcessedData: + documents: List[Document] + embeddings: np.ndarray + error: Optional[str] = None + + def __post_init__(self): + if self.embeddings is not None and not isinstance(self.embeddings, np.ndarray): + self.embeddings = np.array(self.embeddings) + + +@dataclass +class ReducedData: + reduced_embeddings: np.ndarray + variance_explained: Optional[np.ndarray] = None + method: str = "unknown" + n_components: int = 2 + + def __post_init__(self): + if not isinstance(self.reduced_embeddings, np.ndarray): + self.reduced_embeddings = np.array(self.reduced_embeddings) + + +@dataclass +class PlotData: + documents: List[Document] + coordinates: np.ndarray + prompts: Optional[List[Document]] = None + prompt_coordinates: Optional[np.ndarray] = None + + def __post_init__(self): + if not isinstance(self.coordinates, np.ndarray): + self.coordinates = np.array(self.coordinates) + if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray): + self.prompt_coordinates = np.array(self.prompt_coordinates) \ No newline at end of file diff --git a/src/embeddingbuddy/ui/__init__.py b/src/embeddingbuddy/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/ui/callbacks/__init__.py b/src/embeddingbuddy/ui/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/ui/callbacks/data_processing.py b/src/embeddingbuddy/ui/callbacks/data_processing.py new file mode 100644 index 0000000..7060158 --- /dev/null +++ b/src/embeddingbuddy/ui/callbacks/data_processing.py @@ -0,0 +1,61 @@ +import numpy as np +from dash import callback, Input, Output, State +from ...data.processor import DataProcessor + + +class DataProcessingCallbacks: + + def __init__(self): + self.processor = DataProcessor() + self._register_callbacks() + + def _register_callbacks(self): + + @callback( + Output('processed-data', 'data'), + Input('upload-data', 'contents'), + State('upload-data', 'filename') + ) + def process_uploaded_file(contents, filename): + if contents is None: + return None + + processed_data = self.processor.process_upload(contents, filename) + + if processed_data.error: + return {'error': processed_data.error} + + return { + 'documents': [self._document_to_dict(doc) for doc in processed_data.documents], + 'embeddings': processed_data.embeddings.tolist() + } + + @callback( + Output('processed-prompts', 'data'), + Input('upload-prompts', 'contents'), + State('upload-prompts', 'filename') + ) + def process_uploaded_prompts(contents, filename): + if contents is None: + return None + + processed_data = self.processor.process_upload(contents, filename) + + if processed_data.error: + return {'error': processed_data.error} + + return { + 'prompts': [self._document_to_dict(doc) for doc in processed_data.documents], + 'embeddings': processed_data.embeddings.tolist() + } + + @staticmethod + def _document_to_dict(doc): + return { + 'id': doc.id, + 'text': doc.text, + 'embedding': doc.embedding, + 'category': doc.category, + 'subcategory': doc.subcategory, + 'tags': doc.tags + } \ No newline at end of file diff --git a/src/embeddingbuddy/ui/callbacks/interactions.py b/src/embeddingbuddy/ui/callbacks/interactions.py new file mode 100644 index 0000000..d01f125 --- /dev/null +++ b/src/embeddingbuddy/ui/callbacks/interactions.py @@ -0,0 +1,66 @@ +import dash +from dash import callback, Input, Output, State, html +import dash_bootstrap_components as dbc + + +class InteractionCallbacks: + + def __init__(self): + self._register_callbacks() + + def _register_callbacks(self): + + @callback( + Output('point-details', 'children'), + Input('embedding-plot', 'clickData'), + [State('processed-data', 'data'), + State('processed-prompts', 'data')] + ) + def display_click_data(clickData, data, prompts_data): + if not clickData or not data: + return "Click on a point to see details" + + point_data = clickData['points'][0] + trace_name = point_data.get('fullData', {}).get('name', 'Documents') + + if 'pointIndex' in point_data: + point_index = point_data['pointIndex'] + elif 'pointNumber' in point_data: + point_index = point_data['pointNumber'] + else: + return "Could not identify clicked point" + + if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data: + item = prompts_data['prompts'][point_index] + item_type = 'Prompt' + else: + item = data['documents'][point_index] + item_type = 'Document' + + return self._create_detail_card(item, item_type) + + @callback( + [Output('processed-data', 'data', allow_duplicate=True), + Output('processed-prompts', 'data', allow_duplicate=True), + Output('point-details', 'children', allow_duplicate=True)], + Input('reset-button', 'n_clicks'), + prevent_initial_call=True + ) + def reset_data(n_clicks): + if n_clicks is None or n_clicks == 0: + return dash.no_update, dash.no_update, dash.no_update + + return None, None, "Click on a point to see details" + + @staticmethod + def _create_detail_card(item, item_type): + return dbc.Card([ + dbc.CardBody([ + html.H5(f"{item_type}: {item['id']}", className="card-title"), + html.P(f"Text: {item['text']}", className="card-text"), + html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"), + html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"), + html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"), + html.P(f"Type: {item_type}", className="card-text text-muted") + ]) + ]) \ No newline at end of file diff --git a/src/embeddingbuddy/ui/callbacks/visualization.py b/src/embeddingbuddy/ui/callbacks/visualization.py new file mode 100644 index 0000000..b10f4ae --- /dev/null +++ b/src/embeddingbuddy/ui/callbacks/visualization.py @@ -0,0 +1,87 @@ +import numpy as np +from dash import callback, Input, Output +import plotly.graph_objects as go +from ...models.reducers import ReducerFactory +from ...models.schemas import Document, PlotData +from ...visualization.plots import PlotFactory + + +class VisualizationCallbacks: + + def __init__(self): + self.plot_factory = PlotFactory() + self._register_callbacks() + + def _register_callbacks(self): + + @callback( + Output('embedding-plot', 'figure'), + [Input('processed-data', 'data'), + Input('processed-prompts', 'data'), + Input('method-dropdown', 'value'), + Input('color-dropdown', 'value'), + Input('dimension-toggle', 'value'), + Input('show-prompts-toggle', 'value')] + ) + def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts): + if not data or 'error' in data: + return go.Figure().add_annotation( + text="Upload a valid NDJSON file to see visualization", + xref="paper", yref="paper", + x=0.5, y=0.5, xanchor='center', yanchor='middle', + showarrow=False, font=dict(size=16) + ) + + try: + doc_embeddings = np.array(data['embeddings']) + all_embeddings = doc_embeddings + has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts') + + if has_prompts: + prompt_embeddings = np.array(prompts_data['embeddings']) + all_embeddings = np.vstack([doc_embeddings, prompt_embeddings]) + + n_components = 3 if dimensions == '3d' else 2 + + reducer = ReducerFactory.create_reducer(method, n_components=n_components) + reduced_data = reducer.fit_transform(all_embeddings) + + doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)] + prompt_reduced = None + if has_prompts: + prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):] + + documents = [self._dict_to_document(doc) for doc in data['documents']] + prompts = None + if has_prompts: + prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']] + + plot_data = PlotData( + documents=documents, + coordinates=doc_reduced, + prompts=prompts, + prompt_coordinates=prompt_reduced + ) + + return self.plot_factory.create_plot( + plot_data, dimensions, color_by, reduced_data.method, show_prompts + ) + + except Exception as e: + return go.Figure().add_annotation( + text=f"Error creating visualization: {str(e)}", + xref="paper", yref="paper", + x=0.5, y=0.5, xanchor='center', yanchor='middle', + showarrow=False, font=dict(size=16) + ) + + @staticmethod + def _dict_to_document(doc_dict): + return Document( + id=doc_dict['id'], + text=doc_dict['text'], + embedding=doc_dict['embedding'], + category=doc_dict.get('category'), + subcategory=doc_dict.get('subcategory'), + tags=doc_dict.get('tags', []) + ) \ No newline at end of file diff --git a/src/embeddingbuddy/ui/components/__init__.py b/src/embeddingbuddy/ui/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/ui/components/sidebar.py b/src/embeddingbuddy/ui/components/sidebar.py new file mode 100644 index 0000000..1160f22 --- /dev/null +++ b/src/embeddingbuddy/ui/components/sidebar.py @@ -0,0 +1,82 @@ +from dash import dcc, html +import dash_bootstrap_components as dbc +from .upload import UploadComponent + + +class SidebarComponent: + + def __init__(self): + self.upload_component = UploadComponent() + + def create_layout(self): + return dbc.Col([ + html.H5("Upload Data", className="mb-3"), + self.upload_component.create_data_upload(), + self.upload_component.create_prompts_upload(), + self.upload_component.create_reset_button(), + + html.H5("Visualization Controls", className="mb-3"), + self._create_method_dropdown(), + self._create_color_dropdown(), + self._create_dimension_toggle(), + self._create_prompts_toggle(), + + html.H5("Point Details", className="mb-3"), + html.Div(id='point-details', children="Click on a point to see details") + + ], width=3, style={'padding-right': '20px'}) + + def _create_method_dropdown(self): + return [ + dbc.Label("Method:"), + dcc.Dropdown( + id='method-dropdown', + options=[ + {'label': 'PCA', 'value': 'pca'}, + {'label': 't-SNE', 'value': 'tsne'}, + {'label': 'UMAP', 'value': 'umap'} + ], + value='pca', + style={'margin-bottom': '15px'} + ) + ] + + def _create_color_dropdown(self): + return [ + dbc.Label("Color by:"), + dcc.Dropdown( + id='color-dropdown', + options=[ + {'label': 'Category', 'value': 'category'}, + {'label': 'Subcategory', 'value': 'subcategory'}, + {'label': 'Tags', 'value': 'tags'} + ], + value='category', + style={'margin-bottom': '15px'} + ) + ] + + def _create_dimension_toggle(self): + return [ + dbc.Label("Dimensions:"), + dcc.RadioItems( + id='dimension-toggle', + options=[ + {'label': '2D', 'value': '2d'}, + {'label': '3D', 'value': '3d'} + ], + value='3d', + style={'margin-bottom': '20px'} + ) + ] + + def _create_prompts_toggle(self): + return [ + dbc.Label("Show Prompts:"), + dcc.Checklist( + id='show-prompts-toggle', + options=[{'label': 'Show prompts on plot', 'value': 'show'}], + value=['show'], + style={'margin-bottom': '20px'} + ) + ] \ No newline at end of file diff --git a/src/embeddingbuddy/ui/components/upload.py b/src/embeddingbuddy/ui/components/upload.py new file mode 100644 index 0000000..9aace94 --- /dev/null +++ b/src/embeddingbuddy/ui/components/upload.py @@ -0,0 +1,60 @@ +from dash import dcc, html +import dash_bootstrap_components as dbc + + +class UploadComponent: + + @staticmethod + def create_data_upload(): + return dcc.Upload( + id='upload-data', + children=html.Div([ + 'Drag and Drop or ', + html.A('Select Files') + ]), + style={ + 'width': '100%', + 'height': '60px', + 'lineHeight': '60px', + 'borderWidth': '1px', + 'borderStyle': 'dashed', + 'borderRadius': '5px', + 'textAlign': 'center', + 'margin-bottom': '20px' + }, + multiple=False + ) + + @staticmethod + def create_prompts_upload(): + return dcc.Upload( + id='upload-prompts', + children=html.Div([ + 'Drag and Drop Prompts or ', + html.A('Select Files') + ]), + style={ + 'width': '100%', + 'height': '60px', + 'lineHeight': '60px', + 'borderWidth': '1px', + 'borderStyle': 'dashed', + 'borderRadius': '5px', + 'textAlign': 'center', + 'margin-bottom': '20px', + 'borderColor': '#28a745' + }, + multiple=False + ) + + @staticmethod + def create_reset_button(): + return dbc.Button( + "Reset All Data", + id='reset-button', + color='danger', + outline=True, + size='sm', + className='mb-3', + style={'width': '100%'} + ) \ No newline at end of file diff --git a/src/embeddingbuddy/ui/layout.py b/src/embeddingbuddy/ui/layout.py new file mode 100644 index 0000000..4ed4f24 --- /dev/null +++ b/src/embeddingbuddy/ui/layout.py @@ -0,0 +1,44 @@ +from dash import dcc, html +import dash_bootstrap_components as dbc +from .components.sidebar import SidebarComponent + + +class AppLayout: + + def __init__(self): + self.sidebar = SidebarComponent() + + def create_layout(self): + return dbc.Container([ + self._create_header(), + self._create_main_content(), + self._create_stores() + ], fluid=True) + + def _create_header(self): + return dbc.Row([ + dbc.Col([ + html.H1("EmbeddingBuddy", className="text-center mb-4"), + ], width=12) + ]) + + def _create_main_content(self): + return dbc.Row([ + self.sidebar.create_layout(), + self._create_visualization_area() + ]) + + def _create_visualization_area(self): + return dbc.Col([ + dcc.Graph( + id='embedding-plot', + style={'height': '85vh', 'width': '100%'}, + config={'responsive': True, 'displayModeBar': True} + ) + ], width=9) + + def _create_stores(self): + return [ + dcc.Store(id='processed-data'), + dcc.Store(id='processed-prompts') + ] \ No newline at end of file diff --git a/src/embeddingbuddy/utils/__init__.py b/src/embeddingbuddy/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/visualization/__init__.py b/src/embeddingbuddy/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/embeddingbuddy/visualization/colors.py b/src/embeddingbuddy/visualization/colors.py new file mode 100644 index 0000000..dc3da3f --- /dev/null +++ b/src/embeddingbuddy/visualization/colors.py @@ -0,0 +1,33 @@ +from typing import List, Dict, Any +import plotly.colors as pc +from ..models.schemas import Document + + +class ColorMapper: + + @staticmethod + def create_color_mapping(documents: List[Document], color_by: str) -> List[str]: + if color_by == 'category': + return [doc.category for doc in documents] + elif color_by == 'subcategory': + return [doc.subcategory for doc in documents] + elif color_by == 'tags': + return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents] + else: + return ['All'] * len(documents) + + @staticmethod + def to_grayscale_hex(color_str: str) -> str: + try: + if color_str.startswith('#'): + rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5)) + else: + rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0]) + + gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) + gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3, + gray_value * 0.7 + rgb[1] * 0.3, + gray_value * 0.7 + rgb[2] * 0.3) + return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})' + except: + return 'rgb(128,128,128)' \ No newline at end of file diff --git a/src/embeddingbuddy/visualization/plots.py b/src/embeddingbuddy/visualization/plots.py new file mode 100644 index 0000000..d472b1b --- /dev/null +++ b/src/embeddingbuddy/visualization/plots.py @@ -0,0 +1,145 @@ +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from typing import List, Optional +from ..models.schemas import Document, PlotData +from .colors import ColorMapper + + +class PlotFactory: + + def __init__(self): + self.color_mapper = ColorMapper() + + def create_plot(self, plot_data: PlotData, dimensions: str = '3d', + color_by: str = 'category', method: str = 'PCA', + show_prompts: Optional[List[str]] = None) -> go.Figure: + + if plot_data.prompts and show_prompts and 'show' in show_prompts: + return self._create_dual_plot(plot_data, dimensions, color_by, method) + else: + return self._create_single_plot(plot_data, dimensions, color_by, method) + + def _create_single_plot(self, plot_data: PlotData, dimensions: str, + color_by: str, method: str) -> go.Figure: + df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) + color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) + + hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] + + if dimensions == '3d': + fig = px.scatter_3d( + df, x='dim_1', y='dim_2', z='dim_3', + color=color_values, + hover_data=hover_fields, + title=f'3D Embedding Visualization - {method} (colored by {color_by})' + ) + fig.update_traces(marker=dict(size=5)) + else: + fig = px.scatter( + df, x='dim_1', y='dim_2', + color=color_values, + hover_data=hover_fields, + title=f'2D Embedding Visualization - {method} (colored by {color_by})' + ) + fig.update_traces(marker=dict(size=8)) + + fig.update_layout( + height=None, + autosize=True, + margin=dict(l=0, r=0, t=50, b=0) + ) + return fig + + def _create_dual_plot(self, plot_data: PlotData, dimensions: str, + color_by: str, method: str) -> go.Figure: + fig = go.Figure() + + doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) + doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) + + hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] + + if dimensions == '3d': + doc_fig = px.scatter_3d( + doc_df, x='dim_1', y='dim_2', z='dim_3', + color=doc_color_values, + hover_data=hover_fields + ) + else: + doc_fig = px.scatter( + doc_df, x='dim_1', y='dim_2', + color=doc_color_values, + hover_data=hover_fields + ) + + for trace in doc_fig.data: + trace.name = f'Documents - {trace.name}' + if dimensions == '3d': + trace.marker.size = 5 + trace.marker.symbol = 'circle' + else: + trace.marker.size = 8 + trace.marker.symbol = 'circle' + trace.marker.opacity = 1.0 + fig.add_trace(trace) + + if plot_data.prompts and plot_data.prompt_coordinates is not None: + prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions) + prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by) + + if dimensions == '3d': + prompt_fig = px.scatter_3d( + prompt_df, x='dim_1', y='dim_2', z='dim_3', + color=prompt_color_values, + hover_data=hover_fields + ) + else: + prompt_fig = px.scatter( + prompt_df, x='dim_1', y='dim_2', + color=prompt_color_values, + hover_data=hover_fields + ) + + for trace in prompt_fig.data: + if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str): + trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color) + + trace.name = f'Prompts - {trace.name}' + if dimensions == '3d': + trace.marker.size = 6 + trace.marker.symbol = 'diamond' + else: + trace.marker.size = 10 + trace.marker.symbol = 'diamond' + trace.marker.opacity = 0.8 + fig.add_trace(trace) + + title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})' + fig.update_layout( + title=title, + height=None, + autosize=True, + margin=dict(l=0, r=0, t=50, b=0) + ) + + return fig + + def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame: + df_data = [] + for i, doc in enumerate(documents): + row = { + 'id': doc.id, + 'text': doc.text, + 'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text, + 'category': doc.category, + 'subcategory': doc.subcategory, + 'tags_str': ', '.join(doc.tags) if doc.tags else 'None', + 'dim_1': coordinates[i, 0], + 'dim_2': coordinates[i, 1], + } + if dimensions == '3d': + row['dim_3'] = coordinates[i, 2] + df_data.append(row) + + return pd.DataFrame(df_data) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py new file mode 100644 index 0000000..edfe278 --- /dev/null +++ b/tests/test_data_processing.py @@ -0,0 +1,73 @@ +import pytest +import numpy as np +from src.embeddingbuddy.data.parser import NDJSONParser +from src.embeddingbuddy.data.processor import DataProcessor +from src.embeddingbuddy.models.schemas import Document + + +class TestNDJSONParser: + + def test_parse_text_basic(self): + text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}' + documents = NDJSONParser.parse_text(text_content) + + assert len(documents) == 1 + assert documents[0].id == "test1" + assert documents[0].text == "Hello world" + assert documents[0].embedding == [0.1, 0.2, 0.3] + + def test_parse_text_with_metadata(self): + text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}' + documents = NDJSONParser.parse_text(text_content) + + assert documents[0].category == "greeting" + assert documents[0].tags == ["test"] + + def test_parse_text_missing_id(self): + text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}' + documents = NDJSONParser.parse_text(text_content) + + assert len(documents) == 1 + assert documents[0].id is not None # Should be auto-generated + + +class TestDataProcessor: + + def test_extract_embeddings(self): + documents = [ + Document(id="1", text="test1", embedding=[0.1, 0.2]), + Document(id="2", text="test2", embedding=[0.3, 0.4]) + ] + + processor = DataProcessor() + embeddings = processor._extract_embeddings(documents) + + assert embeddings.shape == (2, 2) + assert np.allclose(embeddings[0], [0.1, 0.2]) + assert np.allclose(embeddings[1], [0.3, 0.4]) + + def test_combine_data(self): + from src.embeddingbuddy.models.schemas import ProcessedData + + doc_data = ProcessedData( + documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])], + embeddings=np.array([[0.1, 0.2]]) + ) + + prompt_data = ProcessedData( + documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])], + embeddings=np.array([[0.3, 0.4]]) + ) + + processor = DataProcessor() + all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data) + + assert all_embeddings.shape == (2, 2) + assert len(documents) == 1 + assert len(prompts) == 1 + assert documents[0].id == "1" + assert prompts[0].id == "p1" + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/test_reducers.py b/tests/test_reducers.py new file mode 100644 index 0000000..ff8dba1 --- /dev/null +++ b/tests/test_reducers.py @@ -0,0 +1,89 @@ +import pytest +import numpy as np +from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer + + +class TestReducerFactory: + + def test_create_pca_reducer(self): + reducer = ReducerFactory.create_reducer('pca', n_components=2) + assert isinstance(reducer, PCAReducer) + assert reducer.n_components == 2 + + def test_create_tsne_reducer(self): + reducer = ReducerFactory.create_reducer('tsne', n_components=3) + assert isinstance(reducer, TSNEReducer) + assert reducer.n_components == 3 + + def test_create_umap_reducer(self): + reducer = ReducerFactory.create_reducer('umap', n_components=2) + assert isinstance(reducer, UMAPReducer) + assert reducer.n_components == 2 + + def test_invalid_method(self): + with pytest.raises(ValueError, match="Unknown reduction method"): + ReducerFactory.create_reducer('invalid_method') + + def test_available_methods(self): + methods = ReducerFactory.get_available_methods() + assert 'pca' in methods + assert 'tsne' in methods + assert 'umap' in methods + + +class TestPCAReducer: + + def test_fit_transform(self): + embeddings = np.random.rand(100, 512) + reducer = PCAReducer(n_components=2) + + result = reducer.fit_transform(embeddings) + + assert result.reduced_embeddings.shape == (100, 2) + assert result.variance_explained is not None + assert result.method == "PCA" + assert result.n_components == 2 + + def test_method_name(self): + reducer = PCAReducer() + assert reducer.get_method_name() == "PCA" + + +class TestTSNEReducer: + + def test_fit_transform_small_dataset(self): + embeddings = np.random.rand(30, 10) # Small dataset for faster testing + reducer = TSNEReducer(n_components=2) + + result = reducer.fit_transform(embeddings) + + assert result.reduced_embeddings.shape == (30, 2) + assert result.variance_explained is None # t-SNE doesn't provide this + assert result.method == "t-SNE" + assert result.n_components == 2 + + def test_method_name(self): + reducer = TSNEReducer() + assert reducer.get_method_name() == "t-SNE" + + +class TestUMAPReducer: + + def test_fit_transform(self): + embeddings = np.random.rand(50, 10) + reducer = UMAPReducer(n_components=2) + + result = reducer.fit_transform(embeddings) + + assert result.reduced_embeddings.shape == (50, 2) + assert result.variance_explained is None # UMAP doesn't provide this + assert result.method == "UMAP" + assert result.n_components == 2 + + def test_method_name(self): + reducer = UMAPReducer() + assert reducer.get_method_name() == "UMAP" + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/uv.lock b/uv.lock index 8628586..ed5a3ce 100644 --- a/uv.lock +++ b/uv.lock @@ -133,7 +133,7 @@ wheels = [ [[package]] name = "embeddingbuddy" version = "0.1.0" -source = { virtual = "." } +source = { editable = "." } dependencies = [ { name = "dash" }, { name = "dash-bootstrap-components" }, @@ -142,6 +142,7 @@ dependencies = [ { name = "opentsne" }, { name = "pandas" }, { name = "plotly" }, + { name = "pytest" }, { name = "scikit-learn" }, { name = "umap-learn" }, ] @@ -155,6 +156,7 @@ requires-dist = [ { name = "opentsne", specifier = ">=1.0.0" }, { name = "pandas", specifier = ">=2.1.4" }, { name = "plotly", specifier = ">=5.17.0" }, + { name = "pytest", specifier = ">=8.4.1" }, { name = "scikit-learn", specifier = ">=1.3.2" }, { name = "umap-learn", specifier = ">=0.5.8" }, ] @@ -197,6 +199,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + [[package]] name = "itsdangerous" version = "2.2.0" @@ -473,6 +484,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/a9/12e2dc726ba1ba775a2c6922d5d5b4488ad60bdab0888c337c194c8e6de8/plotly-6.3.0-py3-none-any.whl", hash = "sha256:7ad806edce9d3cdd882eaebaf97c0c9e252043ed1ed3d382c3e3520ec07806d4", size = 9791257, upload-time = "2025-08-12T20:22:09.205Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + [[package]] name = "pynndescent" version = "0.5.13" @@ -489,6 +518,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/53/d23a97e0a2c690d40b165d1062e2c4ccc796be458a1ce59f6ba030434663/pynndescent-0.5.13-py3-none-any.whl", hash = "sha256:69aabb8f394bc631b6ac475a1c7f3994c54adf3f51cd63b2730fefba5771b949", size = 56850, upload-time = "2024-06-17T15:48:31.184Z" }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"