3 Commits

Author SHA1 Message Date
091a2a0f97 fix colors linting
Some checks failed
Test Suite / test (3.12) (push) Successful in 1m27s
Test Suite / test (3.11) (push) Successful in 1m29s
Test Suite / lint (push) Failing after 10s
Test Suite / build (push) Has been skipped
Test Suite / test (3.11) (pull_request) Has been cancelled
Test Suite / test (3.12) (pull_request) Has been cancelled
Test Suite / lint (pull_request) Has been cancelled
Test Suite / build (pull_request) Has been cancelled
Security Scan / security (pull_request) Failing after 39s
Security Scan / dependency-check (pull_request) Failing after 34s
2025-08-13 20:40:08 -07:00
48326c6335 fix linting and other ci errors 2025-08-13 20:39:53 -07:00
450f6b23e0 add ci workflows
Some checks failed
Test Suite / test (3.11) (push) Successful in 2m17s
Test Suite / test (3.12) (push) Successful in 2m20s
Test Suite / lint (push) Failing after 39s
Test Suite / build (push) Has been skipped
Security Scan / security (pull_request) Failing after 50s
Security Scan / dependency-check (pull_request) Failing after 47s
Test Suite / test (3.11) (pull_request) Successful in 1m32s
Test Suite / lint (pull_request) Failing after 23s
Test Suite / test (3.12) (pull_request) Successful in 1m24s
Test Suite / build (pull_request) Has been skipped
2025-08-13 20:26:06 -07:00
37 changed files with 1079 additions and 2120 deletions

View File

@@ -6,7 +6,6 @@
"Bash(uv add:*)" "Bash(uv add:*)"
], ],
"deny": [], "deny": [],
"ask": [], "ask": []
"defaultMode": "acceptEdits"
} }
} }

View File

@@ -26,10 +26,12 @@ jobs:
run: uv python install 3.11 run: uv python install 3.11
- name: Install dependencies - name: Install dependencies
run: uv sync --extra test run: uv sync
- name: Run full test suite - name: Run full test suite
run: uv run pytest tests/ -v --cov=src/embeddingbuddy --cov-report=term-missing run: |
uv add pytest-cov
uv run pytest tests/ -v --cov=src/embeddingbuddy --cov-report=term-missing
build-and-release: build-and-release:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -25,7 +25,12 @@ jobs:
run: uv python install 3.11 run: uv python install 3.11
- name: Install dependencies - name: Install dependencies
run: uv sync --extra security run: uv sync
- name: Add security tools
run: |
uv add bandit[toml]
uv add safety
- name: Run bandit security linter - name: Run bandit security linter
run: uv run bandit -r src/ -f json -o bandit-report.json run: uv run bandit -r src/ -f json -o bandit-report.json
@@ -36,7 +41,7 @@ jobs:
continue-on-error: true continue-on-error: true
- name: Upload security reports - name: Upload security reports
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: security-reports name: security-reports
path: | path: |
@@ -59,12 +64,13 @@ jobs:
- name: Check for dependency vulnerabilities - name: Check for dependency vulnerabilities
run: | run: |
uv sync --extra security uv sync
uv add pip-audit
uv run pip-audit --format=json --output=pip-audit-report.json uv run pip-audit --format=json --output=pip-audit-report.json
continue-on-error: true continue-on-error: true
- name: Upload dependency audit report - name: Upload dependency audit report
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: dependency-audit name: dependency-audit
path: pip-audit-report.json path: pip-audit-report.json

View File

@@ -2,20 +2,16 @@ name: Test Suite
on: on:
push: push:
branches: branches: ["*"]
- "main"
- "develop"
pull_request: pull_request:
branches: branches: ["main", "master"]
- "main"
workflow_dispatch:
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.11"] python-version: ["3.11", "3.12"]
steps: steps:
- name: Checkout code - name: Checkout code
@@ -30,13 +26,15 @@ jobs:
run: uv python install ${{ matrix.python-version }} run: uv python install ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: uv sync --extra test run: uv sync
- name: Run tests with pytest - name: Run tests with pytest
run: uv run pytest tests/ -v --tb=short run: uv run pytest tests/ -v --tb=short
- name: Run tests with coverage - name: Run tests with coverage
run: uv run pytest tests/ --cov=src/embeddingbuddy --cov-report=term-missing --cov-report=xml run: |
uv add pytest-cov
uv run pytest tests/ --cov=src/embeddingbuddy --cov-report=term-missing --cov-report=xml
- name: Upload coverage reports - name: Upload coverage reports
uses: codecov/codecov-action@v4 uses: codecov/codecov-action@v4
@@ -60,7 +58,12 @@ jobs:
run: uv python install 3.11 run: uv python install 3.11
- name: Install dependencies - name: Install dependencies
run: uv sync --extra lint run: uv sync
- name: Add linting tools
run: |
uv add ruff
uv add mypy
- name: Run ruff linter - name: Run ruff linter
run: uv run ruff check src/ tests/ run: uv run ruff check src/ tests/
@@ -68,9 +71,8 @@ jobs:
- name: Run ruff formatter check - name: Run ruff formatter check
run: uv run ruff format --check src/ tests/ run: uv run ruff format --check src/ tests/
# TODO fix this it throws errors - name: Run mypy type checker
# - name: Run mypy type checker run: uv run mypy src/embeddingbuddy/ --ignore-missing-imports
# run: uv run mypy src/embeddingbuddy/ --ignore-missing-imports
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -98,7 +100,7 @@ jobs:
uv run python -c "from src.embeddingbuddy.app import create_app; app = create_app(); print('✅ Package builds and imports successfully')" uv run python -c "from src.embeddingbuddy.app import create_app; app = create_app(); print('✅ Package builds and imports successfully')"
- name: Upload build artifacts - name: Upload build artifacts
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v4
with: with:
name: dist-files name: dist-files
path: dist/ path: dist/

View File

@@ -30,28 +30,9 @@ The app will be available at http://127.0.0.1:8050
**Run tests:** **Run tests:**
```bash ```bash
uv sync --extra test
uv run pytest tests/ -v uv run pytest tests/ -v
``` ```
**Development tools:**
```bash
# Install all dev dependencies
uv sync --extra dev
# Linting and formatting
uv run ruff check src/ tests/
uv run ruff format src/ tests/
# Type checking
uv run mypy src/embeddingbuddy/
# Security scanning
uv run bandit -r src/
uv run safety check
```
**Test with sample data:** **Test with sample data:**
Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality. Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for testing the application functionality.
@@ -61,7 +42,7 @@ Use the included `sample_data.ndjson` and `sample_prompts.ndjson` files for test
The application follows a modular architecture with clear separation of concerns: The application follows a modular architecture with clear separation of concerns:
```text ```
src/embeddingbuddy/ src/embeddingbuddy/
├── app.py # Main application entry point and factory ├── app.py # Main application entry point and factory
├── main.py # Application runner ├── main.py # Application runner
@@ -91,32 +72,27 @@ src/embeddingbuddy/
### Key Components ### Key Components
**Data Layer:** **Data Layer:**
- `data/parser.py` - NDJSON parsing with error handling - `data/parser.py` - NDJSON parsing with error handling
- `data/processor.py` - Data transformation and combination logic - `data/processor.py` - Data transformation and combination logic
- `models/schemas.py` - Dataclasses for type safety and validation - `models/schemas.py` - Dataclasses for type safety and validation
**Algorithm Layer:** **Algorithm Layer:**
- `models/reducers.py` - Modular dimensionality reduction with factory pattern - `models/reducers.py` - Modular dimensionality reduction with factory pattern
- Supports PCA, t-SNE (openTSNE), and UMAP algorithms - Supports PCA, t-SNE (openTSNE), and UMAP algorithms
- Abstract base class for easy extension - Abstract base class for easy extension
**Visualization Layer:** **Visualization Layer:**
- `visualization/plots.py` - Plot factory with single and dual plot support - `visualization/plots.py` - Plot factory with single and dual plot support
- `visualization/colors.py` - Color mapping and grayscale conversion utilities - `visualization/colors.py` - Color mapping and grayscale conversion utilities
- Plotly-based 2D/3D scatter plots with interactive features - Plotly-based 2D/3D scatter plots with interactive features
**UI Layer:** **UI Layer:**
- `ui/layout.py` - Main application layout composition - `ui/layout.py` - Main application layout composition
- `ui/components/` - Reusable, testable UI components - `ui/components/` - Reusable, testable UI components
- `ui/callbacks/` - Organized callbacks grouped by functionality - `ui/callbacks/` - Organized callbacks grouped by functionality
- Bootstrap-styled sidebar with controls and large visualization area - Bootstrap-styled sidebar with controls and large visualization area
**Configuration:** **Configuration:**
- `config/settings.py` - Centralized settings with environment variable support - `config/settings.py` - Centralized settings with environment variable support
- Plot styling, marker configurations, and app-wide constants - Plot styling, marker configurations, and app-wide constants
@@ -136,19 +112,16 @@ Optional fields: `id`, `category`, `subcategory`, `tags`
The refactored callback system is organized by functionality: The refactored callback system is organized by functionality:
**Data Processing (`ui/callbacks/data_processing.py`):** **Data Processing (`ui/callbacks/data_processing.py`):**
- File upload handling - File upload handling
- NDJSON parsing and validation - NDJSON parsing and validation
- Data storage in dcc.Store components - Data storage in dcc.Store components
**Visualization (`ui/callbacks/visualization.py`):** **Visualization (`ui/callbacks/visualization.py`):**
- Dimensionality reduction pipeline - Dimensionality reduction pipeline
- Plot generation and updates - Plot generation and updates
- Method/parameter change handling - Method/parameter change handling
**Interactions (`ui/callbacks/interactions.py`):** **Interactions (`ui/callbacks/interactions.py`):**
- Point click handling and detail display - Point click handling and detail display
- Reset functionality - Reset functionality
- User interaction management - User interaction management
@@ -158,18 +131,15 @@ The refactored callback system is organized by functionality:
The modular design enables comprehensive testing: The modular design enables comprehensive testing:
**Unit Tests:** **Unit Tests:**
- `tests/test_data_processing.py` - Parser and processor logic - `tests/test_data_processing.py` - Parser and processor logic
- `tests/test_reducers.py` - Dimensionality reduction algorithms - `tests/test_reducers.py` - Dimensionality reduction algorithms
- `tests/test_visualization.py` - Plot creation and color mapping - `tests/test_visualization.py` - Plot creation and color mapping
**Integration Tests:** **Integration Tests:**
- End-to-end data pipeline testing - End-to-end data pipeline testing
- Component integration verification - Component integration verification
**Key Testing Benefits:** **Key Testing Benefits:**
- Fast test execution (milliseconds vs seconds) - Fast test execution (milliseconds vs seconds)
- Isolated component testing - Isolated component testing
- Easy mocking and fixture creation - Easy mocking and fixture creation
@@ -197,7 +167,6 @@ Uses modern Python stack with uv for dependency management:
5. **Tests** - Write tests for all new functionality 5. **Tests** - Write tests for all new functionality
**Code Organization Principles:** **Code Organization Principles:**
- Single responsibility principle - Single responsibility principle
- Clear module boundaries - Clear module boundaries
- Testable, isolated components - Testable, isolated components
@@ -205,7 +174,6 @@ Uses modern Python stack with uv for dependency management:
- Error handling at appropriate layers - Error handling at appropriate layers
**Testing Requirements:** **Testing Requirements:**
- Unit tests for all core logic - Unit tests for all core logic
- Integration tests for data flow - Integration tests for data flow
- Component tests for UI elements - Component tests for UI elements

