add other options

This commit is contained in:
2025-08-12 15:48:22 -07:00
parent 850140481d
commit 69a77d18b9
4 changed files with 241 additions and 121 deletions

116
app.py
View File

@@ -11,6 +11,8 @@ import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
import umap
from openTSNE import TSNE
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
@@ -30,11 +32,23 @@ def parse_ndjson(contents):
documents.append(doc)
return documents
def apply_pca(embeddings, n_components=3):
"""Apply PCA to embeddings."""
pca = PCA(n_components=n_components)
reduced = pca.fit_transform(embeddings)
return reduced, pca.explained_variance_ratio_
def apply_dimensionality_reduction(embeddings, method='pca', n_components=3):
"""Apply dimensionality reduction to embeddings."""
if method == 'pca':
reducer = PCA(n_components=n_components)
reduced = reducer.fit_transform(embeddings)
variance_explained = reducer.explained_variance_ratio_
return reduced, variance_explained
elif method == 'tsne':
reducer = TSNE(n_components=n_components, random_state=42)
reduced = reducer.fit(embeddings)
return reduced, None
elif method == 'umap':
reducer = umap.UMAP(n_components=n_components, random_state=42)
reduced = reducer.fit_transform(embeddings)
return reduced, None
else:
raise ValueError(f"Unknown method: {method}")
def create_color_mapping(documents, color_by):
"""Create color mapping for documents based on specified field."""
@@ -49,7 +63,7 @@ def create_color_mapping(documents, color_by):
return values
def create_plot(df, dimensions='3d', color_by='category'):
def create_plot(df, dimensions='3d', color_by='category', method='PCA'):
"""Create plotly scatter plot."""
color_values = create_color_mapping(df.to_dict('records'), color_by)
@@ -65,18 +79,18 @@ def create_plot(df, dimensions='3d', color_by='category'):
if dimensions == '3d':
fig = px.scatter_3d(
df_display, x='pca_1', y='pca_2', z='pca_3',
df_display, x='dim_1', y='dim_2', z='dim_3',
color=color_values,
hover_data=hover_fields,
title=f'3D Embedding Visualization (colored by {color_by})'
title=f'3D Embedding Visualization - {method} (colored by {color_by})'
)
fig.update_traces(marker=dict(size=5))
else:
fig = px.scatter(
df_display, x='pca_1', y='pca_2',
df_display, x='dim_1', y='dim_2',
color=color_values,
hover_data=hover_fields,
title=f'2D Embedding Visualization (colored by {color_by})'
title=f'2D Embedding Visualization - {method} (colored by {color_by})'
)
fig.update_traces(marker=dict(size=8))
@@ -116,6 +130,19 @@ app.layout = dbc.Container([
]),
dbc.Row([
dbc.Col([
dbc.Label("Method:"),
dcc.Dropdown(
id='method-dropdown',
options=[
{'label': 'PCA', 'value': 'pca'},
{'label': 't-SNE', 'value': 'tsne'},
{'label': 'UMAP', 'value': 'umap'}
],
value='pca',
style={'margin-bottom': '10px'}
)
], width=4),
dbc.Col([
dbc.Label("Color by:"),
dcc.Dropdown(
@@ -128,7 +155,7 @@ app.layout = dbc.Container([
value='category',
style={'margin-bottom': '10px'}
)
], width=6),
], width=4),
dbc.Col([
dbc.Label("Dimensions:"),
dcc.RadioItems(
@@ -140,7 +167,7 @@ app.layout = dbc.Container([
value='3d',
inline=True
)
], width=6)
], width=4)
], className="mb-3"),
dbc.Row([
@@ -171,30 +198,10 @@ def process_uploaded_file(contents, filename):
documents = parse_ndjson(contents)
embeddings = np.array([doc['embedding'] for doc in documents])
# Apply PCA
pca_2d, var_2d = apply_pca(embeddings, n_components=2)
pca_3d, var_3d = apply_pca(embeddings, n_components=3)
# Create dataframe
df_data = []
for i, doc in enumerate(documents):
df_data.append({
'id': doc['id'],
'text': doc['text'],
'category': doc.get('category', 'Unknown'),
'subcategory': doc.get('subcategory', 'Unknown'),
'tags': doc.get('tags', []),
'pca_1': pca_3d[i, 0],
'pca_2': pca_3d[i, 1],
'pca_3': pca_3d[i, 2],
'pca_1_2d': pca_2d[i, 0],
'pca_2_2d': pca_2d[i, 1]
})
# Store original embeddings and documents
return {
'documents': df_data,
'variance_explained_2d': var_2d.tolist(),
'variance_explained_3d': var_3d.tolist()
'documents': documents,
'embeddings': embeddings.tolist()
}
except Exception as e:
return {'error': str(e)}
@@ -202,10 +209,11 @@ def process_uploaded_file(contents, filename):
@callback(
Output('embedding-plot', 'figure'),
[Input('processed-data', 'data'),
Input('method-dropdown', 'value'),
Input('color-dropdown', 'value'),
Input('dimension-toggle', 'value')]
)
def update_plot(data, color_by, dimensions):
def update_plot(data, method, color_by, dimensions):
if not data or 'error' in data:
return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization",
@@ -214,13 +222,33 @@ def update_plot(data, color_by, dimensions):
showarrow=False, font=dict(size=16)
)
df = pd.DataFrame(data['documents'])
# Get embeddings and apply selected method
embeddings = np.array(data['embeddings'])
n_components = 3 if dimensions == '3d' else 2
if dimensions == '2d':
df['pca_1'] = df['pca_1_2d']
df['pca_2'] = df['pca_2_2d']
reduced, variance_explained = apply_dimensionality_reduction(
embeddings, method=method, n_components=n_components
)
return create_plot(df, dimensions, color_by)
# Create dataframe with reduced dimensions
df_data = []
for i, doc in enumerate(data['documents']):
row = {
'id': doc['id'],
'text': doc['text'],
'category': doc.get('category', 'Unknown'),
'subcategory': doc.get('subcategory', 'Unknown'),
'tags': doc.get('tags', []),
'dim_1': reduced[i, 0],
'dim_2': reduced[i, 1]
}
if dimensions == '3d':
row['dim_3'] = reduced[i, 2]
df_data.append(row)
df = pd.DataFrame(df_data)
return create_plot(df, dimensions, color_by, method.upper())
@callback(
Output('point-details', 'children'),
@@ -238,9 +266,9 @@ def display_click_data(clickData, data):
dbc.CardBody([
html.H5(f"Document: {doc['id']}", className="card-title"),
html.P(f"Text: {doc['text']}", className="card-text"),
html.P(f"Category: {doc['category']}", className="card-text"),
html.P(f"Subcategory: {doc['subcategory']}", className="card-text"),
html.P(f"Tags: {', '.join(doc['tags']) if doc['tags'] else 'None'}", className="card-text")
html.P(f"Category: {doc.get('category', 'Unknown')}", className="card-text"),
html.P(f"Subcategory: {doc.get('subcategory', 'Unknown')}", className="card-text"),
html.P(f"Tags: {', '.join(doc.get('tags', [])) if doc.get('tags') else 'None'}", className="card-text")
])
])