add ci workflows (#1)
All checks were successful
Security Scan / security (push) Successful in 30s
Security Scan / dependency-check (push) Successful in 25s
Test Suite / test (3.11) (push) Successful in 1m16s
Test Suite / lint (push) Successful in 20s
Test Suite / build (push) Successful in 35s

Reviewed-on: #1
This commit is contained in:
2025-08-13 21:03:42 -07:00
parent 809dbeb783
commit 1ec7e2c38c
24 changed files with 2069 additions and 532 deletions

View File

@@ -0,0 +1,92 @@
# Release pipeline: runs the full test suite, then builds the package and
# publishes a release with the built distributions attached.
#
# Triggers:
#   - pushing a tag that matches v*
#   - manual dispatch with an explicit `version` input
name: Release

on:
  push:
    tags:
      - 'v*'
  workflow_dispatch:
    inputs:
      version:
        description: 'Release version (e.g., v1.0.0)'
        required: true

jobs:
  # Gate: the release job below only starts when this job succeeds.
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync --extra test

      - name: Run full test suite
        run: uv run pytest tests/ -v --cov=src/embeddingbuddy --cov-report=term-missing

  build-and-release:
    runs-on: ubuntu-latest
    needs: test
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync

      - name: Build package
        run: uv build

      # The release-asset action uploads exactly one file, so the contents of
      # dist/ are bundled into a single archive first. The standard-library
      # zipfile CLI is used so no extra tooling is needed on the runner.
      - name: Package distributions
        run: uv run python -m zipfile -c embeddingbuddy-dist.zip dist/

      # The notes are static text written to release-notes.md, which the
      # "Create Release" step passes as the release body.
      - name: Create release notes
        run: |
          echo "# Release Notes" > release-notes.md
          echo "" >> release-notes.md
          echo "## What's New" >> release-notes.md
          echo "" >> release-notes.md
          echo "- Modular architecture with improved testability" >> release-notes.md
          echo "- Comprehensive test suite" >> release-notes.md
          echo "- Enhanced documentation" >> release-notes.md
          echo "- Security scanning and dependency management" >> release-notes.md
          echo "" >> release-notes.md
          echo "## Installation" >> release-notes.md
          echo "" >> release-notes.md
          echo '```bash' >> release-notes.md
          echo 'uv sync' >> release-notes.md
          echo 'uv run python main.py' >> release-notes.md
          echo '```' >> release-notes.md

      # The id is required: the upload step below reads this step's
      # `upload_url` output through `steps.create_release`.
      - name: Create Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITEA_TOKEN }}
        with:
          # Prefer the manual input. On workflow_dispatch, github.ref_name is
          # the branch the run was started from, not the release tag; on a tag
          # push the input is empty and the tag name is used instead.
          tag_name: ${{ github.event.inputs.version || github.ref_name }}
          release_name: Release ${{ github.event.inputs.version || github.ref_name }}
          body_path: release-notes.md
          draft: false
          prerelease: false

      - name: Upload Release Assets
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITEA_TOKEN }}
        with:
          upload_url: ${{ steps.create_release.outputs.upload_url }}
          asset_path: embeddingbuddy-dist.zip
          asset_name: embeddingbuddy-dist.zip
          asset_content_type: application/zip

View File

@@ -0,0 +1,70 @@
# Security scanning: static analysis (bandit), known-vulnerability checks
# (safety) and a dependency audit (pip-audit). Each scanner writes a JSON
# report that is uploaded as a build artifact.
name: Security Scan

on:
  push:
    branches:
      - "main"
      - "master"
      - "develop"
  pull_request:
    branches:
      - "main"
      - "master"
  schedule:
    # Weekly run: Sundays at 02:00 UTC
    - cron: '0 2 * * 0'

jobs:
  security:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync --extra security

      # Both scanners are report-only: continue-on-error keeps the job green
      # when findings are reported, so the reports can still be uploaded.
      - name: Run bandit security linter
        continue-on-error: true
        run: uv run bandit -r src/ -f json -o bandit-report.json

      - name: Run safety vulnerability check
        continue-on-error: true
        run: uv run safety check --json --save-json safety-report.json

      - name: Upload security reports
        uses: actions/upload-artifact@v3
        with:
          name: security-reports
          path: |
            bandit-report.json
            safety-report.json

  dependency-check:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.11

      # Report-only as well: a failing audit does not fail the job.
      - name: Check for dependency vulnerabilities
        continue-on-error: true
        run: |
          uv sync --extra security
          uv run pip-audit --format=json --output=pip-audit-report.json

      - name: Upload dependency audit report
        uses: actions/upload-artifact@v3
        with:
          name: dependency-audit
          path: pip-audit-report.json

104
.gitea/workflows/test.yml Normal file
View File

@@ -0,0 +1,104 @@
# Test pipeline: unit tests with coverage, lint/format checks, and a package
# build that only runs once both of the other jobs have passed.
name: Test Suite

on:
  push:
    branches:
      - "main"
      - "develop"
  pull_request:
    branches:
      - "main"
  workflow_dispatch:

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so the version is read as a string, not a float.
        python-version: ["3.11"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}

      - name: Install dependencies
        run: uv sync --extra test

      # A single pytest invocation produces the verbose test output and both
      # coverage reports, so the suite is not executed twice per job.
      - name: Run tests with coverage
        run: uv run pytest tests/ -v --tb=short --cov=src/embeddingbuddy --cov-report=term-missing --cov-report=xml

      # coverage.xml is written by --cov-report=xml in the step above.
      - name: Upload coverage reports
        uses: codecov/codecov-action@v4
        if: matrix.python-version == '3.11'
        with:
          file: ./coverage.xml
          fail_ci_if_error: false

  lint:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync --extra lint

      - name: Run ruff linter
        run: uv run ruff check src/ tests/

      - name: Run ruff formatter check
        run: uv run ruff format --check src/ tests/

      # TODO: re-enable the mypy step once the type errors it reports are fixed.
      # - name: Run mypy type checker
      #   run: uv run mypy src/embeddingbuddy/ --ignore-missing-imports

  build:
    runs-on: ubuntu-latest
    needs: [test, lint]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python
        run: uv python install 3.11

      - name: Install dependencies
        run: uv sync

      - name: Build package
        run: uv build

      # NOTE(review): this imports via the `src.` path, which looks like it
      # exercises the checked-out source tree rather than the artifacts written
      # to dist/ by the build step - confirm that is the intended smoke test.
      - name: Test installation
        run: |
          uv run python -c "from src.embeddingbuddy.app import create_app; app = create_app(); print('✅ Package builds and imports successfully')"

      - name: Upload build artifacts
        uses: actions/upload-artifact@v3
        with:
          name: dist-files
          path: dist/

76
.gitignore vendored
View File

@@ -1,12 +1,84 @@
# Python-generated files # Python-generated files
__pycache__/ __pycache__/
*.py[oc] *.py[oc]
*.py[cod]
*$py.class
*.so
.Python
build/ build/
develop-eggs/
dist/ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/ wheels/
*.egg-info share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Virtual environments # Virtual environments
.env
.venv .venv
env/
venv/
ENV/
env.bak/
venv.bak/
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Project specific
*.log
.mypy_cache/
.dmypy.json
dmypy.json
temp/ temp/
todo/ todo/
# Security reports
bandit-report.json
safety-report.json
pip-audit-report.json
# Temporary files
*.tmp

View File