View File

@@ -90,7 +90,7 @@ uv run python main.py
The application follows a modular architecture for improved maintainability and testability: The application follows a modular architecture for improved maintainability and testability:
```text ```
src/embeddingbuddy/ src/embeddingbuddy/
├── config/ # Configuration management ├── config/ # Configuration management
│ └── settings.py # Centralized app settings │ └── settings.py # Centralized app settings
@@ -115,8 +115,8 @@ src/embeddingbuddy/
Run the test suite to verify functionality: Run the test suite to verify functionality:
```bash ```bash
# Install test dependencies # Install pytest
uv sync --extra test uv add pytest
# Run all tests # Run all tests
uv run pytest tests/ -v uv run pytest tests/ -v
@@ -128,31 +128,6 @@ uv run pytest tests/test_data_processing.py -v
uv run pytest tests/ --cov=src/embeddingbuddy uv run pytest tests/ --cov=src/embeddingbuddy
``` ```
### Development Tools
Install development dependencies for linting, type checking, and security:
```bash
# Install all dev dependencies
uv sync --extra dev
# Or install specific groups
uv sync --extra test # Testing tools
uv sync --extra lint # Linting and formatting
uv sync --extra security # Security scanning tools
# Run linting
uv run ruff check src/ tests/
uv run ruff format src/ tests/
# Run type checking
uv run mypy src/embeddingbuddy/
# Run security scans
uv run bandit -r src/
uv run safety check
```
### Adding New Features ### Adding New Features
The modular architecture makes it easy to extend functionality: The modular architecture makes it easy to extend functionality:

515
app.py Normal file
View File

@@ -0,0 +1,515 @@
import json
import uuid
from io import StringIO
import base64
import dash
from dash import dcc, html, Input, Output, State, callback
import dash_bootstrap_components as dbc
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
import umap
from openTSNE import TSNE
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
def parse_ndjson(contents):
"""Parse NDJSON content and return list of documents."""
content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
text_content = decoded.decode('utf-8')
documents = []
for line in text_content.strip().split('\n'):
if line.strip():
doc = json.loads(line)
if 'id' not in doc:
doc['id'] = str(uuid.uuid4())
documents.append(doc)
return documents
def apply_dimensionality_reduction(embeddings, method='pca', n_components=3):
"""Apply dimensionality reduction to embeddings."""
if method == 'pca':
reducer = PCA(n_components=n_components)
reduced = reducer.fit_transform(embeddings)
variance_explained = reducer.explained_variance_ratio_
return reduced, variance_explained
elif method == 'tsne':
reducer = TSNE(n_components=n_components, random_state=42)
reduced = reducer.fit(embeddings)
return reduced, None
elif method == 'umap':
reducer = umap.UMAP(n_components=n_components, random_state=42)
reduced = reducer.fit_transform(embeddings)
return reduced, None
else:
raise ValueError(f"Unknown method: {method}")
def create_color_mapping(documents, color_by):
"""Create color mapping for documents based on specified field."""
if color_by == 'category':
values = [doc.get('category', 'Unknown') for doc in documents]
elif color_by == 'subcategory':
values = [doc.get('subcategory', 'Unknown') for doc in documents]
elif color_by == 'tags':
values = [', '.join(doc.get('tags', [])) if doc.get('tags') else 'No tags' for doc in documents]
else:
values = ['All'] * len(documents)
return values
def create_plot(df, dimensions='3d', color_by='category', method='PCA'):
"""Create plotly scatter plot."""
color_values = create_color_mapping(df.to_dict('records'), color_by)
# Truncate text for hover display
df_display = df.copy()
df_display['text_preview'] = df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x)
# Include all metadata fields in hover
hover_fields = ['id', 'text_preview', 'category', 'subcategory']
# Add tags as a string for hover
df_display['tags_str'] = df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None')
hover_fields.append('tags_str')
if dimensions == '3d':
fig = px.scatter_3d(
df_display, x='dim_1', y='dim_2', z='dim_3',
color=color_values,
hover_data=hover_fields,
title=f'3D Embedding Visualization - {method} (colored by {color_by})'
)
fig.update_traces(marker=dict(size=5))
else:
fig = px.scatter(
df_display, x='dim_1', y='dim_2',
color=color_values,
hover_data=hover_fields,
title=f'2D Embedding Visualization - {method} (colored by {color_by})'
)
fig.update_traces(marker=dict(size=8))
fig.update_layout(
height=None, # Let CSS height control this
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig
def create_dual_plot(doc_df, prompt_df, dimensions='3d', color_by='category', method='PCA', show_prompts=None):
"""Create plotly scatter plot with separate traces for documents and prompts."""
# Create the base figure
fig = go.Figure()
# Helper function to convert colors to grayscale
def to_grayscale_hex(color_str):
"""Convert a color to grayscale while maintaining some distinction."""
import plotly.colors as pc
# Try to get RGB values from the color
try:
if color_str.startswith('#'):
# Hex color
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
else:
# Named color or other format - convert through plotly
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
# Convert to grayscale using luminance formula, but keep some color
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
# Make it a bit more gray but not completely
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[1] * 0.3,
gray_value * 0.7 + rgb[2] * 0.3)
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
except:
return 'rgb(128,128,128)' # fallback gray
# Create document plot using plotly express for consistent colors
doc_color_values = create_color_mapping(doc_df.to_dict('records'), color_by)
doc_df_display = doc_df.copy()
doc_df_display['text_preview'] = doc_df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x)
doc_df_display['tags_str'] = doc_df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None')
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
# Create documents plot to get the color mapping
if dimensions == '3d':
doc_fig = px.scatter_3d(
doc_df_display, x='dim_1', y='dim_2', z='dim_3',
color=doc_color_values,
hover_data=hover_fields
)
else:
doc_fig = px.scatter(
doc_df_display, x='dim_1', y='dim_2',
color=doc_color_values,
hover_data=hover_fields
)
# Add document traces to main figure
for trace in doc_fig.data:
trace.name = f'Documents - {trace.name}'
if dimensions == '3d':
trace.marker.size = 5
trace.marker.symbol = 'circle'
else:
trace.marker.size = 8
trace.marker.symbol = 'circle'
trace.marker.opacity = 1.0
fig.add_trace(trace)
# Add prompt traces if they exist
if prompt_df is not None and show_prompts and 'show' in show_prompts:
prompt_color_values = create_color_mapping(prompt_df.to_dict('records'), color_by)
prompt_df_display = prompt_df.copy()
prompt_df_display['text_preview'] = prompt_df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x)
prompt_df_display['tags_str'] = prompt_df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None')
# Create prompts plot to get consistent color grouping
if dimensions == '3d':
prompt_fig = px.scatter_3d(
prompt_df_display, x='dim_1', y='dim_2', z='dim_3',
color=prompt_color_values,
hover_data=hover_fields
)
else:
prompt_fig = px.scatter(
prompt_df_display, x='dim_1', y='dim_2',
color=prompt_color_values,
hover_data=hover_fields
)
# Add prompt traces with grayed colors
for trace in prompt_fig.data:
# Convert the color to grayscale
original_color = trace.marker.color
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
trace.marker.color = to_grayscale_hex(trace.marker.color)
trace.name = f'Prompts - {trace.name}'
if dimensions == '3d':
trace.marker.size = 6
trace.marker.symbol = 'diamond'
else:
trace.marker.size = 10
trace.marker.symbol = 'diamond'
trace.marker.opacity = 0.8
fig.add_trace(trace)
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
fig.update_layout(
title=title,
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig
# Layout
app.layout = dbc.Container([
dbc.Row([
dbc.Col([
html.H1("EmbeddingBuddy", className="text-center mb-4"),
], width=12)
]),
dbc.Row([
# Left sidebar with controls
dbc.Col([
html.H5("Upload Data", className="mb-3"),
dcc.Upload(
id='upload-data',
children=html.Div([
'Drag and Drop or ',
html.A('Select Files')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin-bottom': '20px'
},
multiple=False
),
dcc.Upload(
id='upload-prompts',
children=html.Div([
'Drag and Drop Prompts or ',
html.A('Select Files')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin-bottom': '20px',
'borderColor': '#28a745'
},
multiple=False
),
dbc.Button(
"Reset All Data",
id='reset-button',
color='danger',
outline=True,
size='sm',
className='mb-3',
style={'width': '100%'}
),
html.H5("Visualization Controls", className="mb-3"),
dbc.Label("Method:"),
dcc.Dropdown(
id='method-dropdown',
options=[
{'label': 'PCA', 'value': 'pca'},
{'label': 't-SNE', 'value': 'tsne'},
{'label': 'UMAP', 'value': 'umap'}
],
value='pca',
style={'margin-bottom': '15px'}
),
dbc.Label("Color by:"),
dcc.Dropdown(
id='color-dropdown',
options=[
{'label': 'Category', 'value': 'category'},
{'label': 'Subcategory', 'value': 'subcategory'},
{'label': 'Tags', 'value': 'tags'}
],
value='category',
style={'margin-bottom': '15px'}
),
dbc.Label("Dimensions:"),
dcc.RadioItems(
id='dimension-toggle',
options=[
{'label': '2D', 'value': '2d'},
{'label': '3D', 'value': '3d'}
],
value='3d',
style={'margin-bottom': '20px'}
),
dbc.Label("Show Prompts:"),
dcc.Checklist(
id='show-prompts-toggle',
options=[{'label': 'Show prompts on plot', 'value': 'show'}],
value=['show'],
style={'margin-bottom': '20px'}
),
html.H5("Point Details", className="mb-3"),
html.Div(id='point-details', children="Click on a point to see details")
], width=3, style={'padding-right': '20px'}),
# Main visualization area
dbc.Col([
dcc.Graph(
id='embedding-plot',
style={'height': '85vh', 'width': '100%'},
config={'responsive': True, 'displayModeBar': True}
)
], width=9)
]),
dcc.Store(id='processed-data'),
dcc.Store(id='processed-prompts')
], fluid=True)
@callback(
Output('processed-data', 'data'),
Input('upload-data', 'contents'),
State('upload-data', 'filename')
)
def process_uploaded_file(contents, filename):
if contents is None:
return None
try:
documents = parse_ndjson(contents)
embeddings = np.array([doc['embedding'] for doc in documents])
# Store original embeddings and documents
return {
'documents': documents,
'embeddings': embeddings.tolist()
}
except Exception as e:
return {'error': str(e)}
@callback(
Output('processed-prompts', 'data'),
Input('upload-prompts', 'contents'),
State('upload-prompts', 'filename')
)
def process_uploaded_prompts(contents, filename):
if contents is None:
return None
try:
prompts = parse_ndjson(contents)
embeddings = np.array([prompt['embedding'] for prompt in prompts])
# Store original embeddings and prompts
return {
'prompts': prompts,
'embeddings': embeddings.tolist()
}
except Exception as e:
return {'error': str(e)}
@callback(
Output('embedding-plot', 'figure'),
[Input('processed-data', 'data'),
Input('processed-prompts', 'data'),
Input('method-dropdown', 'value'),
Input('color-dropdown', 'value'),
Input('dimension-toggle', 'value'),
Input('show-prompts-toggle', 'value')]
)
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or 'error' in data:
return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization",
xref="paper", yref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle',
showarrow=False, font=dict(size=16)
)
# Prepare embeddings for dimensionality reduction
doc_embeddings = np.array(data['embeddings'])
all_embeddings = doc_embeddings
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
if has_prompts:
prompt_embeddings = np.array(prompts_data['embeddings'])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == '3d' else 2
# Apply dimensionality reduction to combined data
reduced, variance_explained = apply_dimensionality_reduction(
all_embeddings, method=method, n_components=n_components
)
# Split reduced embeddings back
doc_reduced = reduced[:len(doc_embeddings)]
prompt_reduced = reduced[len(doc_embeddings):] if has_prompts else None
# Create dataframes
doc_df_data = []
for i, doc in enumerate(data['documents']):
row = {
'id': doc['id'],
'text': doc['text'],
'category': doc.get('category', 'Unknown'),
'subcategory': doc.get('subcategory', 'Unknown'),
'tags': doc.get('tags', []),
'dim_1': doc_reduced[i, 0],
'dim_2': doc_reduced[i, 1],
'type': 'document'
}
if dimensions == '3d':
row['dim_3'] = doc_reduced[i, 2]
doc_df_data.append(row)
doc_df = pd.DataFrame(doc_df_data)
prompt_df = None
if has_prompts and prompt_reduced is not None:
prompt_df_data = []
for i, prompt in enumerate(prompts_data['prompts']):
row = {
'id': prompt['id'],
'text': prompt['text'],
'category': prompt.get('category', 'Unknown'),
'subcategory': prompt.get('subcategory', 'Unknown'),
'tags': prompt.get('tags', []),
'dim_1': prompt_reduced[i, 0],
'dim_2': prompt_reduced[i, 1],
'type': 'prompt'
}
if dimensions == '3d':
row['dim_3'] = prompt_reduced[i, 2]
prompt_df_data.append(row)
prompt_df = pd.DataFrame(prompt_df_data)
return create_dual_plot(doc_df, prompt_df, dimensions, color_by, method.upper(), show_prompts)
@callback(
Output('point-details', 'children'),
Input('embedding-plot', 'clickData'),
[State('processed-data', 'data'),
State('processed-prompts', 'data')]
)
def display_click_data(clickData, data, prompts_data):
if not clickData or not data:
return "Click on a point to see details"
# Get point info from click
point_data = clickData['points'][0]
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
if 'pointIndex' in point_data:
point_index = point_data['pointIndex']
elif 'pointNumber' in point_data:
point_index = point_data['pointNumber']
else:
return "Could not identify clicked point"
# Determine which dataset this point belongs to
if trace_name == 'Prompts' and prompts_data and 'prompts' in prompts_data:
item = prompts_data['prompts'][point_index]
item_type = 'Prompt'
else:
item = data['documents'][point_index]
item_type = 'Document'
return dbc.Card([
dbc.CardBody([
html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Text: {item['text']}", className="card-text"),
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
html.P(f"Type: {item_type}", className="card-text text-muted")
])
])
@callback(
[Output('processed-data', 'data', allow_duplicate=True),
Output('processed-prompts', 'data', allow_duplicate=True),
Output('point-details', 'children', allow_duplicate=True)],
Input('reset-button', 'n_clicks'),
prevent_initial_call=True
)
def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0:
return dash.no_update, dash.no_update, dash.no_update
return None, None, "Click on a point to see details"
if __name__ == '__main__':
app.run(debug=True)

View File

@@ -1,2 +0,0 @@
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, 0.2], "text": "Binary junk at start"}
{"id": "doc_002", "embedding": [0.5, 0.1, -0.2, 0.8], "text": "Normal line"}<7D><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>

