refactor and add tests, v0.2.0

This commit is contained in:
2025-08-13 20:07:40 -07:00
parent 76be59254c
commit 809dbeb783
32 changed files with 1401 additions and 32 deletions

View File

@@ -0,0 +1,3 @@
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
__version__ = "0.1.0"

39
src/embeddingbuddy/app.py Normal file
View File

@@ -0,0 +1,39 @@
import dash
import dash_bootstrap_components as dbc
from .config.settings import AppSettings
from .ui.layout import AppLayout
from .ui.callbacks.data_processing import DataProcessingCallbacks
from .ui.callbacks.visualization import VisualizationCallbacks
from .ui.callbacks.interactions import InteractionCallbacks
def create_app():
app = dash.Dash(
__name__,
external_stylesheets=[dbc.themes.BOOTSTRAP]
)
layout_manager = AppLayout()
app.layout = layout_manager.create_layout()
DataProcessingCallbacks()
VisualizationCallbacks()
InteractionCallbacks()
return app
def run_app(app=None, debug=None, host=None, port=None):
if app is None:
app = create_app()
app.run(
debug=debug if debug is not None else AppSettings.DEBUG,
host=host if host is not None else AppSettings.HOST,
port=port if port is not None else AppSettings.PORT
)
if __name__ == '__main__':
app = create_app()
run_app(app)

View File

View File

@@ -0,0 +1,107 @@
from typing import Dict, Any
import os
class AppSettings:
# UI Configuration
UPLOAD_STYLE = {
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin-bottom': '20px'
}
PROMPTS_UPLOAD_STYLE = {
**UPLOAD_STYLE,
'borderColor': '#28a745'
}
PLOT_CONFIG = {
'responsive': True,
'displayModeBar': True
}
PLOT_STYLE = {
'height': '85vh',
'width': '100%'
}
PLOT_LAYOUT_CONFIG = {
'height': None,
'autosize': True,
'margin': dict(l=0, r=0, t=50, b=0)
}
# Dimensionality Reduction Settings
DEFAULT_N_COMPONENTS_3D = 3
DEFAULT_N_COMPONENTS_2D = 2
DEFAULT_RANDOM_STATE = 42
# Available Methods
REDUCTION_METHODS = [
{'label': 'PCA', 'value': 'pca'},
{'label': 't-SNE', 'value': 'tsne'},
{'label': 'UMAP', 'value': 'umap'}
]
COLOR_OPTIONS = [
{'label': 'Category', 'value': 'category'},
{'label': 'Subcategory', 'value': 'subcategory'},
{'label': 'Tags', 'value': 'tags'}
]
DIMENSION_OPTIONS = [
{'label': '2D', 'value': '2d'},
{'label': '3D', 'value': '3d'}
]
# Default Values
DEFAULT_METHOD = 'pca'
DEFAULT_COLOR_BY = 'category'
DEFAULT_DIMENSIONS = '3d'
DEFAULT_SHOW_PROMPTS = ['show']
# Plot Marker Settings
DOCUMENT_MARKER_SIZE_2D = 8
DOCUMENT_MARKER_SIZE_3D = 5
PROMPT_MARKER_SIZE_2D = 10
PROMPT_MARKER_SIZE_3D = 6
DOCUMENT_MARKER_SYMBOL = 'circle'
PROMPT_MARKER_SYMBOL = 'diamond'
DOCUMENT_OPACITY = 1.0
PROMPT_OPACITY = 0.8
# Text Processing
TEXT_PREVIEW_LENGTH = 100
# App Configuration
DEBUG = os.getenv('EMBEDDINGBUDDY_DEBUG', 'True').lower() == 'true'
HOST = os.getenv('EMBEDDINGBUDDY_HOST', '127.0.0.1')
PORT = int(os.getenv('EMBEDDINGBUDDY_PORT', '8050'))
# Bootstrap Theme
EXTERNAL_STYLESHEETS = ['https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css']
@classmethod
def get_plot_marker_config(cls, dimensions: str, is_prompt: bool = False) -> Dict[str, Any]:
if is_prompt:
size = cls.PROMPT_MARKER_SIZE_3D if dimensions == '3d' else cls.PROMPT_MARKER_SIZE_2D
symbol = cls.PROMPT_MARKER_SYMBOL
opacity = cls.PROMPT_OPACITY
else:
size = cls.DOCUMENT_MARKER_SIZE_3D if dimensions == '3d' else cls.DOCUMENT_MARKER_SIZE_2D
symbol = cls.DOCUMENT_MARKER_SYMBOL
opacity = cls.DOCUMENT_OPACITY
return {
'size': size,
'symbol': symbol,
'opacity': opacity
}