@@ -30,9 +30,28 @@ The app will be available at http://127.0.0.1:8050
**Run tests:** **Run tests:**
```bash ```bash
uv sync --extra test
uv run pytest tests/ -v uv run pytest tests/ -v
``` ```
**Development tools:**
```bash
# Install all dev dependencies
uv sync --extra dev
# Linting and formatting
uv run ruff check src/ tests/
uv run ruff format src/ tests/
# Type checking
uv run mypy src/embeddingbuddy/
# Security scanning
uv run bandit -r src/
uv run safety check
```
**Test with sample data:** **Test with sample data:**
Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality. Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality.
@@ -42,7 +61,7 @@ Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for test
The application follows a modular architecture with clear separation of concerns: The application follows a modular architecture with clear separation of concerns:
``` ```text
src/embeddingbuddy/ src/embeddingbuddy/
├── app.py # Main application entry point and factory ├── app.py # Main application entry point and factory
├── main.py # Application runner ├── main.py # Application runner
@@ -72,27 +91,32 @@ src/embeddingbuddy/
### Key Components ### Key Components
**Data Layer:** **Data Layer:**
- `data/parser.py` - NDJSON parsing with error handling - `data/parser.py` - NDJSON parsing with error handling
- `data/processor.py` - Data transformation and combination logic - `data/processor.py` - Data transformation and combination logic
- `models/schemas.py` - Dataclasses for type safety and validation - `models/schemas.py` - Dataclasses for type safety and validation
**Algorithm Layer:** **Algorithm Layer:**
- `models/reducers.py` - Modular dimensionality reduction with factory pattern - `models/reducers.py` - Modular dimensionality reduction with factory pattern
- Supports PCA, t-SNE (openTSNE), and UMAP algorithms - Supports PCA, t-SNE (openTSNE), and UMAP algorithms
- Abstract base class for easy extension - Abstract base class for easy extension
**Visualization Layer:** **Visualization Layer:**
- `visualization/plots.py` - Plot factory with single and dual plot support - `visualization/plots.py` - Plot factory with single and dual plot support
- `visualization/colors.py` - Color mapping and grayscale conversion utilities - `visualization/colors.py` - Color mapping and grayscale conversion utilities
- Plotly-based 2D/3D scatter plots with interactive features - Plotly-based 2D/3D scatter plots with interactive features
**UI Layer:** **UI Layer:**
- `ui/layout.py` - Main application layout composition - `ui/layout.py` - Main application layout composition
- `ui/components/` - Reusable, testable UI components - `ui/components/` - Reusable, testable UI components
- `ui/callbacks/` - Organized callbacks grouped by functionality - `ui/callbacks/` - Organized callbacks grouped by functionality
- Bootstrap-styled sidebar with controls and large visualization area - Bootstrap-styled sidebar with controls and large visualization area
**Configuration:** **Configuration:**
- `config/settings.py` - Centralized settings with environment variable support - `config/settings.py` - Centralized settings with environment variable support
- Plot styling, marker configurations, and app-wide constants - Plot styling, marker configurations, and app-wide constants
@@ -112,16 +136,19 @@ Optional fields: `id`, `category`, `subcategory`, `tags`
The refactored callback system is organized by functionality: The refactored callback system is organized by functionality:
**Data Processing (`ui/callbacks/data_processing.py`):** **Data Processing (`ui/callbacks/data_processing.py`):**
- File upload handling - File upload handling
- NDJSON parsing and validation - NDJSON parsing and validation
- Data storage in dcc.Store components - Data storage in dcc.Store components
**Visualization (`ui/callbacks/visualization.py`):** **Visualization (`ui/callbacks/visualization.py`):**
- Dimensionality reduction pipeline - Dimensionality reduction pipeline
- Plot generation and updates - Plot generation and updates
- Method/parameter change handling - Method/parameter change handling
**Interactions (`ui/callbacks/interactions.py`):** **Interactions (`ui/callbacks/interactions.py`):**
- Point click handling and detail display - Point click handling and detail display
- Reset functionality - Reset functionality
- User interaction management - User interaction management
@@ -131,15 +158,18 @@ The refactored callback system is organized by functionality:
The modular design enables comprehensive testing: The modular design enables comprehensive testing:
**Unit Tests:** **Unit Tests:**
- `tests/test_data_processing.py` - Parser and processor logic - `tests/test_data_processing.py` - Parser and processor logic
- `tests/test_reducers.py` - Dimensionality reduction algorithms - `tests/test_reducers.py` - Dimensionality reduction algorithms
- `tests/test_visualization.py` - Plot creation and color mapping - `tests/test_visualization.py` - Plot creation and color mapping
**Integration Tests:** **Integration Tests:**
- End-to-end data pipeline testing - End-to-end data pipeline testing
- Component integration verification - Component integration verification
**Key Testing Benefits:** **Key Testing Benefits:**
- Fast test execution (milliseconds vs seconds) - Fast test execution (milliseconds vs seconds)
- Isolated component testing - Isolated component testing
- Easy mocking and fixture creation - Easy mocking and fixture creation
@@ -167,6 +197,7 @@ Uses modern Python stack with uv for dependency management:
5. **Tests** - Write tests for all new functionality 5. **Tests** - Write tests for all new functionality
**Code Organization Principles:** **Code Organization Principles:**
- Single responsibility principle - Single responsibility principle
- Clear module boundaries - Clear module boundaries
- Testable, isolated components - Testable, isolated components
@@ -174,7 +205,8 @@ Uses modern Python stack with uv for dependency management:
- Error handling at appropriate layers - Error handling at appropriate layers
**Testing Requirements:** **Testing Requirements:**
- Unit tests for all core logic - Unit tests for all core logic
- Integration tests for data flow - Integration tests for data flow
- Component tests for UI elements - Component tests for UI elements
- Maintain high test coverage - Maintain high test coverage

View File

@@ -14,7 +14,28 @@ dependencies = [
"umap-learn>=0.5.8", "umap-learn>=0.5.8",
"numba>=0.56.4", "numba>=0.56.4",
"openTSNE>=1.0.0", "openTSNE>=1.0.0",
"mypy>=1.17.1",
]
[project.optional-dependencies]
test = [
"pytest>=8.4.1", "pytest>=8.4.1",
"pytest-cov>=4.1.0",
]
lint = [
"ruff>=0.1.0",
"mypy>=1.5.0",
]
security = [
"bandit[toml]>=1.7.5",
"safety>=2.3.0",
"pip-audit>=2.6.0",
]
dev = [
"embeddingbuddy[test,lint,security]",
]
all = [
"embeddingbuddy[test,lint,security]",
] ]
[build-system] [build-system]

View File

@@ -1,3 +1,3 @@
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors.""" """EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
__version__ = "0.1.0" __version__ = "0.1.0"

View File

@@ -8,32 +8,29 @@ from .ui.callbacks.interactions import InteractionCallbacks
def create_app(): def create_app():
app = dash.Dash( app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
__name__,
external_stylesheets=[dbc.themes.BOOTSTRAP]
)
layout_manager = AppLayout() layout_manager = AppLayout()
app.layout = layout_manager.create_layout() app.layout = layout_manager.create_layout()
DataProcessingCallbacks() DataProcessingCallbacks()
VisualizationCallbacks() VisualizationCallbacks()
InteractionCallbacks() InteractionCallbacks()
return app return app
def run_app(app=None, debug=None, host=None, port=None): def run_app(app=None, debug=None, host=None, port=None):
if app is None: if app is None:
app = create_app() app = create_app()
app.run( app.run(
debug=debug if debug is not None else AppSettings.DEBUG, debug=debug if debug is not None else AppSettings.DEBUG,
host=host if host is not None else AppSettings.HOST, host=host if host is not None else AppSettings.HOST,
port=port if port is not None else AppSettings.PORT port=port if port is not None else AppSettings.PORT,
) )
if __name__ == '__main__': if __name__ == "__main__":
app = create_app() app = create_app()
run_app(app) run_app(app)

View File