View File

@@ -1,6 +0,0 @@
{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, 0.2], "text": "First line"}
{"id": "doc_002", "embedding": [0.5, 0.1, -0.2, 0.8], "text": "After empty line"}
{"id": "doc_003", "embedding": [0.3, 0.4, 0.1, -0.1], "text": "After multiple empty lines"}

View File

@@ -1,4 +0,0 @@
{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, 0.2], "text": "4D embedding"}
{"id": "doc_002", "embedding": [0.5, 0.1, -0.2], "text": "3D embedding"}
{"id": "doc_003", "embedding": [0.3, 0.4, 0.1, -0.1, 0.8], "text": "5D embedding"}
{"id": "doc_004", "embedding": [0.2, 0.1], "text": "2D embedding"}

View File

@@ -1,8 +0,0 @@
{"id": "doc_001", "embedding": "not_an_array", "text": "Embedding as string"}
{"id": "doc_002", "embedding": [0.1, "text", 0.7, 0.2], "text": "Mixed types in embedding"}
{"id": "doc_003", "embedding": [], "text": "Empty embedding array"}
{"id": "doc_004", "embedding": [0.1], "text": "Single dimension embedding"}
{"id": "doc_005", "embedding": null, "text": "Null embedding"}
{"id": "doc_006", "embedding": [0.1, 0.2, null, 0.4], "text": "Null value in embedding"}
{"id": "doc_007", "embedding": [0.1, 0.2, "NaN", 0.4], "text": "String NaN in embedding"}
{"id": "doc_008", "embedding": [0.1, 0.2, Infinity, 0.4], "text": "Infinity in embedding"}

View File

@@ -1,5 +0,0 @@
{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, "text": "Valid line"}
{"id": "doc_002", "embedding": [0.5, 0.1, -0.2, 0.8], "text": "Missing closing brace"
{"id": "doc_003" "embedding": [0.3, 0.4, 0.1, -0.1], "text": "Missing colon after id"}
{id: "doc_004", "embedding": [0.2, 0.1, 0.3, 0.4], "text": "Unquoted key"}
{"id": "doc_005", "embedding": [0.1, 0.2, 0.3, 0.4], "text": "Valid line again"}

View File

@@ -1,3 +0,0 @@
{"id": "doc_001", "text": "Sample text without embedding field", "category": "test"}
{"id": "doc_002", "text": "Another text without embedding", "category": "test"}
{"id": "doc_003", "text": "Third text missing embedding", "category": "test"}

View File

@@ -1,3 +0,0 @@
{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, 0.2], "category": "test"}
{"id": "doc_002", "embedding": [0.5, 0.1, -0.2, 0.8], "category": "test"}
{"id": "doc_003", "embedding": [0.3, 0.4, 0.1, -0.1], "category": "test"}

View File

@@ -1,4 +0,0 @@
[
{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, 0.2], "text": "Regular JSON array"},
{"id": "doc_002", "embedding": [0.5, 0.1, -0.2, 0.8], "text": "Instead of NDJSON"}
]

View File

@@ -14,28 +14,7 @@ dependencies = [
"umap-learn>=0.5.8", "umap-learn>=0.5.8",
"numba>=0.56.4", "numba>=0.56.4",
"openTSNE>=1.0.0", "openTSNE>=1.0.0",
"mypy>=1.17.1",
]
[project.optional-dependencies]
test = [
"pytest>=8.4.1", "pytest>=8.4.1",
"pytest-cov>=4.1.0",
]
lint = [
"ruff>=0.1.0",
"mypy>=1.5.0",
]
security = [
"bandit[toml]>=1.7.5",
"safety>=2.3.0",
"pip-audit>=2.6.0",
]
dev = [
"embeddingbuddy[test,lint,security]",
]
all = [
"embeddingbuddy[test,lint,security]",
] ]
[build-system] [build-system]

View File

@@ -8,7 +8,10 @@ from .ui.callbacks.interactions import InteractionCallbacks
def create_app(): def create_app():
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) app = dash.Dash(
__name__,
external_stylesheets=[dbc.themes.BOOTSTRAP]
)
layout_manager = AppLayout() layout_manager = AppLayout()
app.layout = layout_manager.create_layout() app.layout = layout_manager.create_layout()
@@ -27,10 +30,10 @@ def run_app(app=None, debug=None, host=None, port=None):
app.run( app.run(
debug=debug if debug is not None else AppSettings.DEBUG, debug=debug if debug is not None else AppSettings.DEBUG,
host=host if host is not None else AppSettings.HOST, host=host if host is not None else AppSettings.HOST,
port=port if port is not None else AppSettings.PORT, port=port if port is not None else AppSettings.PORT
) )
if __name__ == "__main__": if __name__ == '__main__':
app = create_app() app = create_app()
run_app(app) run_app(app)

View File