View File

View File

@@ -0,0 +1,39 @@
import json
import uuid
import base64
from typing import List, Union
from ..models.schemas import Document, ProcessedData
class NDJSONParser:
@staticmethod
def parse_upload_contents(contents: str) -> List[Document]:
content_type, content_string = contents.split(',')
decoded = base64.b64decode(content_string)
text_content = decoded.decode('utf-8')
return NDJSONParser.parse_text(text_content)
@staticmethod
def parse_text(text_content: str) -> List[Document]:
documents = []
for line in text_content.strip().split('\n'):
if line.strip():
doc_dict = json.loads(line)
doc = NDJSONParser._dict_to_document(doc_dict)
documents.append(doc)
return documents
@staticmethod
def _dict_to_document(doc_dict: dict) -> Document:
if 'id' not in doc_dict:
doc_dict['id'] = str(uuid.uuid4())
return Document(
id=doc_dict['id'],
text=doc_dict['text'],
embedding=doc_dict['embedding'],
category=doc_dict.get('category'),
subcategory=doc_dict.get('subcategory'),
tags=doc_dict.get('tags')
)

View File

@@ -0,0 +1,54 @@
import numpy as np
from typing import List, Optional, Tuple
from ..models.schemas import Document, ProcessedData
from .parser import NDJSONParser
class DataProcessor:
def __init__(self):
self.parser = NDJSONParser()
def process_upload(self, contents: str, filename: Optional[str] = None) -> ProcessedData:
try:
documents = self.parser.parse_upload_contents(contents)
embeddings = self._extract_embeddings(documents)
return ProcessedData(documents=documents, embeddings=embeddings)
except Exception as e:
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
def process_text(self, text_content: str) -> ProcessedData:
try:
documents = self.parser.parse_text(text_content)
embeddings = self._extract_embeddings(documents)
return ProcessedData(documents=documents, embeddings=embeddings)
except Exception as e:
return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
if not documents:
return np.array([])
return np.array([doc.embedding for doc in documents])
def combine_data(self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
if not doc_data or doc_data.error:
raise ValueError("Invalid document data")
all_embeddings = doc_data.embeddings
documents = doc_data.documents
prompts = None
if prompt_data and not prompt_data.error and prompt_data.documents:
all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
prompts = prompt_data.documents
return all_embeddings, documents, prompts
def split_reduced_data(self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0) -> Tuple[np.ndarray, Optional[np.ndarray]]:
doc_reduced = reduced_embeddings[:n_documents]
prompt_reduced = None
if n_prompts > 0:
prompt_reduced = reduced_embeddings[n_documents:n_documents + n_prompts]
return doc_reduced, prompt_reduced

View File

View File

@@ -0,0 +1,95 @@
from abc import ABC, abstractmethod
import numpy as np
from typing import Optional, Tuple
from sklearn.decomposition import PCA
import umap
from openTSNE import TSNE
from .schemas import ReducedData
class DimensionalityReducer(ABC):
def __init__(self, n_components: int = 3, random_state: int = 42):
self.n_components = n_components
self.random_state = random_state
self._reducer = None
@abstractmethod
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
pass
@abstractmethod
def get_method_name(self) -> str:
pass
class PCAReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = PCA(n_components=self.n_components)
reduced = self._reducer.fit_transform(embeddings)
variance_explained = self._reducer.explained_variance_ratio_
return ReducedData(
reduced_embeddings=reduced,
variance_explained=variance_explained,
method=self.get_method_name(),
n_components=self.n_components
)
def get_method_name(self) -> str:
return "PCA"
class TSNEReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = TSNE(n_components=self.n_components, random_state=self.random_state)
reduced = self._reducer.fit(embeddings)
return ReducedData(
reduced_embeddings=reduced,
variance_explained=None,
method=self.get_method_name(),
n_components=self.n_components
)
def get_method_name(self) -> str:
return "t-SNE"
class UMAPReducer(DimensionalityReducer):
def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
self._reducer = umap.UMAP(n_components=self.n_components, random_state=self.random_state)
reduced = self._reducer.fit_transform(embeddings)
return ReducedData(
reduced_embeddings=reduced,
variance_explained=None,
method=self.get_method_name(),
n_components=self.n_components
)
def get_method_name(self) -> str:
return "UMAP"
class ReducerFactory:
@staticmethod
def create_reducer(method: str, n_components: int = 3, random_state: int = 42) -> DimensionalityReducer:
method_lower = method.lower()
if method_lower == 'pca':
return PCAReducer(n_components=n_components, random_state=random_state)
elif method_lower == 'tsne':
return TSNEReducer(n_components=n_components, random_state=random_state)
elif method_lower == 'umap':
return UMAPReducer(n_components=n_components, random_state=random_state)
else:
raise ValueError(f"Unknown reduction method: {method}")
@staticmethod
def get_available_methods() -> list:
return ['pca', 'tsne', 'umap']

View File

@@ -0,0 +1,58 @@
from typing import List, Optional, Any, Dict
from dataclasses import dataclass
import numpy as np
@dataclass
class Document:
id: str
text: str
embedding: List[float]
category: Optional[str] = None
subcategory: Optional[str] = None
tags: Optional[List[str]] = None
def __post_init__(self):
if self.tags is None:
self.tags = []
if self.category is None:
self.category = "Unknown"
if self.subcategory is None:
self.subcategory = "Unknown"
@dataclass
class ProcessedData:
documents: List[Document]
embeddings: np.ndarray
error: Optional[str] = None
def __post_init__(self):
if self.embeddings is not None and not isinstance(self.embeddings, np.ndarray):
self.embeddings = np.array(self.embeddings)
@dataclass
class ReducedData:
reduced_embeddings: np.ndarray
variance_explained: Optional[np.ndarray] = None
method: str = "unknown"
n_components: int = 2
def __post_init__(self):
if not isinstance(self.reduced_embeddings, np.ndarray):
self.reduced_embeddings = np.array(self.reduced_embeddings)
@dataclass
class PlotData:
documents: List[Document]
coordinates: np.ndarray
prompts: Optional[List[Document]] = None
prompt_coordinates: Optional[np.ndarray] = None
def __post_init__(self):
if not isinstance(self.coordinates, np.ndarray):
self.coordinates = np.array(self.coordinates)
if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray):
self.prompt_coordinates = np.array(self.prompt_coordinates)

