92
.gitea/workflows/release.yml
Normal file
92
.gitea/workflows/release.yml
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: 'Release version (e.g., v1.0.0)'
|
||||||
|
required: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.11
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --extra test
|
||||||
|
|
||||||
|
- name: Run full test suite
|
||||||
|
run: uv run pytest tests/ -v --cov=src/embeddingbuddy --cov-report=term-missing
|
||||||
|
|
||||||
|
build-and-release:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: test
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.11
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync
|
||||||
|
|
||||||
|
- name: Build package
|
||||||
|
run: uv build
|
||||||
|
|
||||||
|
- name: Create release notes
|
||||||
|
run: |
|
||||||
|
echo "# Release Notes" > release-notes.md
|
||||||
|
echo "" >> release-notes.md
|
||||||
|
echo "## What's New" >> release-notes.md
|
||||||
|
echo "" >> release-notes.md
|
||||||
|
echo "- Modular architecture with improved testability" >> release-notes.md
|
||||||
|
echo "- Comprehensive test suite" >> release-notes.md
|
||||||
|
echo "- Enhanced documentation" >> release-notes.md
|
||||||
|
echo "- Security scanning and dependency management" >> release-notes.md
|
||||||
|
echo "" >> release-notes.md
|
||||||
|
echo "## Installation" >> release-notes.md
|
||||||
|
echo "" >> release-notes.md
|
||||||
|
echo '```bash' >> release-notes.md
|
||||||
|
echo 'uv sync' >> release-notes.md
|
||||||
|
echo 'uv run python main.py' >> release-notes.md
|
||||||
|
echo '```' >> release-notes.md
|
||||||
|
|
||||||
|
- name: Create Release
|
||||||
|
uses: actions/create-release@v1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||||
|
with:
|
||||||
|
tag_name: ${{ github.ref_name || github.event.inputs.version }}
|
||||||
|
release_name: Release ${{ github.ref_name || github.event.inputs.version }}
|
||||||
|
body_path: release-notes.md
|
||||||
|
draft: false
|
||||||
|
prerelease: false
|
||||||
|
|
||||||
|
- name: Upload Release Assets
|
||||||
|
uses: actions/upload-release-asset@v1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITEA_TOKEN }}
|
||||||
|
with:
|
||||||
|
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||||
|
asset_path: dist/
|
||||||
|
asset_name: embeddingbuddy-dist
|
||||||
|
asset_content_type: application/zip
|
70
.gitea/workflows/security.yml
Normal file
70
.gitea/workflows/security.yml
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
name: Security Scan
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: ["main", "master", "develop"]
|
||||||
|
pull_request:
|
||||||
|
branches: ["main", "master"]
|
||||||
|
schedule:
|
||||||
|
# Run security scan weekly on Sundays at 2 AM UTC
|
||||||
|
- cron: '0 2 * * 0'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
security:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.11
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --extra security
|
||||||
|
|
||||||
|
- name: Run bandit security linter
|
||||||
|
run: uv run bandit -r src/ -f json -o bandit-report.json
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
- name: Run safety vulnerability check
|
||||||
|
run: uv run safety check --json --save-json safety-report.json
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
- name: Upload security reports
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: security-reports
|
||||||
|
path: |
|
||||||
|
bandit-report.json
|
||||||
|
safety-report.json
|
||||||
|
|
||||||
|
dependency-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.11
|
||||||
|
|
||||||
|
- name: Check for dependency vulnerabilities
|
||||||
|
run: |
|
||||||
|
uv sync --extra security
|
||||||
|
uv run pip-audit --format=json --output=pip-audit-report.json
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
|
- name: Upload dependency audit report
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: dependency-audit
|
||||||
|
path: pip-audit-report.json
|
104
.gitea/workflows/test.yml
Normal file
104
.gitea/workflows/test.yml
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
name: Test Suite
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- "main"
|
||||||
|
- "develop"
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- "main"
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.11"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
run: uv python install ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --extra test
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
run: uv run pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
- name: Run tests with coverage
|
||||||
|
run: uv run pytest tests/ --cov=src/embeddingbuddy --cov-report=term-missing --cov-report=xml
|
||||||
|
|
||||||
|
- name: Upload coverage reports
|
||||||
|
uses: codecov/codecov-action@v4
|
||||||
|
if: matrix.python-version == '3.11'
|
||||||
|
with:
|
||||||
|
file: ./coverage.xml
|
||||||
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.11
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync --extra lint
|
||||||
|
|
||||||
|
- name: Run ruff linter
|
||||||
|
run: uv run ruff check src/ tests/
|
||||||
|
|
||||||
|
- name: Run ruff formatter check
|
||||||
|
run: uv run ruff format --check src/ tests/
|
||||||
|
|
||||||
|
# TODO fix this it throws errors
|
||||||
|
# - name: Run mypy type checker
|
||||||
|
# run: uv run mypy src/embeddingbuddy/ --ignore-missing-imports
|
||||||
|
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [test, lint]
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v3
|
||||||
|
with:
|
||||||
|
version: "latest"
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
run: uv python install 3.11
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync
|
||||||
|
|
||||||
|
- name: Build package
|
||||||
|
run: uv build
|
||||||
|
|
||||||
|
- name: Test installation
|
||||||
|
run: |
|
||||||
|
uv run python -c "from src.embeddingbuddy.app import create_app; app = create_app(); print('✅ Package builds and imports successfully')"
|
||||||
|
|
||||||
|
- name: Upload build artifacts
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: dist-files
|
||||||
|
path: dist/
|
76
.gitignore
vendored
76
.gitignore
vendored
@@ -1,12 +1,84 @@
|
|||||||
# Python-generated files
|
# Python-generated files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[oc]
|
*.py[oc]
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
build/
|
build/
|
||||||
|
develop-eggs/
|
||||||
dist/
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
wheels/
|
wheels/
|
||||||
*.egg-info
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
|
.env
|
||||||
.venv
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# IDEs
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
.DS_Store?
|
||||||
|
._*
|
||||||
|
.Spotlight-V100
|
||||||
|
.Trashes
|
||||||
|
ehthumbs.db
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Project specific
|
||||||
|
*.log
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
temp/
|
temp/
|
||||||
todo/
|
todo/
|
||||||
|
|
||||||
|
# Security reports
|
||||||
|
bandit-report.json
|
||||||
|
safety-report.json
|
||||||
|
pip-audit-report.json
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
*.tmp
|
36
CLAUDE.md
36
CLAUDE.md
@@ -30,9 +30,28 @@ The app will be available at http://127.0.0.1:8050
|
|||||||
**Run tests:**
|
**Run tests:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
uv sync --extra test
|
||||||
uv run pytest tests/ -v
|
uv run pytest tests/ -v
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Development tools:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install all dev dependencies
|
||||||
|
uv sync --extra dev
|
||||||
|
|
||||||
|
# Linting and formatting
|
||||||
|
uv run ruff check src/ tests/
|
||||||
|
uv run ruff format src/ tests/
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
uv run mypy src/embeddingbuddy/
|
||||||
|
|
||||||
|
# Security scanning
|
||||||
|
uv run bandit -r src/
|
||||||
|
uv run safety check
|
||||||
|
```
|
||||||
|
|
||||||
**Test with sample data:**
|
**Test with sample data:**
|
||||||
Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality.
|
Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality.
|
||||||
|
|
||||||
@@ -42,7 +61,7 @@ Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for test
|
|||||||
|
|
||||||
The application follows a modular architecture with clear separation of concerns:
|
The application follows a modular architecture with clear separation of concerns:
|
||||||
|
|
||||||
```
|
```text
|
||||||
src/embeddingbuddy/
|
src/embeddingbuddy/
|
||||||
├── app.py # Main application entry point and factory
|
├── app.py # Main application entry point and factory
|
||||||
├── main.py # Application runner
|
├── main.py # Application runner
|
||||||
@@ -72,27 +91,32 @@ src/embeddingbuddy/
|
|||||||
### Key Components
|
### Key Components
|
||||||
|
|
||||||
**Data Layer:**
|
**Data Layer:**
|
||||||
|
|
||||||
- `data/parser.py` - NDJSON parsing with error handling
|
- `data/parser.py` - NDJSON parsing with error handling
|
||||||
- `data/processor.py` - Data transformation and combination logic
|
- `data/processor.py` - Data transformation and combination logic
|
||||||
- `models/schemas.py` - Dataclasses for type safety and validation
|
- `models/schemas.py` - Dataclasses for type safety and validation
|
||||||
|
|
||||||
**Algorithm Layer:**
|
**Algorithm Layer:**
|
||||||
|
|
||||||
- `models/reducers.py` - Modular dimensionality reduction with factory pattern
|
- `models/reducers.py` - Modular dimensionality reduction with factory pattern
|
||||||
- Supports PCA, t-SNE (openTSNE), and UMAP algorithms
|
- Supports PCA, t-SNE (openTSNE), and UMAP algorithms
|
||||||
- Abstract base class for easy extension
|
- Abstract base class for easy extension
|
||||||
|
|
||||||
**Visualization Layer:**
|
**Visualization Layer:**
|
||||||
|
|
||||||
- `visualization/plots.py` - Plot factory with single and dual plot support
|
- `visualization/plots.py` - Plot factory with single and dual plot support
|
||||||
- `visualization/colors.py` - Color mapping and grayscale conversion utilities
|
- `visualization/colors.py` - Color mapping and grayscale conversion utilities
|
||||||
- Plotly-based 2D/3D scatter plots with interactive features
|
- Plotly-based 2D/3D scatter plots with interactive features
|
||||||
|
|
||||||
**UI Layer:**
|
**UI Layer:**
|
||||||
|
|
||||||
- `ui/layout.py` - Main application layout composition
|
- `ui/layout.py` - Main application layout composition
|
||||||
- `ui/components/` - Reusable, testable UI components
|
- `ui/components/` - Reusable, testable UI components
|
||||||
- `ui/callbacks/` - Organized callbacks grouped by functionality
|
- `ui/callbacks/` - Organized callbacks grouped by functionality
|
||||||
- Bootstrap-styled sidebar with controls and large visualization area
|
- Bootstrap-styled sidebar with controls and large visualization area
|
||||||
|
|
||||||
**Configuration:**
|
**Configuration:**
|
||||||
|
|
||||||
- `config/settings.py` - Centralized settings with environment variable support
|
- `config/settings.py` - Centralized settings with environment variable support
|
||||||
- Plot styling, marker configurations, and app-wide constants
|
- Plot styling, marker configurations, and app-wide constants
|
||||||
|
|
||||||
@@ -112,16 +136,19 @@ Optional fields: `id`, `category`, `subcategory`, `tags`
|
|||||||
The refactored callback system is organized by functionality:
|
The refactored callback system is organized by functionality:
|
||||||
|
|
||||||
**Data Processing (`ui/callbacks/data_processing.py`):**
|
**Data Processing (`ui/callbacks/data_processing.py`):**
|
||||||
|
|
||||||
- File upload handling
|
- File upload handling
|
||||||
- NDJSON parsing and validation
|
- NDJSON parsing and validation
|
||||||
- Data storage in dcc.Store components
|
- Data storage in dcc.Store components
|
||||||
|
|
||||||
**Visualization (`ui/callbacks/visualization.py`):**
|
**Visualization (`ui/callbacks/visualization.py`):**
|
||||||
|
|
||||||
- Dimensionality reduction pipeline
|
- Dimensionality reduction pipeline
|
||||||
- Plot generation and updates
|
- Plot generation and updates
|
||||||
- Method/parameter change handling
|
- Method/parameter change handling
|
||||||
|
|
||||||
**Interactions (`ui/callbacks/interactions.py`):**
|
**Interactions (`ui/callbacks/interactions.py`):**
|
||||||
|
|
||||||
- Point click handling and detail display
|
- Point click handling and detail display
|
||||||
- Reset functionality
|
- Reset functionality
|
||||||
- User interaction management
|
- User interaction management
|
||||||
@@ -131,15 +158,18 @@ The refactored callback system is organized by functionality:
|
|||||||
The modular design enables comprehensive testing:
|
The modular design enables comprehensive testing:
|
||||||
|
|
||||||
**Unit Tests:**
|
**Unit Tests:**
|
||||||
|
|
||||||
- `tests/test_data_processing.py` - Parser and processor logic
|
- `tests/test_data_processing.py` - Parser and processor logic
|
||||||
- `tests/test_reducers.py` - Dimensionality reduction algorithms
|
- `tests/test_reducers.py` - Dimensionality reduction algorithms
|
||||||
- `tests/test_visualization.py` - Plot creation and color mapping
|
- `tests/test_visualization.py` - Plot creation and color mapping
|
||||||
|
|
||||||
**Integration Tests:**
|
**Integration Tests:**
|
||||||
|
|
||||||
- End-to-end data pipeline testing
|
- End-to-end data pipeline testing
|
||||||
- Component integration verification
|
- Component integration verification
|
||||||
|
|
||||||
**Key Testing Benefits:**
|
**Key Testing Benefits:**
|
||||||
|
|
||||||
- Fast test execution (milliseconds vs seconds)
|
- Fast test execution (milliseconds vs seconds)
|
||||||
- Isolated component testing
|
- Isolated component testing
|
||||||
- Easy mocking and fixture creation
|
- Easy mocking and fixture creation
|
||||||
@@ -167,6 +197,7 @@ Uses modern Python stack with uv for dependency management:
|
|||||||
5. **Tests** - Write tests for all new functionality
|
5. **Tests** - Write tests for all new functionality
|
||||||
|
|
||||||
**Code Organization Principles:**
|
**Code Organization Principles:**
|
||||||
|
|
||||||
- Single responsibility principle
|
- Single responsibility principle
|
||||||
- Clear module boundaries
|
- Clear module boundaries
|
||||||
- Testable, isolated components
|
- Testable, isolated components
|
||||||
@@ -174,7 +205,8 @@ Uses modern Python stack with uv for dependency management:
|
|||||||
- Error handling at appropriate layers
|
- Error handling at appropriate layers
|
||||||
|
|
||||||
**Testing Requirements:**
|
**Testing Requirements:**
|
||||||
|
|
||||||
- Unit tests for all core logic
|
- Unit tests for all core logic
|
||||||
- Integration tests for data flow
|
- Integration tests for data flow
|
||||||
- Component tests for UI elements
|
- Component tests for UI elements
|
||||||
- Maintain high test coverage
|
- Maintain high test coverage
|
||||||
|
@@ -14,7 +14,28 @@ dependencies = [
|
|||||||
"umap-learn>=0.5.8",
|
"umap-learn>=0.5.8",
|
||||||
"numba>=0.56.4",
|
"numba>=0.56.4",
|
||||||
"openTSNE>=1.0.0",
|
"openTSNE>=1.0.0",
|
||||||
|
"mypy>=1.17.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
"pytest>=8.4.1",
|
"pytest>=8.4.1",
|
||||||
|
"pytest-cov>=4.1.0",
|
||||||
|
]
|
||||||
|
lint = [
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"mypy>=1.5.0",
|
||||||
|
]
|
||||||
|
security = [
|
||||||
|
"bandit[toml]>=1.7.5",
|
||||||
|
"safety>=2.3.0",
|
||||||
|
"pip-audit>=2.6.0",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"embeddingbuddy[test,lint,security]",
|
||||||
|
]
|
||||||
|
all = [
|
||||||
|
"embeddingbuddy[test,lint,security]",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
@@ -1,3 +1,3 @@
|
|||||||
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
|
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
@@ -8,32 +8,29 @@ from .ui.callbacks.interactions import InteractionCallbacks
|
|||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
app = dash.Dash(
|
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
||||||
__name__,
|
|
||||||
external_stylesheets=[dbc.themes.BOOTSTRAP]
|
|
||||||
)
|
|
||||||
|
|
||||||
layout_manager = AppLayout()
|
layout_manager = AppLayout()
|
||||||
app.layout = layout_manager.create_layout()
|
app.layout = layout_manager.create_layout()
|
||||||
|
|
||||||
DataProcessingCallbacks()
|
DataProcessingCallbacks()
|
||||||
VisualizationCallbacks()
|
VisualizationCallbacks()
|
||||||
InteractionCallbacks()
|
InteractionCallbacks()
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
def run_app(app=None, debug=None, host=None, port=None):
|
def run_app(app=None, debug=None, host=None, port=None):
|
||||||
if app is None:
|
if app is None:
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
app.run(
|
app.run(
|
||||||
debug=debug if debug is not None else AppSettings.DEBUG,
|
debug=debug if debug is not None else AppSettings.DEBUG,
|
||||||
host=host if host is not None else AppSettings.HOST,
|
host=host if host is not None else AppSettings.HOST,
|
||||||
port=port if port is not None else AppSettings.PORT
|
port=port if port is not None else AppSettings.PORT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
app = create_app()
|
app = create_app()
|
||||||
run_app(app)
|
run_app(app)
|
||||||
|
@@ -3,105 +3,100 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
class AppSettings:
|
class AppSettings:
|
||||||
|
|
||||||
# UI Configuration
|
# UI Configuration
|
||||||
UPLOAD_STYLE = {
|
UPLOAD_STYLE = {
|
||||||
'width': '100%',
|
"width": "100%",
|
||||||
'height': '60px',
|
"height": "60px",
|
||||||
'lineHeight': '60px',
|
"lineHeight": "60px",
|
||||||
'borderWidth': '1px',
|
"borderWidth": "1px",
|
||||||
'borderStyle': 'dashed',
|
"borderStyle": "dashed",
|
||||||
'borderRadius': '5px',
|
"borderRadius": "5px",
|
||||||
'textAlign': 'center',
|
"textAlign": "center",
|
||||||
'margin-bottom': '20px'
|
"margin-bottom": "20px",
|
||||||
}
|
}
|
||||||
|
|
||||||
PROMPTS_UPLOAD_STYLE = {
|
PROMPTS_UPLOAD_STYLE = {**UPLOAD_STYLE, "borderColor": "#28a745"}
|
||||||
**UPLOAD_STYLE,
|
|
||||||
'borderColor': '#28a745'
|
PLOT_CONFIG = {"responsive": True, "displayModeBar": True}
|
||||||
}
|
|
||||||
|
PLOT_STYLE = {"height": "85vh", "width": "100%"}
|
||||||
PLOT_CONFIG = {
|
|
||||||
'responsive': True,
|
|
||||||
'displayModeBar': True
|
|
||||||
}
|
|
||||||
|
|
||||||
PLOT_STYLE = {
|
|
||||||
'height': '85vh',
|
|
||||||
'width': '100%'
|
|
||||||
}
|
|
||||||
|
|
||||||
PLOT_LAYOUT_CONFIG = {
|
PLOT_LAYOUT_CONFIG = {
|
||||||
'height': None,
|
"height": None,
|
||||||
'autosize': True,
|
"autosize": True,
|
||||||
'margin': dict(l=0, r=0, t=50, b=0)
|
"margin": dict(l=0, r=0, t=50, b=0),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Dimensionality Reduction Settings
|
# Dimensionality Reduction Settings
|
||||||
DEFAULT_N_COMPONENTS_3D = 3
|
DEFAULT_N_COMPONENTS_3D = 3
|
||||||
DEFAULT_N_COMPONENTS_2D = 2
|
DEFAULT_N_COMPONENTS_2D = 2
|
||||||
DEFAULT_RANDOM_STATE = 42
|
DEFAULT_RANDOM_STATE = 42
|
||||||
|
|
||||||
# Available Methods
|
# Available Methods
|
||||||
REDUCTION_METHODS = [
|
REDUCTION_METHODS = [
|
||||||
{'label': 'PCA', 'value': 'pca'},
|
{"label": "PCA", "value": "pca"},
|
||||||
{'label': 't-SNE', 'value': 'tsne'},
|
{"label": "t-SNE", "value": "tsne"},
|
||||||
{'label': 'UMAP', 'value': 'umap'}
|
{"label": "UMAP", "value": "umap"},
|
||||||
]
|
]
|
||||||
|
|
||||||
COLOR_OPTIONS = [
|
COLOR_OPTIONS = [
|
||||||
{'label': 'Category', 'value': 'category'},
|
{"label": "Category", "value": "category"},
|
||||||
{'label': 'Subcategory', 'value': 'subcategory'},
|
{"label": "Subcategory", "value": "subcategory"},
|
||||||
{'label': 'Tags', 'value': 'tags'}
|
{"label": "Tags", "value": "tags"},
|
||||||
]
|
]
|
||||||
|
|
||||||
DIMENSION_OPTIONS = [
|
DIMENSION_OPTIONS = [{"label": "2D", "value": "2d"}, {"label": "3D", "value": "3d"}]
|
||||||
{'label': '2D', 'value': '2d'},
|
|
||||||
{'label': '3D', 'value': '3d'}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
DEFAULT_METHOD = 'pca'
|
DEFAULT_METHOD = "pca"
|
||||||
DEFAULT_COLOR_BY = 'category'
|
DEFAULT_COLOR_BY = "category"
|
||||||
DEFAULT_DIMENSIONS = '3d'
|
DEFAULT_DIMENSIONS = "3d"
|
||||||
DEFAULT_SHOW_PROMPTS = ['show']
|
DEFAULT_SHOW_PROMPTS = ["show"]
|
||||||
|
|
||||||
# Plot Marker Settings
|
# Plot Marker Settings
|
||||||
DOCUMENT_MARKER_SIZE_2D = 8
|
DOCUMENT_MARKER_SIZE_2D = 8
|
||||||
DOCUMENT_MARKER_SIZE_3D = 5
|
DOCUMENT_MARKER_SIZE_3D = 5
|
||||||
PROMPT_MARKER_SIZE_2D = 10
|
PROMPT_MARKER_SIZE_2D = 10
|
||||||
PROMPT_MARKER_SIZE_3D = 6
|
PROMPT_MARKER_SIZE_3D = 6
|
||||||
|
|
||||||
DOCUMENT_MARKER_SYMBOL = 'circle'
|
DOCUMENT_MARKER_SYMBOL = "circle"
|
||||||
PROMPT_MARKER_SYMBOL = 'diamond'
|
PROMPT_MARKER_SYMBOL = "diamond"
|
||||||
|
|
||||||
DOCUMENT_OPACITY = 1.0
|
DOCUMENT_OPACITY = 1.0
|
||||||
PROMPT_OPACITY = 0.8
|
PROMPT_OPACITY = 0.8
|
||||||
|
|
||||||
# Text Processing
|
# Text Processing
|
||||||
TEXT_PREVIEW_LENGTH = 100
|
TEXT_PREVIEW_LENGTH = 100
|
||||||
|
|
||||||
# App Configuration
|
# App Configuration
|
||||||
DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true'
|
DEBUG = os.getenv("EMBEDDINGBUDDY_DEBUG", "True").lower() == "true"
|
||||||
HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1')
|
HOST = os.getenv("EMBEDDINGBUDDY_HOST", "127.0.0.1")
|
||||||
PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050'))
|
PORT = int(os.getenv("EMBEDDINGBUDDY_PORT", "8050"))
|
||||||
|
|
||||||
# Bootstrap Theme
|
# Bootstrap Theme
|
||||||
EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css']
|
EXTERNAL_STYLESHEETS = [
|
||||||
|
"https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
|
||||||
|
]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]:
|
def get_plot_marker_config(
|
||||||
|
cls, dimensions: str, is_prompt: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
if is_prompt:
|
if is_prompt:
|
||||||
size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D
|
size = (
|
||||||
|
cls.PROMPT_MARKER_SIZE_3D
|
||||||
|
if dimensions == "3d"
|
||||||
|
else cls.PROMPT_MARKER_SIZE_2D
|
||||||
|
)
|
||||||
symbol = cls.PROMPT_MARKER_SYMBOL
|
symbol = cls.PROMPT_MARKER_SYMBOL
|
||||||
opacity = cls.PROMPT_OPACITY
|
opacity = cls.PROMPT_OPACITY
|
||||||
else:
|
else:
|
||||||
size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D
|
size = (
|
||||||
|
cls.DOCUMENT_MARKER_SIZE_3D
|
||||||
|
if dimensions == "3d"
|
||||||
|
else cls.DOCUMENT_MARKER_SIZE_2D
|
||||||
|
)
|
||||||
symbol = cls.DOCUMENT_MARKER_SYMBOL
|
symbol = cls.DOCUMENT_MARKER_SYMBOL
|
||||||
opacity = cls.DOCUMENT_OPACITY
|
opacity = cls.DOCUMENT_OPACITY
|
||||||
|
|
||||||
return {
|
return {"size": size, "symbol": symbol, "opacity": opacity}
|
||||||
'size': size,
|
|
||||||
'symbol': symbol,
|
|
||||||
'opacity': opacity
|
|
||||||
}
|
|
||||||
|
@@ -1,39 +1,38 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import base64
|
import base64
|
||||||
from typing import List, Union
|
from typing import List
|
||||||
from ..models.schemas import Document, ProcessedData
|
from ..models.schemas import Document
|
||||||
|
|
||||||
|
|
||||||
class NDJSONParser:
|
class NDJSONParser:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_upload_contents(contents: str) -> List[Document]:
|
def parse_upload_contents(contents: str) -> List[Document]:
|
||||||
content_type, content_string = contents.split(',')
|
content_type, content_string = contents.split(",")
|
||||||
decoded = base64.b64decode(content_string)
|
decoded = base64.b64decode(content_string)
|
||||||
text_content = decoded.decode('utf-8')
|
text_content = decoded.decode("utf-8")
|
||||||
return NDJSONParser.parse_text(text_content)
|
return NDJSONParser.parse_text(text_content)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_text(text_content: str) -> List[Document]:
|
def parse_text(text_content: str) -> List[Document]:
|
||||||
documents = []
|
documents = []
|
||||||
for line in text_content.strip().split('\n'):
|
for line in text_content.strip().split("\n"):
|
||||||
if line.strip():
|
if line.strip():
|
||||||
doc_dict = json.loads(line)
|
doc_dict = json.loads(line)
|
||||||
doc = NDJSONParser._dict_to_document(doc_dict)
|
doc = NDJSONParser._dict_to_document(doc_dict)
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _dict_to_document(doc_dict: dict) -> Document:
|
def _dict_to_document(doc_dict: dict) -> Document:
|
||||||
if 'id' not in doc_dict:
|
if "id" not in doc_dict:
|
||||||
doc_dict['id'] = str(uuid.uuid4())
|
doc_dict["id"] = str(uuid.uuid4())
|
||||||
|
|
||||||
return Document(
|
return Document(
|
||||||
id=doc_dict['id'],
|
id=doc_dict["id"],
|
||||||
text=doc_dict['text'],
|
text=doc_dict["text"],
|
||||||
embedding=doc_dict['embedding'],
|
embedding=doc_dict["embedding"],
|
||||||
category=doc_dict.get('category'),
|
category=doc_dict.get("category"),
|
||||||
subcategory=doc_dict.get('subcategory'),
|
subcategory=doc_dict.get("subcategory"),
|
||||||
tags=doc_dict.get('tags')
|
tags=doc_dict.get("tags"),
|
||||||
)
|
)
|
||||||
|
@@ -5,18 +5,19 @@ from .parser import NDJSONParser
|
|||||||
|
|
||||||
|
|
||||||
class DataProcessor:
|
class DataProcessor:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.parser = NDJSONParser()
|
self.parser = NDJSONParser()
|
||||||
|
|
||||||
def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData:
|
def process_upload(
|
||||||
|
self, contents: str, filename: Optional[str] = None
|
||||||
|
) -> ProcessedData:
|
||||||
try:
|
try:
|
||||||
documents = self.parser.parse_upload_contents(contents)
|
documents = self.parser.parse_upload_contents(contents)
|
||||||
embeddings = self._extract_embeddings(documents)
|
embeddings = self._extract_embeddings(documents)
|
||||||
return ProcessedData(documents=documents, embeddings=embeddings)
|
return ProcessedData(documents=documents, embeddings=embeddings)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
|
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
|
||||||
|
|
||||||
def process_text(self, text_content: str) -> ProcessedData:
|
def process_text(self, text_content: str) -> ProcessedData:
|
||||||
try:
|
try:
|
||||||
documents = self.parser.parse_text(text_content)
|
documents = self.parser.parse_text(text_content)
|
||||||
@@ -24,31 +25,35 @@ class DataProcessor:
|
|||||||
return ProcessedData(documents=documents, embeddings=embeddings)
|
return ProcessedData(documents=documents, embeddings=embeddings)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
|
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
|
||||||
|
|
||||||
def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
|
def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
|
||||||
if not documents:
|
if not documents:
|
||||||
return np.array([])
|
return np.array([])
|
||||||
return np.array([doc.embedding for doc in documents])
|
return np.array([doc.embedding for doc in documents])
|
||||||
|
|
||||||
def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
|
def combine_data(
|
||||||
|
self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None
|
||||||
|
) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
|
||||||
if not doc_data or doc_data.error:
|
if not doc_data or doc_data.error:
|
||||||
raise ValueError("Invalid document data")
|
raise ValueError("Invalid document data")
|
||||||
|
|
||||||
all_embeddings = doc_data.embeddings
|
all_embeddings = doc_data.embeddings
|
||||||
documents = doc_data.documents
|
documents = doc_data.documents
|
||||||
prompts = None
|
prompts = None
|
||||||
|
|
||||||
if prompt_data and not prompt_data.error and prompt_data.documents:
|
if prompt_data and not prompt_data.error and prompt_data.documents:
|
||||||
all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
|
all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
|
||||||
prompts = prompt_data.documents
|
prompts = prompt_data.documents
|
||||||
|
|
||||||
return all_embeddings, documents, prompts
|
return all_embeddings, documents, prompts
|
||||||
|
|
||||||
def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
def split_reduced_data(
|
||||||
|
self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0
|
||||||
|
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||||
doc_reduced = reduced_embeddings[:n_documents]
|
doc_reduced = reduced_embeddings[:n_documents]
|
||||||
prompt_reduced = None
|
prompt_reduced = None
|
||||||
|
|
||||||
if n_prompts > 0:
|
if n_prompts > 0:
|
||||||
prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts]
|
prompt_reduced = reduced_embeddings[n_documents : n_documents + n_prompts]
|
||||||
|
|
||||||
return doc_reduced, prompt_reduced
|
return doc_reduced, prompt_reduced
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Optional, Tuple
|
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
import umap
|
import umap
|
||||||
from openTSNE import TSNE
|
from openTSNE import TSNE
|
||||||
@@ -8,88 +7,89 @@ from .schemas import ReducedData
|
|||||||
|
|
||||||
|
|
||||||
class DimensionalityReducer(ABC):
|
class DimensionalityReducer(ABC):
|
||||||
|
|
||||||
def __init__(self, n_components: int = 3, random_state: int = 42):
|
def __init__(self, n_components: int = 3, random_state: int = 42):
|
||||||
self.n_components = n_components
|
self.n_components = n_components
|
||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
self._reducer = None
|
self._reducer = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_method_name(self) -> str:
|
def get_method_name(self) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PCAReducer(DimensionalityReducer):
|
class PCAReducer(DimensionalityReducer):
|
||||||
|
|
||||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||||
self._reducer = PCA(n_components=self.n_components)
|
self._reducer = PCA(n_components=self.n_components)
|
||||||
reduced = self._reducer.fit_transform(embeddings)
|
reduced = self._reducer.fit_transform(embeddings)
|
||||||
variance_explained = self._reducer.explained_variance_ratio_
|
variance_explained = self._reducer.explained_variance_ratio_
|
||||||
|
|
||||||
return ReducedData(
|
return ReducedData(
|
||||||
reduced_embeddings=reduced,
|
reduced_embeddings=reduced,
|
||||||
variance_explained=variance_explained,
|
variance_explained=variance_explained,
|
||||||
method=self.get_method_name(),
|
method=self.get_method_name(),
|
||||||
n_components=self.n_components
|
n_components=self.n_components,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_method_name(self) -> str:
|
def get_method_name(self) -> str:
|
||||||
return "PCA"
|
return "PCA"
|
||||||
|
|
||||||
|
|
||||||
class TSNEReducer(DimensionalityReducer):
|
class TSNEReducer(DimensionalityReducer):
|
||||||
|
|
||||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||||
self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state)
|
self._reducer = TSNE(
|
||||||
|
n_components=self.n_components, random_state=self.random_state
|
||||||
|
)
|
||||||
reduced = self._reducer.fit(embeddings)
|
reduced = self._reducer.fit(embeddings)
|
||||||
|
|
||||||
return ReducedData(
|
return ReducedData(
|
||||||
reduced_embeddings=reduced,
|
reduced_embeddings=reduced,
|
||||||
variance_explained=None,
|
variance_explained=None,
|
||||||
method=self.get_method_name(),
|
method=self.get_method_name(),
|
||||||
n_components=self.n_components
|
n_components=self.n_components,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_method_name(self) -> str:
|
def get_method_name(self) -> str:
|
||||||
return "t-SNE"
|
return "t-SNE"
|
||||||
|
|
||||||
|
|
||||||
class UMAPReducer(DimensionalityReducer):
|
class UMAPReducer(DimensionalityReducer):
|
||||||
|
|
||||||
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
|
||||||
self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state)
|
self._reducer = umap.UMAP(
|
||||||
|
n_components=self.n_components, random_state=self.random_state
|
||||||
|
)
|
||||||
reduced = self._reducer.fit_transform(embeddings)
|
reduced = self._reducer.fit_transform(embeddings)
|
||||||
|
|
||||||
return ReducedData(
|
return ReducedData(
|
||||||
reduced_embeddings=reduced,
|
reduced_embeddings=reduced,
|
||||||
variance_explained=None,
|
variance_explained=None,
|
||||||
method=self.get_method_name(),
|
method=self.get_method_name(),
|
||||||
n_components=self.n_components
|
n_components=self.n_components,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_method_name(self) -> str:
|
def get_method_name(self) -> str:
|
||||||
return "UMAP"
|
return "UMAP"
|
||||||
|
|
||||||
|
|
||||||
class ReducerFactory:
|
class ReducerFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer:
|
def create_reducer(
|
||||||
|
method: str, n_components: int = 3, random_state: int = 42
|
||||||
|
) -> DimensionalityReducer:
|
||||||
method_lower = method.lower()
|
method_lower = method.lower()
|
||||||
|
|
||||||
if method_lower == 'pca':
|
if method_lower == "pca":
|
||||||
return PCAReducer(n_components=n_components, random_state=random_state)
|
return PCAReducer(n_components=n_components, random_state=random_state)
|
||||||
elif method_lower == 'tsne':
|
elif method_lower == "tsne":
|
||||||
return TSNEReducer(n_components=n_components, random_state=random_state)
|
return TSNEReducer(n_components=n_components, random_state=random_state)
|
||||||
elif method_lower == 'umap':
|
elif method_lower == "umap":
|
||||||
return UMAPReducer(n_components=n_components, random_state=random_state)
|
return UMAPReducer(n_components=n_components, random_state=random_state)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown reduction method: {method}")
|
raise ValueError(f"Unknown reduction method: {method}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_available_methods() -> list:
|
def get_available_methods() -> list:
|
||||||
return ['pca', 'tsne', 'umap']
|
return ["pca", "tsne", "umap"]
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Any, Dict
|
from typing import List, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -50,9 +50,11 @@ class PlotData:
|
|||||||
coordinates: np.ndarray
|
coordinates: np.ndarray
|
||||||
prompts: Optional[List[Document]] = None
|
prompts: Optional[List[Document]] = None
|
||||||
prompt_coordinates: Optional[np.ndarray] = None
|
prompt_coordinates: Optional[np.ndarray] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not isinstance(self.coordinates, np.ndarray):
|
if not isinstance(self.coordinates, np.ndarray):
|
||||||
self.coordinates = np.array(self.coordinates)
|
self.coordinates = np.array(self.coordinates)
|
||||||
if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray):
|
if self.prompt_coordinates is not None and not isinstance(
|
||||||
self.prompt_coordinates = np.array(self.prompt_coordinates)
|
self.prompt_coordinates, np.ndarray
|
||||||
|
):
|
||||||
|
self.prompt_coordinates = np.array(self.prompt_coordinates)
|
||||||
|
@@ -1,61 +1,62 @@
|
|||||||
import numpy as np
|
|
||||||
from dash import callback, Input, Output, State
|
from dash import callback, Input, Output, State
|
||||||
from ...data.processor import DataProcessor
|
from ...data.processor import DataProcessor
|
||||||
|
|
||||||
|
|
||||||
class DataProcessingCallbacks:
|
class DataProcessingCallbacks:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.processor = DataProcessor()
|
self.processor = DataProcessor()
|
||||||
self._register_callbacks()
|
self._register_callbacks()
|
||||||
|
|
||||||
def _register_callbacks(self):
|
def _register_callbacks(self):
|
||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
Output('processed-data', 'data'),
|
Output("processed-data", "data"),
|
||||||
Input('upload-data', 'contents'),
|
Input("upload-data", "contents"),
|
||||||
State('upload-data', 'filename')
|
State("upload-data", "filename"),
|
||||||
)
|
)
|
||||||
def process_uploaded_file(contents, filename):
|
def process_uploaded_file(contents, filename):
|
||||||
if contents is None:
|
if contents is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
processed_data = self.processor.process_upload(contents, filename)
|
processed_data = self.processor.process_upload(contents, filename)
|
||||||
|
|
||||||
if processed_data.error:
|
if processed_data.error:
|
||||||
return {'error': processed_data.error}
|
return {"error": processed_data.error}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'documents': [self._document_to_dict(doc) for doc in processed_data.documents],
|
"documents": [
|
||||||
'embeddings': processed_data.embeddings.tolist()
|
self._document_to_dict(doc) for doc in processed_data.documents
|
||||||
|
],
|
||||||
|
"embeddings": processed_data.embeddings.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
Output('processed-prompts', 'data'),
|
Output("processed-prompts", "data"),
|
||||||
Input('upload-prompts', 'contents'),
|
Input("upload-prompts", "contents"),
|
||||||
State('upload-prompts', 'filename')
|
State("upload-prompts", "filename"),
|
||||||
)
|
)
|
||||||
def process_uploaded_prompts(contents, filename):
|
def process_uploaded_prompts(contents, filename):
|
||||||
if contents is None:
|
if contents is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
processed_data = self.processor.process_upload(contents, filename)
|
processed_data = self.processor.process_upload(contents, filename)
|
||||||
|
|
||||||
if processed_data.error:
|
if processed_data.error:
|
||||||
return {'error': processed_data.error}
|
return {"error": processed_data.error}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents],
|
"prompts": [
|
||||||
'embeddings': processed_data.embeddings.tolist()
|
self._document_to_dict(doc) for doc in processed_data.documents
|
||||||
|
],
|
||||||
|
"embeddings": processed_data.embeddings.tolist(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _document_to_dict(doc):
|
def _document_to_dict(doc):
|
||||||
return {
|
return {
|
||||||
'id': doc.id,
|
"id": doc.id,
|
||||||
'text': doc.text,
|
"text": doc.text,
|
||||||
'embedding': doc.embedding,
|
"embedding": doc.embedding,
|
||||||
'category': doc.category,
|
"category": doc.category,
|
||||||
'subcategory': doc.subcategory,
|
"subcategory": doc.subcategory,
|
||||||
'tags': doc.tags
|
"tags": doc.tags,
|
||||||
}
|
}
|
||||||
|
@@ -4,63 +4,79 @@ import dash_bootstrap_components as dbc
|
|||||||
|
|
||||||
|
|
||||||
class InteractionCallbacks:
|
class InteractionCallbacks:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._register_callbacks()
|
self._register_callbacks()
|
||||||
|
|
||||||
def _register_callbacks(self):
|
def _register_callbacks(self):
|
||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
Output('point-details', 'children'),
|
Output("point-details", "children"),
|
||||||
Input('embedding-plot', 'clickData'),
|
Input("embedding-plot", "clickData"),
|
||||||
[State('processed-data', 'data'),
|
[State("processed-data", "data"), State("processed-prompts", "data")],
|
||||||
State('processed-prompts', 'data')]
|
|
||||||
)
|
)
|
||||||
def display_click_data(clickData, data, prompts_data):
|
def display_click_data(clickData, data, prompts_data):
|
||||||
if not clickData or not data:
|
if not clickData or not data:
|
||||||
return "Click on a point to see details"
|
return "Click on a point to see details"
|
||||||
|
|
||||||
point_data = clickData['points'][0]
|
point_data = clickData["points"][0]
|
||||||
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
|
trace_name = point_data.get("fullData", {}).get("name", "Documents")
|
||||||
|
|
||||||
if 'pointIndex' in point_data:
|
if "pointIndex" in point_data:
|
||||||
point_index = point_data['pointIndex']
|
point_index = point_data["pointIndex"]
|
||||||
elif 'pointNumber' in point_data:
|
elif "pointNumber" in point_data:
|
||||||
point_index = point_data['pointNumber']
|
point_index = point_data["pointNumber"]
|
||||||
else:
|
else:
|
||||||
return "Could not identify clicked point"
|
return "Could not identify clicked point"
|
||||||
|
|
||||||
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data:
|
if (
|
||||||
item = prompts_data['prompts'][point_index]
|
trace_name.startswith("Prompts")
|
||||||
item_type = 'Prompt'
|
and prompts_data
|
||||||
|
and "prompts" in prompts_data
|
||||||
|
):
|
||||||
|
item = prompts_data["prompts"][point_index]
|
||||||
|
item_type = "Prompt"
|
||||||
else:
|
else:
|
||||||
item = data['documents'][point_index]
|
item = data["documents"][point_index]
|
||||||
item_type = 'Document'
|
item_type = "Document"
|
||||||
|
|
||||||
return self._create_detail_card(item, item_type)
|
return self._create_detail_card(item, item_type)
|
||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
[Output('processed-data', 'data', allow_duplicate=True),
|
[
|
||||||
Output('processed-prompts', 'data', allow_duplicate=True),
|
Output("processed-data", "data", allow_duplicate=True),
|
||||||
Output('point-details', 'children', allow_duplicate=True)],
|
Output("processed-prompts", "data", allow_duplicate=True),
|
||||||
Input('reset-button', 'n_clicks'),
|
Output("point-details", "children", allow_duplicate=True),
|
||||||
prevent_initial_call=True
|
],
|
||||||
|
Input("reset-button", "n_clicks"),
|
||||||
|
prevent_initial_call=True,
|
||||||
)
|
)
|
||||||
def reset_data(n_clicks):
|
def reset_data(n_clicks):
|
||||||
if n_clicks is None or n_clicks == 0:
|
if n_clicks is None or n_clicks == 0:
|
||||||
return dash.no_update, dash.no_update, dash.no_update
|
return dash.no_update, dash.no_update, dash.no_update
|
||||||
|
|
||||||
return None, None, "Click on a point to see details"
|
return None, None, "Click on a point to see details"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_detail_card(item, item_type):
|
def _create_detail_card(item, item_type):
|
||||||
return dbc.Card([
|
return dbc.Card(
|
||||||
dbc.CardBody([
|
[
|
||||||
html.H5(f"{item_type}: {item['id']}", className="card-title"),
|
dbc.CardBody(
|
||||||
html.P(f"Text: {item['text']}", className="card-text"),
|
[
|
||||||
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
|
html.H5(f"{item_type}: {item['id']}", className="card-title"),
|
||||||
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
|
html.P(f"Text: {item['text']}", className="card-text"),
|
||||||
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
|
html.P(
|
||||||
html.P(f"Type: {item_type}", className="card-text text-muted")
|
f"Category: {item.get('category', 'Unknown')}",
|
||||||
])
|
className="card-text",
|
||||||
])
|
),
|
||||||
|
html.P(
|
||||||
|
f"Subcategory: {item.get('subcategory', 'Unknown')}",
|
||||||
|
className="card-text",
|
||||||
|
),
|
||||||
|
html.P(
|
||||||
|
f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}",
|
||||||
|
className="card-text",
|
||||||
|
),
|
||||||
|
html.P(f"Type: {item_type}", className="card-text text-muted"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@@ -7,81 +7,102 @@ from ...visualization.plots import PlotFactory
|
|||||||
|
|
||||||
|
|
||||||
class VisualizationCallbacks:
|
class VisualizationCallbacks:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.plot_factory = PlotFactory()
|
self.plot_factory = PlotFactory()
|
||||||
self._register_callbacks()
|
self._register_callbacks()
|
||||||
|
|
||||||
def _register_callbacks(self):
|
def _register_callbacks(self):
|
||||||
|
|
||||||
@callback(
|
@callback(
|
||||||
Output('embedding-plot', 'figure'),
|
Output("embedding-plot", "figure"),
|
||||||
[Input('processed-data', 'data'),
|
[
|
||||||
Input('processed-prompts', 'data'),
|
Input("processed-data", "data"),
|
||||||
Input('method-dropdown', 'value'),
|
Input("processed-prompts", "data"),
|
||||||
Input('color-dropdown', 'value'),
|
Input("method-dropdown", "value"),
|
||||||
Input('dimension-toggle', 'value'),
|
Input("color-dropdown", "value"),
|
||||||
Input('show-prompts-toggle', 'value')]
|
Input("dimension-toggle", "value"),
|
||||||
|
Input("show-prompts-toggle", "value"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
|
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
|
||||||
if not data or 'error' in data:
|
if not data or "error" in data:
|
||||||
return go.Figure().add_annotation(
|
return go.Figure().add_annotation(
|
||||||
text="Upload a valid NDJSON file to see visualization",
|
text="Upload a valid NDJSON file to see visualization",
|
||||||
xref="paper", yref="paper",
|
xref="paper",
|
||||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
yref="paper",
|
||||||
showarrow=False, font=dict(size=16)
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
|
showarrow=False,
|
||||||
|
font=dict(size=16),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
doc_embeddings = np.array(data['embeddings'])
|
doc_embeddings = np.array(data["embeddings"])
|
||||||
all_embeddings = doc_embeddings
|
all_embeddings = doc_embeddings
|
||||||
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
|
has_prompts = (
|
||||||
|
prompts_data
|
||||||
|
and "error" not in prompts_data
|
||||||
|
and prompts_data.get("prompts")
|
||||||
|
)
|
||||||
|
|
||||||
if has_prompts:
|
if has_prompts:
|
||||||
prompt_embeddings = np.array(prompts_data['embeddings'])
|
prompt_embeddings = np.array(prompts_data["embeddings"])
|
||||||
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
|
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
|
||||||
|
|
||||||
n_components = 3 if dimensions == '3d' else 2
|
n_components = 3 if dimensions == "3d" else 2
|
||||||
|
|
||||||
reducer = ReducerFactory.create_reducer(method, n_components=n_components)
|
reducer = ReducerFactory.create_reducer(
|
||||||
|
method, n_components=n_components
|
||||||
|
)
|
||||||
reduced_data = reducer.fit_transform(all_embeddings)
|
reduced_data = reducer.fit_transform(all_embeddings)
|
||||||
|
|
||||||
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
|
doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)]
|
||||||
prompt_reduced = None
|
prompt_reduced = None
|
||||||
if has_prompts:
|
if has_prompts:
|
||||||
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
|
prompt_reduced = reduced_data.reduced_embeddings[
|
||||||
|
len(doc_embeddings) :
|
||||||
documents = [self._dict_to_document(doc) for doc in data['documents']]
|
]
|
||||||
|
|
||||||
|
documents = [self._dict_to_document(doc) for doc in data["documents"]]
|
||||||
prompts = None
|
prompts = None
|
||||||
if has_prompts:
|
if has_prompts:
|
||||||
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
|
prompts = [
|
||||||
|
self._dict_to_document(prompt)
|
||||||
|
for prompt in prompts_data["prompts"]
|
||||||
|
]
|
||||||
|
|
||||||
plot_data = PlotData(
|
plot_data = PlotData(
|
||||||
documents=documents,
|
documents=documents,
|
||||||
coordinates=doc_reduced,
|
coordinates=doc_reduced,
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
prompt_coordinates=prompt_reduced
|
prompt_coordinates=prompt_reduced,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.plot_factory.create_plot(
|
return self.plot_factory.create_plot(
|
||||||
plot_data, dimensions, color_by, reduced_data.method, show_prompts
|
plot_data, dimensions, color_by, reduced_data.method, show_prompts
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return go.Figure().add_annotation(
|
return go.Figure().add_annotation(
|
||||||
text=f"Error creating visualization: {str(e)}",
|
text=f"Error creating visualization: {str(e)}",
|
||||||
xref="paper", yref="paper",
|
xref="paper",
|
||||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
yref="paper",
|
||||||
showarrow=False, font=dict(size=16)
|
x=0.5,
|
||||||
|
y=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
yanchor="middle",
|
||||||
|
showarrow=False,
|
||||||
|
font=dict(size=16),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _dict_to_document(doc_dict):
|
def _dict_to_document(doc_dict):
|
||||||
return Document(
|
return Document(
|
||||||
id=doc_dict['id'],
|
id=doc_dict["id"],
|
||||||
text=doc_dict['text'],
|
text=doc_dict["text"],
|
||||||
embedding=doc_dict['embedding'],
|
embedding=doc_dict["embedding"],
|
||||||
category=doc_dict.get('category'),
|
category=doc_dict.get("category"),
|
||||||
subcategory=doc_dict.get('subcategory'),
|
subcategory=doc_dict.get("subcategory"),
|
||||||
tags=doc_dict.get('tags', [])
|
tags=doc_dict.get("tags", []),
|
||||||
)
|
)
|
||||||
|
@@ -4,79 +4,81 @@ from .upload import UploadComponent
|
|||||||
|
|
||||||
|
|
||||||
class SidebarComponent:
|
class SidebarComponent:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.upload_component = UploadComponent()
|
self.upload_component = UploadComponent()
|
||||||
|
|
||||||
def create_layout(self):
|
def create_layout(self):
|
||||||
return dbc.Col([
|
return dbc.Col(
|
||||||
html.H5("Upload Data", className="mb-3"),
|
[
|
||||||
self.upload_component.create_data_upload(),
|
html.H5("Upload Data", className="mb-3"),
|
||||||
self.upload_component.create_prompts_upload(),
|
self.upload_component.create_data_upload(),
|
||||||
self.upload_component.create_reset_button(),
|
self.upload_component.create_prompts_upload(),
|
||||||
|
self.upload_component.create_reset_button(),
|
||||||
html.H5("Visualization Controls", className="mb-3"),
|
html.H5("Visualization Controls", className="mb-3"),
|
||||||
self._create_method_dropdown(),
|
self._create_method_dropdown(),
|
||||||
self._create_color_dropdown(),
|
self._create_color_dropdown(),
|
||||||
self._create_dimension_toggle(),
|
self._create_dimension_toggle(),
|
||||||
self._create_prompts_toggle(),
|
self._create_prompts_toggle(),
|
||||||
|
html.H5("Point Details", className="mb-3"),
|
||||||
html.H5("Point Details", className="mb-3"),
|
html.Div(
|
||||||
html.Div(id='point-details', children="Click on a point to see details")
|
id="point-details", children="Click on a point to see details"
|
||||||
|
),
|
||||||
], width=3, style={'padding-right': '20px'})
|
],
|
||||||
|
width=3,
|
||||||
|
style={"padding-right": "20px"},
|
||||||
|
)
|
||||||
|
|
||||||
def _create_method_dropdown(self):
|
def _create_method_dropdown(self):
|
||||||
return [
|
return [
|
||||||
dbc.Label("Method:"),
|
dbc.Label("Method:"),
|
||||||
dcc.Dropdown(
|
dcc.Dropdown(
|
||||||
id='method-dropdown',
|
id="method-dropdown",
|
||||||
options=[
|
options=[
|
||||||
{'label': 'PCA', 'value': 'pca'},
|
{"label": "PCA", "value": "pca"},
|
||||||
{'label': 't-SNE', 'value': 'tsne'},
|
{"label": "t-SNE", "value": "tsne"},
|
||||||
{'label': 'UMAP', 'value': 'umap'}
|
{"label": "UMAP", "value": "umap"},
|
||||||
],
|
],
|
||||||
value='pca',
|
value="pca",
|
||||||
style={'margin-bottom': '15px'}
|
style={"margin-bottom": "15px"},
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_color_dropdown(self):
|
def _create_color_dropdown(self):
|
||||||
return [
|
return [
|
||||||
dbc.Label("Color by:"),
|
dbc.Label("Color by:"),
|
||||||
dcc.Dropdown(
|
dcc.Dropdown(
|
||||||
id='color-dropdown',
|
id="color-dropdown",
|
||||||
options=[
|
options=[
|
||||||
{'label': 'Category', 'value': 'category'},
|
{"label": "Category", "value": "category"},
|
||||||
{'label': 'Subcategory', 'value': 'subcategory'},
|
{"label": "Subcategory", "value": "subcategory"},
|
||||||
{'label': 'Tags', 'value': 'tags'}
|
{"label": "Tags", "value": "tags"},
|
||||||
],
|
],
|
||||||
value='category',
|
value="category",
|
||||||
style={'margin-bottom': '15px'}
|
style={"margin-bottom": "15px"},
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_dimension_toggle(self):
|
def _create_dimension_toggle(self):
|
||||||
return [
|
return [
|
||||||
dbc.Label("Dimensions:"),
|
dbc.Label("Dimensions:"),
|
||||||
dcc.RadioItems(
|
dcc.RadioItems(
|
||||||
id='dimension-toggle',
|
id="dimension-toggle",
|
||||||
options=[
|
options=[
|
||||||
{'label': '2D', 'value': '2d'},
|
{"label": "2D", "value": "2d"},
|
||||||
{'label': '3D', 'value': '3d'}
|
{"label": "3D", "value": "3d"},
|
||||||
],
|
],
|
||||||
value='3d',
|
value="3d",
|
||||||
style={'margin-bottom': '20px'}
|
style={"margin-bottom": "20px"},
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _create_prompts_toggle(self):
|
def _create_prompts_toggle(self):
|
||||||
return [
|
return [
|
||||||
dbc.Label("Show Prompts:"),
|
dbc.Label("Show Prompts:"),
|
||||||
dcc.Checklist(
|
dcc.Checklist(
|
||||||
id='show-prompts-toggle',
|
id="show-prompts-toggle",
|
||||||
options=[{'label': 'Show prompts on plot', 'value': 'show'}],
|
options=[{"label": "Show prompts on plot", "value": "show"}],
|
||||||
value=['show'],
|
value=["show"],
|
||||||
style={'margin-bottom': '20px'}
|
style={"margin-bottom": "20px"},
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
@@ -3,58 +3,51 @@ import dash_bootstrap_components as dbc
|
|||||||
|
|
||||||
|
|
||||||
class UploadComponent:
|
class UploadComponent:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_data_upload():
|
def create_data_upload():
|
||||||
return dcc.Upload(
|
return dcc.Upload(
|
||||||
id='upload-data',
|
id="upload-data",
|
||||||
children=html.Div([
|
children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
|
||||||
'Drag and Drop or ',
|
|
||||||
html.A('Select Files')
|
|
||||||
]),
|
|
||||||
style={
|
style={
|
||||||
'width': '100%',
|
"width": "100%",
|
||||||
'height': '60px',
|
"height": "60px",
|
||||||
'lineHeight': '60px',
|
"lineHeight": "60px",
|
||||||
'borderWidth': '1px',
|
"borderWidth": "1px",
|
||||||
'borderStyle': 'dashed',
|
"borderStyle": "dashed",
|
||||||
'borderRadius': '5px',
|
"borderRadius": "5px",
|
||||||
'textAlign': 'center',
|
"textAlign": "center",
|
||||||
'margin-bottom': '20px'
|
"margin-bottom": "20px",
|
||||||
},
|
},
|
||||||
multiple=False
|
multiple=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_prompts_upload():
|
def create_prompts_upload():
|
||||||
return dcc.Upload(
|
return dcc.Upload(
|
||||||
id='upload-prompts',
|
id="upload-prompts",
|
||||||
children=html.Div([
|
children=html.Div(["Drag and Drop Prompts or ", html.A("Select Files")]),
|
||||||
'Drag and Drop Prompts or ',
|
|
||||||
html.A('Select Files')
|
|
||||||
]),
|
|
||||||
style={
|
style={
|
||||||
'width': '100%',
|
"width": "100%",
|
||||||
'height': '60px',
|
"height": "60px",
|
||||||
'lineHeight': '60px',
|
"lineHeight": "60px",
|
||||||
'borderWidth': '1px',
|
"borderWidth": "1px",
|
||||||
'borderStyle': 'dashed',
|
"borderStyle": "dashed",
|
||||||
'borderRadius': '5px',
|
"borderRadius": "5px",
|
||||||
'textAlign': 'center',
|
"textAlign": "center",
|
||||||
'margin-bottom': '20px',
|
"margin-bottom": "20px",
|
||||||
'borderColor': '#28a745'
|
"borderColor": "#28a745",
|
||||||
},
|
},
|
||||||
multiple=False
|
multiple=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_reset_button():
|
def create_reset_button():
|
||||||
return dbc.Button(
|
return dbc.Button(
|
||||||
"Reset All Data",
|
"Reset All Data",
|
||||||
id='reset-button',
|
id="reset-button",
|
||||||
color='danger',
|
color="danger",
|
||||||
outline=True,
|
outline=True,
|
||||||
size='sm',
|
size="sm",
|
||||||
className='mb-3',
|
className="mb-3",
|
||||||
style={'width': '100%'}
|
style={"width": "100%"},
|
||||||
)
|
)
|
||||||
|
@@ -4,41 +4,43 @@ from .components.sidebar import SidebarComponent
|
|||||||
|
|
||||||
|
|
||||||
class AppLayout:
|
class AppLayout:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sidebar = SidebarComponent()
|
self.sidebar = SidebarComponent()
|
||||||
|
|
||||||
def create_layout(self):
|
def create_layout(self):
|
||||||
return dbc.Container([
|
return dbc.Container(
|
||||||
self._create_header(),
|
[self._create_header(), self._create_main_content(), self._create_stores()],
|
||||||
self._create_main_content(),
|
fluid=True,
|
||||||
self._create_stores()
|
)
|
||||||
], fluid=True)
|
|
||||||
|
|
||||||
def _create_header(self):
|
def _create_header(self):
|
||||||
return dbc.Row([
|
return dbc.Row(
|
||||||
dbc.Col([
|
[
|
||||||
html.H1("EmbeddingBuddy", className="text-center mb-4"),
|
dbc.Col(
|
||||||
], width=12)
|
[
|
||||||
])
|
html.H1("EmbeddingBuddy", className="text-center mb-4"),
|
||||||
|
],
|
||||||
|
width=12,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _create_main_content(self):
|
def _create_main_content(self):
|
||||||
return dbc.Row([
|
return dbc.Row(
|
||||||
self.sidebar.create_layout(),
|
[self.sidebar.create_layout(), self._create_visualization_area()]
|
||||||
self._create_visualization_area()
|
)
|
||||||
])
|
|
||||||
|
|
||||||
def _create_visualization_area(self):
|
def _create_visualization_area(self):
|
||||||
return dbc.Col([
|
return dbc.Col(
|
||||||
dcc.Graph(
|
[
|
||||||
id='embedding-plot',
|
dcc.Graph(
|
||||||
style={'height': '85vh', 'width': '100%'},
|
id="embedding-plot",
|
||||||
config={'responsive': True, 'displayModeBar': True}
|
style={"height": "85vh", "width": "100%"},
|
||||||
)
|
config={"responsive": True, "displayModeBar": True},
|
||||||
], width=9)
|
)
|
||||||
|
],
|
||||||
|
width=9,
|
||||||
|
)
|
||||||
|
|
||||||
def _create_stores(self):
|
def _create_stores(self):
|
||||||
return [
|
return [dcc.Store(id="processed-data"), dcc.Store(id="processed-prompts")]
|
||||||
dcc.Store(id='processed-data'),
|
|
||||||
dcc.Store(id='processed-prompts')
|
|
||||||
]
|
|
||||||
|
@@ -1,33 +1,36 @@
|
|||||||
from typing import List, Dict, Any
|
from typing import List
|
||||||
import plotly.colors as pc
|
import plotly.colors as pc
|
||||||
from ..models.schemas import Document
|
from ..models.schemas import Document
|
||||||
|
|
||||||
|
|
||||||
class ColorMapper:
|
class ColorMapper:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
|
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
|
||||||
if color_by == 'category':
|
if color_by == "category":
|
||||||
return [doc.category for doc in documents]
|
return [doc.category for doc in documents]
|
||||||
elif color_by == 'subcategory':
|
elif color_by == "subcategory":
|
||||||
return [doc.subcategory for doc in documents]
|
return [doc.subcategory for doc in documents]
|
||||||
elif color_by == 'tags':
|
elif color_by == "tags":
|
||||||
return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents]
|
return [", ".join(doc.tags) if doc.tags else "No tags" for doc in documents]
|
||||||
else:
|
else:
|
||||||
return ['All'] * len(documents)
|
return ["All"] * len(documents)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_grayscale_hex(color_str: str) -> str:
|
def to_grayscale_hex(color_str: str) -> str:
|
||||||
try:
|
try:
|
||||||
if color_str.startswith('#'):
|
if color_str.startswith("#"):
|
||||||
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
|
rgb = tuple(int(color_str[i : i + 2], 16) for i in (1, 3, 5))
|
||||||
else:
|
else:
|
||||||
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
|
rgb = pc.hex_to_rgb(
|
||||||
|
pc.convert_colors_to_same_type([color_str], colortype="hex")[0][0]
|
||||||
|
)
|
||||||
|
|
||||||
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
|
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
|
||||||
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3,
|
gray_rgb = (
|
||||||
gray_value * 0.7 + rgb[1] * 0.3,
|
gray_value * 0.7 + rgb[0] * 0.3,
|
||||||
gray_value * 0.7 + rgb[2] * 0.3)
|
gray_value * 0.7 + rgb[1] * 0.3,
|
||||||
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
|
gray_value * 0.7 + rgb[2] * 0.3,
|
||||||
except:
|
)
|
||||||
return 'rgb(128,128,128)'
|
return f"rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})"
|
||||||
|
except: # noqa: E722
|
||||||
|
return "rgb(128,128,128)"
|
||||||
|
@@ -7,139 +7,172 @@ from .colors import ColorMapper
|
|||||||
|
|
||||||
|
|
||||||
class PlotFactory:
|
class PlotFactory:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.color_mapper = ColorMapper()
|
self.color_mapper = ColorMapper()
|
||||||
|
|
||||||
def create_plot(self, plot_data: PlotData, dimensions: str = '3d',
|
def create_plot(
|
||||||
color_by: str = 'category', method: str = 'PCA',
|
self,
|
||||||
show_prompts: Optional[List[str]] = None) -> go.Figure:
|
plot_data: PlotData,
|
||||||
|
dimensions: str = "3d",
|
||||||
if plot_data.prompts and show_prompts and 'show' in show_prompts:
|
color_by: str = "category",
|
||||||
|
method: str = "PCA",
|
||||||
|
show_prompts: Optional[List[str]] = None,
|
||||||
|
) -> go.Figure:
|
||||||
|
if plot_data.prompts and show_prompts and "show" in show_prompts:
|
||||||
return self._create_dual_plot(plot_data, dimensions, color_by, method)
|
return self._create_dual_plot(plot_data, dimensions, color_by, method)
|
||||||
else:
|
else:
|
||||||
return self._create_single_plot(plot_data, dimensions, color_by, method)
|
return self._create_single_plot(plot_data, dimensions, color_by, method)
|
||||||
|
|
||||||
def _create_single_plot(self, plot_data: PlotData, dimensions: str,
|
def _create_single_plot(
|
||||||
color_by: str, method: str) -> go.Figure:
|
self, plot_data: PlotData, dimensions: str, color_by: str, method: str
|
||||||
df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
|
) -> go.Figure:
|
||||||
color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
|
df = self._prepare_dataframe(
|
||||||
|
plot_data.documents, plot_data.coordinates, dimensions
|
||||||
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
|
)
|
||||||
|
color_values = self.color_mapper.create_color_mapping(
|
||||||
if dimensions == '3d':
|
plot_data.documents, color_by
|
||||||
|
)
|
||||||
|
|
||||||
|
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
|
||||||
|
|
||||||
|
if dimensions == "3d":
|
||||||
fig = px.scatter_3d(
|
fig = px.scatter_3d(
|
||||||
df, x='dim_1', y='dim_2', z='dim_3',
|
df,
|
||||||
|
x="dim_1",
|
||||||
|
y="dim_2",
|
||||||
|
z="dim_3",
|
||||||
color=color_values,
|
color=color_values,
|
||||||
hover_data=hover_fields,
|
hover_data=hover_fields,
|
||||||
title=f'3D Embedding Visualization - {method} (colored by {color_by})'
|
title=f"3D Embedding Visualization - {method} (colored by {color_by})",
|
||||||
)
|
)
|
||||||
fig.update_traces(marker=dict(size=5))
|
fig.update_traces(marker=dict(size=5))
|
||||||
else:
|
else:
|
||||||
fig = px.scatter(
|
fig = px.scatter(
|
||||||
df, x='dim_1', y='dim_2',
|
df,
|
||||||
|
x="dim_1",
|
||||||
|
y="dim_2",
|
||||||
color=color_values,
|
color=color_values,
|
||||||
hover_data=hover_fields,
|
hover_data=hover_fields,
|
||||||
title=f'2D Embedding Visualization - {method} (colored by {color_by})'
|
title=f"2D Embedding Visualization - {method} (colored by {color_by})",
|
||||||
)
|
)
|
||||||
fig.update_traces(marker=dict(size=8))
|
fig.update_traces(marker=dict(size=8))
|
||||||
|
|
||||||
fig.update_layout(
|
fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0))
|
||||||
height=None,
|
|
||||||
autosize=True,
|
|
||||||
margin=dict(l=0, r=0, t=50, b=0)
|
|
||||||
)
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
def _create_dual_plot(self, plot_data: PlotData, dimensions: str,
|
def _create_dual_plot(
|
||||||
color_by: str, method: str) -> go.Figure:
|
self, plot_data: PlotData, dimensions: str, color_by: str, method: str
|
||||||
|
) -> go.Figure:
|
||||||
fig = go.Figure()
|
fig = go.Figure()
|
||||||
|
|
||||||
doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
|
doc_df = self._prepare_dataframe(
|
||||||
doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
|
plot_data.documents, plot_data.coordinates, dimensions
|
||||||
|
)
|
||||||
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
|
doc_color_values = self.color_mapper.create_color_mapping(
|
||||||
|
plot_data.documents, color_by
|
||||||
if dimensions == '3d':
|
)
|
||||||
|
|
||||||
|
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
|
||||||
|
|
||||||
|
if dimensions == "3d":
|
||||||
doc_fig = px.scatter_3d(
|
doc_fig = px.scatter_3d(
|
||||||
doc_df, x='dim_1', y='dim_2', z='dim_3',
|
doc_df,
|
||||||
|
x="dim_1",
|
||||||
|
y="dim_2",
|
||||||
|
z="dim_3",
|
||||||
color=doc_color_values,
|
color=doc_color_values,
|
||||||
hover_data=hover_fields
|
hover_data=hover_fields,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
doc_fig = px.scatter(
|
doc_fig = px.scatter(
|
||||||
doc_df, x='dim_1', y='dim_2',
|
doc_df,
|
||||||
|
x="dim_1",
|
||||||
|
y="dim_2",
|
||||||
color=doc_color_values,
|
color=doc_color_values,
|
||||||
hover_data=hover_fields
|
hover_data=hover_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
for trace in doc_fig.data:
|
for trace in doc_fig.data:
|
||||||
trace.name = f'Documents - {trace.name}'
|
trace.name = f"Documents - {trace.name}"
|
||||||
if dimensions == '3d':
|
if dimensions == "3d":
|
||||||
trace.marker.size = 5
|
trace.marker.size = 5
|
||||||
trace.marker.symbol = 'circle'
|
trace.marker.symbol = "circle"
|
||||||
else:
|
else:
|
||||||
trace.marker.size = 8
|
trace.marker.size = 8
|
||||||
trace.marker.symbol = 'circle'
|
trace.marker.symbol = "circle"
|
||||||
trace.marker.opacity = 1.0
|
trace.marker.opacity = 1.0
|
||||||
fig.add_trace(trace)
|
fig.add_trace(trace)
|
||||||
|
|
||||||
if plot_data.prompts and plot_data.prompt_coordinates is not None:
|
if plot_data.prompts and plot_data.prompt_coordinates is not None:
|
||||||
prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions)
|
prompt_df = self._prepare_dataframe(
|
||||||
prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by)
|
plot_data.prompts, plot_data.prompt_coordinates, dimensions
|
||||||
|
)
|
||||||
if dimensions == '3d':
|
prompt_color_values = self.color_mapper.create_color_mapping(
|
||||||
|
plot_data.prompts, color_by
|
||||||
|
)
|
||||||
|
|
||||||
|
if dimensions == "3d":
|
||||||
prompt_fig = px.scatter_3d(
|
prompt_fig = px.scatter_3d(
|
||||||
prompt_df, x='dim_1', y='dim_2', z='dim_3',
|
prompt_df,
|
||||||
|
x="dim_1",
|
||||||
|
y="dim_2",
|
||||||
|
z="dim_3",
|
||||||
color=prompt_color_values,
|
color=prompt_color_values,
|
||||||
hover_data=hover_fields
|
hover_data=hover_fields,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt_fig = px.scatter(
|
prompt_fig = px.scatter(
|
||||||
prompt_df, x='dim_1', y='dim_2',
|
prompt_df,
|
||||||
|
x="dim_1",
|
||||||
|
y="dim_2",
|
||||||
color=prompt_color_values,
|
color=prompt_color_values,
|
||||||
hover_data=hover_fields
|
hover_data=hover_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
for trace in prompt_fig.data:
|
for trace in prompt_fig.data:
|
||||||
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
|
if hasattr(trace.marker, "color") and isinstance(
|
||||||
trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color)
|
trace.marker.color, str
|
||||||
|
):
|
||||||
trace.name = f'Prompts - {trace.name}'
|
trace.marker.color = self.color_mapper.to_grayscale_hex(
|
||||||
if dimensions == '3d':
|
trace.marker.color
|
||||||
|
)
|
||||||
|
|
||||||
|
trace.name = f"Prompts - {trace.name}"
|
||||||
|
if dimensions == "3d":
|
||||||
trace.marker.size = 6
|
trace.marker.size = 6
|
||||||
trace.marker.symbol = 'diamond'
|
trace.marker.symbol = "diamond"
|
||||||
else:
|
else:
|
||||||
trace.marker.size = 10
|
trace.marker.size = 10
|
||||||
trace.marker.symbol = 'diamond'
|
trace.marker.symbol = "diamond"
|
||||||
trace.marker.opacity = 0.8
|
trace.marker.opacity = 0.8
|
||||||
fig.add_trace(trace)
|
fig.add_trace(trace)
|
||||||
|
|
||||||
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
|
title = f"{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})"
|
||||||
fig.update_layout(
|
fig.update_layout(
|
||||||
title=title,
|
title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)
|
||||||
height=None,
|
|
||||||
autosize=True,
|
|
||||||
margin=dict(l=0, r=0, t=50, b=0)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame:
|
def _prepare_dataframe(
|
||||||
|
self, documents: List[Document], coordinates, dimensions: str
|
||||||
|
) -> pd.DataFrame:
|
||||||
df_data = []
|
df_data = []
|
||||||
for i, doc in enumerate(documents):
|
for i, doc in enumerate(documents):
|
||||||
row = {
|
row = {
|
||||||
'id': doc.id,
|
"id": doc.id,
|
||||||
'text': doc.text,
|
"text": doc.text,
|
||||||
'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text,
|
"text_preview": doc.text[:100] + "..."
|
||||||
'category': doc.category,
|
if len(doc.text) > 100
|
||||||
'subcategory': doc.subcategory,
|
else doc.text,
|
||||||
'tags_str': ', '.join(doc.tags) if doc.tags else 'None',
|
"category": doc.category,
|
||||||
'dim_1': coordinates[i, 0],
|
"subcategory": doc.subcategory,
|
||||||
'dim_2': coordinates[i, 1],
|
"tags_str": ", ".join(doc.tags) if doc.tags else "None",
|
||||||
|
"dim_1": coordinates[i, 0],
|
||||||
|
"dim_2": coordinates[i, 1],
|
||||||
}
|
}
|
||||||
if dimensions == '3d':
|
if dimensions == "3d":
|
||||||
row['dim_3'] = coordinates[i, 2]
|
row["dim_3"] = coordinates[i, 2]
|
||||||
df_data.append(row)
|
df_data.append(row)
|
||||||
|
|
||||||
return pd.DataFrame(df_data)
|
return pd.DataFrame(df_data)
|
||||||
|
@@ -6,62 +6,64 @@ from src.embeddingbuddy.models.schemas import Document
|
|||||||
|
|
||||||
|
|
||||||
class TestNDJSONParser:
|
class TestNDJSONParser:
|
||||||
|
|
||||||
def test_parse_text_basic(self):
|
def test_parse_text_basic(self):
|
||||||
text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
|
text_content = (
|
||||||
|
'{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
|
||||||
|
)
|
||||||
documents = NDJSONParser.parse_text(text_content)
|
documents = NDJSONParser.parse_text(text_content)
|
||||||
|
|
||||||
assert len(documents) == 1
|
assert len(documents) == 1
|
||||||
assert documents[0].id == "test1"
|
assert documents[0].id == "test1"
|
||||||
assert documents[0].text == "Hello world"
|
assert documents[0].text == "Hello world"
|
||||||
assert documents[0].embedding == [0.1, 0.2, 0.3]
|
assert documents[0].embedding == [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
def test_parse_text_with_metadata(self):
|
def test_parse_text_with_metadata(self):
|
||||||
text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}'
|
text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}'
|
||||||
documents = NDJSONParser.parse_text(text_content)
|
documents = NDJSONParser.parse_text(text_content)
|
||||||
|
|
||||||
assert documents[0].category == "greeting"
|
assert documents[0].category == "greeting"
|
||||||
assert documents[0].tags == ["test"]
|
assert documents[0].tags == ["test"]
|
||||||
|
|
||||||
def test_parse_text_missing_id(self):
|
def test_parse_text_missing_id(self):
|
||||||
text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}'
|
text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}'
|
||||||
documents = NDJSONParser.parse_text(text_content)
|
documents = NDJSONParser.parse_text(text_content)
|
||||||
|
|
||||||
assert len(documents) == 1
|
assert len(documents) == 1
|
||||||
assert documents[0].id is not None # Should be auto-generated
|
assert documents[0].id is not None # Should be auto-generated
|
||||||
|
|
||||||
|
|
||||||
class TestDataProcessor:
|
class TestDataProcessor:
|
||||||
|
|
||||||
def test_extract_embeddings(self):
|
def test_extract_embeddings(self):
|
||||||
documents = [
|
documents = [
|
||||||
Document(id="1", text="test1", embedding=[0.1, 0.2]),
|
Document(id="1", text="test1", embedding=[0.1, 0.2]),
|
||||||
Document(id="2", text="test2", embedding=[0.3, 0.4])
|
Document(id="2", text="test2", embedding=[0.3, 0.4]),
|
||||||
]
|
]
|
||||||
|
|
||||||
processor = DataProcessor()
|
processor = DataProcessor()
|
||||||
embeddings = processor._extract_embeddings(documents)
|
embeddings = processor._extract_embeddings(documents)
|
||||||
|
|
||||||
assert embeddings.shape == (2, 2)
|
assert embeddings.shape == (2, 2)
|
||||||
assert np.allclose(embeddings[0], [0.1, 0.2])
|
assert np.allclose(embeddings[0], [0.1, 0.2])
|
||||||
assert np.allclose(embeddings[1], [0.3, 0.4])
|
assert np.allclose(embeddings[1], [0.3, 0.4])
|
||||||
|
|
||||||
def test_combine_data(self):
|
def test_combine_data(self):
|
||||||
from src.embeddingbuddy.models.schemas import ProcessedData
|
from src.embeddingbuddy.models.schemas import ProcessedData
|
||||||
|
|
||||||
doc_data = ProcessedData(
|
doc_data = ProcessedData(
|
||||||
documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
|
documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
|
||||||
embeddings=np.array([[0.1, 0.2]])
|
embeddings=np.array([[0.1, 0.2]]),
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_data = ProcessedData(
|
prompt_data = ProcessedData(
|
||||||
documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
|
documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
|
||||||
embeddings=np.array([[0.3, 0.4]])
|
embeddings=np.array([[0.3, 0.4]]),
|
||||||
)
|
)
|
||||||
|
|
||||||
processor = DataProcessor()
|
processor = DataProcessor()
|
||||||
all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data)
|
all_embeddings, documents, prompts = processor.combine_data(
|
||||||
|
doc_data, prompt_data
|
||||||
|
)
|
||||||
|
|
||||||
assert all_embeddings.shape == (2, 2)
|
assert all_embeddings.shape == (2, 2)
|
||||||
assert len(documents) == 1
|
assert len(documents) == 1
|
||||||
assert len(prompts) == 1
|
assert len(prompts) == 1
|
||||||
@@ -70,4 +72,4 @@ class TestDataProcessor:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
@@ -1,89 +1,90 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer
|
from src.embeddingbuddy.models.reducers import (
|
||||||
|
ReducerFactory,
|
||||||
|
PCAReducer,
|
||||||
|
TSNEReducer,
|
||||||
|
UMAPReducer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestReducerFactory:
|
class TestReducerFactory:
|
||||||
|
|
||||||
def test_create_pca_reducer(self):
|
def test_create_pca_reducer(self):
|
||||||
reducer = ReducerFactory.create_reducer('pca', n_components=2)
|
reducer = ReducerFactory.create_reducer("pca", n_components=2)
|
||||||
assert isinstance(reducer, PCAReducer)
|
assert isinstance(reducer, PCAReducer)
|
||||||
assert reducer.n_components == 2
|
assert reducer.n_components == 2
|
||||||
|
|
||||||
def test_create_tsne_reducer(self):
|
def test_create_tsne_reducer(self):
|
||||||
reducer = ReducerFactory.create_reducer('tsne', n_components=3)
|
reducer = ReducerFactory.create_reducer("tsne", n_components=3)
|
||||||
assert isinstance(reducer, TSNEReducer)
|
assert isinstance(reducer, TSNEReducer)
|
||||||
assert reducer.n_components == 3
|
assert reducer.n_components == 3
|
||||||
|
|
||||||
def test_create_umap_reducer(self):
|
def test_create_umap_reducer(self):
|
||||||
reducer = ReducerFactory.create_reducer('umap', n_components=2)
|
reducer = ReducerFactory.create_reducer("umap", n_components=2)
|
||||||
assert isinstance(reducer, UMAPReducer)
|
assert isinstance(reducer, UMAPReducer)
|
||||||
assert reducer.n_components == 2
|
assert reducer.n_components == 2
|
||||||
|
|
||||||
def test_invalid_method(self):
|
def test_invalid_method(self):
|
||||||
with pytest.raises(ValueError, match="Unknown reduction method"):
|
with pytest.raises(ValueError, match="Unknown reduction method"):
|
||||||
ReducerFactory.create_reducer('invalid_method')
|
ReducerFactory.create_reducer("invalid_method")
|
||||||
|
|
||||||
def test_available_methods(self):
|
def test_available_methods(self):
|
||||||
methods = ReducerFactory.get_available_methods()
|
methods = ReducerFactory.get_available_methods()
|
||||||
assert 'pca' in methods
|
assert "pca" in methods
|
||||||
assert 'tsne' in methods
|
assert "tsne" in methods
|
||||||
assert 'umap' in methods
|
assert "umap" in methods
|
||||||
|
|
||||||
|
|
||||||
class TestPCAReducer:
|
class TestPCAReducer:
|
||||||
|
|
||||||
def test_fit_transform(self):
|
def test_fit_transform(self):
|
||||||
embeddings = np.random.rand(100, 512)
|
embeddings = np.random.rand(100, 512)
|
||||||
reducer = PCAReducer(n_components=2)
|
reducer = PCAReducer(n_components=2)
|
||||||
|
|
||||||
result = reducer.fit_transform(embeddings)
|
result = reducer.fit_transform(embeddings)
|
||||||
|
|
||||||
assert result.reduced_embeddings.shape == (100, 2)
|
assert result.reduced_embeddings.shape == (100, 2)
|
||||||
assert result.variance_explained is not None
|
assert result.variance_explained is not None
|
||||||
assert result.method == "PCA"
|
assert result.method == "PCA"
|
||||||
assert result.n_components == 2
|
assert result.n_components == 2
|
||||||
|
|
||||||
def test_method_name(self):
|
def test_method_name(self):
|
||||||
reducer = PCAReducer()
|
reducer = PCAReducer()
|
||||||
assert reducer.get_method_name() == "PCA"
|
assert reducer.get_method_name() == "PCA"
|
||||||
|
|
||||||
|
|
||||||
class TestTSNEReducer:
|
class TestTSNEReducer:
|
||||||
|
|
||||||
def test_fit_transform_small_dataset(self):
|
def test_fit_transform_small_dataset(self):
|
||||||
embeddings = np.random.rand(30, 10) # Small dataset for faster testing
|
embeddings = np.random.rand(30, 10) # Small dataset for faster testing
|
||||||
reducer = TSNEReducer(n_components=2)
|
reducer = TSNEReducer(n_components=2)
|
||||||
|
|
||||||
result = reducer.fit_transform(embeddings)
|
result = reducer.fit_transform(embeddings)
|
||||||
|
|
||||||
assert result.reduced_embeddings.shape == (30, 2)
|
assert result.reduced_embeddings.shape == (30, 2)
|
||||||
assert result.variance_explained is None # t-SNE doesn't provide this
|
assert result.variance_explained is None # t-SNE doesn't provide this
|
||||||
assert result.method == "t-SNE"
|
assert result.method == "t-SNE"
|
||||||
assert result.n_components == 2
|
assert result.n_components == 2
|
||||||
|
|
||||||
def test_method_name(self):
|
def test_method_name(self):
|
||||||
reducer = TSNEReducer()
|
reducer = TSNEReducer()
|
||||||
assert reducer.get_method_name() == "t-SNE"
|
assert reducer.get_method_name() == "t-SNE"
|
||||||
|
|
||||||
|
|
||||||
class TestUMAPReducer:
|
class TestUMAPReducer:
|
||||||
|
|
||||||
def test_fit_transform(self):
|
def test_fit_transform(self):
|
||||||
embeddings = np.random.rand(50, 10)
|
embeddings = np.random.rand(50, 10)
|
||||||
reducer = UMAPReducer(n_components=2)
|
reducer = UMAPReducer(n_components=2)
|
||||||
|
|
||||||
result = reducer.fit_transform(embeddings)
|
result = reducer.fit_transform(embeddings)
|
||||||
|
|
||||||
assert result.reduced_embeddings.shape == (50, 2)
|
assert result.reduced_embeddings.shape == (50, 2)
|
||||||
assert result.variance_explained is None # UMAP doesn't provide this
|
assert result.variance_explained is None # UMAP doesn't provide this
|
||||||
assert result.method == "UMAP"
|
assert result.method == "UMAP"
|
||||||
assert result.n_components == 2
|
assert result.n_components == 2
|
||||||
|
|
||||||
def test_method_name(self):
|
def test_method_name(self):
|
||||||
reducer = UMAPReducer()
|
reducer = UMAPReducer()
|
||||||
assert reducer.get_method_name() == "UMAP"
|
assert reducer.get_method_name() == "UMAP"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
Reference in New Issue
Block a user