add ci workflows (#1)
All checks were successful
All checks were successful
Reviewed-on: godber/embedding-buddy#1
This commit is contained in:
@@ -1,61 +1,62 @@
|
||||
import numpy as np
|
||||
from dash import callback, Input, Output, State
|
||||
from ...data.processor import DataProcessor
|
||||
|
||||
|
||||
class DataProcessingCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self.processor = DataProcessor()
|
||||
self._register_callbacks()
|
||||
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('processed-data', 'data'),
|
||||
Input('upload-data', 'contents'),
|
||||
State('upload-data', 'filename')
|
||||
Output("processed-data", "data"),
|
||||
Input("upload-data", "contents"),
|
||||
State("upload-data", "filename"),
|
||||
)
|
||||
def process_uploaded_file(contents, filename):
|
||||
if contents is None:
|
||||
return None
|
||||
|
||||
|
||||
processed_data = self.processor.process_upload(contents, filename)
|
||||
|
||||
|
||||
if processed_data.error:
|
||||
return {'error': processed_data.error}
|
||||
|
||||
return {"error": processed_data.error}
|
||||
|
||||
return {
|
||||
'documents': [self._document_to_dict(doc) for doc in processed_data.documents],
|
||||
'embeddings': processed_data.embeddings.tolist()
|
||||
"documents": [
|
||||
self._document_to_dict(doc) for doc in processed_data.documents
|
||||
],
|
||||
"embeddings": processed_data.embeddings.tolist(),
|
||||
}
|
||||
|
||||
|
||||
@callback(
|
||||
Output('processed-prompts', 'data'),
|
||||
Input('upload-prompts', 'contents'),
|
||||
State('upload-prompts', 'filename')
|
||||
Output("processed-prompts", "data"),
|
||||
Input("upload-prompts", "contents"),
|
||||
State("upload-prompts", "filename"),
|
||||
)
|
||||
def process_uploaded_prompts(contents, filename):
|
||||
if contents is None:
|
||||
return None
|
||||
|
||||
|
||||
processed_data = self.processor.process_upload(contents, filename)
|
||||
|
||||
|
||||
if processed_data.error:
|
||||
return {'error': processed_data.error}
|
||||
|
||||
return {"error": processed_data.error}
|
||||
|
||||
return {
|
||||
'prompts': [self._document_to_dict(doc) for doc in processed_data.documents],
|
||||
'embeddings': processed_data.embeddings.tolist()
|
||||
"prompts": [
|
||||
self._document_to_dict(doc) for doc in processed_data.documents
|
||||
],
|
||||
"embeddings": processed_data.embeddings.tolist(),
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _document_to_dict(doc):
|
||||
return {
|
||||
'id': doc.id,
|
||||
'text': doc.text,
|
||||
'embedding': doc.embedding,
|
||||
'category': doc.category,
|
||||
'subcategory': doc.subcategory,
|
||||
'tags': doc.tags
|
||||
}
|
||||
"id": doc.id,
|
||||
"text": doc.text,
|
||||
"embedding": doc.embedding,
|
||||
"category": doc.category,
|
||||
"subcategory": doc.subcategory,
|
||||
"tags": doc.tags,
|
||||
}
|
||||
|
@@ -4,63 +4,79 @@ import dash_bootstrap_components as dbc
|
||||
|
||||
|
||||
class InteractionCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self._register_callbacks()
|
||||
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('point-details', 'children'),
|
||||
Input('embedding-plot', 'clickData'),
|
||||
[State('processed-data', 'data'),
|
||||
State('processed-prompts', 'data')]
|
||||
Output("point-details", "children"),
|
||||
Input("embedding-plot", "clickData"),
|
||||
[State("processed-data", "data"), State("processed-prompts", "data")],
|
||||
)
|
||||
def display_click_data(clickData, data, prompts_data):
|
||||
if not clickData or not data:
|
||||
return "Click on a point to see details"
|
||||
|
||||
point_data = clickData['points'][0]
|
||||
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
|
||||
|
||||
if 'pointIndex' in point_data:
|
||||
point_index = point_data['pointIndex']
|
||||
elif 'pointNumber' in point_data:
|
||||
point_index = point_data['pointNumber']
|
||||
|
||||
point_data = clickData["points"][0]
|
||||
trace_name = point_data.get("fullData", {}).get("name", "Documents")
|
||||
|
||||
if "pointIndex" in point_data:
|
||||
point_index = point_data["pointIndex"]
|
||||
elif "pointNumber" in point_data:
|
||||
point_index = point_data["pointNumber"]
|
||||
else:
|
||||
return "Could not identify clicked point"
|
||||
|
||||
if trace_name.startswith('Prompts') and prompts_data and 'prompts' in prompts_data:
|
||||
item = prompts_data['prompts'][point_index]
|
||||
item_type = 'Prompt'
|
||||
|
||||
if (
|
||||
trace_name.startswith("Prompts")
|
||||
and prompts_data
|
||||
and "prompts" in prompts_data
|
||||
):
|
||||
item = prompts_data["prompts"][point_index]
|
||||
item_type = "Prompt"
|
||||
else:
|
||||
item = data['documents'][point_index]
|
||||
item_type = 'Document'
|
||||
|
||||
item = data["documents"][point_index]
|
||||
item_type = "Document"
|
||||
|
||||
return self._create_detail_card(item, item_type)
|
||||
|
||||
|
||||
@callback(
|
||||
[Output('processed-data', 'data', allow_duplicate=True),
|
||||
Output('processed-prompts', 'data', allow_duplicate=True),
|
||||
Output('point-details', 'children', allow_duplicate=True)],
|
||||
Input('reset-button', 'n_clicks'),
|
||||
prevent_initial_call=True
|
||||
[
|
||||
Output("processed-data", "data", allow_duplicate=True),
|
||||
Output("processed-prompts", "data", allow_duplicate=True),
|
||||
Output("point-details", "children", allow_duplicate=True),
|
||||
],
|
||||
Input("reset-button", "n_clicks"),
|
||||
prevent_initial_call=True,
|
||||
)
|
||||
def reset_data(n_clicks):
|
||||
if n_clicks is None or n_clicks == 0:
|
||||
return dash.no_update, dash.no_update, dash.no_update
|
||||
|
||||
|
||||
return None, None, "Click on a point to see details"
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _create_detail_card(item, item_type):
|
||||
return dbc.Card([
|
||||
dbc.CardBody([
|
||||
html.H5(f"{item_type}: {item['id']}", className="card-title"),
|
||||
html.P(f"Text: {item['text']}", className="card-text"),
|
||||
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
|
||||
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
|
||||
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
|
||||
html.P(f"Type: {item_type}", className="card-text text-muted")
|
||||
])
|
||||
])
|
||||
return dbc.Card(
|
||||
[
|
||||
dbc.CardBody(
|
||||
[
|
||||
html.H5(f"{item_type}: {item['id']}", className="card-title"),
|
||||
html.P(f"Text: {item['text']}", className="card-text"),
|
||||
html.P(
|
||||
f"Category: {item.get('category', 'Unknown')}",
|
||||
className="card-text",
|
||||
),
|
||||
html.P(
|
||||
f"Subcategory: {item.get('subcategory', 'Unknown')}",
|
||||
className="card-text",
|
||||
),
|
||||
html.P(
|
||||
f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}",
|
||||
className="card-text",
|
||||
),
|
||||
html.P(f"Type: {item_type}", className="card-text text-muted"),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
@@ -7,81 +7,102 @@ from ...visualization.plots import PlotFactory
|
||||
|
||||
|
||||
class VisualizationCallbacks:
|
||||
|
||||
def __init__(self):
|
||||
self.plot_factory = PlotFactory()
|
||||
self._register_callbacks()
|
||||
|
||||
|
||||
def _register_callbacks(self):
|
||||
|
||||
@callback(
|
||||
Output('embedding-plot', 'figure'),
|
||||
[Input('processed-data', 'data'),
|
||||
Input('processed-prompts', 'data'),
|
||||
Input('method-dropdown', 'value'),
|
||||
Input('color-dropdown', 'value'),
|
||||
Input('dimension-toggle', 'value'),
|
||||
Input('show-prompts-toggle', 'value')]
|
||||
Output("embedding-plot", "figure"),
|
||||
[
|
||||
Input("processed-data", "data"),
|
||||
Input("processed-prompts", "data"),
|
||||
Input("method-dropdown", "value"),
|
||||
Input("color-dropdown", "value"),
|
||||
Input("dimension-toggle", "value"),
|
||||
Input("show-prompts-toggle", "value"),
|
||||
],
|
||||
)
|
||||
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
|
||||
if not data or 'error' in data:
|
||||
if not data or "error" in data:
|
||||
return go.Figure().add_annotation(
|
||||
text="Upload a valid NDJSON file to see visualization",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16)
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.5,
|
||||
y=0.5,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
font=dict(size=16),
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
doc_embeddings = np.array(data['embeddings'])
|
||||
doc_embeddings = np.array(data["embeddings"])
|
||||
all_embeddings = doc_embeddings
|
||||
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
|
||||
|
||||
has_prompts = (
|
||||
prompts_data
|
||||
and "error" not in prompts_data
|
||||
and prompts_data.get("prompts")
|
||||
)
|
||||
|
||||
if has_prompts:
|
||||
prompt_embeddings = np.array(prompts_data['embeddings'])
|
||||
prompt_embeddings = np.array(prompts_data["embeddings"])
|
||||
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
|
||||
|
||||
n_components = 3 if dimensions == '3d' else 2
|
||||
|
||||
reducer = ReducerFactory.create_reducer(method, n_components=n_components)
|
||||
|
||||
n_components = 3 if dimensions == "3d" else 2
|
||||
|
||||
reducer = ReducerFactory.create_reducer(
|
||||
method, n_components=n_components
|
||||
)
|
||||
reduced_data = reducer.fit_transform(all_embeddings)
|
||||
|
||||
doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
|
||||
|
||||
doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)]
|
||||
prompt_reduced = None
|
||||
if has_prompts:
|
||||
prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
|
||||
|
||||
documents = [self._dict_to_document(doc) for doc in data['documents']]
|
||||
prompt_reduced = reduced_data.reduced_embeddings[
|
||||
len(doc_embeddings) :
|
||||
]
|
||||
|
||||
documents = [self._dict_to_document(doc) for doc in data["documents"]]
|
||||
prompts = None
|
||||
if has_prompts:
|
||||
prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
|
||||
|
||||
prompts = [
|
||||
self._dict_to_document(prompt)
|
||||
for prompt in prompts_data["prompts"]
|
||||
]
|
||||
|
||||
plot_data = PlotData(
|
||||
documents=documents,
|
||||
coordinates=doc_reduced,
|
||||
prompts=prompts,
|
||||
prompt_coordinates=prompt_reduced
|
||||
prompt_coordinates=prompt_reduced,
|
||||
)
|
||||
|
||||
|
||||
return self.plot_factory.create_plot(
|
||||
plot_data, dimensions, color_by, reduced_data.method, show_prompts
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return go.Figure().add_annotation(
|
||||
text=f"Error creating visualization: {str(e)}",
|
||||
xref="paper", yref="paper",
|
||||
x=0.5, y=0.5, xanchor='center', yanchor='middle',
|
||||
showarrow=False, font=dict(size=16)
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0.5,
|
||||
y=0.5,
|
||||
xanchor="center",
|
||||
yanchor="middle",
|
||||
showarrow=False,
|
||||
font=dict(size=16),
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_document(doc_dict):
|
||||
return Document(
|
||||
id=doc_dict['id'],
|
||||
text=doc_dict['text'],
|
||||
embedding=doc_dict['embedding'],
|
||||
category=doc_dict.get('category'),
|
||||
subcategory=doc_dict.get('subcategory'),
|
||||
tags=doc_dict.get('tags', [])
|
||||
)
|
||||
id=doc_dict["id"],
|
||||
text=doc_dict["text"],
|
||||
embedding=doc_dict["embedding"],
|
||||
category=doc_dict.get("category"),
|
||||
subcategory=doc_dict.get("subcategory"),
|
||||
tags=doc_dict.get("tags", []),
|
||||
)
|
||||
|
Reference in New Issue
Block a user