@@ -3,28 +3,38 @@ import os
class AppSettings: class AppSettings:
# UI Configuration # UI Configuration
UPLOAD_STYLE = { UPLOAD_STYLE = {
"width": "100%", 'width': '100%',
"height": "60px", 'height': '60px',
"lineHeight": "60px", 'lineHeight': '60px',
"borderWidth": "1px", 'borderWidth': '1px',
"borderStyle": "dashed", 'borderStyle': 'dashed',
"borderRadius": "5px", 'borderRadius': '5px',
"textAlign": "center", 'textAlign': 'center',
"margin-bottom": "20px", 'margin-bottom': '20px'
} }
PROMPTS_UPLOAD_STYLE = {**UPLOAD_STYLE, "borderColor": "#28a745"} PROMPTS_UPLOAD_STYLE = {
**UPLOAD_STYLE,
'borderColor': '#28a745'
}
PLOT_CONFIG = {"responsive": True, "displayModeBar": True} PLOT_CONFIG = {
'responsive': True,
'displayModeBar': True
}
PLOT_STYLE = {"height": "85vh", "width": "100%"} PLOT_STYLE = {
'height': '85vh',
'width': '100%'
}
PLOT_LAYOUT_CONFIG = { PLOT_LAYOUT_CONFIG = {
"height": None, 'height': None,
"autosize": True, 'autosize': True,
"margin": dict(l=0, r=0, t=50, b=0), 'margin': dict(l=0, r=0, t=50, b=0)
} }
# Dimensionality Reduction Settings # Dimensionality Reduction Settings
@@ -34,24 +44,27 @@ class AppSettings:
# Available Methods # Available Methods
REDUCTION_METHODS = [ REDUCTION_METHODS = [
{"label": "PCA", "value": "pca"}, {'label': 'PCA', 'value': 'pca'},
{"label": "t-SNE", "value": "tsne"}, {'label': 't-SNE', 'value': 'tsne'},
{"label": "UMAP", "value": "umap"}, {'label': 'UMAP', 'value': 'umap'}
] ]
COLOR_OPTIONS = [ COLOR_OPTIONS = [
{"label": "Category", "value": "category"}, {'label': 'Category', 'value': 'category'},
{"label": "Subcategory", "value": "subcategory"}, {'label': 'Subcategory', 'value': 'subcategory'},
{"label": "Tags", "value": "tags"}, {'label': 'Tags', 'value': 'tags'}
] ]
DIMENSION_OPTIONS = [{"label": "2D", "value": "2d"}, {"label": "3D", "value": "3d"}] DIMENSION_OPTIONS = [
{'label': '2D', 'value': '2d'},
{'label': '3D', 'value': '3d'}
]
# Default Values # Default Values
DEFAULT_METHOD = "pca" DEFAULT_METHOD = 'pca'
DEFAULT_COLOR_BY = "category" DEFAULT_COLOR_BY = 'category'
DEFAULT_DIMENSIONS = "3d" DEFAULT_DIMENSIONS = '3d'
DEFAULT_SHOW_PROMPTS = ["show"] DEFAULT_SHOW_PROMPTS = ['show']
# Plot Marker Settings # Plot Marker Settings
DOCUMENT_MARKER_SIZE_2D = 8 DOCUMENT_MARKER_SIZE_2D = 8
@@ -59,8 +72,8 @@ class AppSettings:
PROMPT_MARKER_SIZE_2D = 10 PROMPT_MARKER_SIZE_2D = 10
PROMPT_MARKER_SIZE_3D = 6 PROMPT_MARKER_SIZE_3D = 6
DOCUMENT_MARKER_SYMBOL = "circle" DOCUMENT_MARKER_SYMBOL = 'circle'
PROMPT_MARKER_SYMBOL = "diamond" PROMPT_MARKER_SYMBOL = 'diamond'
DOCUMENT_OPACITY = 1.0 DOCUMENT_OPACITY = 1.0
PROMPT_OPACITY = 0.8 PROMPT_OPACITY = 0.8
@@ -69,34 +82,26 @@ class AppSettings:
TEXT_PREVIEW_LENGTH = 100 TEXT_PREVIEW_LENGTH = 100
# App Configuration # App Configuration
DEBUG = os.getenv("EMBEDDINGBUDDY_DEBUG", "True").lower() == "true" DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true'
HOST = os.getenv("EMBEDDINGBUDDY_HOST", "127.0.0.1") HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1')
PORT = int(os.getenv("EMBEDDINGBUDDY_PORT", "8050")) PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050'))
# Bootstrap Theme # Bootstrap Theme
EXTERNAL_STYLESHEETS = [ EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css']
"https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
]
@classmethod @classmethod
def get_plot_marker_config( def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]:
cls, dimensions: str, is_prompt: bool = False
) -> Dict[str, Any]:
if is_prompt: if is_prompt:
size = ( size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D
cls.PROMPT_MARKER_SIZE_3D
if dimensions == "3d"
else cls.PROMPT_MARKER_SIZE_2D
)
symbol = cls.PROMPT_MARKER_SYMBOL symbol = cls.PROMPT_MARKER_SYMBOL
opacity = cls.PROMPT_OPACITY opacity = cls.PROMPT_OPACITY
else: else:
size = ( size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D
cls.DOCUMENT_MARKER_SIZE_3D
if dimensions == "3d"
else cls.DOCUMENT_MARKER_SIZE_2D
)
symbol = cls.DOCUMENT_MARKER_SYMBOL symbol = cls.DOCUMENT_MARKER_SYMBOL
opacity = cls.DOCUMENT_OPACITY opacity = cls.DOCUMENT_OPACITY
return {"size": size, "symbol": symbol, "opacity": opacity} return {
'size': size,
'symbol': symbol,
'opacity': opacity
}

View File

@@ -6,67 +6,34 @@ from ..models.schemas import Document
class NDJSONParser: class NDJSONParser:
@staticmethod @staticmethod
def parse_upload_contents(contents: str) -> List[Document]: def parse_upload_contents(contents: str) -> List[Document]:
content_type, content_string = contents.split(",") content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string) decoded = base64.b64decode(content_string)
text_content = decoded.decode("utf-8") text_content = decoded.decode('utf-8')
return NDJSONParser.parse_text(text_content) return NDJSONParser.parse_text(text_content)
@staticmethod @staticmethod
def parse_text(text_content: str) -> List[Document]: def parse_text(text_content: str) -> List[Document]:
documents = [] documents = []
for line_num, line in enumerate(text_content.strip().split("\n"), 1): for line in text_content.strip().split('\n'):
if line.strip(): if line.strip():
try:
doc_dict = json.loads(line) doc_dict = json.loads(line)
doc = NDJSONParser._dict_to_document(doc_dict) doc = NDJSONParser._dict_to_document(doc_dict)
documents.append(doc) documents.append(doc)
except json.JSONDecodeError as e:
raise json.JSONDecodeError(
f"Invalid JSON on line {line_num}: {e.msg}", e.doc, e.pos
)
except KeyError as e:
raise KeyError(f"Missing required field {e} on line {line_num}")
except (TypeError, ValueError) as e:
raise ValueError(
f"Invalid data format on line {line_num}: {str(e)}"
)
return documents return documents
@staticmethod @staticmethod
def _dict_to_document(doc_dict: dict) -> Document: def _dict_to_document(doc_dict: dict) -> Document:
if "id" not in doc_dict: if 'id' not in doc_dict:
doc_dict["id"] = str(uuid.uuid4()) doc_dict['id'] = str(uuid.uuid4())
# Validate required fields
if "text" not in doc_dict:
raise KeyError("'text'")
if "embedding" not in doc_dict:
raise KeyError("'embedding'")
# Validate embedding format
embedding = doc_dict["embedding"]
if not isinstance(embedding, list):
raise ValueError(
f"Embedding must be a list, got {type(embedding).__name__}"
)
if not embedding:
raise ValueError("Embedding cannot be empty")
# Check that all embedding values are numbers
for i, val in enumerate(embedding):
if not isinstance(val, (int, float)) or val != val: # NaN check
raise ValueError(
f"Embedding contains invalid value at index {i}: {val}"
)
return Document( return Document(
id=doc_dict["id"], id=doc_dict['id'],
text=doc_dict["text"], text=doc_dict['text'],
embedding=embedding, embedding=doc_dict['embedding'],
category=doc_dict.get("category"), category=doc_dict.get('category'),
subcategory=doc_dict.get("subcategory"), subcategory=doc_dict.get('subcategory'),
tags=doc_dict.get("tags"), tags=doc_dict.get('tags')
) )

View File

@@ -5,12 +5,11 @@ from .parser import NDJSONParser
class DataProcessor: class DataProcessor:
def __init__(self): def __init__(self):
self.parser = NDJSONParser() self.parser = NDJSONParser()
def process_upload( def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData:
self, contents: str, filename: Optional[str] = None
) -> ProcessedData:
try: try:
documents = self.parser.parse_upload_contents(contents) documents = self.parser.parse_upload_contents(contents)
embeddings = self._extract_embeddings(documents) embeddings = self._extract_embeddings(documents)
@@ -31,9 +30,7 @@ class DataProcessor:
return np.array([]) return np.array([])
return np.array([doc.embedding for doc in documents]) return np.array([doc.embedding for doc in documents])
def combine_data( def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None
) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
if not doc_data or doc_data.error: if not doc_data or doc_data.error:
raise ValueError("Invalid document data") raise ValueError("Invalid document data")
@@ -47,13 +44,11 @@ class DataProcessor:
return all_embeddings, documents, prompts return all_embeddings, documents, prompts
def split_reduced_data( def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]:
self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
doc_reduced = reduced_embeddings[:n_documents] doc_reduced = reduced_embeddings[:n_documents]
prompt_reduced = None prompt_reduced = None
if n_prompts > 0: if n_prompts > 0:
prompt_reduced = reduced_embeddings[n_documents : n_documents + n_prompts] prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts]
return doc_reduced, prompt_reduced return doc_reduced, prompt_reduced

View File