@@ -3,105 +3,100 @@ import os
class AppSettings: class AppSettings:
# UI Configuration # UI Configuration
UPLOAD_STYLE = { UPLOAD_STYLE = {
'width': '100%', "width": "100%",
'height': '60px', "height": "60px",
'lineHeight': '60px', "lineHeight": "60px",
'borderWidth': '1px', "borderWidth": "1px",
'borderStyle': 'dashed', "borderStyle": "dashed",
'borderRadius': '5px', "borderRadius": "5px",
'textAlign': 'center', "textAlign": "center",
'margin-bottom': '20px' "margin-bottom": "20px",
} }
PROMPTS_UPLOAD_STYLE = { PROMPTS_UPLOAD_STYLE = {**UPLOAD_STYLE, "borderColor": "#28a745"}
**UPLOAD_STYLE,
'borderColor': '#28a745' PLOT_CONFIG = {"responsive": True, "displayModeBar": True}
}
PLOT_STYLE = {"height": "85vh", "width": "100%"}
PLOT_CONFIG = {
'responsive': True,
'displayModeBar': True
}
PLOT_STYLE = {
'height': '85vh',
'width': '100%'
}
PLOT_LAYOUT_CONFIG = { PLOT_LAYOUT_CONFIG = {
'height': None, "height": None,
'autosize': True, "autosize": True,
'margin': dict(l=0, r=0, t=50, b=0) "margin": dict(l=0, r=0, t=50, b=0),
} }
# Dimensionality Reduction Settings # Dimensionality Reduction Settings
DEFAULT_N_COMPONENTS_3D = 3 DEFAULT_N_COMPONENTS_3D = 3
DEFAULT_N_COMPONENTS_2D = 2 DEFAULT_N_COMPONENTS_2D = 2
DEFAULT_RANDOM_STATE = 42 DEFAULT_RANDOM_STATE = 42
# Available Methods # Available Methods
REDUCTION_METHODS = [ REDUCTION_METHODS = [
{'label': 'PCA', 'value': 'pca'}, {"label": "PCA", "value": "pca"},
{'label': 't-SNE', 'value': 'tsne'}, {"label": "t-SNE", "value": "tsne"},
{'label': 'UMAP', 'value': 'umap'} {"label": "UMAP", "value": "umap"},
] ]
COLOR_OPTIONS = [ COLOR_OPTIONS = [
{'label': 'Category', 'value': 'category'}, {"label": "Category", "value": "category"},
{'label': 'Subcategory', 'value': 'subcategory'}, {"label": "Subcategory", "value": "subcategory"},
{'label': 'Tags', 'value': 'tags'} {"label": "Tags", "value": "tags"},
] ]
DIMENSION_OPTIONS = [ DIMENSION_OPTIONS = [{"label": "2D", "value": "2d"}, {"label": "3D", "value": "3d"}]
{'label': '2D', 'value': '2d'},
{'label': '3D', 'value': '3d'}
]
# Default Values # Default Values
DEFAULT_METHOD = 'pca' DEFAULT_METHOD = "pca"
DEFAULT_COLOR_BY = 'category' DEFAULT_COLOR_BY = "category"
DEFAULT_DIMENSIONS = '3d' DEFAULT_DIMENSIONS = "3d"
DEFAULT_SHOW_PROMPTS = ['show'] DEFAULT_SHOW_PROMPTS = ["show"]
# Plot Marker Settings # Plot Marker Settings
DOCUMENT_MARKER_SIZE_2D = 8 DOCUMENT_MARKER_SIZE_2D = 8
DOCUMENT_MARKER_SIZE_3D = 5 DOCUMENT_MARKER_SIZE_3D = 5
PROMPT_MARKER_SIZE_2D = 10 PROMPT_MARKER_SIZE_2D = 10
PROMPT_MARKER_SIZE_3D = 6 PROMPT_MARKER_SIZE_3D = 6
DOCUMENT_MARKER_SYMBOL = 'circle' DOCUMENT_MARKER_SYMBOL = "circle"
PROMPT_MARKER_SYMBOL = 'diamond' PROMPT_MARKER_SYMBOL = "diamond"
DOCUMENT_OPACITY = 1.0 DOCUMENT_OPACITY = 1.0
PROMPT_OPACITY = 0.8 PROMPT_OPACITY = 0.8
# Text Processing # Text Processing
TEXT_PREVIEW_LENGTH = 100 TEXT_PREVIEW_LENGTH = 100
# App Configuration # App Configuration
DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true' DEBUG = os.getenv("EMBEDDINGBUDDY_DEBUG", "True").lower() == "true"
HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1') HOST = os.getenv("EMBEDDINGBUDDY_HOST", "127.0.0.1")
PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050')) PORT = int(os.getenv("EMBEDDINGBUDDY_PORT", "8050"))
# Bootstrap Theme # Bootstrap Theme
EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css'] EXTERNAL_STYLESHEETS = [
"https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
]
@classmethod @classmethod
def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]: def get_plot_marker_config(
cls, dimensions: str, is_prompt: bool = False
) -> Dict[str, Any]:
if is_prompt: if is_prompt:
size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D size = (
cls.PROMPT_MARKER_SIZE_3D
if dimensions == "3d"
else cls.PROMPT_MARKER_SIZE_2D
)
symbol = cls.PROMPT_MARKER_SYMBOL symbol = cls.PROMPT_MARKER_SYMBOL
opacity = cls.PROMPT_OPACITY opacity = cls.PROMPT_OPACITY
else: else:
size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D size = (
cls.DOCUMENT_MARKER_SIZE_3D
if dimensions == "3d"
else cls.DOCUMENT_MARKER_SIZE_2D
)
symbol = cls.DOCUMENT_MARKER_SYMBOL symbol = cls.DOCUMENT_MARKER_SYMBOL
opacity = cls.DOCUMENT_OPACITY opacity = cls.DOCUMENT_OPACITY
return { return {"size": size, "symbol": symbol, "opacity": opacity}
'size': size,
'symbol': symbol,
'opacity': opacity
}

View File

@@ -1,39 +1,38 @@
import json import json
import uuid import uuid
import base64 import base64
from typing import List, Union from typing import List
from ..models.schemas import Document, ProcessedData from ..models.schemas import Document
class NDJSONParser: class NDJSONParser:
@staticmethod @staticmethod
def parse_upload_contents(contents: str) -> List[Document]: def parse_upload_contents(contents: str) -> List[Document]:
content_type, content_string = contents.split(',') content_type, content_string = contents.split(",")
decoded = base64.b64decode(content_string) decoded = base64.b64decode(content_string)
text_content = decoded.decode('utf-8') text_content = decoded.decode("utf-8")
return NDJSONParser.parse_text(text_content) return NDJSONParser.parse_text(text_content)
@staticmethod @staticmethod
def parse_text(text_content: str) -> List[Document]: def parse_text(text_content: str) -> List[Document]:
documents = [] documents = []
for line in text_content.strip().split('\n'): for line in text_content.strip().split("\n"):
if line.strip(): if line.strip():
doc_dict = json.loads(line) doc_dict = json.loads(line)
doc = NDJSONParser._dict_to_document(doc_dict) doc = NDJSONParser._dict_to_document(doc_dict)
documents.append(doc) documents.append(doc)
return documents return documents
@staticmethod @staticmethod
def _dict_to_document(doc_dict: dict) -> Document: def _dict_to_document(doc_dict: dict) -> Document:
if 'id' not in doc_dict: if "id" not in doc_dict:
doc_dict['id'] = str(uuid.uuid4()) doc_dict["id"] = str(uuid.uuid4())
return Document( return Document(
id=doc_dict['id'], id=doc_dict["id"],
text=doc_dict['text'], text=doc_dict["text"],
embedding=doc_dict['embedding'], embedding=doc_dict["embedding"],
category=doc_dict.get('category'), category=doc_dict.get("category"),
subcategory=doc_dict.get('subcategory'), subcategory=doc_dict.get("subcategory"),
tags=doc_dict.get('tags') tags=doc_dict.get("tags"),
) )

View File

