add prompts

This commit is contained in:
2025-08-12 18:48:02 -07:00
parent 6c1e6d7395
commit 64685b9b4f
3 changed files with 334 additions and 23 deletions

275
app.py
View File

@@ -101,6 +101,117 @@ def create_plot(df, dimensions='3d', color_by='category', method='PCA'):
)
return fig
def create_dual_plot(doc_df, prompt_df, dimensions='3d', color_by='category', method='PCA', show_prompts=None):
"""Create plotly scatter plot with separate traces for documents and prompts."""
# Create the base figure
fig = go.Figure()
# Helper function to convert colors to grayscale
def to_grayscale_hex(color_str):
"""Convert a color to grayscale while maintaining some distinction."""
import plotly.colors as pc
# Try to get RGB values from the color
try:
if color_str.startswith('#'):
# Hex color
rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
else:
# Named color or other format - convert through plotly
rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
# Convert to grayscale using luminance formula, but keep some color
gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
# Make it a bit more gray but not completely
gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3,
gray_value * 0.7 + rgb[1] * 0.3,
gray_value * 0.7 + rgb[2] * 0.3)
return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
except:
return 'rgb(128,128,128)' # fallback gray
# Create document plot using plotly express for consistent colors
doc_color_values = create_color_mapping(doc_df.to_dict('records'), color_by)
doc_df_display = doc_df.copy()
doc_df_display['text_preview'] = doc_df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x)
doc_df_display['tags_str'] = doc_df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None')
hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
# Create documents plot to get the color mapping
if dimensions == '3d':
doc_fig = px.scatter_3d(
doc_df_display, x='dim_1', y='dim_2', z='dim_3',
color=doc_color_values,
hover_data=hover_fields
)
else:
doc_fig = px.scatter(
doc_df_display, x='dim_1', y='dim_2',
color=doc_color_values,
hover_data=hover_fields
)
# Add document traces to main figure
for trace in doc_fig.data:
trace.name = f'Documents - {trace.name}'
if dimensions == '3d':
trace.marker.size = 5
trace.marker.symbol = 'circle'
else:
trace.marker.size = 8
trace.marker.symbol = 'circle'
trace.marker.opacity = 1.0
fig.add_trace(trace)
# Add prompt traces if they exist
if prompt_df is not None and show_prompts and 'show' in show_prompts:
prompt_color_values = create_color_mapping(prompt_df.to_dict('records'), color_by)
prompt_df_display = prompt_df.copy()
prompt_df_display['text_preview'] = prompt_df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x)
prompt_df_display['tags_str'] = prompt_df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None')
# Create prompts plot to get consistent color grouping
if dimensions == '3d':
prompt_fig = px.scatter_3d(
prompt_df_display, x='dim_1', y='dim_2', z='dim_3',
color=prompt_color_values,
hover_data=hover_fields
)
else:
prompt_fig = px.scatter(
prompt_df_display, x='dim_1', y='dim_2',
color=prompt_color_values,
hover_data=hover_fields
)
# Add prompt traces with grayed colors
for trace in prompt_fig.data:
# Convert the color to grayscale
original_color = trace.marker.color
if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
trace.marker.color = to_grayscale_hex(trace.marker.color)
trace.name = f'Prompts - {trace.name}'
if dimensions == '3d':
trace.marker.size = 6
trace.marker.symbol = 'diamond'
else:
trace.marker.size = 10
trace.marker.symbol = 'diamond'
trace.marker.opacity = 0.8
fig.add_trace(trace)
title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
fig.update_layout(
title=title,
height=None,
autosize=True,
margin=dict(l=0, r=0, t=50, b=0)
)
return fig
# Layout
app.layout = dbc.Container([
dbc.Row([
@@ -132,6 +243,36 @@ app.layout = dbc.Container([
multiple=False
),
dcc.Upload(
id='upload-prompts',
children=html.Div([
'Drag and Drop Prompts or ',
html.A('Select Files')
]),
style={
'width': '100%',
'height': '60px',
'lineHeight': '60px',
'borderWidth': '1px',
'borderStyle': 'dashed',
'borderRadius': '5px',
'textAlign': 'center',
'margin-bottom': '20px',
'borderColor': '#28a745'
},
multiple=False
),
dbc.Button(
"Reset All Data",
id='reset-button',
color='danger',
outline=True,
size='sm',
className='mb-3',
style={'width': '100%'}
),
html.H5("Visualization Controls", className="mb-3"),
dbc.Label("Method:"),
@@ -169,6 +310,14 @@ app.layout = dbc.Container([
style={'margin-bottom': '20px'}
),
dbc.Label("Show Prompts:"),
dcc.Checklist(
id='show-prompts-toggle',
options=[{'label': 'Show prompts on plot', 'value': 'show'}],
value=['show'],
style={'margin-bottom': '20px'}
),
html.H5("Point Details", className="mb-3"),
html.Div(id='point-details', children="Click on a point to see details")
@@ -184,7 +333,8 @@ app.layout = dbc.Container([
], width=9)
]),
dcc.Store(id='processed-data')
dcc.Store(id='processed-data'),
dcc.Store(id='processed-prompts')
], fluid=True)
@callback(
@@ -208,14 +358,37 @@ def process_uploaded_file(contents, filename):
except Exception as e:
return {'error': str(e)}
@callback(
Output('processed-prompts', 'data'),
Input('upload-prompts', 'contents'),
State('upload-prompts', 'filename')
)
def process_uploaded_prompts(contents, filename):
if contents is None:
return None
try:
prompts = parse_ndjson(contents)
embeddings = np.array([prompt['embedding'] for prompt in prompts])
# Store original embeddings and prompts
return {
'prompts': prompts,
'embeddings': embeddings.tolist()
}
except Exception as e:
return {'error': str(e)}
@callback(
Output('embedding-plot', 'figure'),
[Input('processed-data', 'data'),
Input('processed-prompts', 'data'),
Input('method-dropdown', 'value'),
Input('color-dropdown', 'value'),
Input('dimension-toggle', 'value')]
Input('dimension-toggle', 'value'),
Input('show-prompts-toggle', 'value')]
)
def update_plot(data, method, color_by, dimensions):
def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
if not data or 'error' in data:
return go.Figure().add_annotation(
text="Upload a valid NDJSON file to see visualization",
@@ -224,16 +397,28 @@ def update_plot(data, method, color_by, dimensions):
showarrow=False, font=dict(size=16)
)
# Get embeddings and apply selected method
embeddings = np.array(data['embeddings'])
# Prepare embeddings for dimensionality reduction
doc_embeddings = np.array(data['embeddings'])
all_embeddings = doc_embeddings
has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
if has_prompts:
prompt_embeddings = np.array(prompts_data['embeddings'])
all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
n_components = 3 if dimensions == '3d' else 2
# Apply dimensionality reduction to combined data
reduced, variance_explained = apply_dimensionality_reduction(
embeddings, method=method, n_components=n_components
all_embeddings, method=method, n_components=n_components
)
# Create dataframe with reduced dimensions
df_data = []
# Split reduced embeddings back
doc_reduced = reduced[:len(doc_embeddings)]
prompt_reduced = reduced[len(doc_embeddings):] if has_prompts else None
# Create dataframes
doc_df_data = []
for i, doc in enumerate(data['documents']):
row = {
'id': doc['id'],
@@ -241,28 +426,52 @@ def update_plot(data, method, color_by, dimensions):
'category': doc.get('category', 'Unknown'),
'subcategory': doc.get('subcategory', 'Unknown'),
'tags': doc.get('tags', []),
'dim_1': reduced[i, 0],
'dim_2': reduced[i, 1]
'dim_1': doc_reduced[i, 0],
'dim_2': doc_reduced[i, 1],
'type': 'document'
}
if dimensions == '3d':
row['dim_3'] = reduced[i, 2]
df_data.append(row)
row['dim_3'] = doc_reduced[i, 2]
doc_df_data.append(row)
df = pd.DataFrame(df_data)
doc_df = pd.DataFrame(doc_df_data)
return create_plot(df, dimensions, color_by, method.upper())
prompt_df = None
if has_prompts and prompt_reduced is not None:
prompt_df_data = []
for i, prompt in enumerate(prompts_data['prompts']):
row = {
'id': prompt['id'],
'text': prompt['text'],
'category': prompt.get('category', 'Unknown'),
'subcategory': prompt.get('subcategory', 'Unknown'),
'tags': prompt.get('tags', []),
'dim_1': prompt_reduced[i, 0],
'dim_2': prompt_reduced[i, 1],
'type': 'prompt'
}
if dimensions == '3d':
row['dim_3'] = prompt_reduced[i, 2]
prompt_df_data.append(row)
prompt_df = pd.DataFrame(prompt_df_data)
return create_dual_plot(doc_df, prompt_df, dimensions, color_by, method.upper(), show_prompts)
@callback(
Output('point-details', 'children'),
Input('embedding-plot', 'clickData'),
State('processed-data', 'data')
[State('processed-data', 'data'),
State('processed-prompts', 'data')]
)
def display_click_data(clickData, data):
def display_click_data(clickData, data, prompts_data):
if not clickData or not data:
return "Click on a point to see details"
# Get point index - try different possible keys
# Get point info from click
point_data = clickData['points'][0]
trace_name = point_data.get('fullData', {}).get('name', 'Documents')
if 'pointIndex' in point_data:
point_index = point_data['pointIndex']
elif 'pointNumber' in point_data:
@@ -270,17 +479,37 @@ def display_click_data(clickData, data):
else:
return "Could not identify clicked point"
doc = data['documents'][point_index]
# Determine which dataset this point belongs to
if trace_name == 'Prompts' and prompts_data and 'prompts' in prompts_data:
item = prompts_data['prompts'][point_index]
item_type = 'Prompt'
else:
item = data['documents'][point_index]
item_type = 'Document'
return dbc.Card([
dbc.CardBody([
html.H5(f"Document: {doc['id']}", className="card-title"),
html.P(f"Text: {doc['text']}", className="card-text"),
html.P(f"Category: {doc.get('category', 'Unknown')}", className="card-text"),
html.P(f"Subcategory: {doc.get('subcategory', 'Unknown')}", className="card-text"),
html.P(f"Tags: {', '.join(doc.get('tags', [])) if doc.get('tags') else 'None'}", className="card-text")
html.H5(f"{item_type}: {item['id']}", className="card-title"),
html.P(f"Text: {item['text']}", className="card-text"),
html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"),
html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"),
html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"),
html.P(f"Type: {item_type}", className="card-text text-muted")
])
])
@callback(
[Output('processed-data', 'data', allow_duplicate=True),
Output('processed-prompts', 'data', allow_duplicate=True),
Output('point-details', 'children', allow_duplicate=True)],
Input('reset-button', 'n_clicks'),
prevent_initial_call=True
)
def reset_data(n_clicks):
if n_clicks is None or n_clicks == 0:
return dash.no_update, dash.no_update, dash.no_update
return None, None, "Click on a point to see details"
if __name__ == '__main__':
app.run(debug=True)