@@ -7,6 +7,7 @@ from .schemas import ReducedData
class DimensionalityReducer(ABC): class DimensionalityReducer(ABC):
def __init__(self, n_components: int = 3, random_state: int = 42): def __init__(self, n_components: int = 3, random_state: int = 42):
self.n_components = n_components self.n_components = n_components
self.random_state = random_state self.random_state = random_state
@@ -22,6 +23,7 @@ class DimensionalityReducer(ABC):
class PCAReducer(DimensionalityReducer): class PCAReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = PCA(n_components=self.n_components) self._reducer = PCA(n_components=self.n_components)
reduced = self._reducer.fit_transform(embeddings) reduced = self._reducer.fit_transform(embeddings)
@@ -31,7 +33,7 @@ class PCAReducer(DimensionalityReducer):
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=variance_explained, variance_explained=variance_explained,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components, n_components=self.n_components
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
@@ -39,17 +41,16 @@ class PCAReducer(DimensionalityReducer):
class TSNEReducer(DimensionalityReducer): class TSNEReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = TSNE( self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state)
n_components=self.n_components, random_state=self.random_state
)
reduced = self._reducer.fit(embeddings) reduced = self._reducer.fit(embeddings)
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=None, variance_explained=None,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components, n_components=self.n_components
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
@@ -57,17 +58,16 @@ class TSNEReducer(DimensionalityReducer):
class UMAPReducer(DimensionalityReducer): class UMAPReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData: def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = umap.UMAP( self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state)
n_components=self.n_components, random_state=self.random_state
)
reduced = self._reducer.fit_transform(embeddings) reduced = self._reducer.fit_transform(embeddings)
return ReducedData( return ReducedData(
reduced_embeddings=reduced, reduced_embeddings=reduced,
variance_explained=None, variance_explained=None,
method=self.get_method_name(), method=self.get_method_name(),
n_components=self.n_components, n_components=self.n_components
) )
def get_method_name(self) -> str: def get_method_name(self) -> str:
@@ -75,21 +75,20 @@ class UMAPReducer(DimensionalityReducer):
class ReducerFactory: class ReducerFactory:
@staticmethod @staticmethod
def create_reducer( def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer:
method: str, n_components: int = 3, random_state: int = 42
) -> DimensionalityReducer:
method_lower = method.lower() method_lower = method.lower()
if method_lower == "pca": if method_lower == 'pca':
return PCAReducer(n_components=n_components, random_state=random_state) return PCAReducer(n_components=n_components, random_state=random_state)
elif method_lower == "tsne": elif method_lower == 'tsne':
return TSNEReducer(n_components=n_components, random_state=random_state) return TSNEReducer(n_components=n_components, random_state=random_state)
elif method_lower == "umap": elif method_lower == 'umap':
return UMAPReducer(n_components=n_components, random_state=random_state) return UMAPReducer(n_components=n_components, random_state=random_state)
else: else:
raise ValueError(f"Unknown reduction method: {method}") raise ValueError(f"Unknown reduction method: {method}")
@staticmethod @staticmethod
def get_available_methods() -> list: def get_available_methods() -> list:
return ["pca", "tsne", "umap"] return ['pca', 'tsne', 'umap']

View File

@@ -54,7 +54,5 @@ class PlotData:
def __post_init__(self): def __post_init__(self):
if not isinstance(self.coordinates, np.ndarray): if not isinstance(self.coordinates, np.ndarray):
self.coordinates = np.array(self.coordinates) self.coordinates = np.array(self.coordinates)
if self.prompt_coordinates is not None and not isinstance( if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray):
self.prompt_coordinates, np.ndarray
):
self.prompt_coordinates = np.array(self.prompt_coordinates) self.prompt_coordinates = np.array(self.prompt_coordinates)

View File

@@ -3,53 +3,36 @@ from ...data.processor import DataProcessor
class DataProcessingCallbacks: class DataProcessingCallbacks:
def __init__(self): def __init__(self):
self.processor = DataProcessor() self.processor = DataProcessor()
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
[ Output('processed-data', 'data'),
Output("processed-data", "data", allow_duplicate=True), Input('upload-data', 'contents'),
Output("upload-error-alert", "children", allow_duplicate=True), State('upload-data', 'filename')
Output("upload-error-alert", "is_open", allow_duplicate=True),
],
Input("upload-data", "contents"),
State("upload-data", "filename"),
prevent_initial_call=True,
) )
def process_uploaded_file(contents, filename): def process_uploaded_file(contents, filename):
if contents is None: if contents is None:
return None, "", False return None
processed_data = self.processor.process_upload(contents, filename) processed_data = self.processor.process_upload(contents, filename)
if processed_data.error: if processed_data.error:
error_message = self._format_error_message( return {'error': processed_data.error}
processed_data.error, filename
)
return (
{"error": processed_data.error},
error_message,
True, # Show error alert
)
return ( return {
{ 'documents': [self._document_to_dict(doc) for doc in processed_data.documents],
"documents": [ 'embeddings': processed_data.embeddings.tolist()
self._document_to_dict(doc) for doc in processed_data.documents }
],
"embeddings": processed_data.embeddings.tolist(),
},
"",
False, # Hide error alert
)
@callback( @callback(
Output("processed-prompts", "data", allow_duplicate=True), Output('processed-prompts', 'data'),
Input("upload-prompts", "contents"), Input('upload-prompts', 'contents'),
State("upload-prompts", "filename"), State('upload-prompts', 'filename')
prevent_initial_call=True,
) )
def process_uploaded_prompts(contents, filename): def process_uploaded_prompts(contents, filename):
if contents is None: if contents is None:
@@ -58,63 +41,20 @@ class DataProcessingCallbacks:
processed_data = self.processor.process_upload(contents, filename) processed_data = self.processor.process_upload(contents, filename)
if processed_data.error: if processed_data.error:
return {"error": processed_data.error} return {'error': processed_data.error}
return { return {
"prompts": [ 'prompts': [self._document_to_dict(doc) for doc in processed_data.documents],
self._document_to_dict(doc) for doc in processed_data.documents 'embeddings': processed_data.embeddings.tolist()
],
"embeddings": processed_data.embeddings.tolist(),
} }
@staticmethod @staticmethod
def _document_to_dict(doc): def _document_to_dict(doc):
return { return {
"id": doc.id, 'id': doc.id,
"text": doc.text, 'text': doc.text,
"embedding": doc.embedding, 'embedding': doc.embedding,
"category": doc.category, 'category': doc.category,
"subcategory": doc.subcategory, 'subcategory': doc.subcategory,
"tags": doc.tags, 'tags': doc.tags
} }
@staticmethod
def _format_error_message(error: str, filename: str | None = None) -> str:
"""Format error message with helpful guidance for users."""
file_part = f" in file '{filename}'" if filename else ""
# Check for common error patterns and provide helpful messages
if "embedding" in error.lower() and (
"key" in error.lower() or "required field" in error.lower()
):
return (
f"❌ Missing 'embedding' field{file_part}. "
"Each line must contain an 'embedding' field with a list of numbers."
)
elif "text" in error.lower() and (
"key" in error.lower() or "required field" in error.lower()
):
return (
f"❌ Missing 'text' field{file_part}. "
"Each line must contain a 'text' field with the document content."
)
elif "json" in error.lower() and "decode" in error.lower():
return (
f"❌ Invalid JSON format{file_part}. "
"Please check that each line is valid JSON with proper syntax (quotes, braces, etc.)."
)
elif "unicode" in error.lower() or "decode" in error.lower():
return (
f"❌ File encoding issue{file_part}. "
"Please ensure the file is saved in UTF-8 format and contains no binary data."
)
elif "array" in error.lower() or "list" in error.lower():
return (
f"❌ Invalid embedding format{file_part}. "
"Embeddings must be arrays/lists of numbers, not strings or other types."
)
else:
return (
f"❌ Error processing file{file_part}: {error}. "
"Please check that your file is valid NDJSON with required 'text' and 'embedding' fields."
)

View File

@@ -4,50 +4,47 @@ import dash_bootstrap_components as dbc
class InteractionCallbacks: class InteractionCallbacks:
def __init__(self): def __init__(self):
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output("point-details", "children"), Output('point-details', 'children'),
Input("embedding-plot", "clickData"), Input('embedding-plot', 'clickData'),
[State("processed-data", "data"), State("processed-prompts", "data")], [State('processed-data', 'data'),
State('processed-prompts', 'data')]
) )
def display_click_data(clickData, data, prompts_data): def display_click_data(clickData, data, prompts_data):
if not clickData or not data: if not clickData or not data:
return "Click on a point to see details" return "Click on a point to see details"
point_data = clickData["points"][0] point_data = clickData['points'][0]
trace_name = point_data.get("fullData", {}).get("name", "Documents") trace_name = point_data.get('fullData', {}).get('name', 'Documents')
if "pointIndex" in point_data: if 'pointIndex' in point_data:
point_index = point_data["pointIndex"] point_index = point_data['pointIndex']
elif "pointNumber" in point_data: elif 'pointNumber' in point_data:
point_index = point_data["pointNumber"] point_index = point_data['pointNumber']
else: else:
return "Could not identify clicked point" return "Could not identify clicked point"
if ( if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data:
trace_name.startswith("Prompts") item = prompts_data['prompts'][point_index]
and prompts_data item_type = 'Prompt'
and "prompts" in prompts_data
):
item = prompts_data["prompts"][point_index]
item_type = "Prompt"
else: else:
item = data["documents"][point_index] item = data['documents'][point_index]
item_type = "Document" item_type = 'Document'
return self._create_detail_card(item, item_type) return self._create_detail_card(item, item_type)
@callback( @callback(
[ [Output('processed-data', 'data', allow_duplicate=True),
Output("processed-data", "data", allow_duplicate=True), Output('processed-prompts', 'data', allow_duplicate=True),
Output("processed-prompts", "data", allow_duplicate=True), Output('point-details', 'children', allow_duplicate=True)],
Output("point-details", "children", allow_duplicate=True), Input('reset-button', 'n_clicks'),
], prevent_initial_call=True
Input("reset-button", "n_clicks"),
prevent_initial_call=True,
) )
def reset_data(n_clicks): def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0: if n_clicks is None or n_clicks == 0:
@@ -57,26 +54,13 @@ class InteractionCallbacks:
@staticmethod @staticmethod
def _create_detail_card(item, item_type): def _create_detail_card(item, item_type):
return dbc.Card( return dbc.Card([
[ dbc.CardBody([
dbc.CardBody(
[
html.H5(f"{item_type}: {item['id']}", className="card-title"), html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Text: {item['text']}", className="card-text"), html.P(f"Text: {item['text']}", className="card-text"),
html.P( html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
f"Category: {item.get('category', 'Unknown')}", html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
className="card-text", html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
), html.P(f"Type: {item_type}", className="card-text text-muted")
html.P( ])
f"Subcategory: {item.get('subcategory', 'Unknown')}", ])
className="card-text",
),
html.P(
f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}",
className="card-text",
),
html.P(f"Type: {item_type}", className="card-text text-muted"),
]
)
]
)

View File