@@ -5,18 +5,19 @@ from .parser import NDJSONParser
class DataProcessor: class DataProcessor:
def __init__(self): def __init__(self):
self.parser = NDJSONParser() self.parser = NDJSONParser()
def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData: def process_upload(
self, contents: str, filename: Optional[str] = None
) -> ProcessedData:
try: try:
documents = self.parser.parse_upload_contents(contents) documents = self.parser.parse_upload_contents(contents)
embeddings = self._extract_embeddings(documents) embeddings = self._extract_embeddings(documents)
return ProcessedData(documents=documents, embeddings=embeddings) return ProcessedData(documents=documents, embeddings=embeddings)
except Exception as e: except Exception as e:
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
def process_text(self, text_content: str) -> ProcessedData: def process_text(self, text_content: str) -> ProcessedData:
try: try:
documents = self.parser.parse_text(text_content) documents = self.parser.parse_text(text_content)
@@ -24,31 +25,35 @@ class DataProcessor:
return ProcessedData(documents=documents, embeddings=embeddings) return ProcessedData(documents=documents, embeddings=embeddings)
except Exception as e: except Exception as e:
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e)) return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
def _extract_embeddings(self, documents: List[Document]) -> np.ndarray: def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
if not documents: if not documents:
return np.array([]) return np.array([])
return np.array([doc.embedding for doc in documents]) return np.array([doc.embedding for doc in documents])
def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]: def combine_data(
self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None
) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
if not doc_data or doc_data.error: if not doc_data or doc_data.error:
raise ValueError("Invalid document data") raise ValueError("Invalid document data")
all_embeddings = doc_data.embeddings all_embeddings = doc_data.embeddings
documents = doc_data.documents documents = doc_data.documents
prompts = None prompts = None
if prompt_data and not prompt_data.error and prompt_data.documents: if prompt_data and not prompt_data.error and prompt_data.documents:
all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings]) all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
prompts = prompt_data.documents prompts = prompt_data.documents
return all_embeddings, documents, prompts return all_embeddings, documents, prompts
def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]: def split_reduced_data(
self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
doc_reduced = reduced_embeddings[:n_documents] doc_reduced = reduced_embeddings[:n_documents]
prompt_reduced = None prompt_reduced = None
if n_prompts > 0: if n_prompts > 0:
prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts] prompt_reduced = reduced_embeddings[n_documents : n_documents + n_prompts]
return doc_reduced, prompt_reduced return doc_reduced, prompt_reduced

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import numpy as np import numpy as np
from typing import Optional, Tuple
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
import umap import umap
from openTSNE import TSNE from openTSNE import TSNE
@@ -8,88 +7,89 @@ from .schemas import ReducedData
class DimensionalityReducer(ABC): class DimensionalityReducer(ABC):
def __init__(self, n_components: int = 3, random_state: int = 42): def __init__(self, n_components: int = 3, random_state: int = 42):
self.n_components = n_components self.n_components = n_components
self.random_state = random_state self.random_state = random_state
self._reducer = None self._reducer = None
@abstractmethod @abstractmethod
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
pass pass
@abstractmethod @abstractmethod
def get_method_name(self) -> str: def get_method_name(self) -> str:
pass pass
class PCAReducer(DimensionalityReducer): class PCAReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = PCA(n_components=self.n_components) self._reducer = PCA(n_components=self.n_components)
reduced = self._reducer.fit_transform(embeddings) reduced = self._reducer.fit_transform(embeddings)
variance_explained = self._reducer.explained_variance_ratio_ variance_explained = self._reducer.explained_variance_ratio_
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=variance_explained, variance_explained=variance_explained,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components n_components=self.n_components,
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
return "PCA" return "PCA"
class TSNEReducer(DimensionalityReducer): class TSNEReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state) self._reducer = TSNE(
n_components=self.n_components, random_state=self.random_state
)
reduced = self._reducer.fit(embeddings) reduced = self._reducer.fit(embeddings)
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=None, variance_explained=None,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components n_components=self.n_components,
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
return "t-SNE" return "t-SNE"
class UMAPReducer(DimensionalityReducer): class UMAPReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state) self._reducer = umap.UMAP(
n_components=self.n_components, random_state=self.random_state
)
reduced = self._reducer.fit_transform(embeddings) reduced = self._reducer.fit_transform(embeddings)
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=None, variance_explained=None,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components n_components=self.n_components,
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
return "UMAP" return "UMAP"
class ReducerFactory: class ReducerFactory:
@staticmethod @staticmethod
def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer: def create_reducer(
method: str, n_components: int = 3, random_state: int = 42
) -> DimensionalityReducer:
method_lower = method.lower() method_lower = method.lower()
if method_lower == 'pca': if method_lower == "pca":
return PCAReducer(n_components=n_components, random_state=random_state) return PCAReducer(n_components=n_components, random_state=random_state)
elif method_lower == 'tsne': elif method_lower == "tsne":
return TSNEReducer(n_components=n_components, random_state=random_state) return TSNEReducer(n_components=n_components, random_state=random_state)
elif method_lower == 'umap': elif method_lower == "umap":
return UMAPReducer(n_components=n_components, random_state=random_state) return UMAPReducer(n_components=n_components, random_state=random_state)
else: else:
raise ValueError(f"Unknown reduction method: {method}") raise ValueError(f"Unknown reduction method: {method}")
@staticmethod @staticmethod
def get_available_methods() -> list: def get_available_methods() -> list:
return ['pca', 'tsne', 'umap'] return ["pca", "tsne", "umap"]

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Any, Dict from typing import List, Optional
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@@ -50,9 +50,11 @@ class PlotData:
coordinates: np.ndarray coordinates: np.ndarray
prompts: Optional[List[Document]] = None prompts: Optional[List[Document]] = None
prompt_coordinates: Optional[np.ndarray] = None prompt_coordinates: Optional[np.ndarray] = None
def __post_init__(self): def __post_init__(self):
if not isinstance(self.coordinates, np.ndarray): if not isinstance(self.coordinates, np.ndarray):
self.coordinates = np.array(self.coordinates) self.coordinates = np.array(self.coordinates)
if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray): if self.prompt_coordinates is not None and not isinstance(
self.prompt_coordinates = np.array(self.prompt_coordinates) self.prompt_coordinates, np.ndarray
):
self.prompt_coordinates = np.array(self.prompt_coordinates)

View File

@@ -1,61 +1,62 @@
import numpy as np
from dash import callback, Input, Output, State from dash import callback, Input, Output, State
from ...data.processor import DataProcessor from ...data.processor import DataProcessor
class DataProcessingCallbacks: class DataProcessingCallbacks:
def __init__(self): def __init__(self):
self.processor = DataProcessor() self.processor = DataProcessor()
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output('processed-data', 'data'), Output("processed-data", "data"),
Input('upload-data', 'contents'), Input("upload-data", "contents"),
State('upload-data', 'filename') State("upload-data", "filename"),
) )
def process_uploaded_file(contents, filename): def process_uploaded_file(contents, filename):
if contents is None: if contents is None:
return None return None
processed_data = self.processor.process_upload(contents, filename) processed_data = self.processor.process_upload(contents, filename)
if processed_data.error: if processed_data.error:
return {'error': processed_data.error} return {"error": processed_data.error}
return { return {
'documents': [self._document_to_dict(doc) for doc in processed_data.documents], "documents": [
'embeddings': processed_data.embeddings.tolist() self._document_to_dict(doc) for doc in processed_data.documents
],
"embeddings": processed_data.embeddings.tolist(),
} }
@callback( @callback(
Output('processed-prompts', 'data'), Output("processed-prompts", "data"),
Input('upload-prompts', 'contents'), Input("upload-prompts", "contents"),
State('upload-prompts', 'filename') State("upload-prompts", "filename"),
) )
def process_uploaded_prompts(contents, filename): def process_uploaded_prompts(contents, filename):
if contents is None: if contents is None:
return None return None
processed_data = self.processor.process_upload(contents, filename) processed_data = self.processor.process_upload(contents, filename)
if processed_data.error: if processed_data.error:
return {'error': processed_data.error} return {"error": processed_data.error}
return { return {
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents], "prompts": [
'embeddings': processed_data.embeddings.tolist() self._document_to_dict(doc) for doc in processed_data.documents
],
"embeddings": processed_data.embeddings.tolist(),
} }
@staticmethod @staticmethod
def _document_to_dict(doc): def _document_to_dict(doc):
return { return {
'id': doc.id, "id": doc.id,
'text': doc.text, "text": doc.text,
'embedding': doc.embedding, "embedding": doc.embedding,
'category': doc.category, "category": doc.category,
'subcategory': doc.subcategory, "subcategory": doc.subcategory,
'tags': doc.tags "tags": doc.tags,
} }

