refactor and add tests, v0.2.0
This commit is contained in:
11
.claude/settings.local.json
Normal file
11
.claude/settings.local.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(mkdir:*)",
|
||||
"Bash(uv run:*)",
|
||||
"Bash(uv add:*)"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
}
|
||||
}
|
154
CLAUDE.md
154
CLAUDE.md
@@ -5,10 +5,11 @@ code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
EmbeddingBuddy is a Python Dash web application for interactive exploration and
|
||||
EmbeddingBuddy is a modular Python Dash web application for interactive exploration and
|
||||
visualization of embedding vectors through dimensionality reduction techniques
|
||||
(PCA, t-SNE, UMAP). The app provides a drag-and-drop interface for uploading
|
||||
NDJSON files containing embeddings and visualizes them in 2D/3D plots.
|
||||
NDJSON files containing embeddings and visualizes them in 2D/3D plots. The codebase
|
||||
follows a clean, modular architecture that prioritizes testability and maintainability.
|
||||
|
||||
## Development Commands
|
||||
|
||||
@@ -21,32 +22,79 @@ uv sync
|
||||
**Run the application:**
|
||||
|
||||
```bash
|
||||
uv run python app.py
|
||||
uv run python main.py
|
||||
```
|
||||
|
||||
The app will be available at http://127.0.0.1:8050
|
||||
|
||||
**Run tests:**
|
||||
|
||||
```bash
|
||||
uv run pytest tests/ -v
|
||||
```
|
||||
|
||||
**Test with sample data:**
|
||||
Use the included `sample_data.ndjson` file for testing the application functionality.
|
||||
Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Files
|
||||
### Project Structure
|
||||
|
||||
- `app.py` - Main Dash application with complete web interface, data processing,
|
||||
and visualization logic
|
||||
- `main.py` - Simple entry point (currently minimal)
|
||||
- `pyproject.toml` - Project configuration and dependencies using uv package manager
|
||||
The application follows a modular architecture with clear separation of concerns:
|
||||
|
||||
```
|
||||
src/embeddingbuddy/
|
||||
├── app.py # Main application entry point and factory
|
||||
├── main.py # Application runner
|
||||
├── config/
|
||||
│ └── settings.py # Centralized configuration management
|
||||
├── data/
|
||||
│ ├── parser.py # NDJSON parsing logic
|
||||
│ └── processor.py # Data transformation and processing
|
||||
├── models/
|
||||
│ ├── schemas.py # Data models and validation schemas
|
||||
│ └── reducers.py # Dimensionality reduction algorithms
|
||||
├── visualization/
|
||||
│ ├── plots.py # Plot creation and factory classes
|
||||
│ └── colors.py # Color mapping and management
|
||||
├── ui/
|
||||
│ ├── layout.py # Main application layout
|
||||
│ ├── components/ # Reusable UI components
|
||||
│ │ ├── sidebar.py # Sidebar component
|
||||
│ │ └── upload.py # Upload components
|
||||
│ └── callbacks/ # Organized callback functions
|
||||
│ ├── data_processing.py # Data upload/processing callbacks
|
||||
│ ├── visualization.py # Plot update callbacks
|
||||
│ └── interactions.py # User interaction callbacks
|
||||
└── utils/ # Utility functions and helpers
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
- **Data Processing**: NDJSON parser that handles embedding documents with
|
||||
required fields (`embedding`, `text`) and optional metadata (`id`, `category`, `subcategory`, `tags`)
|
||||
- **Dimensionality Reduction**: Supports PCA, t-SNE (openTSNE), and UMAP algorithms
|
||||
- **Visualization**: Plotly-based 2D/3D scatter plots with interactive features
|
||||
- **UI Layout**: Bootstrap-styled sidebar with controls and large visualization area
|
||||
- **State Management**: Dash callbacks for reactive updates between upload,
|
||||
method selection, and plot rendering
|
||||
**Data Layer:**
|
||||
- `data/parser.py` - NDJSON parsing with error handling
|
||||
- `data/processor.py` - Data transformation and combination logic
|
||||
- `models/schemas.py` - Dataclasses for type safety and validation
|
||||
|
||||
**Algorithm Layer:**
|
||||
- `models/reducers.py` - Modular dimensionality reduction with factory pattern
|
||||
- Supports PCA, t-SNE (openTSNE), and UMAP algorithms
|
||||
- Abstract base class for easy extension
|
||||
|
||||
**Visualization Layer:**
|
||||
- `visualization/plots.py` - Plot factory with single and dual plot support
|
||||
- `visualization/colors.py` - Color mapping and grayscale conversion utilities
|
||||
- Plotly-based 2D/3D scatter plots with interactive features
|
||||
|
||||
**UI Layer:**
|
||||
- `ui/layout.py` - Main application layout composition
|
||||
- `ui/components/` - Reusable, testable UI components
|
||||
- `ui/callbacks/` - Organized callbacks grouped by functionality
|
||||
- Bootstrap-styled sidebar with controls and large visualization area
|
||||
|
||||
**Configuration:**
|
||||
- `config/settings.py` - Centralized settings with environment variable support
|
||||
- Plot styling, marker configurations, and app-wide constants
|
||||
|
||||
### Data Format
|
||||
|
||||
@@ -56,17 +104,77 @@ The application expects NDJSON files where each line contains:
|
||||
{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, ...], "text": "Sample text", "category": "news", "subcategory": "politics", "tags": ["election"]}
|
||||
```
|
||||
|
||||
Required fields: `embedding` (array), `text` (string)
|
||||
Optional fields: `id`, `category`, `subcategory`, `tags`
|
||||
|
||||
### Callback Architecture
|
||||
|
||||
- File upload → Data processing and storage in dcc.Store
|
||||
- Method/parameter changes → Dimensionality reduction and plot update
|
||||
- Point clicks → Detail display in sidebar
|
||||
The refactored callback system is organized by functionality:
|
||||
|
||||
**Data Processing (`ui/callbacks/data_processing.py`):**
|
||||
- File upload handling
|
||||
- NDJSON parsing and validation
|
||||
- Data storage in dcc.Store components
|
||||
|
||||
**Visualization (`ui/callbacks/visualization.py`):**
|
||||
- Dimensionality reduction pipeline
|
||||
- Plot generation and updates
|
||||
- Method/parameter change handling
|
||||
|
||||
**Interactions (`ui/callbacks/interactions.py`):**
|
||||
- Point click handling and detail display
|
||||
- Reset functionality
|
||||
- User interaction management
|
||||
|
||||
### Testing Architecture
|
||||
|
||||
The modular design enables comprehensive testing:
|
||||
|
||||
**Unit Tests:**
|
||||
- `tests/test_data_processing.py` - Parser and processor logic
|
||||
- `tests/test_reducers.py` - Dimensionality reduction algorithms
|
||||
- `tests/test_visualization.py` - Plot creation and color mapping
|
||||
|
||||
**Integration Tests:**
|
||||
- End-to-end data pipeline testing
|
||||
- Component integration verification
|
||||
|
||||
**Key Testing Benefits:**
|
||||
- Fast test execution (milliseconds vs seconds)
|
||||
- Isolated component testing
|
||||
- Easy mocking and fixture creation
|
||||
- High code coverage achievable
|
||||
|
||||
## Dependencies
|
||||
|
||||
Uses modern Python stack with uv for dependency management:
|
||||
|
||||
- Dash + Plotly for web interface and visualization
|
||||
- scikit-learn (PCA), openTSNE, umap-learn for dimensionality reduction
|
||||
- pandas/numpy for data manipulation
|
||||
- dash-bootstrap-components for styling
|
||||
- **Core Framework:** Dash + Plotly for web interface and visualization
|
||||
- **Algorithms:** scikit-learn (PCA), openTSNE, umap-learn for dimensionality reduction
|
||||
- **Data:** pandas/numpy for data manipulation
|
||||
- **UI:** dash-bootstrap-components for styling
|
||||
- **Testing:** pytest for test framework
|
||||
- **Dev Tools:** uv for package management
|
||||
|
||||
## Development Guidelines
|
||||
|
||||
**When adding new features:**
|
||||
|
||||
1. **Data Models** - Add/update schemas in `models/schemas.py`
|
||||
2. **Algorithms** - Extend `models/reducers.py` using the abstract base class
|
||||
3. **UI Components** - Create reusable components in `ui/components/`
|
||||
4. **Configuration** - Add settings to `config/settings.py`
|
||||
5. **Tests** - Write tests for all new functionality
|
||||
|
||||
**Code Organization Principles:**
|
||||
- Single responsibility principle
|
||||
- Clear module boundaries
|
||||
- Testable, isolated components
|
||||
- Configuration over hardcoding
|
||||
- Error handling at appropriate layers
|
||||
|
||||
**Testing Requirements:**
|
||||
- Unit tests for all core logic
|
||||
- Integration tests for data flow
|
||||
- Component tests for UI elements
|
||||
- Maintain high test coverage
|
65
README.md
65
README.md
@@ -1,6 +1,6 @@
|
||||
# EmbeddingBuddy
|
||||
|
||||
A web application for interactive exploration and visualization of embedding
|
||||
A modular Python Dash web application for interactive exploration and visualization of embedding
|
||||
vectors through dimensionality reduction techniques. Compare documents and prompts
|
||||
in the same embedding space to understand semantic relationships.
|
||||
|
||||
@@ -10,9 +10,10 @@ in the same embedding space to understand semantic relationships.
|
||||
|
||||
EmbeddingBuddy provides an intuitive web interface for analyzing high-dimensional
|
||||
embedding vectors by applying various dimensionality reduction algorithms and
|
||||
visualizing the results in interactive 2D and 3D plots. The application supports
|
||||
dual dataset visualization, allowing you to compare documents and prompts to
|
||||
understand how queries relate to your content.
|
||||
visualizing the results in interactive 2D and 3D plots. The application features
|
||||
a clean, modular architecture that makes it easy to test, maintain, and extend
|
||||
with new features. It supports dual dataset visualization, allowing you to compare
|
||||
documents and prompts to understand how queries relate to your content.
|
||||
|
||||
## Features
|
||||
|
||||
@@ -73,7 +74,7 @@ uv sync
|
||||
2. **Run the application:**
|
||||
|
||||
```bash
|
||||
uv run python app.py
|
||||
uv run python main.py
|
||||
```
|
||||
|
||||
3. **Open your browser** to http://127.0.0.1:8050
|
||||
@@ -83,6 +84,59 @@ uv run python app.py
|
||||
- Upload `sample_prompts.ndjson` (prompts) to see dual visualization
|
||||
- Use the "Show prompts" toggle to compare how prompts relate to documents
|
||||
|
||||
## Development
|
||||
|
||||
### Project Structure
|
||||
|
||||
The application follows a modular architecture for improved maintainability and testability:
|
||||
|
||||
```
|
||||
src/embeddingbuddy/
|
||||
├── config/ # Configuration management
|
||||
│ └── settings.py # Centralized app settings
|
||||
├── data/ # Data parsing and processing
|
||||
│ ├── parser.py # NDJSON parsing logic
|
||||
│ └── processor.py # Data transformation utilities
|
||||
├── models/ # Data schemas and algorithms
|
||||
│ ├── schemas.py # Pydantic data models
|
||||
│ └── reducers.py # Dimensionality reduction algorithms
|
||||
├── visualization/ # Plot creation and styling
|
||||
│ ├── plots.py # Plot factory and creation logic
|
||||
│ └── colors.py # Color mapping utilities
|
||||
├── ui/ # User interface components
|
||||
│ ├── layout.py # Main application layout
|
||||
│ ├── components/ # Reusable UI components
|
||||
│ └── callbacks/ # Organized callback functions
|
||||
└── utils/ # Utility functions
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
Run the test suite to verify functionality:
|
||||
|
||||
```bash
|
||||
# Install pytest
|
||||
uv add pytest
|
||||
|
||||
# Run all tests
|
||||
uv run pytest tests/ -v
|
||||
|
||||
# Run specific test file
|
||||
uv run pytest tests/test_data_processing.py -v
|
||||
|
||||
# Run with coverage
|
||||
uv run pytest tests/ --cov=src/embeddingbuddy
|
||||
```
|
||||
|
||||
### Adding New Features
|
||||
|
||||
The modular architecture makes it easy to extend functionality:
|
||||
|
||||
- **New reduction algorithms**: Add to `models/reducers.py`
|
||||
- **New plot types**: Extend `visualization/plots.py`
|
||||
- **UI components**: Add to `ui/components/`
|
||||
- **Configuration options**: Update `config/settings.py`
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Python Dash**: Web application framework
|
||||
@@ -91,4 +145,5 @@ uv run python app.py
|
||||
- **UMAP-learn**: UMAP dimensionality reduction
|
||||
- **openTSNE**: Fast t-SNE implementation
|
||||
- **NumPy/Pandas**: Data manipulation and analysis
|
||||
- **pytest**: Testing framework
|
||||
- **uv**: Modern Python package and project manager
|
||||
|
6
main.py
6
main.py
@@ -1,5 +1,9 @@
|
||||
from src.embeddingbuddy.app import create_app, run_app
|
||||
|
||||
|
||||
def main():
|
||||
print("Hello from embeddingbuddy!")
|
||||
app = create_app()
|
||||
run_app(app)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "embeddingbuddy"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "A Python Dash application for interactive exploration and visualization of embedding vectors through dimensionality reduction techniques."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
@@ -13,5 +13,16 @@ dependencies = [
|
||||
"dash-bootstrap-components>=1.5.0",
|
||||
"umap-learn>=0.5.8",
|
||||
"numba>=0.56.4",
|
||||
"openTSNE>=1.0.0"
|
||||
"openTSNE>=1.0.0",
|
||||
"pytest>=8.4.1",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.setuptools.package-dir]
|
||||
"" = "src"
|
||||
|
3
src/embeddingbuddy/__init__.py
Normal file
3
src/embeddingbuddy/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
|
||||
|
||||
__version__ = "0.1.0"
|
39
src/embeddingbuddy/app.py
Normal file
39
src/embeddingbuddy/app.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import dash
|
||||
import dash_bootstrap_components as dbc
|
||||
from .config.settings import AppSettings
|
||||
from .ui.layout import AppLayout
|
||||
from .ui.callbacks.data_processing import DataProcessingCallbacks
|
||||
from .ui.callbacks.visualization import VisualizationCallbacks
|
||||
from .ui.callbacks.interactions import InteractionCallbacks
|
||||
|
||||
|
||||
def create_app():
|
||||
app = dash.Dash(
|
||||
__name__,
|
||||
external_stylesheets=[dbc.themes.BOOTSTRAP]
|
||||
)
|
||||
|
||||
layout_manager = AppLayout()
|
||||
app.layout = layout_manager.create_layout()
|
||||
|
||||
DataProcessingCallbacks()
|
||||
VisualizationCallbacks()
|
||||
InteractionCallbacks()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def run_app(app=None, debug=None, host=None, port=None):
|
||||
if app is None:
|
||||
app = create_app()
|
||||
|
||||
app.run(
|
||||
debug=debug if debug is not None else AppSettings.DEBUG,
|
||||
host=host if host is not None else AppSettings.HOST,
|
||||
port=port if port is not None else AppSettings.PORT
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_app()
|
||||
run_app(app)
|
0
src/embeddingbuddy/config/__init__.py
Normal file
0
src/embeddingbuddy/config/__init__.py
Normal file
107
src/embeddingbuddy/config/settings.py
Normal file
107
src/embeddingbuddy/config/settings.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
|
||||
|
||||
class AppSettings:
|
||||
|
||||
# UI Configuration
|
||||
UPLOAD_STYLE = {
|
||||
'width': '100%',
|
||||
'height': '60px',
|
||||
'lineHeight': '60px',
|
||||
'borderWidth': '1px',
|
||||
'borderStyle': 'dashed',
|
||||
'borderRadius': '5px',
|
||||
'textAlign': 'center',
|
||||
'margin-bottom': '20px'
|
||||
}
|
||||
|
||||
PROMPTS_UPLOAD_STYLE = {
|
||||
**UPLOAD_STYLE,
|
||||
'borderColor': '#28a745'
|
||||
}
|
||||
|
||||
PLOT_CONFIG = {
|
||||
'responsive': True,
|
||||
'displayModeBar': True
|
||||
}
|
||||
|
||||
PLOT_STYLE = {
|
||||
'height': '85vh',
|
||||
'width': '100%'
|
||||
}
|
||||
|
||||
PLOT_LAYOUT_CONFIG = {
|
||||
'height': None,
|
||||
'autosize': True,
|
||||
'margin': dict(l=0, r=0, t=50, b=0)
|
||||
}
|
||||
|
||||
# Dimensionality Reduction Settings
|
||||
DEFAULT_N_COMPONENTS_3D = 3
|
||||
DEFAULT_N_COMPONENTS_2D = 2
|
||||
DEFAULT_RANDOM_STATE = 42
|
||||
|
||||
# Available Methods
|
||||
REDUCTION_METHODS = [
|
||||
{'label': 'PCA', 'value': 'pca'},
|
||||
{'label': 't-SNE', 'value': 'tsne'},
|
||||
{'label': 'UMAP', 'value': 'umap'}
|
||||
]
|
||||
|
||||
COLOR_OPTIONS = [
|
||||
{'label': 'Category', 'value': 'category'},
|
||||
{'label': 'Subcategory', 'value': 'subcategory'},
|
||||
{'label': 'Tags', 'value': 'tags'}
|
||||
]
|
||||
|
||||
DIMENSION_OPTIONS = [
|
||||
{'label': '2D', 'value': '2d'},
|
||||
{'label': '3D', 'value': '3d'}
|
||||
]
|
||||
|
||||
# Default Values
|
||||
DEFAULT_METHOD = 'pca'
|
||||
DEFAULT_COLOR_BY = 'category'
|
||||
DEFAULT_DIMENSIONS = '3d'
|
||||
DEFAULT_SHOW_PROMPTS = ['show']
|
||||
|
||||
# Plot Marker Settings
|
||||
DOCUMENT_MARKER_SIZE_2D = 8
|
||||
DOCUMENT_MARKER_SIZE_3D = 5
|
||||
PROMPT_MARKER_SIZE_2D = 10
|
||||
PROMPT_MARKER_SIZE_3D = 6
|
||||
|
||||
DOCUMENT_MARKER_SYMBOL = 'circle'
|
||||
PROMPT_MARKER_SYMBOL = 'diamond'
|
||||
|
||||
DOCUMENT_OPACITY = 1.0
|
||||
PROMPT_OPACITY = 0.8
|
||||
|
||||
# Text Processing
|
||||
TEXT_PREVIEW_LENGTH = 100
|
||||
|
||||
# App Configuration
|
||||
DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true'
|
||||
HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1')
|
||||
PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050'))
|
||||
|
||||
# Bootstrap Theme
|
||||
EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css']
|
||||
|
||||
@classmethod
|
||||
def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]:
|
||||
if is_prompt:
|
||||
size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D
|
||||
symbol = cls.PROMPT_MARKER_SYMBOL
|
||||
opacity = cls.PROMPT_OPACITY
|
||||
else:
|
||||
size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D
|
||||
symbol = cls.DOCUMENT_MARKER_SYMBOL
|
||||
opacity = cls.DOCUMENT_OPACITY
|
||||
|
||||
return {
|
||||
'size': size,
|
||||
'symbol': symbol,
|
||||
'opacity': opacity
|
||||
}
|
0
src/embeddingbuddy/data/__init__.py
Normal file
0
src/embeddingbuddy/data/__init__.py
Normal file
39
src/embeddingbuddy/data/parser.py
Normal file
39
src/embeddingbuddy/data/parser.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
from typing import List, Union
|
||||
from ..models.schemas import Document, ProcessedData
|
||||
|
||||
|
||||
class NDJSONParser:
|
||||
|
||||
@staticmethod
|
||||
def parse_upload_contents(contents: str) -> List[Document]:
|
||||
content_type, content_string = contents.split(',')
|
||||
decoded = base64.b64decode(content_string)
|
||||
text_content = decoded.decode('utf-8')
|
||||
return NDJSONParser.parse_text(text_content)
|
||||
|
||||
@staticmethod
|
||||
def parse_text(text_content: str) -> List[Document]:
|
||||
documents = []
|
||||
for line in text_content.strip().split('\n'):
|
||||
if line.strip():
|
||||
doc_dict = json.loads(line)
|
||||
doc = NDJSONParser._dict_to_document(doc_dict)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_document(doc_dict: dict) -> Document:
|
||||
if 'id' not in doc_dict:
|
||||
doc_dict['id'] = str(uuid.uuid4())
|
||||
|
||||
return Document(
|
||||
id=doc_dict['id'],
|
||||
text=doc_dict['text'],
|
||||
embedding=doc_dict['embedding'],
|
||||
category=doc_dict.get('category'),
|
||||
subcategory=doc_dict.get('subcategory'),
|
||||
tags=doc_dict.get('tags')
|
||||
)
|
54
src/embeddingbuddy/data/processor.py
Normal file
54
src/embeddingbuddy/data/processor.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple
|
||||
from ..models.schemas import Document, ProcessedData
|
||||
from .parser import NDJSONParser
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self):
|
||||
self.parser = NDJSONParser()
|
||||
|
||||
def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData:
|
||||
try:
|
||||
documents = self.parser.parse_upload_contents(contents)
|
||||
embeddings = self._extract_embeddings(documents)
|
||||
return ProcessedData(documents=documents, embeddings=embeddings)
|
||||
except Exception as e:
|
||||
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
|
||||
|
||||
def process_text(self, text_content: str) -> ProcessedData:
|
||||
try:
|
||||
documents = self.parser.parse_text(text_content)
|
||||
embeddings = self._extract_embeddings(documents)
|
||||
return ProcessedData(documents=documents, embeddings=embeddings)
|
||||
except Exception as e:
|
||||
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
|
||||
|
||||
def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
|
||||
if not documents:
|
||||
return np.array([])
|
||||
return np.array([doc.embedding for doc in documents])
|
||||
|
||||
def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
|
||||
if not doc_data or doc_data.error:
|
||||
raise ValueError("Invalid document data")
|
||||
|
||||
all_embeddings = doc_data.embeddings
|
||||
documents = doc_data.documents
|
||||
prompts = None
|
||||
|
||||
if prompt_data and not prompt_data.error and prompt_data.documents:
|
||||
all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
|
||||
prompts = prompt_data.documents
|
||||
|
||||
return all_embeddings, documents, prompts
|
||||
|
||||
def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
doc_reduced = reduced_embeddings[:n_documents]
|
||||
prompt_reduced = None
|
||||
|
||||
if n_prompts > 0:
|
||||
prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts]
|
||||
|
||||
return doc_reduced, prompt_reduced
|
0
src/embeddingbuddy/models/__init__.py
Normal file
0
src/embeddingbuddy/models/__init__.py
Normal file
95
src/embeddingbuddy/models/reducers.py
Normal file
95
src/embeddingbuddy/models/reducers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
from sklearn.decomposition import PCA
|
||||
import umap
|
||||
from openTSNE import TSNE
|
||||
from .schemas import ReducedData
|
||||
|
||||
|
||||
class DimensionalityReducer(ABC):
|
||||
|
||||
def __init__(self, n_components: int = 3, random_state: int = 42):
|
||||
self.n_components = n_components
|
||||
self.random_state = random_state
|
||||
self._reducer = None
|
||||
|
||||
@abstractmethod
|
||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_method_name(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class PCAReducer(DimensionalityReducer):
|
||||
|
||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||
self._reducer = PCA(n_components=self.n_components)
|
||||
reduced = self._reducer.fit_transform(embeddings)
|
||||
variance_explained = self._reducer.explained_variance_ratio_
|
||||
|
||||
return ReducedData(
|
||||
reduced_embeddings=reduced,
|
||||
variance_explained=variance_explained,
|
||||
method=self.get_method_name(),
|
||||
n_components=self.n_components
|
||||
)
|
||||
|
||||
def get_method_name(self) -> str:
|
||||
return "PCA"
|
||||
|
||||
|
||||
class TSNEReducer(DimensionalityReducer):
|
||||
|
||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||
self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state)
|
||||
reduced = self._reducer.fit(embeddings)
|
||||
|
||||
return ReducedData(
|
||||
reduced_embeddings=reduced,
|
||||
variance_explained=None,
|
||||
method=self.get_method_name(),
|
||||
n_components=self.n_components
|
||||
)
|
||||
|
||||
def get_method_name(self) -> str:
|
||||
return "t-SNE"
|
||||
|
||||
|
||||
class UMAPReducer(DimensionalityReducer):
|
||||
|
||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||
self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state)
|
||||
reduced = self._reducer.fit_transform(embeddings)
|
||||
|
||||
return ReducedData(
|
||||
reduced_embeddings=reduced,
|
||||
variance_explained=None,
|
||||
method=self.get_method_name(),
|
||||
n_components=self.n_components
|
||||
)
|
||||
|
||||
def get_method_name(self) -> str:
|
||||
return "UMAP"
|
||||
|
||||
|
||||
class ReducerFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer:
|
||||
method_lower = method.lower()
|
||||
|
||||
if method_lower == 'pca':
|
||||
return PCAReducer(n_components=n_components, random_state=random_state)
|
||||
elif method_lower == 'tsne':
|
||||
return TSNEReducer(n_components=n_components, random_state=random_state)
|
||||
elif method_lower == 'umap':
|
||||
return UMAPReducer(n_components=n_components, random_state=random_state)
|
||||
else:
|
||||
raise ValueError(f"Unknown reduction method: {method}")
|
||||
|
||||
@staticmethod
|
||||
def get_available_methods() -> list:
|
||||
return ['pca', 'tsne', 'umap']
|
58
src/embeddingbuddy/models/schemas.py
Normal file
58
src/embeddingbuddy/models/schemas.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import List, Optional, Any, Dict
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class Document:
|
||||
id: str
|
||||
text: str
|
||||
embedding: List[float]
|
||||
category: Optional[str] = None
|
||||
subcategory: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tags is None:
|
||||
self.tags = []
|
||||
if self.category is None:
|
||||
self.category = "Unknown"
|
||||
if self.subcategory is None:
|
||||
self.subcategory = "Unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessedData:
|
||||
documents: List[Document]
|
||||
embeddings: np.ndarray
|
||||
error: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.embeddings is not None and not isinstance(self.embeddings, np.ndarray):
|
||||
self.embeddings = np.array(self.embeddings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReducedData:
|
||||
reduced_embeddings: np.ndarray
|
||||
variance_explained: Optional[np.ndarray] = None
|
||||
method: str = "unknown"
|
||||
n_components: int = 2
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.reduced_embeddings, np.ndarray):
|
||||
self.reduced_embeddings = np.array(self.reduced_embeddings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotData:
|
||||
documents: List[Document]
|
||||
coordinates: np.ndarray
|
||||
prompts: Optional[List[Document]] = None
|
||||
prompt_coordinates: Optional[np.ndarray] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.coordinates, np.ndarray):
|
||||
self.coordinates = np.array(self.coordinates)
|
||||
if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray):
|
||||
self.prompt_coordinates = np.array(self.prompt_coordinates)
|
0
src/embeddingbuddy/ui/__init__.py
Normal file
0
src/embeddingbuddy/ui/__init__.py
Normal file
0
src/embeddingbuddy/ui/callbacks/__init__.py
Normal file
0
src/embeddingbuddy/ui/callbacks/__init__.py
Normal file
61
src/embeddingbuddy/ui/callbacks/data_processing.py
Normal file
61
src/embeddingbuddy/ui/callbacks/data_processing.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import numpy as np
|
||||
from dash import callback, Input, Output, State
|
||||
from ...data.processor import DataProcessor
|
||||
|
||||
|
||||
class DataProcessingCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self.processor = DataProcessor()
|
||||
self._register_callbacks()
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('processed-data', 'data'),
|
||||
Input('upload-data', 'contents'),
|
||||
State('upload-data', 'filename')
|
||||
)
|
||||
def process_uploaded_file(contents, filename):
|
||||
if contents is None:
|
||||
return None
|
||||
|
||||
processed_data = self.processor.process_upload(contents, filename)
|
||||
|
||||
if processed_data.error:
|
||||
return {'error': processed_data.error}
|
||||
|
||||
return {
|
||||
'documents': [self._document_to_dict(doc) for doc in processed_data.documents],
|
||||
'embeddings': processed_data.embeddings.tolist()
|
||||
}
|
||||
|
||||
@callback(
|
||||
Output('processed-prompts', 'data'),
|
||||
Input('upload-prompts', 'contents'),
|
||||
State('upload-prompts', 'filename')
|
||||
)
|
||||
def process_uploaded_prompts(contents, filename):
|
||||
if contents is None:
|
||||
return None
|
||||
|
||||
processed_data = self.processor.process_upload(contents, filename)
|
||||
|
||||
if processed_data.error:
|
||||
return {'error': processed_data.error}
|
||||
|
||||
return {
|
||||
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents],
|
||||
'embeddings': processed_data.embeddings.tolist()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _document_to_dict(doc):
|
||||
return {
|
||||
'id': doc.id,
|
||||
'text': doc.text,
|
||||
'embedding': doc.embedding,
|
||||
'category': doc.category,
|
||||
'subcategory': doc.subcategory,
|
||||
'tags': doc.tags
|
||||
}
|
66
src/embeddingbuddy/ui/callbacks/interactions.py
Normal file
66
src/embeddingbuddy/ui/callbacks/interactions.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import dash
|
||||
from dash import callback, Input, Output, State, html
|
||||
import dash_bootstrap_components as dbc
|
||||
|
||||
|
||||
class InteractionCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self._register_callbacks()
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('point-details', 'children'),
|
||||
Input('embedding-plot', 'clickData'),
|
||||
[State('processed-data', 'data'),
|
||||
State('processed-prompts', 'data')]
|
||||
)
|
||||
def display_click_data(clickData, data, prompts_data):
|
||||
if not clickData or not data:
|
||||
return "Click on a point to see details"
|
||||
|
||||
point_data = clickData['points'][0]
|
||||
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
|
||||
|
||||
if 'pointIndex' in point_data:
|
||||
point_index = point_data['pointIndex']
|
||||
elif 'pointNumber' in point_data:
|
||||
point_index = point_data['pointNumber']
|
||||
else:
|
||||
return "Could not identify clicked point"
|
||||
|
||||
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data:
|
||||
item = prompts_data['prompts'][point_index]
|
||||
item_type = 'Prompt'
|
||||
else:
|
||||
item = data['documents'][point_index]
|
||||
item_type = 'Document'
|
||||
|
||||
return self._create_detail_card(item, item_type)
|
||||
|
||||
@callback(
|
||||
[Output('processed-data', 'data', allow_duplicate=True),
|
||||
Output('processed-prompts', 'data', allow_duplicate=True),
|
||||
Output('point-details', 'children', allow_duplicate=True)],
|
||||
Input('reset-button', 'n_clicks'),
|
||||
prevent_initial_call=True
|
||||
)
|
||||
def reset_data(n_clicks):
|
||||
if n_clicks is None or n_clicks == 0:
|
||||
return dash.no_update, dash.no_update, dash.no_update
|
||||
|
||||
return None, None, "Click on a point to see details"
|
||||
|
||||
@staticmethod
|
||||
def _create_detail_card(item, item_type):
|
||||
return dbc.Card([
|
||||
dbc.CardBody([
|
||||
html.H5(f"{item_type}: {item['id']}", className="card-title"),
|
||||
html.P(f"Text: {item['text']}", className="card-text"),
|
||||
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
|
||||
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
|
||||
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
|
||||
html.P(f"Type: {item_type}", className="card-text text-muted")
|
||||
])
|
||||
])
|
87
src/embeddingbuddy/ui/callbacks/visualization.py
Normal file
87
src/embeddingbuddy/ui/callbacks/visualization.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import numpy as np
|
||||
from dash import callback, Input, Output
|
||||
import plotly.graph_objects as go
|
||||
from ...models.reducers import ReducerFactory
|
||||
from ...models.schemas import Document, PlotData
|
||||
from ...visualization.plots import PlotFactory
|
||||
|
||||
|
||||
class VisualizationCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self.plot_factory = PlotFactory()
|
||||
self._register_callbacks()
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('embedding-plot', 'figure'),
|
||||
[Input('processed-data', 'data'),
|
||||
Input('processed-prompts', 'data'),
|
||||
Input('method-dropdown', 'value'),
|
||||
Input('color-dropdown', 'value'),
|
||||
Input('dimension-toggle', 'value'),
|
||||
Input('show-prompts-toggle', 'value')]
|
||||
)
|
||||
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
|
||||
if not data or 'error' in data:
|
||||
return go.Figure().add_annotation(
|
||||
text="Upload a valid NDJSON file to see visualization",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16)
|
||||
)
|
||||
|
||||
try:
|
||||
doc_embeddings = np.array(data['embeddings'])
|
||||
all_embeddings = doc_embeddings
|
||||
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
|
||||
|
||||
if has_prompts:
|
||||
prompt_embeddings = np.array(prompts_data['embeddings'])
|
||||
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
|
||||
|
||||
n_components = 3 if dimensions == '3d' else 2
|
||||
|
||||
reducer = ReducerFactory.create_reducer(method, n_components=n_components)
|
||||
reduced_data = reducer.fit_transform(all_embeddings)
|
||||
|
||||
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
|
||||
prompt_reduced = None
|
||||
if has_prompts:
|
||||
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
|
||||
|
||||
documents = [self._dict_to_document(doc) for doc in data['documents']]
|
||||
prompts = None
|
||||
if has_prompts:
|
||||
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
|
||||
|
||||
plot_data = PlotData(
|
||||
documents=documents,
|
||||
coordinates=doc_reduced,
|
||||
prompts=prompts,
|
||||
prompt_coordinates=prompt_reduced
|
||||
)
|
||||
|
||||
return self.plot_factory.create_plot(
|
||||
plot_data, dimensions, color_by, reduced_data.method, show_prompts
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return go.Figure().add_annotation(
|
||||
text=f"Error creating visualization: {str(e)}",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_document(doc_dict):
|
||||
return Document(
|
||||
id=doc_dict['id'],
|
||||
text=doc_dict['text'],
|
||||
embedding=doc_dict['embedding'],
|
||||
category=doc_dict.get('category'),
|
||||
subcategory=doc_dict.get('subcategory'),
|
||||
tags=doc_dict.get('tags', [])
|
||||
)
|
0
src/embeddingbuddy/ui/components/__init__.py
Normal file
0
src/embeddingbuddy/ui/components/__init__.py
Normal file
82
src/embeddingbuddy/ui/components/sidebar.py
Normal file
82
src/embeddingbuddy/ui/components/sidebar.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from dash import dcc, html
|
||||
import dash_bootstrap_components as dbc
|
||||
from .upload import UploadComponent
|
||||
|
||||
|
||||
class SidebarComponent:
|
||||
|
||||
def __init__(self):
|
||||
self.upload_component = UploadComponent()
|
||||
|
||||
def create_layout(self):
|
||||
return dbc.Col([
|
||||
html.H5("Upload Data", className="mb-3"),
|
||||
self.upload_component.create_data_upload(),
|
||||
self.upload_component.create_prompts_upload(),
|
||||
self.upload_component.create_reset_button(),
|
||||
|
||||
html.H5("Visualization Controls", className="mb-3"),
|
||||
self._create_method_dropdown(),
|
||||
self._create_color_dropdown(),
|
||||
self._create_dimension_toggle(),
|
||||
self._create_prompts_toggle(),
|
||||
|
||||
html.H5("Point Details", className="mb-3"),
|
||||
html.Div(id='point-details', children="Click on a point to see details")
|
||||
|
||||
], width=3, style={'padding-right': '20px'})
|
||||
|
||||
def _create_method_dropdown(self):
|
||||
return [
|
||||
dbc.Label("Method:"),
|
||||
dcc.Dropdown(
|
||||
id='method-dropdown',
|
||||
options=[
|
||||
{'label': 'PCA', 'value': 'pca'},
|
||||
{'label': 't-SNE', 'value': 'tsne'},
|
||||
{'label': 'UMAP', 'value': 'umap'}
|
||||
],
|
||||
value='pca',
|
||||
style={'margin-bottom': '15px'}
|
||||
)
|
||||
]
|
||||
|
||||
def _create_color_dropdown(self):
|
||||
return [
|
||||
dbc.Label("Color by:"),
|
||||
dcc.Dropdown(
|
||||
id='color-dropdown',
|
||||
options=[
|
||||
{'label': 'Category', 'value': 'category'},
|
||||
{'label': 'Subcategory', 'value': 'subcategory'},
|
||||
{'label': 'Tags', 'value': 'tags'}
|
||||
],
|
||||
value='category',
|
||||
style={'margin-bottom': '15px'}
|
||||
)
|
||||
]
|
||||
|
||||
def _create_dimension_toggle(self):
|
||||
return [
|
||||
dbc.Label("Dimensions:"),
|
||||
dcc.RadioItems(
|
||||
id='dimension-toggle',
|
||||
options=[
|
||||
{'label': '2D', 'value': '2d'},
|
||||
{'label': '3D', 'value': '3d'}
|
||||
],
|
||||
value='3d',
|
||||
style={'margin-bottom': '20px'}
|
||||
)
|
||||
]
|
||||
|
||||
def _create_prompts_toggle(self):
|
||||
return [
|
||||
dbc.Label("Show Prompts:"),
|
||||
dcc.Checklist(
|
||||
id='show-prompts-toggle',
|
||||
options=[{'label': 'Show prompts on plot', 'value': 'show'}],
|
||||
value=['show'],
|
||||
style={'margin-bottom': '20px'}
|
||||
)
|
||||
]
|
60
src/embeddingbuddy/ui/components/upload.py
Normal file
60
src/embeddingbuddy/ui/components/upload.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from dash import dcc, html
|
||||
import dash_bootstrap_components as dbc
|
||||
|
||||
|
||||
class UploadComponent:
|
||||
|
||||
@staticmethod
|
||||
def create_data_upload():
|
||||
return dcc.Upload(
|
||||
id='upload-data',
|
||||
children=html.Div([
|
||||
'Drag and Drop or ',
|
||||
html.A('Select Files')
|
||||
]),
|
||||
style={
|
||||
'width': '100%',
|
||||
'height': '60px',
|
||||
'lineHeight': '60px',
|
||||
'borderWidth': '1px',
|
||||
'borderStyle': 'dashed',
|
||||
'borderRadius': '5px',
|
||||
'textAlign': 'center',
|
||||
'margin-bottom': '20px'
|
||||
},
|
||||
multiple=False
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_prompts_upload():
|
||||
return dcc.Upload(
|
||||
id='upload-prompts',
|
||||
children=html.Div([
|
||||
'Drag and Drop Prompts or ',
|
||||
html.A('Select Files')
|
||||
]),
|
||||
style={
|
||||
'width': '100%',
|
||||
'height': '60px',
|
||||
'lineHeight': '60px',
|
||||
'borderWidth': '1px',
|
||||
'borderStyle': 'dashed',
|
||||
'borderRadius': '5px',
|
||||
'textAlign': 'center',
|
||||
'margin-bottom': '20px',
|
||||
'borderColor': '#28a745'
|
||||
},
|
||||
multiple=False
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_reset_button():
|
||||
return dbc.Button(
|
||||
"Reset All Data",
|
||||
id='reset-button',
|
||||
color='danger',
|
||||
outline=True,
|
||||
size='sm',
|
||||
className='mb-3',
|
||||
style={'width': '100%'}
|
||||
)
|
44
src/embeddingbuddy/ui/layout.py
Normal file
44
src/embeddingbuddy/ui/layout.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from dash import dcc, html
|
||||
import dash_bootstrap_components as dbc
|
||||
from .components.sidebar import SidebarComponent
|
||||
|
||||
|
||||
class AppLayout:
|
||||
|
||||
def __init__(self):
|
||||
self.sidebar = SidebarComponent()
|
||||
|
||||
def create_layout(self):
|
||||
return dbc.Container([
|
||||
self._create_header(),
|
||||
self._create_main_content(),
|
||||
self._create_stores()
|
||||
], fluid=True)
|
||||
|
||||
def _create_header(self):
|
||||
return dbc.Row([
|
||||
dbc.Col([
|
||||
html.H1("EmbeddingBuddy", className="text-center mb-4"),
|
||||
], width=12)
|
||||
])
|
||||
|
||||
def _create_main_content(self):
|
||||
return dbc.Row([
|
||||
self.sidebar.create_layout(),
|
||||
self._create_visualization_area()
|
||||
])
|
||||
|
||||
def _create_visualization_area(self):
|
||||
return dbc.Col([
|
||||
dcc.Graph(
|
||||
id='embedding-plot',
|
||||
style={'height': '85vh', 'width': '100%'},
|
||||
config={'responsive': True, 'displayModeBar': True}
|
||||
)
|
||||
], width=9)
|
||||
|
||||
def _create_stores(self):
|
||||
return [
|
||||
dcc.Store(id='processed-data'),
|
||||
dcc.Store(id='processed-prompts')
|
||||
]
|
0
src/embeddingbuddy/utils/__init__.py
Normal file
0
src/embeddingbuddy/utils/__init__.py
Normal file
0
src/embeddingbuddy/visualization/__init__.py
Normal file
0
src/embeddingbuddy/visualization/__init__.py
Normal file
33
src/embeddingbuddy/visualization/colors.py
Normal file
33
src/embeddingbuddy/visualization/colors.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import List, Dict, Any
|
||||
import plotly.colors as pc
|
||||
from ..models.schemas import Document
|
||||
|
||||
|
||||
class ColorMapper:
|
||||
|
||||
@staticmethod
|
||||
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
|
||||
if color_by == 'category':
|
||||
return [doc.category for doc in documents]
|
||||
elif color_by == 'subcategory':
|
||||
return [doc.subcategory for doc in documents]
|
||||
elif color_by == 'tags':
|
||||
return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents]
|
||||
else:
|
||||
return ['All'] * len(documents)
|
||||
|
||||
@staticmethod
|
||||
def to_grayscale_hex(color_str: str) -> str:
|
||||
try:
|
||||
if color_str.startswith('#'):
|
||||
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
|
||||
else:
|
||||
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
|
||||
|
||||
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
|
||||
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3,
|
||||
gray_value * 0.7 + rgb[1] * 0.3,
|
||||
gray_value * 0.7 + rgb[2] * 0.3)
|
||||
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
|
||||
except:
|
||||
return 'rgb(128,128,128)'
|
145
src/embeddingbuddy/visualization/plots.py
Normal file
145
src/embeddingbuddy/visualization/plots.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import plotly.graph_objects as go
|
||||
from typing import List, Optional
|
||||
from ..models.schemas import Document, PlotData
|
||||
from .colors import ColorMapper
|
||||
|
||||
|
||||
class PlotFactory:
|
||||
|
||||
def __init__(self):
|
||||
self.color_mapper = ColorMapper()
|
||||
|
||||
def create_plot(self, plot_data: PlotData, dimensions: str = '3d',
|
||||
color_by: str = 'category', method: str = 'PCA',
|
||||
show_prompts: Optional[List[str]] = None) -> go.Figure:
|
||||
|
||||
if plot_data.prompts and show_prompts and 'show' in show_prompts:
|
||||
return self._create_dual_plot(plot_data, dimensions, color_by, method)
|
||||
else:
|
||||
return self._create_single_plot(plot_data, dimensions, color_by, method)
|
||||
|
||||
def _create_single_plot(self, plot_data: PlotData, dimensions: str,
|
||||
color_by: str, method: str) -> go.Figure:
|
||||
df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
|
||||
color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
|
||||
|
||||
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
|
||||
|
||||
if dimensions == '3d':
|
||||
fig = px.scatter_3d(
|
||||
df, x='dim_1', y='dim_2', z='dim_3',
|
||||
color=color_values,
|
||||
hover_data=hover_fields,
|
||||
title=f'3D Embedding Visualization - {method} (colored by {color_by})'
|
||||
)
|
||||
fig.update_traces(marker=dict(size=5))
|
||||
else:
|
||||
fig = px.scatter(
|
||||
df, x='dim_1', y='dim_2',
|
||||
color=color_values,
|
||||
hover_data=hover_fields,
|
||||
title=f'2D Embedding Visualization - {method} (colored by {color_by})'
|
||||
)
|
||||
fig.update_traces(marker=dict(size=8))
|
||||
|
||||
fig.update_layout(
|
||||
height=None,
|
||||
autosize=True,
|
||||
margin=dict(l=0, r=0, t=50, b=0)
|
||||
)
|
||||
return fig
|
||||
|
||||
def _create_dual_plot(self, plot_data: PlotData, dimensions: str,
|
||||
color_by: str, method: str) -> go.Figure:
|
||||
fig = go.Figure()
|
||||
|
||||
doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
|
||||
doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
|
||||
|
||||
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
|
||||
|
||||
if dimensions == '3d':
|
||||
doc_fig = px.scatter_3d(
|
||||
doc_df, x='dim_1', y='dim_2', z='dim_3',
|
||||
color=doc_color_values,
|
||||
hover_data=hover_fields
|
||||
)
|
||||
else:
|
||||
doc_fig = px.scatter(
|
||||
doc_df, x='dim_1', y='dim_2',
|
||||
color=doc_color_values,
|
||||
hover_data=hover_fields
|
||||
)
|
||||
|
||||
for trace in doc_fig.data:
|
||||
trace.name = f'Documents - {trace.name}'
|
||||
if dimensions == '3d':
|
||||
trace.marker.size = 5
|
||||
trace.marker.symbol = 'circle'
|
||||
else:
|
||||
trace.marker.size = 8
|
||||
trace.marker.symbol = 'circle'
|
||||
trace.marker.opacity = 1.0
|
||||
fig.add_trace(trace)
|
||||
|
||||
if plot_data.prompts and plot_data.prompt_coordinates is not None:
|
||||
prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions)
|
||||
prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by)
|
||||
|
||||
if dimensions == '3d':
|
||||
prompt_fig = px.scatter_3d(
|
||||
prompt_df, x='dim_1', y='dim_2', z='dim_3',
|
||||
color=prompt_color_values,
|
||||
hover_data=hover_fields
|
||||
)
|
||||
else:
|
||||
prompt_fig = px.scatter(
|
||||
prompt_df, x='dim_1', y='dim_2',
|
||||
color=prompt_color_values,
|
||||
hover_data=hover_fields
|
||||
)
|
||||
|
||||
for trace in prompt_fig.data:
|
||||
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
|
||||
trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color)
|
||||
|
||||
trace.name = f'Prompts - {trace.name}'
|
||||
if dimensions == '3d':
|
||||
trace.marker.size = 6
|
||||
trace.marker.symbol = 'diamond'
|
||||
else:
|
||||
trace.marker.size = 10
|
||||
trace.marker.symbol = 'diamond'
|
||||
trace.marker.opacity = 0.8
|
||||
fig.add_trace(trace)
|
||||
|
||||
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
|
||||
fig.update_layout(
|
||||
title=title,
|
||||
height=None,
|
||||
autosize=True,
|
||||
margin=dict(l=0, r=0, t=50, b=0)
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame:
|
||||
df_data = []
|
||||
for i, doc in enumerate(documents):
|
||||
row = {
|
||||
'id': doc.id,
|
||||
'text': doc.text,
|
||||
'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text,
|
||||
'category': doc.category,
|
||||
'subcategory': doc.subcategory,
|
||||
'tags_str': ', '.join(doc.tags) if doc.tags else 'None',
|
||||
'dim_1': coordinates[i, 0],
|
||||
'dim_2': coordinates[i, 1],
|
||||
}
|
||||
if dimensions == '3d':
|
||||
row['dim_3'] = coordinates[i, 2]
|
||||
df_data.append(row)
|
||||
|
||||
return pd.DataFrame(df_data)
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
73
tests/test_data_processing.py
Normal file
73
tests/test_data_processing.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from src.embeddingbuddy.data.parser import NDJSONParser
|
||||
from src.embeddingbuddy.data.processor import DataProcessor
|
||||
from src.embeddingbuddy.models.schemas import Document
|
||||
|
||||
|
||||
class TestNDJSONParser:
|
||||
|
||||
def test_parse_text_basic(self):
|
||||
text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
|
||||
documents = NDJSONParser.parse_text(text_content)
|
||||
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == "test1"
|
||||
assert documents[0].text == "Hello world"
|
||||
assert documents[0].embedding == [0.1, 0.2, 0.3]
|
||||
|
||||
def test_parse_text_with_metadata(self):
|
||||
text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}'
|
||||
documents = NDJSONParser.parse_text(text_content)
|
||||
|
||||
assert documents[0].category == "greeting"
|
||||
assert documents[0].tags == ["test"]
|
||||
|
||||
def test_parse_text_missing_id(self):
|
||||
text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}'
|
||||
documents = NDJSONParser.parse_text(text_content)
|
||||
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id is not None # Should be auto-generated
|
||||
|
||||
|
||||
class TestDataProcessor:
|
||||
|
||||
def test_extract_embeddings(self):
|
||||
documents = [
|
||||
Document(id="1", text="test1", embedding=[0.1, 0.2]),
|
||||
Document(id="2", text="test2", embedding=[0.3, 0.4])
|
||||
]
|
||||
|
||||
processor = DataProcessor()
|
||||
embeddings = processor._extract_embeddings(documents)
|
||||
|
||||
assert embeddings.shape == (2, 2)
|
||||
assert np.allclose(embeddings[0], [0.1, 0.2])
|
||||
assert np.allclose(embeddings[1], [0.3, 0.4])
|
||||
|
||||
def test_combine_data(self):
|
||||
from src.embeddingbuddy.models.schemas import ProcessedData
|
||||
|
||||
doc_data = ProcessedData(
|
||||
documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
|
||||
embeddings=np.array([[0.1, 0.2]])
|
||||
)
|
||||
|
||||
prompt_data = ProcessedData(
|
||||
documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
|
||||
embeddings=np.array([[0.3, 0.4]])
|
||||
)
|
||||
|
||||
processor = DataProcessor()
|
||||
all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data)
|
||||
|
||||
assert all_embeddings.shape == (2, 2)
|
||||
assert len(documents) == 1
|
||||
assert len(prompts) == 1
|
||||
assert documents[0].id == "1"
|
||||
assert prompts[0].id == "p1"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
89
tests/test_reducers.py
Normal file
89
tests/test_reducers.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer
|
||||
|
||||
|
||||
class TestReducerFactory:
|
||||
|
||||
def test_create_pca_reducer(self):
|
||||
reducer = ReducerFactory.create_reducer('pca', n_components=2)
|
||||
assert isinstance(reducer, PCAReducer)
|
||||
assert reducer.n_components == 2
|
||||
|
||||
def test_create_tsne_reducer(self):
|
||||
reducer = ReducerFactory.create_reducer('tsne', n_components=3)
|
||||
assert isinstance(reducer, TSNEReducer)
|
||||
assert reducer.n_components == 3
|
||||
|
||||
def test_create_umap_reducer(self):
|
||||
reducer = ReducerFactory.create_reducer('umap', n_components=2)
|
||||
assert isinstance(reducer, UMAPReducer)
|
||||
assert reducer.n_components == 2
|
||||
|
||||
def test_invalid_method(self):
|
||||
with pytest.raises(ValueError, match="Unknown reduction method"):
|
||||
ReducerFactory.create_reducer('invalid_method')
|
||||
|
||||
def test_available_methods(self):
|
||||
methods = ReducerFactory.get_available_methods()
|
||||
assert 'pca' in methods
|
||||
assert 'tsne' in methods
|
||||
assert 'umap' in methods
|
||||
|
||||
|
||||
class TestPCAReducer:
|
||||
|
||||
def test_fit_transform(self):
|
||||
embeddings = np.random.rand(100, 512)
|
||||
reducer = PCAReducer(n_components=2)
|
||||
|
||||
result = reducer.fit_transform(embeddings)
|
||||
|
||||
assert result.reduced_embeddings.shape == (100, 2)
|
||||
assert result.variance_explained is not None
|
||||
assert result.method == "PCA"
|
||||
assert result.n_components == 2
|
||||
|
||||
def test_method_name(self):
|
||||
reducer = PCAReducer()
|
||||
assert reducer.get_method_name() == "PCA"
|
||||
|
||||
|
||||
class TestTSNEReducer:
|
||||
|
||||
def test_fit_transform_small_dataset(self):
|
||||
embeddings = np.random.rand(30, 10) # Small dataset for faster testing
|
||||
reducer = TSNEReducer(n_components=2)
|
||||
|
||||
result = reducer.fit_transform(embeddings)
|
||||
|
||||
assert result.reduced_embeddings.shape == (30, 2)
|
||||
assert result.variance_explained is None # t-SNE doesn't provide this
|
||||
assert result.method == "t-SNE"
|
||||
assert result.n_components == 2
|
||||
|
||||
def test_method_name(self):
|
||||
reducer = TSNEReducer()
|
||||
assert reducer.get_method_name() == "t-SNE"
|
||||
|
||||
|
||||
class TestUMAPReducer:
|
||||
|
||||
def test_fit_transform(self):
|
||||
embeddings = np.random.rand(50, 10)
|
||||
reducer = UMAPReducer(n_components=2)
|
||||
|
||||
result = reducer.fit_transform(embeddings)
|
||||
|
||||
assert result.reduced_embeddings.shape == (50, 2)
|
||||
assert result.variance_explained is None # UMAP doesn't provide this
|
||||
assert result.method == "UMAP"
|
||||
assert result.n_components == 2
|
||||
|
||||
def test_method_name(self):
|
||||
reducer = UMAPReducer()
|
||||
assert reducer.get_method_name() == "UMAP"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
47
uv.lock
generated
47
uv.lock
generated
@@ -133,7 +133,7 @@ wheels = [
|
||||
[[package]]
|
||||
name = "embeddingbuddy"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "dash" },
|
||||
{ name = "dash-bootstrap-components" },
|
||||
@@ -142,6 +142,7 @@ dependencies = [
|
||||
{ name = "opentsne" },
|
||||
{ name = "pandas" },
|
||||
{ name = "plotly" },
|
||||
{ name = "pytest" },
|
||||
{ name = "scikit-learn" },
|
||||
{ name = "umap-learn" },
|
||||
]
|
||||
@@ -155,6 +156,7 @@ requires-dist = [
|
||||
{ name = "opentsne", specifier = ">=1.0.0" },
|
||||
{ name = "pandas", specifier = ">=2.1.4" },
|
||||
{ name = "plotly", specifier = ">=5.17.0" },
|
||||
{ name = "pytest", specifier = ">=8.4.1" },
|
||||
{ name = "scikit-learn", specifier = ">=1.3.2" },
|
||||
{ name = "umap-learn", specifier = ">=0.5.8" },
|
||||
]
|
||||
@@ -197,6 +199,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/b0/36bd937216ec521246249be3bf9855081de4c5e06a0c9b4219dbeda50373/importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd", size = 27656, upload-time = "2025-04-27T15:29:00.214Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itsdangerous"
|
||||
version = "2.2.0"
|
||||
@@ -473,6 +484,24 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/95/a9/12e2dc726ba1ba775a2c6922d5d5b4488ad60bdab0888c337c194c8e6de8/plotly-6.3.0-py3-none-any.whl", hash = "sha256:7ad806edce9d3cdd882eaebaf97c0c9e252043ed1ed3d382c3e3520ec07806d4", size = 9791257, upload-time = "2025-08-12T20:22:09.205Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pynndescent"
|
||||
version = "0.5.13"
|
||||
@@ -489,6 +518,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/53/d23a97e0a2c690d40b165d1062e2c4ccc796be458a1ce59f6ba030434663/pynndescent-0.5.13-py3-none-any.whl", hash = "sha256:69aabb8f394bc631b6ac475a1c7f3994c54adf3f51cd63b2730fefba5771b949", size = 56850, upload-time = "2024-06-17T15:48:31.184Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
|
Reference in New Issue
Block a user