@@ -7,76 +7,60 @@ from ...visualization.plots import PlotFactory
class VisualizationCallbacks: class VisualizationCallbacks:
def __init__(self): def __init__(self):
self.plot_factory = PlotFactory() self.plot_factory = PlotFactory()
self._register_callbacks() self._register_callbacks()
def _register_callbacks(self): def _register_callbacks(self):
@callback( @callback(
Output("embedding-plot", "figure"), Output('embedding-plot', 'figure'),
[ [Input('processed-data', 'data'),
Input("processed-data", "data"), Input('processed-prompts', 'data'),
Input("processed-prompts", "data"), Input('method-dropdown', 'value'),
Input("method-dropdown", "value"), Input('color-dropdown', 'value'),
Input("color-dropdown", "value"), Input('dimension-toggle', 'value'),
Input("dimension-toggle", "value"), Input('show-prompts-toggle', 'value')]
Input("show-prompts-toggle", "value"),
],
) )
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts): def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or "error" in data: if not data or 'error' in data:
return go.Figure().add_annotation( return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization", text="Upload a valid NDJSON file to see visualization",
xref="paper", xref="paper", yref="paper",
yref="paper", x=0.5, y=0.5, xanchor='center', yanchor='middle',
x=0.5, showarrow=False, font=dict(size=16)
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
font=dict(size=16),
) )
try: try:
doc_embeddings = np.array(data["embeddings"]) doc_embeddings = np.array(data['embeddings'])
all_embeddings = doc_embeddings all_embeddings = doc_embeddings
has_prompts = ( has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
prompts_data
and "error" not in prompts_data
and prompts_data.get("prompts")
)
if has_prompts: if has_prompts:
prompt_embeddings = np.array(prompts_data["embeddings"]) prompt_embeddings = np.array(prompts_data['embeddings'])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings]) all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == "3d" else 2 n_components = 3 if dimensions == '3d' else 2
reducer = ReducerFactory.create_reducer( reducer = ReducerFactory.create_reducer(method, n_components=n_components)
method, n_components=n_components
)
reduced_data = reducer.fit_transform(all_embeddings) reduced_data = reducer.fit_transform(all_embeddings)
doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)] doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
prompt_reduced = None prompt_reduced = None
if has_prompts: if has_prompts:
prompt_reduced = reduced_data.reduced_embeddings[ prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
len(doc_embeddings) :
]
documents = [self._dict_to_document(doc) for doc in data["documents"]] documents = [self._dict_to_document(doc) for doc in data['documents']]
prompts = None prompts = None
if has_prompts: if has_prompts:
prompts = [ prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
self._dict_to_document(prompt)
for prompt in prompts_data["prompts"]
]
plot_data = PlotData( plot_data = PlotData(
documents=documents, documents=documents,
coordinates=doc_reduced, coordinates=doc_reduced,
prompts=prompts, prompts=prompts,
prompt_coordinates=prompt_reduced, prompt_coordinates=prompt_reduced
) )
return self.plot_factory.create_plot( return self.plot_factory.create_plot(
@@ -86,23 +70,18 @@ class VisualizationCallbacks:
except Exception as e: except Exception as e:
return go.Figure().add_annotation( return go.Figure().add_annotation(
text=f"Error creating visualization: {str(e)}", text=f"Error creating visualization: {str(e)}",
xref="paper", xref="paper", yref="paper",
yref="paper", x=0.5, y=0.5, xanchor='center', yanchor='middle',
x=0.5, showarrow=False, font=dict(size=16)
y=0.5,
xanchor="center",
yanchor="middle",
showarrow=False,
font=dict(size=16),
) )
@staticmethod @staticmethod
def _dict_to_document(doc_dict): def _dict_to_document(doc_dict):
return Document( return Document(
id=doc_dict["id"], id=doc_dict['id'],
text=doc_dict["text"], text=doc_dict['text'],
embedding=doc_dict["embedding"], embedding=doc_dict['embedding'],
category=doc_dict.get("category"), category=doc_dict.get('category'),
subcategory=doc_dict.get("subcategory"), subcategory=doc_dict.get('subcategory'),
tags=doc_dict.get("tags", []), tags=doc_dict.get('tags', [])
) )

View File

@@ -4,84 +4,79 @@ from .upload import UploadComponent
class SidebarComponent: class SidebarComponent:
def __init__(self): def __init__(self):
self.upload_component = UploadComponent() self.upload_component = UploadComponent()
def create_layout(self): def create_layout(self):
return dbc.Col( return dbc.Col([
[
html.H5("Upload Data", className="mb-3"), html.H5("Upload Data", className="mb-3"),
self.upload_component.create_error_alert(),
self.upload_component.create_data_upload(), self.upload_component.create_data_upload(),
self.upload_component.create_prompts_upload(), self.upload_component.create_prompts_upload(),
self.upload_component.create_reset_button(), self.upload_component.create_reset_button(),
html.H5("Visualization Controls", className="mb-3"), html.H5("Visualization Controls", className="mb-3"),
] self._create_method_dropdown(),
+ self._create_method_dropdown() self._create_color_dropdown(),
+ self._create_color_dropdown() self._create_dimension_toggle(),
+ self._create_dimension_toggle() self._create_prompts_toggle(),
+ self._create_prompts_toggle()
+ [
html.H5("Point Details", className="mb-3"), html.H5("Point Details", className="mb-3"),
html.Div( html.Div(id='point-details', children="Click on a point to see details")
id="point-details", children="Click on a point to see details"
), ], width=3, style={'padding-right': '20px'})
],
width=3,
style={"padding-right": "20px"},
)
def _create_method_dropdown(self): def _create_method_dropdown(self):
return [ return [
dbc.Label("Method:"), dbc.Label("Method:"),
dcc.Dropdown( dcc.Dropdown(
id="method-dropdown", id='method-dropdown',
options=[ options=[
{"label": "PCA", "value": "pca"}, {'label': 'PCA', 'value': 'pca'},
{"label": "t-SNE", "value": "tsne"}, {'label': 't-SNE', 'value': 'tsne'},
{"label": "UMAP", "value": "umap"}, {'label': 'UMAP', 'value': 'umap'}
], ],
value="pca", value='pca',
style={"margin-bottom": "15px"}, style={'margin-bottom': '15px'}
), )
] ]
def _create_color_dropdown(self): def _create_color_dropdown(self):
return [ return [
dbc.Label("Color by:"), dbc.Label("Color by:"),
dcc.Dropdown( dcc.Dropdown(
id="color-dropdown", id='color-dropdown',
options=[ options=[
{"label": "Category", "value": "category"}, {'label': 'Category', 'value': 'category'},
{"label": "Subcategory", "value": "subcategory"}, {'label': 'Subcategory', 'value': 'subcategory'},
{"label": "Tags", "value": "tags"}, {'label': 'Tags', 'value': 'tags'}
], ],
value="category", value='category',
style={"margin-bottom": "15px"}, style={'margin-bottom': '15px'}
), )
] ]
def _create_dimension_toggle(self): def _create_dimension_toggle(self):
return [ return [
dbc.Label("Dimensions:"), dbc.Label("Dimensions:"),
dcc.RadioItems( dcc.RadioItems(
id="dimension-toggle", id='dimension-toggle',
options=[ options=[
{"label": "2D", "value": "2d"}, {'label': '2D', 'value': '2d'},
{"label": "3D", "value": "3d"}, {'label': '3D', 'value': '3d'}
], ],
value="3d", value='3d',
style={"margin-bottom": "20px"}, style={'margin-bottom': '20px'}
), )
] ]
def _create_prompts_toggle(self): def _create_prompts_toggle(self):
return [ return [
dbc.Label("Show Prompts:"), dbc.Label("Show Prompts:"),
dcc.Checklist( dcc.Checklist(
id="show-prompts-toggle", id='show-prompts-toggle',
options=[{"label": "Show prompts on plot", "value": "show"}], options=[{'label': 'Show prompts on plot', 'value': 'show'}],
value=["show"], value=['show'],
style={"margin-bottom": "20px"}, style={'margin-bottom': '20px'}
), )
] ]

View File

@@ -3,62 +3,58 @@ import dash_bootstrap_components as dbc
class UploadComponent: class UploadComponent:
@staticmethod @staticmethod
def create_data_upload(): def create_data_upload():
return dcc.Upload( return dcc.Upload(
id="upload-data", id='upload-data',
children=html.Div(["Drag and Drop or ", html.A("Select Files")]), children=html.Div([
'Drag and Drop or ',
html.A('Select Files')
]),
style={ style={
"width": "100%", 'width': '100%',
"height": "60px", 'height': '60px',
"lineHeight": "60px", 'lineHeight': '60px',
"borderWidth": "1px", 'borderWidth': '1px',
"borderStyle": "dashed", 'borderStyle': 'dashed',
"borderRadius": "5px", 'borderRadius': '5px',
"textAlign": "center", 'textAlign': 'center',
"margin-bottom": "20px", 'margin-bottom': '20px'
}, },
multiple=False, multiple=False
) )
@staticmethod @staticmethod
def create_prompts_upload(): def create_prompts_upload():
return dcc.Upload( return dcc.Upload(
id="upload-prompts", id='upload-prompts',
children=html.Div(["Drag and Drop Prompts or ", html.A("Select Files")]), children=html.Div([
'Drag and Drop Prompts or ',
html.A('Select Files')
]),
style={ style={
"width": "100%", 'width': '100%',
"height": "60px", 'height': '60px',
"lineHeight": "60px", 'lineHeight': '60px',
"borderWidth": "1px", 'borderWidth': '1px',
"borderStyle": "dashed", 'borderStyle': 'dashed',
"borderRadius": "5px", 'borderRadius': '5px',
"textAlign": "center", 'textAlign': 'center',
"margin-bottom": "20px", 'margin-bottom': '20px',
"borderColor": "#28a745", 'borderColor': '#28a745'
}, },
multiple=False, multiple=False
) )
@staticmethod @staticmethod
def create_reset_button(): def create_reset_button():
return dbc.Button( return dbc.Button(
"Reset All Data", "Reset All Data",
id="reset-button", id='reset-button',
color="danger", color='danger',
outline=True, outline=True,
size="sm", size='sm',
className="mb-3", className='mb-3',
style={"width": "100%"}, style={'width': '100%'}
)
@staticmethod
def create_error_alert():
"""Create error alert component for data upload issues."""
return dbc.Alert(
id="upload-error-alert",
dismissable=True,
is_open=False,
color="danger",
className="mb-3",
) )

View File

@@ -4,44 +4,41 @@ from .components.sidebar import SidebarComponent
class AppLayout: class AppLayout:
def __init__(self): def __init__(self):
self.sidebar = SidebarComponent() self.sidebar = SidebarComponent()
def create_layout(self): def create_layout(self):
return dbc.Container( return dbc.Container([
[self._create_header(), self._create_main_content()] self._create_header(),
+ self._create_stores(), self._create_main_content(),
fluid=True, self._create_stores()
) ], fluid=True)
def _create_header(self): def _create_header(self):
return dbc.Row( return dbc.Row([
[ dbc.Col([
dbc.Col(
[
html.H1("EmbeddingBuddy", className="text-center mb-4"), html.H1("EmbeddingBuddy", className="text-center mb-4"),
], ], width=12)
width=12, ])
)
]
)
def _create_main_content(self): def _create_main_content(self):
return dbc.Row( return dbc.Row([
[self.sidebar.create_layout(), self._create_visualization_area()] self.sidebar.create_layout(),
) self._create_visualization_area()
])
def _create_visualization_area(self): def _create_visualization_area(self):
return dbc.Col( return dbc.Col([
[
dcc.Graph( dcc.Graph(
id="embedding-plot", id='embedding-plot',
style={"height": "85vh", "width": "100%"}, style={'height': '85vh', 'width': '100%'},
config={"responsive": True, "displayModeBar": True}, config={'responsive': True, 'displayModeBar': True}
)
],
width=9,
) )
], width=9)
def _create_stores(self): def _create_stores(self):
return [dcc.Store(id="processed-data"), dcc.Store(id="processed-prompts")] return [
dcc.Store(id='processed-data'),
dcc.Store(id='processed-prompts')
]

View File

@@ -4,33 +4,30 @@ from ..models.schemas import Document
class ColorMapper: class ColorMapper:
@staticmethod @staticmethod
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]: def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
if color_by == "category": if color_by == 'category':
return [doc.category for doc in documents] return [doc.category for doc in documents]
elif color_by == "subcategory": elif color_by == 'subcategory':
return [doc.subcategory for doc in documents] return [doc.subcategory for doc in documents]
elif color_by == "tags": elif color_by == 'tags':
return [", ".join(doc.tags) if doc.tags else "No tags" for doc in documents] return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents]
else: else:
return ["All"] * len(documents) return ['All'] * len(documents)
@staticmethod @staticmethod
def to_grayscale_hex(color_str: str) -> str: def to_grayscale_hex(color_str: str) -> str:
try: try:
if color_str.startswith("#"): if color_str.startswith('#'):
rgb = tuple(int(color_str[i : i + 2], 16) for i in (1, 3, 5)) rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
else: else:
rgb = pc.hex_to_rgb( rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
pc.convert_colors_to_same_type([color_str], colortype="hex")[0][0]
)
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
gray_rgb = ( gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[1] * 0.3, gray_value * 0.7 + rgb[1] * 0.3,
gray_value * 0.7 + rgb[2] * 0.3, gray_value * 0.7 + rgb[2] * 0.3)
) return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
return f"rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})"
except: # noqa: E722 except: # noqa: E722
return "rgb(128,128,128)" return 'rgb(128,128,128)'