View File

@@ -4,63 +4,79 @@ import dash_bootstrap_components as dbc
class InteractionCallbacks: class InteractionCallbacks:
def __init__(self): def __init__(self):
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output('point-details', 'children'), Output("point-details", "children"),
Input('embedding-plot', 'clickData'), Input("embedding-plot", "clickData"),
[State('processed-data', 'data'), [State("processed-data", "data"), State("processed-prompts", "data")],
State('processed-prompts', 'data')]
) )
def display_click_data(clickData, data, prompts_data): def display_click_data(clickData, data, prompts_data):
if not clickData or not data: if not clickData or not data:
return "Click on a point to see details" return "Click on a point to see details"
point_data = clickData['points'][0] point_data = clickData["points"][0]
trace_name = point_data.get('fullData', {}).get('name', 'Documents') trace_name = point_data.get("fullData", {}).get("name", "Documents")
if 'pointIndex' in point_data: if "pointIndex" in point_data:
point_index = point_data['pointIndex'] point_index = point_data["pointIndex"]
elif 'pointNumber' in point_data: elif "pointNumber" in point_data:
point_index = point_data['pointNumber'] point_index = point_data["pointNumber"]
else: else:
return "Could not identify clicked point" return "Could not identify clicked point"
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data: if (
item = prompts_data['prompts'][point_index] trace_name.startswith("Prompts")
item_type = 'Prompt' and prompts_data
and "prompts" in prompts_data
):
item = prompts_data["prompts"][point_index]
item_type = "Prompt"
else: else:
item = data['documents'][point_index] item = data["documents"][point_index]
item_type = 'Document' item_type = "Document"
return self._create_detail_card(item, item_type) return self._create_detail_card(item, item_type)
@callback( @callback(
[Output('processed-data', 'data', allow_duplicate=True), [
Output('processed-prompts', 'data', allow_duplicate=True), Output("processed-data", "data", allow_duplicate=True),
Output('point-details', 'children', allow_duplicate=True)], Output("processed-prompts", "data", allow_duplicate=True),
Input('reset-button', 'n_clicks'), Output("point-details", "children", allow_duplicate=True),
prevent_initial_call=True ],
Input("reset-button", "n_clicks"),
prevent_initial_call=True,
) )
def reset_data(n_clicks): def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0: if n_clicks is None or n_clicks == 0:
return dash.no_update, dash.no_update, dash.no_update return dash.no_update, dash.no_update, dash.no_update
return None, None, "Click on a point to see details" return None, None, "Click on a point to see details"
@staticmethod @staticmethod
def _create_detail_card(item, item_type): def _create_detail_card(item, item_type):
return dbc.Card([ return dbc.Card(
dbc.CardBody([ [
html.H5(f"{item_type}: {item['id']}", className="card-title"), dbc.CardBody(
html.P(f"Text: {item['text']}", className="card-text"), [
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"), html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"), html.P(f"Text: {item['text']}", className="card-text"),
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"), html.P(
html.P(f"Type: {item_type}", className="card-text text-muted") f"Category: {item.get('category', 'Unknown')}",
]) className="card-text",
]) ),
html.P(
f"Subcategory: {item.get('subcategory', 'Unknown')}",
className="card-text",
),
html.P(
f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}",
className="card-text",
),
html.P(f"Type: {item_type}", className="card-text text-muted"),
]
)
]
)

View File

@@ -7,81 +7,102 @@ from ...visualization.plots import PlotFactory
class VisualizationCallbacks: class VisualizationCallbacks:
def __init__(self): def __init__(self):
self.plot_factory = PlotFactory() self.plot_factory = PlotFactory()
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output('embedding-plot', 'figure'), Output("embedding-plot", "figure"),
[Input('processed-data', 'data'), [
Input('processed-prompts', 'data'), Input("processed-data", "data"),
Input('method-dropdown', 'value'), Input("processed-prompts", "data"),
Input('color-dropdown', 'value'), Input("method-dropdown", "value"),
Input('dimension-toggle', 'value'), Input("color-dropdown", "value"),
Input('show-prompts-toggle', 'value')] Input("dimension-toggle", "value"),
Input("show-prompts-toggle", "value"),
],
) )
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts): def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or 'error' in data: if not data or "error" in data:
return go.Figure().add_annotation( return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization", text="Upload a valid NDJSON file to see visualization",
xref="paper", yref="paper", xref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle', yref="paper",
showarrow=False, font=dict(size=16) x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
font=dict(size=16),
) )
try: try:
doc_embeddings = np.array(data['embeddings']) doc_embeddings = np.array(data["embeddings"])
all_embeddings = doc_embeddings all_embeddings = doc_embeddings
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts') has_prompts = (
prompts_data
and "error" not in prompts_data
and prompts_data.get("prompts")
)
if has_prompts: if has_prompts:
prompt_embeddings = np.array(prompts_data['embeddings']) prompt_embeddings = np.array(prompts_data["embeddings"])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings]) all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == '3d' else 2 n_components = 3 if dimensions == "3d" else 2
reducer = ReducerFactory.create_reducer(method, n_components=n_components) reducer = ReducerFactory.create_reducer(
method, n_components=n_components
)
reduced_data = reducer.fit_transform(all_embeddings) reduced_data = reducer.fit_transform(all_embeddings)
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)] doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)]
prompt_reduced = None prompt_reduced = None
if has_prompts: if has_prompts:
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):] prompt_reduced = reduced_data.reduced_embeddings[
len(doc_embeddings) :
documents = [self._dict_to_document(doc) for doc in data['documents']] ]
documents = [self._dict_to_document(doc) for doc in data["documents"]]
prompts = None prompts = None
if has_prompts: if has_prompts:
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']] prompts = [
self._dict_to_document(prompt)
for prompt in prompts_data["prompts"]
]
plot_data = PlotData( plot_data = PlotData(
documents=documents, documents=documents,
coordinates=doc_reduced, coordinates=doc_reduced,
prompts=prompts, prompts=prompts,
prompt_coordinates=prompt_reduced prompt_coordinates=prompt_reduced,
) )
return self.plot_factory.create_plot( return self.plot_factory.create_plot(
plot_data, dimensions, color_by, reduced_data.method, show_prompts plot_data, dimensions, color_by, reduced_data.method, show_prompts
) )
except Exception as e: except Exception as e:
return go.Figure().add_annotation( return go.Figure().add_annotation(
text=f"Error creating visualization: {str(e)}", text=f"Error creating visualization: {str(e)}",
xref="paper", yref="paper", xref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle', yref="paper",
showarrow=False, font=dict(size=16) x=0.5,
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
font=dict(size=16),
) )
@staticmethod @staticmethod
def _dict_to_document(doc_dict): def _dict_to_document(doc_dict):
return Document( return Document(
id=doc_dict['id'], id=doc_dict["id"],
text=doc_dict['text'], text=doc_dict["text"],
embedding=doc_dict['embedding'], embedding=doc_dict["embedding"],
category=doc_dict.get('category'), category=doc_dict.get("category"),
subcategory=doc_dict.get('subcategory'), subcategory=doc_dict.get("subcategory"),
tags=doc_dict.get('tags', []) tags=doc_dict.get("tags", []),
) )

View File

