238 lines
7.3 KiB
Python
238 lines
7.3 KiB
Python
import json
|
|
import uuid
|
|
from io import StringIO
|
|
import base64
|
|
|
|
import dash
|
|
from dash import dcc, html, Input, Output, State, callback
|
|
import dash_bootstrap_components as dbc
|
|
import plotly.express as px
|
|
import plotly.graph_objects as go
|
|
import pandas as pd
|
|
import numpy as np
|
|
from sklearn.decomposition import PCA
|
|
|
|
|
|
# Application instance; the Bootstrap stylesheet backs the dbc layout components.
app = dash.Dash(name=__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
|
|
|
|
def parse_ndjson(contents):
    """Decode an uploaded data URL and parse its payload as NDJSON.

    Args:
        contents: String of the form ``"<header>,<base64 payload>"`` (the
            format handed over by ``dcc.Upload``).

    Returns:
        A list with one decoded JSON document per non-blank line. Documents
        without an ``'id'`` key get a random UUID4 string assigned as their id.

    Raises:
        ValueError: If ``contents`` has no comma separator, the payload is not
            valid base64 / UTF-8, or a line is not valid JSON
            (``binascii.Error``, ``UnicodeDecodeError`` and
            ``json.JSONDecodeError`` are all ``ValueError`` subclasses).
    """
    # Only the first comma separates the header from the payload.
    _content_type, encoded = contents.split(',', 1)
    # 'utf-8-sig' also accepts files saved with a UTF-8 byte-order mark,
    # which json.loads would otherwise reject on the first line.
    text_content = base64.b64decode(encoded).decode('utf-8-sig')

    documents = []
    # Split on '\n' only: json.loads tolerates a trailing '\r' (CRLF files),
    # while str.splitlines() would also break on U+2028/U+2029, which may
    # legally appear unescaped inside JSON strings.
    for line in text_content.split('\n'):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        doc = json.loads(line)
        if 'id' not in doc:
            doc['id'] = str(uuid.uuid4())
        documents.append(doc)
    return documents
|
|
|
|
def apply_pca(embeddings, n_components=3):
    """Project embeddings onto their leading principal components.

    PCA cannot extract more than ``min(n_samples, n_features)`` components,
    so a tiny upload (e.g. two documents) or very low-dimensional embeddings
    would make ``PCA`` raise. When an integer ``n_components`` exceeds that
    limit, PCA is fitted with as many components as the data supports and
    the results are zero-padded. The returned shapes are then still
    ``(n_samples, n_components)`` and ``(n_components,)``. Whenever the
    unclamped fit would have succeeded, results are unchanged.

    Args:
        embeddings: Array-like of embeddings; expected to be 2-D
            (n_samples, n_features).
        n_components: Number of output dimensions (passed through to PCA
            unchanged when it is not an integer).

    Returns:
        Tuple ``(reduced, explained_variance_ratio)``.

    Raises:
        ValueError: Propagated from PCA for input it cannot fit (e.g. no
            samples, non-2-D input).
    """
    matrix = np.asarray(embeddings)

    fitted_components = n_components
    if isinstance(n_components, (int, np.integer)) and matrix.ndim == 2:
        # PCA supports at most min(n_samples, n_features) components.
        fitted_components = min(n_components, *matrix.shape)

    pca = PCA(n_components=fitted_components)
    reduced = pca.fit_transform(matrix)
    variance = pca.explained_variance_ratio_

    if fitted_components != n_components:
        # Components the data cannot provide are represented as zeros so
        # callers can always index up to n_components - 1.
        # NOTE(review): with a single sample the variance ratios are
        # degenerate (possibly NaN) — confirm how the UI should treat that.
        padding = n_components - fitted_components
        reduced = np.pad(reduced, ((0, 0), (0, padding)))
        variance = np.pad(variance, (0, padding))

    return reduced, variance
|
|
|
|
def _tags_to_color_label(tags):
    """Turn a document's tags value into a single colour-group label."""
    if not tags:
        return 'No tags'
    if isinstance(tags, str):
        # A bare string is one tag; joining it would split it into characters.
        return tags
    # str() so non-string tag items (e.g. numbers) don't make join() raise.
    return ', '.join(str(tag) for tag in tags)


def create_color_mapping(documents, color_by):
    """Return one colour-group label per document.

    Args:
        documents: Sized sequence of dict-like records.
        color_by: ``'category'`` or ``'subcategory'`` (label is that field,
            ``'Unknown'`` when absent), ``'tags'`` (comma-joined tags, or
            ``'No tags'``); any other value puts every document in ``'All'``.

    Returns:
        A list of labels, parallel to ``documents``.
    """
    if color_by in ('category', 'subcategory'):
        return [doc.get(color_by, 'Unknown') for doc in documents]
    if color_by == 'tags':
        return [_tags_to_color_label(doc.get('tags')) for doc in documents]
    return ['All'] * len(documents)
|
|
|
|
def create_plot(df, dimensions='3d', color_by='category'):
    """Build the Plotly Express scatter figure for the projected documents.

    Args:
        df: DataFrame with ``id``, ``text`` and ``pca_1``/``pca_2`` columns
            (plus ``pca_3`` for the 3-D view).
        dimensions: ``'3d'`` for a 3-D scatter; any other value gives 2-D.
        color_by: Grouping passed to ``create_color_mapping``.

    Returns:
        A Plotly figure with fixed marker size and a 600 px height.
    """
    group_labels = create_color_mapping(df.to_dict('records'), color_by)
    common = dict(color=group_labels, hover_data=['id', 'text'])

    if dimensions != '3d':
        figure = px.scatter(
            df, x='pca_1', y='pca_2',
            title=f'2D Embedding Visualization (colored by {color_by})',
            **common
        )
        marker_size = 8
    else:
        figure = px.scatter_3d(
            df, x='pca_1', y='pca_2', z='pca_3',
            title=f'3D Embedding Visualization (colored by {color_by})',
            **common
        )
        marker_size = 5

    figure.update_traces(marker=dict(size=marker_size))
    figure.update_layout(height=600)
    return figure
|
|
|
|
# Layout: title, file upload, colour/dimension controls, the scatter plot,
# a details panel for clicked points, and a dcc.Store for processed data.
# Style dicts use camelCase keys, as Dash passes them straight to React.
app.layout = dbc.Container([
    dbc.Row([
        dbc.Col([
            html.H1("EmbeddingBuddy", className="text-center mb-4"),
            html.P("Upload NDJSON file with embeddings to visualize", className="text-center text-muted")
        ])
    ]),

    dbc.Row([
        dbc.Col([
            dcc.Upload(
                id='upload-data',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select Files')
                ]),
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    'margin': '10px'
                },
                multiple=False
            )
        ])
    ]),

    dbc.Row([
        dbc.Col([
            dbc.Label("Color by:"),
            dcc.Dropdown(
                id='color-dropdown',
                options=[
                    {'label': 'Category', 'value': 'category'},
                    {'label': 'Subcategory', 'value': 'subcategory'},
                    {'label': 'Tags', 'value': 'tags'}
                ],
                value='category',
                style={'marginBottom': '10px'}
            )
        ], width=6),
        dbc.Col([
            dbc.Label("Dimensions:"),
            dcc.RadioItems(
                id='dimension-toggle',
                options=[
                    {'label': '2D', 'value': '2d'},
                    {'label': '3D', 'value': '3d'}
                ],
                value='3d',
                inline=True
            )
        ], width=6)
    ], className="mb-3"),

    dbc.Row([
        dbc.Col([
            dcc.Graph(id='embedding-plot')
        ])
    ]),

    dbc.Row([
        dbc.Col([
            html.Div(id='point-details', style={'marginTop': '20px'})
        ])
    ]),

    dcc.Store(id='processed-data')
], fluid=True)
|
|
|
|
@callback(
    Output('processed-data', 'data'),
    Input('upload-data', 'contents'),
    State('upload-data', 'filename')
)
def process_uploaded_file(contents, filename):
    """Parse an uploaded NDJSON file and store PCA projections of its embeddings.

    Args:
        contents: Upload contents string, or ``None`` before any upload.
        filename: Name of the uploaded file; part of the callback signature
            but not otherwise used here.

    Returns:
        ``None`` when nothing was uploaded; ``{'error': message}`` if parsing,
        validation or PCA fails; otherwise a dict with ``'documents'`` (one
        record per document: id, text, category, subcategory, tags and the
        3-D and 2-D PCA coordinates), ``'variance_explained_2d'`` and
        ``'variance_explained_3d'``.
    """
    if contents is None:
        return None

    try:
        documents = parse_ndjson(contents)

        # Validate up front so the user gets a readable message instead of a
        # bare KeyError repr or a PCA shape error.
        if not documents:
            raise ValueError("The uploaded file contains no documents")
        for number, doc in enumerate(documents, start=1):
            missing = [field for field in ('embedding', 'text') if field not in doc]
            if missing:
                raise ValueError(
                    f"Document {number} is missing required field(s): {', '.join(missing)}"
                )

        embeddings = np.array([doc['embedding'] for doc in documents])

        # The 2-D and 3-D views are fitted as separate PCA runs.
        pca_2d, var_2d = apply_pca(embeddings, n_components=2)
        pca_3d, var_3d = apply_pca(embeddings, n_components=3)

        records = [
            {
                'id': doc['id'],
                'text': doc['text'],
                'category': doc.get('category', 'Unknown'),
                'subcategory': doc.get('subcategory', 'Unknown'),
                'tags': doc.get('tags', []),
                'pca_1': xyz[0],
                'pca_2': xyz[1],
                'pca_3': xyz[2],
                'pca_1_2d': xy[0],
                'pca_2_2d': xy[1]
            }
            for doc, xyz, xy in zip(documents, pca_3d, pca_2d)
        ]

        return {
            'documents': records,
            'variance_explained_2d': var_2d.tolist(),
            'variance_explained_3d': var_3d.tolist()
        }
    except Exception as e:
        # Callback boundary: any failure becomes an error payload that the
        # plot callback checks for, instead of an unhandled callback error.
        return {'error': str(e)}