View File

@@ -7,172 +7,139 @@ from .colors import ColorMapper
class PlotFactory: class PlotFactory:
def __init__(self): def __init__(self):
self.color_mapper = ColorMapper() self.color_mapper = ColorMapper()
def create_plot( def create_plot(self, plot_data: PlotData, dimensions: str = '3d',
self, color_by: str = 'category', method: str = 'PCA',
plot_data: PlotData, show_prompts: Optional[List[str]] = None) -> go.Figure:
dimensions: str = "3d",
color_by: str = "category", if plot_data.prompts and show_prompts and 'show' in show_prompts:
method: str = "PCA",
show_prompts: Optional[List[str]] = None,
) -> go.Figure:
if plot_data.prompts and show_prompts and "show" in show_prompts:
return self._create_dual_plot(plot_data, dimensions, color_by, method) return self._create_dual_plot(plot_data, dimensions, color_by, method)
else: else:
return self._create_single_plot(plot_data, dimensions, color_by, method) return self._create_single_plot(plot_data, dimensions, color_by, method)
def _create_single_plot( def _create_single_plot(self, plot_data: PlotData, dimensions: str,
self, plot_data: PlotData, dimensions: str, color_by: str, method: str color_by: str, method: str) -> go.Figure:
) -> go.Figure: df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
df = self._prepare_dataframe( color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
plot_data.documents, plot_data.coordinates, dimensions
)
color_values = self.color_mapper.create_color_mapping(
plot_data.documents, color_by
)
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"] hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
if dimensions == "3d": if dimensions == '3d':
fig = px.scatter_3d( fig = px.scatter_3d(
df, df, x='dim_1', y='dim_2', z='dim_3',
x="dim_1",
y="dim_2",
z="dim_3",
color=color_values, color=color_values,
hover_data=hover_fields, hover_data=hover_fields,
title=f"3D Embedding Visualization - {method} (colored by {color_by})", title=f'3D Embedding Visualization - {method} (colored by {color_by})'
) )
fig.update_traces(marker=dict(size=5)) fig.update_traces(marker=dict(size=5))
else: else:
fig = px.scatter( fig = px.scatter(
df, df, x='dim_1', y='dim_2',
x="dim_1",
y="dim_2",
color=color_values, color=color_values,
hover_data=hover_fields, hover_data=hover_fields,
title=f"2D Embedding Visualization - {method} (colored by {color_by})", title=f'2D Embedding Visualization - {method} (colored by {color_by})'
) )
fig.update_traces(marker=dict(size=8)) fig.update_traces(marker=dict(size=8))
fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)) fig.update_layout(
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig return fig
def _create_dual_plot( def _create_dual_plot(self, plot_data: PlotData, dimensions: str,
self, plot_data: PlotData, dimensions: str, color_by: str, method: str color_by: str, method: str) -> go.Figure:
) -> go.Figure:
fig = go.Figure() fig = go.Figure()
doc_df = self._prepare_dataframe( doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
plot_data.documents, plot_data.coordinates, dimensions doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
)
doc_color_values = self.color_mapper.create_color_mapping(
plot_data.documents, color_by
)
hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"] hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
if dimensions == "3d": if dimensions == '3d':
doc_fig = px.scatter_3d( doc_fig = px.scatter_3d(
doc_df, doc_df, x='dim_1', y='dim_2', z='dim_3',
x="dim_1",
y="dim_2",
z="dim_3",
color=doc_color_values, color=doc_color_values,
hover_data=hover_fields, hover_data=hover_fields
) )
else: else:
doc_fig = px.scatter( doc_fig = px.scatter(
doc_df, doc_df, x='dim_1', y='dim_2',
x="dim_1",
y="dim_2",
color=doc_color_values, color=doc_color_values,
hover_data=hover_fields, hover_data=hover_fields
) )
for trace in doc_fig.data: for trace in doc_fig.data:
trace.name = f"Documents - {trace.name}" trace.name = f'Documents - {trace.name}'
if dimensions == "3d": if dimensions == '3d':
trace.marker.size = 5 trace.marker.size = 5
trace.marker.symbol = "circle" trace.marker.symbol = 'circle'
else: else:
trace.marker.size = 8 trace.marker.size = 8
trace.marker.symbol = "circle" trace.marker.symbol = 'circle'
trace.marker.opacity = 1.0 trace.marker.opacity = 1.0
fig.add_trace(trace) fig.add_trace(trace)
if plot_data.prompts and plot_data.prompt_coordinates is not None: if plot_data.prompts and plot_data.prompt_coordinates is not None:
prompt_df = self._prepare_dataframe( prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions)
plot_data.prompts, plot_data.prompt_coordinates, dimensions prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by)
)
prompt_color_values = self.color_mapper.create_color_mapping(
plot_data.prompts, color_by
)
if dimensions == "3d": if dimensions == '3d':
prompt_fig = px.scatter_3d( prompt_fig = px.scatter_3d(
prompt_df, prompt_df, x='dim_1', y='dim_2', z='dim_3',
x="dim_1",
y="dim_2",
z="dim_3",
color=prompt_color_values, color=prompt_color_values,
hover_data=hover_fields, hover_data=hover_fields
) )
else: else:
prompt_fig = px.scatter( prompt_fig = px.scatter(
prompt_df, prompt_df, x='dim_1', y='dim_2',
x="dim_1",
y="dim_2",
color=prompt_color_values, color=prompt_color_values,
hover_data=hover_fields, hover_data=hover_fields
) )
for trace in prompt_fig.data: for trace in prompt_fig.data:
if hasattr(trace.marker, "color") and isinstance( if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
trace.marker.color, str trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color)
):
trace.marker.color = self.color_mapper.to_grayscale_hex(
trace.marker.color
)
trace.name = f"Prompts - {trace.name}" trace.name = f'Prompts - {trace.name}'
if dimensions == "3d": if dimensions == '3d':
trace.marker.size = 6 trace.marker.size = 6
trace.marker.symbol = "diamond" trace.marker.symbol = 'diamond'
else: else:
trace.marker.size = 10 trace.marker.size = 10
trace.marker.symbol = "diamond" trace.marker.symbol = 'diamond'
trace.marker.opacity = 0.8 trace.marker.opacity = 0.8
fig.add_trace(trace) fig.add_trace(trace)
title = f"{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})" title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
fig.update_layout( fig.update_layout(
title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0) title=title,
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
) )
return fig return fig
def _prepare_dataframe( def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame:
self, documents: List[Document], coordinates, dimensions: str
) -> pd.DataFrame:
df_data = [] df_data = []
for i, doc in enumerate(documents): for i, doc in enumerate(documents):
row = { row = {
"id": doc.id, 'id': doc.id,
"text": doc.text, 'text': doc.text,
"text_preview": doc.text[:100] + "..." 'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text,
if len(doc.text) > 100 'category': doc.category,
else doc.text, 'subcategory': doc.subcategory,
"category": doc.category, 'tags_str': ', '.join(doc.tags) if doc.tags else 'None',
"subcategory": doc.subcategory, 'dim_1': coordinates[i, 0],
"tags_str": ", ".join(doc.tags) if doc.tags else "None", 'dim_2': coordinates[i, 1],
"dim_1": coordinates[i, 0],
"dim_2": coordinates[i, 1],
} }
if dimensions == "3d": if dimensions == '3d':
row["dim_3"] = coordinates[i, 2] row['dim_3'] = coordinates[i, 2]
df_data.append(row) df_data.append(row)
return pd.DataFrame(df_data) return pd.DataFrame(df_data)

View File