@@ -4,79 +4,81 @@ from .upload import UploadComponent
class SidebarComponent: class SidebarComponent:
def __init__(self): def __init__(self):
self.upload_component = UploadComponent() self.upload_component = UploadComponent()
def create_layout(self): def create_layout(self):
return dbc.Col([ return dbc.Col(
html.H5("Upload Data", className="mb-3"), [
self.upload_component.create_data_upload(), html.H5("Upload Data", className="mb-3"),
self.upload_component.create_prompts_upload(), self.upload_component.create_data_upload(),
self.upload_component.create_reset_button(), self.upload_component.create_prompts_upload(),
self.upload_component.create_reset_button(),
html.H5("Visualization Controls", className="mb-3"), html.H5("Visualization Controls", className="mb-3"),
self._create_method_dropdown(), self._create_method_dropdown(),
self._create_color_dropdown(), self._create_color_dropdown(),
self._create_dimension_toggle(), self._create_dimension_toggle(),
self._create_prompts_toggle(), self._create_prompts_toggle(),
html.H5("Point Details", className="mb-3"),
html.H5("Point Details", className="mb-3"), html.Div(
html.Div(id='point-details', children="Click on a point to see details") id="point-details", children="Click on a point to see details"
),
], width=3, style={'padding-right': '20px'}) ],
width=3,
style={"padding-right": "20px"},
)
def _create_method_dropdown(self): def _create_method_dropdown(self):
return [ return [
dbc.Label("Method:"), dbc.Label("Method:"),
dcc.Dropdown( dcc.Dropdown(
id='method-dropdown', id="method-dropdown",
options=[ options=[
{'label': 'PCA', 'value': 'pca'}, {"label": "PCA", "value": "pca"},
{'label': 't-SNE', 'value': 'tsne'}, {"label": "t-SNE", "value": "tsne"},
{'label': 'UMAP', 'value': 'umap'} {"label": "UMAP", "value": "umap"},
], ],
value='pca', value="pca",
style={'margin-bottom': '15px'} style={"margin-bottom": "15px"},
) ),
] ]
def _create_color_dropdown(self): def _create_color_dropdown(self):
return [ return [
dbc.Label("Color by:"), dbc.Label("Color by:"),
dcc.Dropdown( dcc.Dropdown(
id='color-dropdown', id="color-dropdown",
options=[ options=[
{'label': 'Category', 'value': 'category'}, {"label": "Category", "value": "category"},
{'label': 'Subcategory', 'value': 'subcategory'}, {"label": "Subcategory", "value": "subcategory"},
{'label': 'Tags', 'value': 'tags'} {"label": "Tags", "value": "tags"},
], ],
value='category', value="category",
style={'margin-bottom': '15px'} style={"margin-bottom": "15px"},
) ),
] ]
def _create_dimension_toggle(self): def _create_dimension_toggle(self):
return [ return [
dbc.Label("Dimensions:"), dbc.Label("Dimensions:"),
dcc.RadioItems( dcc.RadioItems(
id='dimension-toggle', id="dimension-toggle",
options=[ options=[
{'label': '2D', 'value': '2d'}, {"label": "2D", "value": "2d"},
{'label': '3D', 'value': '3d'} {"label": "3D", "value": "3d"},
], ],
value='3d', value="3d",
style={'margin-bottom': '20px'} style={"margin-bottom": "20px"},
) ),
] ]
def _create_prompts_toggle(self): def _create_prompts_toggle(self):
return [ return [
dbc.Label("Show Prompts:"), dbc.Label("Show Prompts:"),
dcc.Checklist( dcc.Checklist(
id='show-prompts-toggle', id="show-prompts-toggle",
options=[{'label': 'Show prompts on plot', 'value': 'show'}], options=[{"label": "Show prompts on plot", "value": "show"}],
value=['show'], value=["show"],
style={'margin-bottom': '20px'} style={"margin-bottom": "20px"},
) ),
] ]

View File

@@ -3,58 +3,51 @@ import dash_bootstrap_components as dbc
class UploadComponent: class UploadComponent:
@staticmethod @staticmethod
def create_data_upload(): def create_data_upload():
return dcc.Upload( return dcc.Upload(
id='upload-data', id="upload-data",
children=html.Div([ children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
'Drag and Drop or ',
html.A('Select Files')
]),
style={ style={
'width': '100%', "width": "100%",
'height': '60px', "height": "60px",
'lineHeight': '60px', "lineHeight": "60px",
'borderWidth': '1px', "borderWidth": "1px",
'borderStyle': 'dashed', "borderStyle": "dashed",
'borderRadius': '5px', "borderRadius": "5px",
'textAlign': 'center', "textAlign": "center",
'margin-bottom': '20px' "margin-bottom": "20px",
}, },
multiple=False multiple=False,
) )
@staticmethod @staticmethod
def create_prompts_upload(): def create_prompts_upload():
return dcc.Upload( return dcc.Upload(
id='upload-prompts', id="upload-prompts",
children=html.Div([ children=html.Div(["Drag and Drop Prompts or ", html.A("Select Files")]),
'Drag and Drop Prompts or ',
html.A('Select Files')
]),
style={ style={
'width': '100%', "width": "100%",
'height': '60px', "height": "60px",
'lineHeight': '60px', "lineHeight": "60px",
'borderWidth': '1px', "borderWidth": "1px",
'borderStyle': 'dashed', "borderStyle": "dashed",
'borderRadius': '5px', "borderRadius": "5px",
'textAlign': 'center', "textAlign": "center",
'margin-bottom': '20px', "margin-bottom": "20px",
'borderColor': '#28a745' "borderColor": "#28a745",
}, },
multiple=False multiple=False,
) )
@staticmethod @staticmethod
def create_reset_button(): def create_reset_button():
return dbc.Button( return dbc.Button(
"Reset All Data", "Reset All Data",
id='reset-button', id="reset-button",
color='danger', color="danger",
outline=True, outline=True,
size='sm', size="sm",
className='mb-3', className="mb-3",
style={'width': '100%'} style={"width": "100%"},
) )

View File

@@ -4,41 +4,43 @@ from .components.sidebar import SidebarComponent
class AppLayout: class AppLayout:
def __init__(self): def __init__(self):
self.sidebar = SidebarComponent() self.sidebar = SidebarComponent()
def create_layout(self): def create_layout(self):
return dbc.Container([ return dbc.Container(
self._create_header(), [self._create_header(), self._create_main_content(), self._create_stores()],
self._create_main_content(), fluid=True,
self._create_stores() )
], fluid=True)
def _create_header(self): def _create_header(self):
return dbc.Row([ return dbc.Row(
dbc.Col([ [
html.H1("EmbeddingBuddy", className="text-center mb-4"), dbc.Col(
], width=12) [
]) html.H1("EmbeddingBuddy", className="text-center mb-4"),
],
width=12,
)
]
)
def _create_main_content(self): def _create_main_content(self):
return dbc.Row([ return dbc.Row(
self.sidebar.create_layout(), [self.sidebar.create_layout(), self._create_visualization_area()]
self._create_visualization_area() )
])
def _create_visualization_area(self): def _create_visualization_area(self):
return dbc.Col([ return dbc.Col(
dcc.Graph( [
id='embedding-plot', dcc.Graph(
style={'height': '85vh', 'width': '100%'}, id="embedding-plot",
config={'responsive': True, 'displayModeBar': True} style={"height": "85vh", "width": "100%"},
) config={"responsive": True, "displayModeBar": True},
], width=9) )
],
width=9,
)
def _create_stores(self): def _create_stores(self):
return [ return [dcc.Store(id="processed-data"), dcc.Store(id="processed-prompts")]
dcc.Store(id='processed-data'),
dcc.Store(id='processed-prompts')
]

View File

@@ -1,33 +1,36 @@
from typing import List, Dict, Any from typing import List
import plotly.colors as pc import plotly.colors as pc
from ..models.schemas import Document from ..models.schemas import Document
class ColorMapper: class ColorMapper:
@staticmethod @staticmethod
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]: def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
if color_by == 'category': if color_by == "category":
return [doc.category for doc in documents] return [doc.category for doc in documents]
elif color_by == 'subcategory': elif color_by == "subcategory":
return [doc.subcategory for doc in documents] return [doc.subcategory for doc in documents]
elif color_by == 'tags': elif color_by == "tags":
return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents] return [", ".join(doc.tags) if doc.tags else "No tags" for doc in documents]
else: else:
return ['All'] * len(documents) return ["All"] * len(documents)
@staticmethod @staticmethod
def to_grayscale_hex(color_str: str) -> str: def to_grayscale_hex(color_str: str) -> str:
try: try:
if color_str.startswith('#'): if color_str.startswith("#"):
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5)) rgb = tuple(int(color_str[i : i + 2], 16) for i in (1, 3, 5))
else: else:
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0]) rgb = pc.hex_to_rgb(
pc.convert_colors_to_same_type([color_str], colortype="hex")[0][0]
)
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3, gray_rgb = (
gray_value * 0.7 + rgb[1] * 0.3, gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[2] * 0.3) gray_value * 0.7 + rgb[1] * 0.3,
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})' gray_value * 0.7 + rgb[2] * 0.3,
except: )
return 'rgb(128,128,128)' return f"rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})"
except: # noqa: E722
return "rgb(128,128,128)"