|
|
|
|
@callback(
    Output('embedding-plot', 'figure'),
    Input('processed-data', 'data'),
    Input('color-dropdown', 'value'),
    Input('dimension-toggle', 'value')
)
def update_plot(data, color_by, dimensions):
    """Redraw the scatter plot when the data, colouring or dimensionality changes.

    Args:
        data: Contents of the ``processed-data`` store: ``None``, an
            ``{'error': ...}`` payload, or the processed documents.
        color_by: Value of the colour dropdown.
        dimensions: ``'2d'`` or ``'3d'`` from the dimension toggle.

    Returns:
        The embedding scatter figure, or a placeholder figure with a message
        (including the processing error, when one was recorded).
    """
    if not data or 'error' in data:
        message = "Upload a valid NDJSON file to see visualization"
        error = data.get('error') if data else None
        if error:
            # Surface why the upload was rejected. Escape the characters that
            # Plotly's annotation text would otherwise treat as markup.
            safe_error = (
                str(error).replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
            )
            message = f"{message}<br>Error: {safe_error}"
        return go.Figure().add_annotation(
            text=message,
            xref="paper", yref="paper",
            x=0.5, y=0.5, xanchor='center', yanchor='middle',
            showarrow=False, font=dict(size=16)
        )

    df = pd.DataFrame(data['documents'])

    if dimensions == '2d':
        # create_plot always reads pca_1/pca_2, so swap in the coordinates
        # from the dedicated 2-component projection.
        df['pca_1'] = df['pca_1_2d']
        df['pca_2'] = df['pca_2_2d']

    return create_plot(df, dimensions, color_by)