@@ -1,197 +0,0 @@
"""Tests for handling bad/invalid data files."""
import pytest
import json
import base64
from src.embeddingbuddy.data.parser import NDJSONParser
from src.embeddingbuddy.data.processor import DataProcessor
class TestBadDataHandling:
"""Test suite for various types of invalid input data."""
def setup_method(self):
"""Set up test fixtures."""
self.parser = NDJSONParser()
self.processor = DataProcessor()
def _create_upload_contents(self, text_content: str) -> str:
"""Helper to create upload contents format."""
encoded = base64.b64encode(text_content.encode("utf-8")).decode("utf-8")
return f"data:application/json;base64,{encoded}"
def test_missing_embedding_field(self):
"""Test files missing required embedding field."""
bad_content = '{"id": "doc_001", "text": "Sample text", "category": "test"}'
with pytest.raises(KeyError, match="embedding"):
self.parser.parse_text(bad_content)
# Test processor error handling
upload_contents = self._create_upload_contents(bad_content)
result = self.processor.process_upload(upload_contents)
assert result.error is not None
assert "embedding" in result.error
def test_missing_text_field(self):
"""Test files missing required text field."""
bad_content = (
'{"id": "doc_001", "embedding": [0.1, 0.2, 0.3], "category": "test"}'
)
with pytest.raises(KeyError, match="text"):
self.parser.parse_text(bad_content)
# Test processor error handling
upload_contents = self._create_upload_contents(bad_content)
result = self.processor.process_upload(upload_contents)
assert result.error is not None
assert "text" in result.error
def test_malformed_json_lines(self):
"""Test files with malformed JSON syntax."""
# Missing closing brace
bad_content = '{"id": "doc_001", "embedding": [0.1, 0.2], "text": "test"'
with pytest.raises(json.JSONDecodeError):
self.parser.parse_text(bad_content)
# Test processor error handling
upload_contents = self._create_upload_contents(bad_content)
result = self.processor.process_upload(upload_contents)
assert result.error is not None
def test_invalid_embedding_types(self):
"""Test files with invalid embedding data types."""
test_cases = [
# String instead of array
'{"id": "doc_001", "embedding": "not_an_array", "text": "test"}',
# Mixed types in array
'{"id": "doc_002", "embedding": [0.1, "text", 0.3], "text": "test"}',
# Empty array
'{"id": "doc_003", "embedding": [], "text": "test"}',
# Null embedding
'{"id": "doc_004", "embedding": null, "text": "test"}',
]
for bad_content in test_cases:
upload_contents = self._create_upload_contents(bad_content)
result = self.processor.process_upload(upload_contents)
assert result.error is not None, f"Should fail for: {bad_content}"
def test_inconsistent_embedding_dimensions(self):
"""Test files with embeddings of different dimensions."""
bad_content = """{"id": "doc_001", "embedding": [0.1, 0.2, 0.3, 0.4], "text": "4D embedding"}
{"id": "doc_002", "embedding": [0.1, 0.2, 0.3], "text": "3D embedding"}"""
upload_contents = self._create_upload_contents(bad_content)
result = self.processor.process_upload(upload_contents)
# This might succeed parsing but fail in processing
# The error depends on where dimension validation occurs
if result.error is None:
# If parsing succeeds, check that embeddings have inconsistent shapes
assert len(result.documents) == 2
assert len(result.documents[0].embedding) != len(
result.documents[1].embedding
)
def test_empty_lines_in_ndjson(self):
"""Test files with empty lines mixed in."""
content_with_empty_lines = """{"id": "doc_001", "embedding": [0.1, 0.2], "text": "First line"}
{"id": "doc_002", "embedding": [0.3, 0.4], "text": "After empty line"}"""
# This should work - empty lines should be skipped
documents = self.parser.parse_text(content_with_empty_lines)
assert len(documents) == 2
assert documents[0].id == "doc_001"
assert documents[1].id == "doc_002"
def test_not_ndjson_format(self):
"""Test regular JSON array instead of NDJSON."""
json_array = """[
{"id": "doc_001", "embedding": [0.1, 0.2], "text": "First"},
{"id": "doc_002", "embedding": [0.3, 0.4], "text": "Second"}
]"""
with pytest.raises(json.JSONDecodeError):
self.parser.parse_text(json_array)
def test_binary_content_in_file(self):
"""Test files with binary content mixed in."""
# Simulate binary content that can't be decoded
binary_content = (
b'\x00\x01\x02{"id": "doc_001", "embedding": [0.1], "text": "test"}'
)
# This should result in an error when processing
encoded = base64.b64encode(binary_content).decode("utf-8")
upload_contents = f"data:application/json;base64,{encoded}"
result = self.processor.process_upload(upload_contents)
# Should either fail with UnicodeDecodeError or JSON parsing error
assert result.error is not None
def test_extremely_large_embeddings(self):
"""Test embeddings with very large dimensions."""
large_embedding = [0.1] * 10000 # 10k dimensions
content = json.dumps(
{
"id": "doc_001",
"embedding": large_embedding,
"text": "Large embedding test",
}
)
# This should work but might be slow
upload_contents = self._create_upload_contents(content)
result = self.processor.process_upload(upload_contents)
if result.error is None:
assert len(result.documents) == 1
assert len(result.documents[0].embedding) == 10000
def test_special_characters_in_text(self):
"""Test handling of special characters and unicode."""
special_content = json.dumps(
{
"id": "doc_001",
"embedding": [0.1, 0.2],
"text": 'Special chars: 🚀 ñoñó 中文 \n\t"',
},
ensure_ascii=False,
)
upload_contents = self._create_upload_contents(special_content)
result = self.processor.process_upload(upload_contents)
assert result.error is None
assert len(result.documents) == 1
assert "🚀" in result.documents[0].text
def test_processor_error_structure(self):
"""Test that processor returns proper error structure."""
bad_content = '{"invalid": "json"' # Missing closing brace
upload_contents = self._create_upload_contents(bad_content)
result = self.processor.process_upload(upload_contents)
# Check error structure
assert result.error is not None
assert isinstance(result.error, str)
assert len(result.documents) == 0
assert result.embeddings.size == 0
def test_multiple_errors_in_file(self):
"""Test file with multiple different types of errors."""
multi_error_content = """{"id": "doc_001", "text": "Missing embedding"}
{"id": "doc_002", "embedding": "wrong_type", "text": "Wrong embedding type"}
{"id": "doc_003", "embedding": [0.1, 0.2], "text": "Valid line"}
{"id": "doc_004", "embedding": [0.3, 0.4]""" # Missing text and closing brace
upload_contents = self._create_upload_contents(multi_error_content)
result = self.processor.process_upload(upload_contents)
# Should fail on first error encountered
assert result.error is not None

View File

@@ -6,10 +6,9 @@ from src.embeddingbuddy.models.schemas import Document
class TestNDJSONParser: class TestNDJSONParser:
def test_parse_text_basic(self): def test_parse_text_basic(self):
text_content = ( text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
'{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
)
documents = NDJSONParser.parse_text(text_content) documents = NDJSONParser.parse_text(text_content)
assert len(documents) == 1 assert len(documents) == 1
@@ -33,10 +32,11 @@ class TestNDJSONParser:
class TestDataProcessor: class TestDataProcessor:
def test_extract_embeddings(self): def test_extract_embeddings(self):
documents = [ documents = [
Document(id="1", text="test1", embedding=[0.1, 0.2]), Document(id="1", text="test1", embedding=[0.1, 0.2]),
Document(id="2", text="test2", embedding=[0.3, 0.4]), Document(id="2", text="test2", embedding=[0.3, 0.4])
] ]
processor = DataProcessor() processor = DataProcessor()
@@ -51,18 +51,16 @@ class TestDataProcessor:
doc_data = ProcessedData( doc_data = ProcessedData(
documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])], documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
embeddings=np.array([[0.1, 0.2]]), embeddings=np.array([[0.1, 0.2]])
) )
prompt_data = ProcessedData( prompt_data = ProcessedData(
documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])], documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
embeddings=np.array([[0.3, 0.4]]), embeddings=np.array([[0.3, 0.4]])
) )
processor = DataProcessor() processor = DataProcessor()
all_embeddings, documents, prompts = processor.combine_data( all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data)
doc_data, prompt_data
)
assert all_embeddings.shape == (2, 2) assert all_embeddings.shape == (2, 2)
assert len(documents) == 1 assert len(documents) == 1

View File

@@ -1,41 +1,38 @@
import pytest import pytest
import numpy as np import numpy as np
from src.embeddingbuddy.models.reducers import ( from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer
ReducerFactory,
PCAReducer,
TSNEReducer,
UMAPReducer,
)
class TestReducerFactory: class TestReducerFactory:
def test_create_pca_reducer(self): def test_create_pca_reducer(self):
reducer = ReducerFactory.create_reducer("pca", n_components=2) reducer = ReducerFactory.create_reducer('pca', n_components=2)
assert isinstance(reducer, PCAReducer) assert isinstance(reducer, PCAReducer)
assert reducer.n_components == 2 assert reducer.n_components == 2
def test_create_tsne_reducer(self): def test_create_tsne_reducer(self):
reducer = ReducerFactory.create_reducer("tsne", n_components=3) reducer = ReducerFactory.create_reducer('tsne', n_components=3)
assert isinstance(reducer, TSNEReducer) assert isinstance(reducer, TSNEReducer)
assert reducer.n_components == 3 assert reducer.n_components == 3
def test_create_umap_reducer(self): def test_create_umap_reducer(self):
reducer = ReducerFactory.create_reducer("umap", n_components=2) reducer = ReducerFactory.create_reducer('umap', n_components=2)
assert isinstance(reducer, UMAPReducer) assert isinstance(reducer, UMAPReducer)
assert reducer.n_components == 2 assert reducer.n_components == 2
def test_invalid_method(self): def test_invalid_method(self):
with pytest.raises(ValueError, match="Unknown reduction method"): with pytest.raises(ValueError, match="Unknown reduction method"):
ReducerFactory.create_reducer("invalid_method") ReducerFactory.create_reducer('invalid_method')
def test_available_methods(self): def test_available_methods(self):
methods = ReducerFactory.get_available_methods() methods = ReducerFactory.get_available_methods()
assert "pca" in methods assert 'pca' in methods
assert "tsne" in methods assert 'tsne' in methods
assert "umap" in methods assert 'umap' in methods
class TestPCAReducer: class TestPCAReducer:
def test_fit_transform(self): def test_fit_transform(self):
embeddings = np.random.rand(100, 512) embeddings = np.random.rand(100, 512)
reducer = PCAReducer(n_components=2) reducer = PCAReducer(n_components=2)
@@ -53,6 +50,7 @@ class TestPCAReducer:
class TestTSNEReducer: class TestTSNEReducer:
def test_fit_transform_small_dataset(self): def test_fit_transform_small_dataset(self):
embeddings = np.random.rand(30, 10) # Small dataset for faster testing embeddings = np.random.rand(30, 10) # Small dataset for faster testing
reducer = TSNEReducer(n_components=2) reducer = TSNEReducer(n_components=2)
@@ -70,6 +68,7 @@ class TestTSNEReducer:
class TestUMAPReducer: class TestUMAPReducer:
def test_fit_transform(self): def test_fit_transform(self):
embeddings = np.random.rand(50, 10) embeddings = np.random.rand(50, 10)
reducer = UMAPReducer(n_components=2) reducer = UMAPReducer(n_components=2)

1082
uv.lock generated

File diff suppressed because it is too large Load Diff