View File

View File

@@ -0,0 +1,61 @@
import numpy as np
from dash import callback, Input, Output, State
from ...data.processor import DataProcessor
class DataProcessingCallbacks:
def __init__(self):
self.processor = DataProcessor()
self._register_callbacks()
def _register_callbacks(self):
@callback(
Output('processed-data', 'data'),
Input('upload-data', 'contents'),
State('upload-data', 'filename')
)
def process_uploaded_file(contents, filename):
if contents is None:
return None
processed_data = self.processor.process_upload(contents, filename)
if processed_data.error:
return {'error': processed_data.error}
return {
'documents': [self._document_to_dict(doc) for doc in processed_data.documents],
'embeddings': processed_data.embeddings.tolist()
}
@callback(
Output('processed-prompts', 'data'),
Input('upload-prompts', 'contents'),
State('upload-prompts', 'filename')
)
def process_uploaded_prompts(contents, filename):
if contents is None:
return None
processed_data = self.processor.process_upload(contents, filename)
if processed_data.error:
return {'error': processed_data.error}
return {
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents],
'embeddings': processed_data.embeddings.tolist()
}
@staticmethod
def _document_to_dict(doc):
return {
'id': doc.id,
'text': doc.text,
'embedding': doc.embedding,
'category': doc.category,
'subcategory': doc.subcategory,
'tags': doc.tags
}