|
|
|
|
def _find_clicked_document(point, documents):
    """Return the stored record behind a clicked plot point, or None if unresolved."""
    # Plotly Express draws one trace per colour group, so ``pointIndex`` is
    # an index into the clicked trace, not into ``documents``. create_plot
    # passes hover_data=['id', 'text'], which Plotly Express carries in
    # ``customdata``; its first entry is the document id (assumes no other
    # custom_data is configured — verify if create_plot changes).
    custom = point.get('customdata')
    if isinstance(custom, (list, tuple)) and custom:
        clicked_id = custom[0]
        return next((doc for doc in documents if doc['id'] == clicked_id), None)

    # Fallback without customdata: only correct when the figure has a single
    # trace (all points in one colour group).
    index = point.get('pointIndex', point.get('pointNumber'))
    if isinstance(index, int) and 0 <= index < len(documents):
        return documents[index]
    return None


@callback(
    Output('point-details', 'children'),
    Input('embedding-plot', 'clickData'),
    State('processed-data', 'data')
)
def display_click_data(clickData, data):
    """Show a details card for the document behind the clicked point.

    Args:
        clickData: Plotly click event data from the embedding graph.
        data: Contents of the ``processed-data`` store (may be ``None`` or an
            ``{'error': ...}`` payload without documents).

    Returns:
        A ``dbc.Card`` with the document's fields, or a hint string when
        there is nothing to show.
    """
    documents = data.get('documents') if data else None
    if not clickData or not documents:
        return "Click on a point to see details"

    points = clickData.get('points') or []
    doc = _find_clicked_document(points[0], documents) if points else None
    if doc is None:
        return "Click on a point to see details"

    tags = doc.get('tags')
    if not tags:
        tags_text = 'None'
    elif isinstance(tags, str):
        tags_text = tags
    else:
        # str() so non-string tag items don't make join() raise.
        tags_text = ', '.join(str(tag) for tag in tags)

    return dbc.Card([
        dbc.CardBody([
            html.H5(f"Document: {doc['id']}", className="card-title"),
            html.P(f"Text: {doc['text']}", className="card-text"),
            html.P(f"Category: {doc['category']}", className="card-text"),
            html.P(f"Subcategory: {doc['subcategory']}", className="card-text"),
            html.P(f"Tags: {tags_text}", className="card-text")
        ])
    ])
|
|
|
|
if __name__ == '__main__':
    # Run the app's built-in server with debug=True; intended for local use,
    # not for exposing publicly.
    app.run(debug=True)