refactor and add tests, v0.2.0
This commit is contained in:
87
src/embeddingbuddy/ui/callbacks/visualization.py
Normal file
87
src/embeddingbuddy/ui/callbacks/visualization.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import numpy as np
|
||||
from dash import callback, Input, Output
|
||||
import plotly.graph_objects as go
|
||||
from ...models.reducers import ReducerFactory
|
||||
from ...models.schemas import Document, PlotData
|
||||
from ...visualization.plots import PlotFactory
|
||||
|
||||
|
||||
class VisualizationCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self.plot_factory = PlotFactory()
|
||||
self._register_callbacks()
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('embedding-plot', 'figure'),
|
||||
[Input('processed-data', 'data'),
|
||||
Input('processed-prompts', 'data'),
|
||||
Input('method-dropdown', 'value'),
|
||||
Input('color-dropdown', 'value'),
|
||||
Input('dimension-toggle', 'value'),
|
||||
Input('show-prompts-toggle', 'value')]
|
||||
)
|
||||
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
|
||||
if not data or 'error' in data:
|
||||
return go.Figure().add_annotation(
|
||||
text="Upload a valid NDJSON file to see visualization",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16)
|
||||
)
|
||||
|
||||
try:
|
||||
doc_embeddings = np.array(data['embeddings'])
|
||||
all_embeddings = doc_embeddings
|
||||
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
|
||||
|
||||
if has_prompts:
|
||||
prompt_embeddings = np.array(prompts_data['embeddings'])
|
||||
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
|
||||
|
||||
n_components = 3 if dimensions == '3d' else 2
|
||||
|
||||
reducer = ReducerFactory.create_reducer(method, n_components=n_components)
|
||||
reduced_data = reducer.fit_transform(all_embeddings)
|
||||
|
||||
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
|
||||
prompt_reduced = None
|
||||
if has_prompts:
|
||||
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
|
||||
|
||||
documents = [self._dict_to_document(doc) for doc in data['documents']]
|
||||
prompts = None
|
||||
if has_prompts:
|
||||
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
|
||||
|
||||
plot_data = PlotData(
|
||||
documents=documents,
|
||||
coordinates=doc_reduced,
|
||||
prompts=prompts,
|
||||
prompt_coordinates=prompt_reduced
|
||||
)
|
||||
|
||||
return self.plot_factory.create_plot(
|
||||
plot_data, dimensions, color_by, reduced_data.method, show_prompts
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return go.Figure().add_annotation(
|
||||
text=f"Error creating visualization: {str(e)}",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_document(doc_dict):
|
||||
return Document(
|
||||
id=doc_dict['id'],
|
||||
text=doc_dict['text'],
|
||||
embedding=doc_dict['embedding'],
|
||||
category=doc_dict.get('category'),
|
||||
subcategory=doc_dict.get('subcategory'),
|
||||
tags=doc_dict.get('tags', [])
|
||||
)
|
Reference in New Issue
Block a user