View File

@@ -0,0 +1,66 @@
import dash
from dash import callback, Input, Output, State, html
import dash_bootstrap_components as dbc
class InteractionCallbacks:
def __init__(self):
self._register_callbacks()
def _register_callbacks(self):
@callback(
Output('point-details', 'children'),
Input('embedding-plot', 'clickData'),
[State('processed-data', 'data'),
State('processed-prompts', 'data')]
)
def display_click_data(clickData, data, prompts_data):
if not clickData or not data:
return "Click on a point to see details"
point_data = clickData['points'][0]
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
if 'pointIndex' in point_data:
point_index = point_data['pointIndex']
elif 'pointNumber' in point_data:
point_index = point_data['pointNumber']
else:
return "Could not identify clicked point"
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data:
item = prompts_data['prompts'][point_index]
item_type = 'Prompt'
else:
item = data['documents'][point_index]
item_type = 'Document'
return self._create_detail_card(item, item_type)
@callback(
[Output('processed-data', 'data', allow_duplicate=True),
Output('processed-prompts', 'data', allow_duplicate=True),
Output('point-details', 'children', allow_duplicate=True)],
Input('reset-button', 'n_clicks'),
prevent_initial_call=True
)
def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0:
return dash.no_update, dash.no_update, dash.no_update
return None, None, "Click on a point to see details"
@staticmethod
def _create_detail_card(item, item_type):
return dbc.Card([
dbc.CardBody([
html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Text: {item['text']}", className="card-text"),
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
html.P(f"Type: {item_type}", className="card-text text-muted")
])
])

View File

@@ -0,0 +1,87 @@
import numpy as np
from dash import callback, Input, Output
import plotly.graph_objects as go
from ...models.reducers import ReducerFactory
from ...models.schemas import Document, PlotData
from ...visualization.plots import PlotFactory
class VisualizationCallbacks:
def __init__(self):
self.plot_factory = PlotFactory()
self._register_callbacks()
def _register_callbacks(self):
@callback(
Output('embedding-plot', 'figure'),
[Input('processed-data', 'data'),
Input('processed-prompts', 'data'),
Input('method-dropdown', 'value'),
Input('color-dropdown', 'value'),
Input('dimension-toggle', 'value'),
Input('show-prompts-toggle', 'value')]
)
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or 'error' in data:
return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization",
xref="paper", yref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle',
showarrow=False, font=dict(size=16)
)
try:
doc_embeddings = np.array(data['embeddings'])
all_embeddings = doc_embeddings
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
if has_prompts:
prompt_embeddings = np.array(prompts_data['embeddings'])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == '3d' else 2
reducer = ReducerFactory.create_reducer(method, n_components=n_components)
reduced_data = reducer.fit_transform(all_embeddings)
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
prompt_reduced = None
if has_prompts:
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
documents = [self._dict_to_document(doc) for doc in data['documents']]
prompts = None
if has_prompts:
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
plot_data = PlotData(
documents=documents,
coordinates=doc_reduced,
prompts=prompts,
prompt_coordinates=prompt_reduced
)
return self.plot_factory.create_plot(
plot_data, dimensions, color_by, reduced_data.method, show_prompts
)
except Exception as e:
return go.Figure().add_annotation(
text=f"Error creating visualization: {str(e)}",
xref="paper", yref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle',
showarrow=False, font=dict(size=16)
)
@staticmethod
def _dict_to_document(doc_dict):
return Document(
id=doc_dict['id'],
text=doc_dict['text'],
embedding=doc_dict['embedding'],
category=doc_dict.get('category'),
subcategory=doc_dict.get('subcategory'),
tags=doc_dict.get('tags', [])
)

View File