View File

@@ -7,139 +7,172 @@ from .colors import ColorMapper
class PlotFactory: class PlotFactory:
def __init__(self): def __init__(self):
self.color_mapper = ColorMapper() self.color_mapper = ColorMapper()
def create_plot(self, plot_data: PlotData, dimensions: str = '3d', def create_plot(
color_by: str = 'category', method: str = 'PCA', self,
show_prompts: Optional[List[str]] = None) -> go.Figure: plot_data: PlotData,
dimensions: str = "3d",
if plot_data.prompts and show_prompts and 'show' in show_prompts: color_by: str = "category",
method: str = "PCA",
show_prompts: Optional[List[str]] = None,
) -> go.Figure:
if plot_data.prompts and show_prompts and "show" in show_prompts:
return self._create_dual_plot(plot_data, dimensions, color_by, method) return self._create_dual_plot(plot_data, dimensions, color_by, method)
else: else:
return self._create_single_plot(plot_data, dimensions, color_by, method) return self._create_single_plot(plot_data, dimensions, color_by, method)
def _create_single_plot(self, plot_data: PlotData, dimensions: str, def _create_single_plot(
color_by: str, method: str) -> go.Figure: self, plot_data: PlotData, dimensions: str, color_by: str, method: str
df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) ) -> go.Figure:
color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) df = self._prepare_dataframe(
plot_data.documents, plot_data.coordinates, dimensions
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] )
color_values = self.color_mapper.create_color_mapping(
if dimensions == '3d': plot_data.documents, color_by
)
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
if dimensions == "3d":
fig = px.scatter_3d( fig = px.scatter_3d(
df, x='dim_1', y='dim_2', z='dim_3', df,
x="dim_1",
y="dim_2",
z="dim_3",
color=color_values, color=color_values,
hover_data=hover_fields, hover_data=hover_fields,
title=f'3D Embedding Visualization - {method} (colored by {color_by})' title=f"3D Embedding Visualization - {method} (colored by {color_by})",
) )
fig.update_traces(marker=dict(size=5)) fig.update_traces(marker=dict(size=5))
else: else:
fig = px.scatter( fig = px.scatter(
df, x='dim_1', y='dim_2', df,
x="dim_1",
y="dim_2",
color=color_values, color=color_values,
hover_data=hover_fields, hover_data=hover_fields,
title=f'2D Embedding Visualization - {method} (colored by {color_by})' title=f"2D Embedding Visualization - {method} (colored by {color_by})",
) )
fig.update_traces(marker=dict(size=8)) fig.update_traces(marker=dict(size=8))
fig.update_layout( fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0))
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig return fig
def _create_dual_plot(self, plot_data: PlotData, dimensions: str, def _create_dual_plot(
color_by: str, method: str) -> go.Figure: self, plot_data: PlotData, dimensions: str, color_by: str, method: str
) -> go.Figure:
fig = go.Figure() fig = go.Figure()
doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions) doc_df = self._prepare_dataframe(
doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by) plot_data.documents, plot_data.coordinates, dimensions
)
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] doc_color_values = self.color_mapper.create_color_mapping(
plot_data.documents, color_by
if dimensions == '3d': )
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
if dimensions == "3d":
doc_fig = px.scatter_3d( doc_fig = px.scatter_3d(
doc_df, x='dim_1', y='dim_2', z='dim_3', doc_df,
x="dim_1",
y="dim_2",
z="dim_3",
color=doc_color_values, color=doc_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
else: else:
doc_fig = px.scatter( doc_fig = px.scatter(
doc_df, x='dim_1', y='dim_2', doc_df,
x="dim_1",
y="dim_2",
color=doc_color_values, color=doc_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
for trace in doc_fig.data: for trace in doc_fig.data:
trace.name = f'Documents - {trace.name}' trace.name = f"Documents - {trace.name}"
if dimensions == '3d': if dimensions == "3d":
trace.marker.size = 5 trace.marker.size = 5
trace.marker.symbol = 'circle' trace.marker.symbol = "circle"
else: else:
trace.marker.size = 8 trace.marker.size = 8
trace.marker.symbol = 'circle' trace.marker.symbol = "circle"
trace.marker.opacity = 1.0 trace.marker.opacity = 1.0
fig.add_trace(trace) fig.add_trace(trace)
if plot_data.prompts and plot_data.prompt_coordinates is not None: if plot_data.prompts and plot_data.prompt_coordinates is not None:
prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions) prompt_df = self._prepare_dataframe(
prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by) plot_data.prompts, plot_data.prompt_coordinates, dimensions
)
if dimensions == '3d': prompt_color_values = self.color_mapper.create_color_mapping(
plot_data.prompts, color_by
)
if dimensions == "3d":
prompt_fig = px.scatter_3d( prompt_fig = px.scatter_3d(
prompt_df, x='dim_1', y='dim_2', z='dim_3', prompt_df,
x="dim_1",
y="dim_2",
z="dim_3",
color=prompt_color_values, color=prompt_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
else: else:
prompt_fig = px.scatter( prompt_fig = px.scatter(
prompt_df, x='dim_1', y='dim_2', prompt_df,
x="dim_1",
y="dim_2",
color=prompt_color_values, color=prompt_color_values,
hover_data=hover_fields hover_data=hover_fields,
) )
for trace in prompt_fig.data: for trace in prompt_fig.data:
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str): if hasattr(trace.marker, "color") and isinstance(
trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color) trace.marker.color, str
):
trace.name = f'Prompts - {trace.name}' trace.marker.color = self.color_mapper.to_grayscale_hex(
if dimensions == '3d': trace.marker.color
)
trace.name = f"Prompts - {trace.name}"
if dimensions == "3d":
trace.marker.size = 6 trace.marker.size = 6
trace.marker.symbol = 'diamond' trace.marker.symbol = "diamond"
else: else:
trace.marker.size = 10 trace.marker.size = 10
trace.marker.symbol = 'diamond' trace.marker.symbol = "diamond"
trace.marker.opacity = 0.8 trace.marker.opacity = 0.8
fig.add_trace(trace) fig.add_trace(trace)
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})' title = f"{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})"
fig.update_layout( fig.update_layout(
title=title, title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
) )
return fig return fig
def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame: def _prepare_dataframe(
self, documents: List[Document], coordinates, dimensions: str
) -> pd.DataFrame:
df_data = [] df_data = []
for i, doc in enumerate(documents): for i, doc in enumerate(documents):
row = { row = {
'id': doc.id, "id": doc.id,
'text': doc.text, "text": doc.text,
'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text, "text_preview": doc.text[:100] + "..."
'category': doc.category, if len(doc.text) > 100
'subcategory': doc.subcategory, else doc.text,
'tags_str': ', '.join(doc.tags) if doc.tags else 'None', "category": doc.category,
'dim_1': coordinates[i, 0], "subcategory": doc.subcategory,
'dim_2': coordinates[i, 1], "tags_str": ", ".join(doc.tags) if doc.tags else "None",
"dim_1": coordinates[i, 0],
"dim_2": coordinates[i, 1],
} }
if dimensions == '3d': if dimensions == "3d":
row['dim_3'] = coordinates[i, 2] row["dim_3"] = coordinates[i, 2]
df_data.append(row) df_data.append(row)
return pd.DataFrame(df_data) return pd.DataFrame(df_data)

