Some checks failed
Security Scan / dependency-check (pull_request) Successful in 35s
Security Scan / security (pull_request) Successful in 40s
Test Suite / lint (pull_request) Failing after 40s
Test Suite / test (3.11) (pull_request) Successful in 1m26s
Test Suite / build (pull_request) Has been skipped
109 lines
3.9 KiB
Python
109 lines
3.9 KiB
Python
import numpy as np
|
|
from dash import callback, Input, Output
|
|
import plotly.graph_objects as go
|
|
from ...models.reducers import ReducerFactory
|
|
from ...models.schemas import Document, PlotData
|
|
from ...visualization.plots import PlotFactory
|
|
|
|
|
|
class VisualizationCallbacks:
    """Register the Dash callback that renders the embedding plot.

    Instantiating the class creates a ``PlotFactory`` and immediately
    registers ``update_plot`` as the callback producing the ``figure``
    property of the ``embedding-plot`` component.
    """

    def __init__(self):
        self.plot_factory = PlotFactory()
        self._register_callbacks()

    def _register_callbacks(self):
        """Define ``update_plot`` and register it through ``@callback``."""

        @callback(
            Output("embedding-plot", "figure"),
            [
                Input("processed-data", "data"),
                Input("processed-prompts", "data"),
                Input("method-dropdown", "value"),
                Input("color-dropdown", "value"),
                Input("dimension-toggle", "value"),
                Input("show-prompts-toggle", "value"),
            ],
        )
        def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
            """Build the figure for the current data and control values.

            Args:
                data: Value of ``processed-data``. Read via the keys
                    ``"embeddings"`` and ``"documents"``; a falsy value or
                    one containing an ``"error"`` key yields the upload hint.
                prompts_data: Value of ``processed-prompts``. Used only when
                    it is truthy, has no ``"error"`` key and has a non-empty
                    ``"prompts"`` entry; its ``"embeddings"`` are then read.
                method: Passed to ``ReducerFactory.create_reducer``.
                color_by: Forwarded unchanged to ``create_plot``.
                dimensions: ``"3d"`` requests 3 components; any other value
                    requests 2. Also forwarded to ``create_plot``.
                show_prompts: Forwarded unchanged to ``create_plot``. It does
                    not affect whether prompt embeddings are reduced.

            Returns:
                The figure returned by ``PlotFactory.create_plot``, or an
                empty figure carrying a centered message when there is no
                usable data or when any step raises.
            """
            if not data or "error" in data:
                return self._message_figure(
                    "Upload a valid NDJSON file to see visualization"
                )

            # Callback boundary: every failure below is reported inside the
            # figure itself rather than propagated to the caller.
            try:
                doc_embeddings, all_embeddings, has_prompts = self._stack_embeddings(
                    data, prompts_data
                )

                n_components = 3 if dimensions == "3d" else 2
                reducer = ReducerFactory.create_reducer(
                    method, n_components=n_components
                )
                # Documents and prompts go through a single fit_transform call.
                reduced_data = reducer.fit_transform(all_embeddings)

                # Rows keep the stacking order: documents first, then prompts.
                n_docs = len(doc_embeddings)
                doc_reduced = reduced_data.reduced_embeddings[:n_docs]
                prompt_reduced = None
                if has_prompts:
                    prompt_reduced = reduced_data.reduced_embeddings[n_docs:]

                documents = [self._dict_to_document(doc) for doc in data["documents"]]
                prompts = None
                if has_prompts:
                    prompts = [
                        self._dict_to_document(prompt)
                        for prompt in prompts_data["prompts"]
                    ]

                plot_data = PlotData(
                    documents=documents,
                    coordinates=doc_reduced,
                    prompts=prompts,
                    prompt_coordinates=prompt_reduced,
                )

                return self.plot_factory.create_plot(
                    plot_data, dimensions, color_by, reduced_data.method, show_prompts
                )
            except Exception as e:
                return self._message_figure(f"Error creating visualization: {e}")

    @staticmethod
    def _stack_embeddings(data, prompts_data):
        """Collect document embeddings and append prompt embeddings if usable.

        Args:
            data: Mapping whose ``"embeddings"`` entry is converted to an array.
            prompts_data: Optional mapping. Counts as usable only when it is
                truthy, has no ``"error"`` key and a non-empty ``"prompts"``
                entry.

        Returns:
            A ``(doc_embeddings, all_embeddings, has_prompts)`` tuple. Without
            usable prompts ``all_embeddings`` is the same array object as
            ``doc_embeddings``; otherwise it is the document rows followed by
            the prompt rows.

        Raises:
            KeyError: If a required ``"embeddings"`` key is missing.
            ValueError: If ``np.vstack`` cannot stack the two arrays.
        """
        doc_embeddings = np.array(data["embeddings"])
        has_prompts = bool(
            prompts_data
            and "error" not in prompts_data
            and prompts_data.get("prompts")
        )
        if not has_prompts:
            return doc_embeddings, doc_embeddings, False

        prompt_embeddings = np.array(prompts_data["embeddings"])
        return doc_embeddings, np.vstack([doc_embeddings, prompt_embeddings]), True

    @staticmethod
    def _message_figure(text):
        """Return an empty figure with ``text`` centered in paper coordinates."""
        return go.Figure().add_annotation(
            text=text,
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
            xanchor="center",
            yanchor="middle",
            showarrow=False,
            font=dict(size=16),
        )

    @staticmethod
    def _dict_to_document(doc_dict):
        """Build a ``Document`` from a plain dict.

        ``"id"``, ``"text"`` and ``"embedding"`` are required (a missing key
        raises ``KeyError``). ``"category"`` and ``"subcategory"`` default to
        ``None`` and ``"tags"`` defaults to an empty list.
        """
        return Document(
            id=doc_dict["id"],
            text=doc_dict["text"],
            embedding=doc_dict["embedding"],
            category=doc_dict.get("category"),
            subcategory=doc_dict.get("subcategory"),
            tags=doc_dict.get("tags", []),
        )
|