@@ -0,0 +1,82 @@
from dash import dcc, html
import dash_bootstrap_components as dbc
from .upload import UploadComponent
class SidebarComponent:
def __init__(self):
self.upload_component = UploadComponent()
def create_layout(self):
return dbc.Col([
html.H5("Upload Data", className="mb-3"),
self.upload_component.create_data_upload(),
self.upload_component.create_prompts_upload(),
self.upload_component.create_reset_button(),
html.H5("Visualization Controls", className="mb-3"),
self._create_method_dropdown(),
self._create_color_dropdown(),
self._create_dimension_toggle(),
self._create_prompts_toggle(),
html.H5("Point Details", className="mb-3"),
html.Div(id='point-details', children="Click on a point to see details")
], width=3, style={'padding-right': '20px'})
def _create_method_dropdown(self):
return [
dbc.Label("Method:"),
dcc.Dropdown(
id='method-dropdown',
options=[
{'label': 'PCA', 'value': 'pca'},
{'label': 't-SNE', 'value': 'tsne'},
{'label': 'UMAP', 'value': 'umap'}
],
value='pca',
style={'margin-bottom': '15px'}
)
]
def _create_color_dropdown(self):
return [
dbc.Label("Color by:"),
dcc.Dropdown(
id='color-dropdown',
options=[
{'label': 'Category', 'value': 'category'},
{'label': 'Subcategory', 'value': 'subcategory'},
{'label': 'Tags', 'value': 'tags'}
],
value='category',
style={'margin-bottom': '15px'}
)
]
def _create_dimension_toggle(self):
return [
dbc.Label("Dimensions:"),
dcc.RadioItems(
id='dimension-toggle',
options=[
{'label': '2D', 'value': '2d'},
{'label': '3D', 'value': '3d'}
],
value='3d',
style={'margin-bottom': '20px'}
)
]
def _create_prompts_toggle(self):
return [
dbc.Label("Show Prompts:"),
dcc.Checklist(
id='show-prompts-toggle',
options=[{'label': 'Show prompts on plot', 'value': 'show'}],
value=['show'],
style={'margin-bottom': '20px'}
)
]

View File

@@ -0,0 +1,60 @@
from dash import dcc, html
import dash_bootstrap_components as dbc
class UploadComponent:
@staticmethod
def create_data_upload():
return dcc.Upload(
id='upload-data',
children=html.Div([
'Drag and Drop or ',
html.A('Select Files')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin-bottom': '20px'
},
multiple=False
)
@staticmethod
def create_prompts_upload():
return dcc.Upload(
id='upload-prompts',
children=html.Div([
'Drag and Drop Prompts or ',
html.A('Select Files')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin-bottom': '20px',
'borderColor': '#28a745'
},
multiple=False
)
@staticmethod
def create_reset_button():
return dbc.Button(
"Reset All Data",
id='reset-button',
color='danger',
outline=True,
size='sm',
className='mb-3',
style={'width': '100%'}
)

View File

@@ -0,0 +1,44 @@
from dash import dcc, html
import dash_bootstrap_components as dbc
from .components.sidebar import SidebarComponent
class AppLayout:
def __init__(self):
self.sidebar = SidebarComponent()
def create_layout(self):
return dbc.Container([
self._create_header(),
self._create_main_content(),
self._create_stores()
], fluid=True)
def _create_header(self):
return dbc.Row([
dbc.Col([
html.H1("EmbeddingBuddy", className="text-center mb-4"),
], width=12)
])
def _create_main_content(self):
return dbc.Row([
self.sidebar.create_layout(),
self._create_visualization_area()
])
def _create_visualization_area(self):
return dbc.Col([
dcc.Graph(
id='embedding-plot',
style={'height': '85vh', 'width': '100%'},
config={'responsive': True, 'displayModeBar': True}
)
], width=9)
def _create_stores(self):
return [
dcc.Store(id='processed-data'),
dcc.Store(id='processed-prompts')
]

View File

View File

@@ -0,0 +1,33 @@
from typing import List, Dict, Any
import plotly.colors as pc
from ..models.schemas import Document
class ColorMapper:
@staticmethod
def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
if color_by == 'category':
return [doc.category for doc in documents]
elif color_by == 'subcategory':
return [doc.subcategory for doc in documents]
elif color_by == 'tags':
return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents]
else:
return ['All'] * len(documents)
@staticmethod
def to_grayscale_hex(color_str: str) -> str:
try:
if color_str.startswith('#'):
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
else:
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[1] * 0.3,
gray_value * 0.7 + rgb[2] * 0.3)
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
except:
return 'rgb(128,128,128)'

