refactor and add tests, v0.2.0

This commit is contained in:
2025-08-13 20:07:40 -07:00
parent 76be59254c
commit 809dbeb783
32 changed files with 1401 additions and 32 deletions

View File

@@ -0,0 +1,61 @@
import numpy as np
from dash import callback, Input, Output, State
from ...data.processor import DataProcessor
class DataProcessingCallbacks:
def __init__(self):
self.processor = DataProcessor()
self._register_callbacks()
def _register_callbacks(self):
@callback(
Output('processed-data', 'data'),
Input('upload-data', 'contents'),
State('upload-data', 'filename')
)
def process_uploaded_file(contents, filename):
if contents is None:
return None
processed_data = self.processor.process_upload(contents, filename)
if processed_data.error:
return {'error': processed_data.error}
return {
'documents': [self._document_to_dict(doc) for doc in processed_data.documents],
'embeddings': processed_data.embeddings.tolist()
}
@callback(
Output('processed-prompts', 'data'),
Input('upload-prompts', 'contents'),
State('upload-prompts', 'filename')
)
def process_uploaded_prompts(contents, filename):
if contents is None:
return None
processed_data = self.processor.process_upload(contents, filename)
if processed_data.error:
return {'error': processed_data.error}
return {
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents],
'embeddings': processed_data.embeddings.tolist()
}
@staticmethod
def _document_to_dict(doc):
return {
'id': doc.id,
'text': doc.text,
'embedding': doc.embedding,
'category': doc.category,
'subcategory': doc.subcategory,
'tags': doc.tags
}

View File

@@ -0,0 +1,66 @@
import dash
from dash import callback, Input, Output, State, html
import dash_bootstrap_components as dbc
class InteractionCallbacks:
def __init__(self):
self._register_callbacks()
def _register_callbacks(self):
@callback(
Output('point-details', 'children'),
Input('embedding-plot', 'clickData'),
[State('processed-data', 'data'),
State('processed-prompts', 'data')]
)
def display_click_data(clickData, data, prompts_data):
if not clickData or not data:
return "Click on a point to see details"
point_data = clickData['points'][0]
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
if 'pointIndex' in point_data:
point_index = point_data['pointIndex']
elif 'pointNumber' in point_data:
point_index = point_data['pointNumber']
else:
return "Could not identify clicked point"
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data:
item = prompts_data['prompts'][point_index]
item_type = 'Prompt'
else:
item = data['documents'][point_index]
item_type = 'Document'
return self._create_detail_card(item, item_type)
@callback(
[Output('processed-data', 'data', allow_duplicate=True),
Output('processed-prompts', 'data', allow_duplicate=True),
Output('point-details', 'children', allow_duplicate=True)],
Input('reset-button', 'n_clicks'),
prevent_initial_call=True
)
def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0:
return dash.no_update, dash.no_update, dash.no_update
return None, None, "Click on a point to see details"
@staticmethod
def _create_detail_card(item, item_type):
return dbc.Card([
dbc.CardBody([
html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Text: {item['text']}", className="card-text"),
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
html.P(f"Type: {item_type}", className="card-text text-muted")
])
])

View File

@@ -0,0 +1,87 @@
import numpy as np
from dash import callback, Input, Output
import plotly.graph_objects as go
from ...models.reducers import ReducerFactory
from ...models.schemas import Document, PlotData
from ...visualization.plots import PlotFactory
class VisualizationCallbacks:
def __init__(self):
self.plot_factory = PlotFactory()
self._register_callbacks()
def _register_callbacks(self):
@callback(
Output('embedding-plot', 'figure'),
[Input('processed-data', 'data'),
Input('processed-prompts', 'data'),
Input('method-dropdown', 'value'),
Input('color-dropdown', 'value'),
Input('dimension-toggle', 'value'),
Input('show-prompts-toggle', 'value')]
)
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or 'error' in data:
return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization",
xref="paper", yref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle',
showarrow=False, font=dict(size=16)
)
try:
doc_embeddings = np.array(data['embeddings'])
all_embeddings = doc_embeddings
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
if has_prompts:
prompt_embeddings = np.array(prompts_data['embeddings'])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == '3d' else 2
reducer = ReducerFactory.create_reducer(method, n_components=n_components)
reduced_data = reducer.fit_transform(all_embeddings)
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
prompt_reduced = None
if has_prompts:
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
documents = [self._dict_to_document(doc) for doc in data['documents']]
prompts = None
if has_prompts:
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
plot_data = PlotData(
documents=documents,
coordinates=doc_reduced,
prompts=prompts,
prompt_coordinates=prompt_reduced
)
return self.plot_factory.create_plot(
plot_data, dimensions, color_by, reduced_data.method, show_prompts
)
except Exception as e:
return go.Figure().add_annotation(
text=f"Error creating visualization: {str(e)}",
xref="paper", yref="paper",
x=0.5, y=0.5, xanchor='center', yanchor='middle',
showarrow=False, font=dict(size=16)
)
@staticmethod
def _dict_to_document(doc_dict):
return Document(
id=doc_dict['id'],
text=doc_dict['text'],
embedding=doc_dict['embedding'],
category=doc_dict.get('category'),
subcategory=doc_dict.get('subcategory'),
tags=doc_dict.get('tags', [])
)