Some checks failed
Security Scan / dependency-check (pull_request) Successful in 35s
Security Scan / security (pull_request) Successful in 40s
Test Suite / lint (pull_request) Failing after 40s
Test Suite / test (3.11) (pull_request) Successful in 1m26s
Test Suite / build (pull_request) Has been skipped
179 lines
6.0 KiB
Python
179 lines
6.0 KiB
Python
import pandas as pd
|
|
import plotly.express as px
|
|
import plotly.graph_objects as go
|
|
from typing import List, Optional
|
|
from ..models.schemas import Document, PlotData
|
|
from .colors import ColorMapper
|
|
|
|
|
|
class PlotFactory:
    """Build Plotly scatter figures from embedded documents and prompts.

    The factory turns a ``PlotData`` payload into either a single scatter
    plot of the documents, or a "dual" plot that overlays prompts on top of
    the documents with a distinct marker style.
    """

    # Columns shown in the hover tooltip; all are produced by
    # ``_prepare_dataframe``.
    _HOVER_FIELDS = ("id", "text_preview", "category", "subcategory", "tags_str")

    # Texts longer than this are truncated (and suffixed with "...") in the
    # hover preview column.
    _PREVIEW_LENGTH = 100

    # Marker sizes as (3D size, 2D size).
    _DOCUMENT_MARKER_SIZES = (5, 8)
    _PROMPT_MARKER_SIZES = (6, 10)

    def __init__(self):
        """Create the factory with its own ``ColorMapper`` instance."""
        self.color_mapper = ColorMapper()

    def create_plot(
        self,
        plot_data: PlotData,
        dimensions: str = "3d",
        color_by: str = "category",
        method: str = "PCA",
        show_prompts: Optional[List[str]] = None,
    ) -> go.Figure:
        """Return a figure for ``plot_data``.

        Args:
            plot_data: Documents, their coordinates and (optionally) prompts
                with their own coordinates.
            dimensions: ``"3d"`` selects a 3D scatter; any other value is
                rendered as a 2D scatter.
            color_by: Attribute name handed to the color mapper.
            method: Label of the reduction method, used only in the title.
            show_prompts: When this contains the string ``"show"`` and
                ``plot_data.prompts`` is non-empty, prompts are overlaid on
                the documents.

        Returns:
            The dual (documents + prompts) figure when prompts are requested
            and available, otherwise the documents-only figure.
        """
        wants_prompts = bool(show_prompts) and "show" in show_prompts
        if plot_data.prompts and wants_prompts:
            return self._create_dual_plot(plot_data, dimensions, color_by, method)
        return self._create_single_plot(plot_data, dimensions, color_by, method)

    def _create_single_plot(
        self, plot_data: PlotData, dimensions: str, color_by: str, method: str
    ) -> go.Figure:
        """Return a titled scatter figure containing only the documents."""
        is_3d = dimensions == "3d"
        df = self._prepare_dataframe(
            plot_data.documents, plot_data.coordinates, dimensions
        )
        color_values = self.color_mapper.create_color_mapping(
            plot_data.documents, color_by
        )

        # Any value other than "3d" is drawn (and labelled) as 2D.
        label = "3D" if is_3d else "2D"
        title = f"{label} Embedding Visualization - {method} (colored by {color_by})"

        fig = self._build_scatter(df, color_values, dimensions, title=title)
        fig.update_traces(
            marker=dict(size=self._pick_size(self._DOCUMENT_MARKER_SIZES, is_3d))
        )
        fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0))
        return fig

    def _create_dual_plot(
        self, plot_data: PlotData, dimensions: str, color_by: str, method: str
    ) -> go.Figure:
        """Return a figure with document traces and, if present, prompt traces.

        Documents are drawn as fully opaque circles. Prompts are drawn as
        slightly transparent diamonds, and any prompt trace whose marker
        color is a plain string is passed through
        ``ColorMapper.to_grayscale_hex`` first.
        """
        is_3d = dimensions == "3d"
        fig = go.Figure()

        doc_df = self._prepare_dataframe(
            plot_data.documents, plot_data.coordinates, dimensions
        )
        doc_color_values = self.color_mapper.create_color_mapping(
            plot_data.documents, color_by
        )
        doc_fig = self._build_scatter(doc_df, doc_color_values, dimensions)
        self._transfer_traces(
            fig,
            doc_fig,
            label="Documents",
            size=self._pick_size(self._DOCUMENT_MARKER_SIZES, is_3d),
            symbol="circle",
            opacity=1.0,
        )

        if plot_data.prompts and plot_data.prompt_coordinates is not None:
            prompt_df = self._prepare_dataframe(
                plot_data.prompts, plot_data.prompt_coordinates, dimensions
            )
            prompt_color_values = self.color_mapper.create_color_mapping(
                plot_data.prompts, color_by
            )
            prompt_fig = self._build_scatter(
                prompt_df, prompt_color_values, dimensions
            )
            self._transfer_traces(
                fig,
                prompt_fig,
                label="Prompts",
                size=self._pick_size(self._PROMPT_MARKER_SIZES, is_3d),
                symbol="diamond",
                opacity=0.8,
                grayscale=True,
            )

        title = (
            f"{dimensions.upper()} Embedding Visualization - {method} "
            f"(colored by {color_by})"
        )
        fig.update_layout(
            title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)
        )
        return fig

    @staticmethod
    def _pick_size(sizes, is_3d: bool) -> int:
        """Return the 3D or 2D entry of a ``(size_3d, size_2d)`` pair."""
        size_3d, size_2d = sizes
        return size_3d if is_3d else size_2d

    def _build_scatter(
        self, df: pd.DataFrame, color_values, dimensions: str, title=None
    ) -> go.Figure:
        """Build the Plotly Express scatter shared by every plot variant.

        Uses ``px.scatter_3d`` with a ``dim_3`` z-axis when ``dimensions`` is
        ``"3d"`` and ``px.scatter`` otherwise.
        """
        axes = {"x": "dim_1", "y": "dim_2"}
        if dimensions == "3d":
            scatter = px.scatter_3d
            axes["z"] = "dim_3"
        else:
            scatter = px.scatter
        return scatter(
            df,
            color=color_values,
            # Fresh list per call so the shared class constant is never
            # handed out by reference.
            hover_data=list(self._HOVER_FIELDS),
            title=title,
            **axes,
        )

    def _transfer_traces(
        self,
        target: go.Figure,
        source: go.Figure,
        label: str,
        size: int,
        symbol: str,
        opacity: float,
        grayscale: bool = False,
    ) -> None:
        """Restyle every trace of ``source`` and append it to ``target``.

        Each trace name is prefixed with ``"<label> - "``. With
        ``grayscale=True`` a string marker color is converted through the
        color mapper before the trace is added.
        """
        for trace in source.data:
            if (
                grayscale
                and hasattr(trace.marker, "color")
                and isinstance(trace.marker.color, str)
            ):
                trace.marker.color = self.color_mapper.to_grayscale_hex(
                    trace.marker.color
                )
            trace.name = f"{label} - {trace.name}"
            trace.marker.size = size
            trace.marker.symbol = symbol
            trace.marker.opacity = opacity
            target.add_trace(trace)

    def _prepare_dataframe(
        self, documents: List[Document], coordinates, dimensions: str
    ) -> pd.DataFrame:
        """Return one row per document with its metadata and coordinates.

        Args:
            documents: Objects exposing ``id``, ``text``, ``category``,
                ``subcategory`` and ``tags``.
            coordinates: Indexed as ``coordinates[i, j]`` for document ``i``
                and axis ``j``.
            dimensions: ``"3d"`` adds a ``dim_3`` column read from axis 2.

        Returns:
            A DataFrame with columns ``id``, ``text``, ``text_preview``,
            ``category``, ``subcategory``, ``tags_str``, ``dim_1``, ``dim_2``
            and, for 3D, ``dim_3``. The columns are present even when
            ``documents`` is empty so later column lookups still resolve.
        """
        is_3d = dimensions == "3d"
        limit = self._PREVIEW_LENGTH

        columns = [
            "id",
            "text",
            "text_preview",
            "category",
            "subcategory",
            "tags_str",
            "dim_1",
            "dim_2",
        ]
        if is_3d:
            columns.append("dim_3")

        rows = []
        for i, doc in enumerate(documents):
            text = doc.text
            row = {
                "id": doc.id,
                "text": text,
                "text_preview": text[:limit] + "..." if len(text) > limit else text,
                "category": doc.category,
                "subcategory": doc.subcategory,
                "tags_str": ", ".join(doc.tags) if doc.tags else "None",
                "dim_1": coordinates[i, 0],
                "dim_2": coordinates[i, 1],
            }
            if is_3d:
                row["dim_3"] = coordinates[i, 2]
            rows.append(row)

        # Passing the column list keeps the schema stable for empty input.
        return pd.DataFrame(rows, columns=columns)
|