View File

@@ -0,0 +1,145 @@
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from typing import List, Optional
from ..models.schemas import Document, PlotData
from .colors import ColorMapper
class PlotFactory:
def __init__(self):
self.color_mapper = ColorMapper()
def create_plot(self, plot_data: PlotData, dimensions: str = '3d',
color_by: str = 'category', method: str = 'PCA',
show_prompts: Optional[List[str]] = None) -> go.Figure:
if plot_data.prompts and show_prompts and 'show' in show_prompts:
return self._create_dual_plot(plot_data, dimensions, color_by, method)
else:
return self._create_single_plot(plot_data, dimensions, color_by, method)
def _create_single_plot(self, plot_data: PlotData, dimensions: str,
color_by: str, method: str) -> go.Figure:
df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
if dimensions == '3d':
fig = px.scatter_3d(
df, x='dim_1', y='dim_2', z='dim_3',
color=color_values,
hover_data=hover_fields,
title=f'3D Embedding Visualization - {method} (colored by {color_by})'
)
fig.update_traces(marker=dict(size=5))
else:
fig = px.scatter(
df, x='dim_1', y='dim_2',
color=color_values,
hover_data=hover_fields,
title=f'2D Embedding Visualization - {method} (colored by {color_by})'
)
fig.update_traces(marker=dict(size=8))
fig.update_layout(
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig
def _create_dual_plot(self, plot_data: PlotData, dimensions: str,
color_by: str, method: str) -> go.Figure:
fig = go.Figure()
doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
if dimensions == '3d':
doc_fig = px.scatter_3d(
doc_df, x='dim_1', y='dim_2', z='dim_3',
color=doc_color_values,
hover_data=hover_fields
)
else:
doc_fig = px.scatter(
doc_df, x='dim_1', y='dim_2',
color=doc_color_values,
hover_data=hover_fields
)
for trace in doc_fig.data:
trace.name = f'Documents - {trace.name}'
if dimensions == '3d':
trace.marker.size = 5
trace.marker.symbol = 'circle'
else:
trace.marker.size = 8
trace.marker.symbol = 'circle'
trace.marker.opacity = 1.0
fig.add_trace(trace)
if plot_data.prompts and plot_data.prompt_coordinates is not None:
prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions)
prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by)
if dimensions == '3d':
prompt_fig = px.scatter_3d(
prompt_df, x='dim_1', y='dim_2', z='dim_3',
color=prompt_color_values,
hover_data=hover_fields
)
else:
prompt_fig = px.scatter(
prompt_df, x='dim_1', y='dim_2',
color=prompt_color_values,
hover_data=hover_fields
)
for trace in prompt_fig.data:
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color)
trace.name = f'Prompts - {trace.name}'
if dimensions == '3d':
trace.marker.size = 6
trace.marker.symbol = 'diamond'
else:
trace.marker.size = 10
trace.marker.symbol = 'diamond'
trace.marker.opacity = 0.8
fig.add_trace(trace)
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
fig.update_layout(
title=title,
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig
def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame:
df_data = []
for i, doc in enumerate(documents):
row = {
'id': doc.id,
'text': doc.text,
'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text,
'category': doc.category,
'subcategory': doc.subcategory,
'tags_str': ', '.join(doc.tags) if doc.tags else 'None',
'dim_1': coordinates[i, 0],
'dim_2': coordinates[i, 1],
}
if dimensions == '3d':
row['dim_3'] = coordinates[i, 2]
df_data.append(row)
return pd.DataFrame(df_data)