View File

@@ -6,62 +6,64 @@ from src.embeddingbuddy.models.schemas import Document
class TestNDJSONParser: class TestNDJSONParser:
def test_parse_text_basic(self): def test_parse_text_basic(self):
text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}' text_content = (
'{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
)
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert len(documents) == 1 assert len(documents) == 1
assert documents[0].id == "test1" assert documents[0].id == "test1"
assert documents[0].text == "Hello world" assert documents[0].text == "Hello world"
assert documents[0].embedding == [0.1, 0.2, 0.3] assert documents[0].embedding == [0.1, 0.2, 0.3]
def test_parse_text_with_metadata(self): def test_parse_text_with_metadata(self):
text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}' text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}'
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert documents[0].category == "greeting" assert documents[0].category == "greeting"
assert documents[0].tags == ["test"] assert documents[0].tags == ["test"]
def test_parse_text_missing_id(self): def test_parse_text_missing_id(self):
text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}' text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}'
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert len(documents) == 1 assert len(documents) == 1
assert documents[0].id is not None # Should be auto-generated assert documents[0].id is not None # Should be auto-generated
class TestDataProcessor: class TestDataProcessor:
def test_extract_embeddings(self): def test_extract_embeddings(self):
documents = [ documents = [
Document(id="1", text="test1", embedding=[0.1, 0.2]), Document(id="1", text="test1", embedding=[0.1, 0.2]),
Document(id="2", text="test2", embedding=[0.3, 0.4]) Document(id="2", text="test2", embedding=[0.3, 0.4]),
] ]
processor = DataProcessor() processor = DataProcessor()
embeddings = processor._extract_embeddings(documents) embeddings = processor._extract_embeddings(documents)
assert embeddings.shape == (2, 2) assert embeddings.shape == (2, 2)
assert np.allclose(embeddings[0], [0.1, 0.2]) assert np.allclose(embeddings[0], [0.1, 0.2])
assert np.allclose(embeddings[1], [0.3, 0.4]) assert np.allclose(embeddings[1], [0.3, 0.4])
def test_combine_data(self): def test_combine_data(self):
from src.embeddingbuddy.models.schemas import ProcessedData from src.embeddingbuddy.models.schemas import ProcessedData
doc_data = ProcessedData( doc_data = ProcessedData(
documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])], documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
embeddings=np.array([[0.1, 0.2]]) embeddings=np.array([[0.1, 0.2]]),
) )
prompt_data = ProcessedData( prompt_data = ProcessedData(
documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])], documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
embeddings=np.array([[0.3, 0.4]]) embeddings=np.array([[0.3, 0.4]]),
) )
processor = DataProcessor() processor = DataProcessor()
all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data) all_embeddings, documents, prompts = processor.combine_data(
doc_data, prompt_data
)
assert all_embeddings.shape == (2, 2) assert all_embeddings.shape == (2, 2)
assert len(documents) == 1 assert len(documents) == 1
assert len(prompts) == 1 assert len(prompts) == 1
@@ -70,4 +72,4 @@ class TestDataProcessor:
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])

View File

@@ -1,89 +1,90 @@
import pytest import pytest
import numpy as np import numpy as np
from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer from src.embeddingbuddy.models.reducers import (
ReducerFactory,
PCAReducer,
TSNEReducer,
UMAPReducer,
)
class TestReducerFactory: class TestReducerFactory:
def test_create_pca_reducer(self): def test_create_pca_reducer(self):
reducer = ReducerFactory.create_reducer('pca', n_components=2) reducer = ReducerFactory.create_reducer("pca", n_components=2)
assert isinstance(reducer, PCAReducer) assert isinstance(reducer, PCAReducer)
assert reducer.n_components == 2 assert reducer.n_components == 2
def test_create_tsne_reducer(self): def test_create_tsne_reducer(self):
reducer = ReducerFactory.create_reducer('tsne', n_components=3) reducer = ReducerFactory.create_reducer("tsne", n_components=3)
assert isinstance(reducer, TSNEReducer) assert isinstance(reducer, TSNEReducer)
assert reducer.n_components == 3 assert reducer.n_components == 3
def test_create_umap_reducer(self): def test_create_umap_reducer(self):
reducer = ReducerFactory.create_reducer('umap', n_components=2) reducer = ReducerFactory.create_reducer("umap", n_components=2)
assert isinstance(reducer, UMAPReducer) assert isinstance(reducer, UMAPReducer)
assert reducer.n_components == 2 assert reducer.n_components == 2
def test_invalid_method(self): def test_invalid_method(self):
with pytest.raises(ValueError, match="Unknown reduction method"): with pytest.raises(ValueError, match="Unknown reduction method"):
ReducerFactory.create_reducer('invalid_method') ReducerFactory.create_reducer("invalid_method")
def test_available_methods(self): def test_available_methods(self):
methods = ReducerFactory.get_available_methods() methods = ReducerFactory.get_available_methods()
assert 'pca' in methods assert "pca" in methods
assert 'tsne' in methods assert "tsne" in methods
assert 'umap' in methods assert "umap" in methods
class TestPCAReducer: class TestPCAReducer:
def test_fit_transform(self): def test_fit_transform(self):
embeddings = np.random.rand(100, 512) embeddings = np.random.rand(100, 512)
reducer = PCAReducer(n_components=2) reducer = PCAReducer(n_components=2)
result = reducer.fit_transform(embeddings) result = reducer.fit_transform(embeddings)
assert result.reduced_embeddings.shape == (100, 2) assert result.reduced_embeddings.shape == (100, 2)
assert result.variance_explained is not None assert result.variance_explained is not None
assert result.method == "PCA" assert result.method == "PCA"
assert result.n_components == 2 assert result.n_components == 2
def test_method_name(self): def test_method_name(self):
reducer = PCAReducer() reducer = PCAReducer()
assert reducer.get_method_name() == "PCA" assert reducer.get_method_name() == "PCA"
class TestTSNEReducer: class TestTSNEReducer:
def test_fit_transform_small_dataset(self): def test_fit_transform_small_dataset(self):
embeddings = np.random.rand(30, 10) # Small dataset for faster testing embeddings = np.random.rand(30, 10) # Small dataset for faster testing
reducer = TSNEReducer(n_components=2) reducer = TSNEReducer(n_components=2)
result = reducer.fit_transform(embeddings) result = reducer.fit_transform(embeddings)
assert result.reduced_embeddings.shape == (30, 2) assert result.reduced_embeddings.shape == (30, 2)
assert result.variance_explained is None # t-SNE doesn't provide this assert result.variance_explained is None # t-SNE doesn't provide this
assert result.method == "t-SNE" assert result.method == "t-SNE"
assert result.n_components == 2 assert result.n_components == 2
def test_method_name(self): def test_method_name(self):
reducer = TSNEReducer() reducer = TSNEReducer()
assert reducer.get_method_name() == "t-SNE" assert reducer.get_method_name() == "t-SNE"
class TestUMAPReducer: class TestUMAPReducer:
def test_fit_transform(self): def test_fit_transform(self):
embeddings = np.random.rand(50, 10) embeddings = np.random.rand(50, 10)
reducer = UMAPReducer(n_components=2) reducer = UMAPReducer(n_components=2)
result = reducer.fit_transform(embeddings) result = reducer.fit_transform(embeddings)
assert result.reduced_embeddings.shape == (50, 2) assert result.reduced_embeddings.shape == (50, 2)
assert result.variance_explained is None # UMAP doesn't provide this assert result.variance_explained is None # UMAP doesn't provide this
assert result.method == "UMAP" assert result.method == "UMAP"
assert result.n_components == 2 assert result.n_components == 2
def test_method_name(self): def test_method_name(self):
reducer = UMAPReducer() reducer = UMAPReducer()
assert reducer.get_method_name() == "UMAP" assert reducer.get_method_name() == "UMAP"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])

1082
uv.lock generated

File diff suppressed because it is too large Load Diff