From 64685b9b4f5f05b310c556add444e8c47fd4dfd4 Mon Sep 17 00:00:00 2001 From: Austin Godber Date: Tue, 12 Aug 2025 18:48:02 -0700 Subject: [PATCH] add prompts --- CLAUDE.md | 72 +++++++++++ app.py | 275 ++++++++++++++++++++++++++++++++++++++---- sample_prompts.ndjson | 10 ++ 3 files changed, 334 insertions(+), 23 deletions(-) create mode 100644 CLAUDE.md create mode 100644 sample_prompts.ndjson diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..787ee92 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,72 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with +code in this repository. + +## Project Overview + +EmbeddingBuddy is a Python Dash web application for interactive exploration and +visualization of embedding vectors through dimensionality reduction techniques +(PCA, t-SNE, UMAP). The app provides a drag-and-drop interface for uploading +NDJSON files containing embeddings and visualizes them in 2D/3D plots. + +## Development Commands + +**Install dependencies:** + +```bash +uv sync +``` + +**Run the application:** + +```bash +uv run python app.py +``` + +The app will be available at http://127.0.0.1:8050 + +**Test with sample data:** +Use the included `sample_data.ndjson` file for testing the application functionality. + +## Architecture + +### Core Files + +- `app.py` - Main Dash application with complete web interface, data processing, + and visualization logic +- `main.py` - Simple entry point (currently minimal) +- `pyproject.toml` - Project configuration and dependencies using uv package manager + +### Key Components + +- **Data Processing**: NDJSON parser that handles embedding documents with + required fields (`embedding`, `text`) and optional metadata (`id`, `category`, `subcategory`, `tags`) +- **Dimensionality Reduction**: Supports PCA, t-SNE (openTSNE), and UMAP algorithms +- **Visualization**: Plotly-based 2D/3D scatter plots with interactive features +- **UI Layout**: Bootstrap-styled sidebar with controls and large visualization area +- **State Management**: Dash callbacks for reactive updates between upload, + method selection, and plot rendering + +### Data Format + +The application expects NDJSON files where each line contains: + +```json +{"id": "doc_001", "embedding": [0.1, -0.3, 0.7, ...], "text": "Sample text", "category": "news", "subcategory": "politics", "tags": ["election"]} +``` + +### Callback Architecture + +- File upload → Data processing and storage in dcc.Store +- Method/parameter changes → Dimensionality reduction and plot update +- Point clicks → Detail display in sidebar + +## Dependencies + +Uses modern Python stack with uv for dependency management: + +- Dash + Plotly for web interface and visualization +- scikit-learn (PCA), openTSNE, umap-learn for dimensionality reduction +- pandas/numpy for data manipulation +- dash-bootstrap-components for styling \ No newline at end of file diff --git a/app.py b/app.py index 70bc549..e34874d 100644 --- a/app.py +++ b/app.py @@ -101,6 +101,117 @@ def create_plot(df, dimensions='3d', color_by='category', method='PCA'): ) return fig +def create_dual_plot(doc_df, prompt_df, dimensions='3d', color_by='category', method='PCA', show_prompts=None): + """Create plotly scatter plot with separate traces for documents and prompts.""" + + # Create the base figure + fig = go.Figure() + + # Helper function to convert colors to grayscale + def to_grayscale_hex(color_str): + """Convert a color to grayscale while maintaining some distinction.""" + import plotly.colors as pc + # Try to get RGB values from the color + try: + if color_str.startswith('#'): + # Hex color + rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5)) + else: + # Named color or other format - convert through plotly + rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0]) + + # Convert to grayscale using luminance formula, but keep some color + gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]) + # Make it a bit more gray but not completely + gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3, + gray_value * 0.7 + rgb[1] * 0.3, + gray_value * 0.7 + rgb[2] * 0.3) + return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})' + except: + return 'rgb(128,128,128)' # fallback gray + + # Create document plot using plotly express for consistent colors + doc_color_values = create_color_mapping(doc_df.to_dict('records'), color_by) + doc_df_display = doc_df.copy() + doc_df_display['text_preview'] = doc_df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x) + doc_df_display['tags_str'] = doc_df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None') + + hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str'] + + # Create documents plot to get the color mapping + if dimensions == '3d': + doc_fig = px.scatter_3d( + doc_df_display, x='dim_1', y='dim_2', z='dim_3', + color=doc_color_values, + hover_data=hover_fields + ) + else: + doc_fig = px.scatter( + doc_df_display, x='dim_1', y='dim_2', + color=doc_color_values, + hover_data=hover_fields + ) + + # Add document traces to main figure + for trace in doc_fig.data: + trace.name = f'Documents - {trace.name}' + if dimensions == '3d': + trace.marker.size = 5 + trace.marker.symbol = 'circle' + else: + trace.marker.size = 8 + trace.marker.symbol = 'circle' + trace.marker.opacity = 1.0 + fig.add_trace(trace) + + # Add prompt traces if they exist + if prompt_df is not None and show_prompts and 'show' in show_prompts: + prompt_color_values = create_color_mapping(prompt_df.to_dict('records'), color_by) + prompt_df_display = prompt_df.copy() + prompt_df_display['text_preview'] = prompt_df_display['text'].apply(lambda x: x[:100] + "..." if len(x) > 100 else x) + prompt_df_display['tags_str'] = prompt_df_display['tags'].apply(lambda x: ', '.join(x) if x else 'None') + + # Create prompts plot to get consistent color grouping + if dimensions == '3d': + prompt_fig = px.scatter_3d( + prompt_df_display, x='dim_1', y='dim_2', z='dim_3', + color=prompt_color_values, + hover_data=hover_fields + ) + else: + prompt_fig = px.scatter( + prompt_df_display, x='dim_1', y='dim_2', + color=prompt_color_values, + hover_data=hover_fields + ) + + # Add prompt traces with grayed colors + for trace in prompt_fig.data: + # Convert the color to grayscale + original_color = trace.marker.color + if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str): + trace.marker.color = to_grayscale_hex(trace.marker.color) + + trace.name = f'Prompts - {trace.name}' + if dimensions == '3d': + trace.marker.size = 6 + trace.marker.symbol = 'diamond' + else: + trace.marker.size = 10 + trace.marker.symbol = 'diamond' + trace.marker.opacity = 0.8 + fig.add_trace(trace) + + title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})' + fig.update_layout( + title=title, + height=None, + autosize=True, + margin=dict(l=0, r=0, t=50, b=0) + ) + + return fig + # Layout app.layout = dbc.Container([ dbc.Row([ @@ -132,6 +243,36 @@ app.layout = dbc.Container([ multiple=False ), + dcc.Upload( + id='upload-prompts', + children=html.Div([ + 'Drag and Drop Prompts or ', + html.A('Select Files') + ]), + style={ + 'width': '100%', + 'height': '60px', + 'lineHeight': '60px', + 'borderWidth': '1px', + 'borderStyle': 'dashed', + 'borderRadius': '5px', + 'textAlign': 'center', + 'margin-bottom': '20px', + 'borderColor': '#28a745' + }, + multiple=False + ), + + dbc.Button( + "Reset All Data", + id='reset-button', + color='danger', + outline=True, + size='sm', + className='mb-3', + style={'width': '100%'} + ), + html.H5("Visualization Controls", className="mb-3"), dbc.Label("Method:"), @@ -169,6 +310,14 @@ app.layout = dbc.Container([ style={'margin-bottom': '20px'} ), + dbc.Label("Show Prompts:"), + dcc.Checklist( + id='show-prompts-toggle', + options=[{'label': 'Show prompts on plot', 'value': 'show'}], + value=['show'], + style={'margin-bottom': '20px'} + ), + html.H5("Point Details", className="mb-3"), html.Div(id='point-details', children="Click on a point to see details") @@ -184,7 +333,8 @@ app.layout = dbc.Container([ ], width=9) ]), - dcc.Store(id='processed-data') + dcc.Store(id='processed-data'), + dcc.Store(id='processed-prompts') ], fluid=True) @callback( @@ -208,14 +358,37 @@ def process_uploaded_file(contents, filename): except Exception as e: return {'error': str(e)} +@callback( + Output('processed-prompts', 'data'), + Input('upload-prompts', 'contents'), + State('upload-prompts', 'filename') +) +def process_uploaded_prompts(contents, filename): + if contents is None: + return None + + try: + prompts = parse_ndjson(contents) + embeddings = np.array([prompt['embedding'] for prompt in prompts]) + + # Store original embeddings and prompts + return { + 'prompts': prompts, + 'embeddings': embeddings.tolist() + } + except Exception as e: + return {'error': str(e)} + @callback( Output('embedding-plot', 'figure'), [Input('processed-data', 'data'), + Input('processed-prompts', 'data'), Input('method-dropdown', 'value'), Input('color-dropdown', 'value'), - Input('dimension-toggle', 'value')] + Input('dimension-toggle', 'value'), + Input('show-prompts-toggle', 'value')] ) -def update_plot(data, method, color_by, dimensions): +def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts): if not data or 'error' in data: return go.Figure().add_annotation( text="Upload a valid NDJSON file to see visualization", @@ -224,16 +397,28 @@ def update_plot(data, method, color_by, dimensions): showarrow=False, font=dict(size=16) ) - # Get embeddings and apply selected method - embeddings = np.array(data['embeddings']) + # Prepare embeddings for dimensionality reduction + doc_embeddings = np.array(data['embeddings']) + all_embeddings = doc_embeddings + has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts') + + if has_prompts: + prompt_embeddings = np.array(prompts_data['embeddings']) + all_embeddings = np.vstack([doc_embeddings, prompt_embeddings]) + n_components = 3 if dimensions == '3d' else 2 + # Apply dimensionality reduction to combined data reduced, variance_explained = apply_dimensionality_reduction( - embeddings, method=method, n_components=n_components + all_embeddings, method=method, n_components=n_components ) - # Create dataframe with reduced dimensions - df_data = [] + # Split reduced embeddings back + doc_reduced = reduced[:len(doc_embeddings)] + prompt_reduced = reduced[len(doc_embeddings):] if has_prompts else None + + # Create dataframes + doc_df_data = [] for i, doc in enumerate(data['documents']): row = { 'id': doc['id'], @@ -241,28 +426,52 @@ def update_plot(data, method, color_by, dimensions): 'category': doc.get('category', 'Unknown'), 'subcategory': doc.get('subcategory', 'Unknown'), 'tags': doc.get('tags', []), - 'dim_1': reduced[i, 0], - 'dim_2': reduced[i, 1] + 'dim_1': doc_reduced[i, 0], + 'dim_2': doc_reduced[i, 1], + 'type': 'document' } if dimensions == '3d': - row['dim_3'] = reduced[i, 2] - df_data.append(row) + row['dim_3'] = doc_reduced[i, 2] + doc_df_data.append(row) - df = pd.DataFrame(df_data) + doc_df = pd.DataFrame(doc_df_data) - return create_plot(df, dimensions, color_by, method.upper()) + prompt_df = None + if has_prompts and prompt_reduced is not None: + prompt_df_data = [] + for i, prompt in enumerate(prompts_data['prompts']): + row = { + 'id': prompt['id'], + 'text': prompt['text'], + 'category': prompt.get('category', 'Unknown'), + 'subcategory': prompt.get('subcategory', 'Unknown'), + 'tags': prompt.get('tags', []), + 'dim_1': prompt_reduced[i, 0], + 'dim_2': prompt_reduced[i, 1], + 'type': 'prompt' + } + if dimensions == '3d': + row['dim_3'] = prompt_reduced[i, 2] + prompt_df_data.append(row) + + prompt_df = pd.DataFrame(prompt_df_data) + + return create_dual_plot(doc_df, prompt_df, dimensions, color_by, method.upper(), show_prompts) @callback( Output('point-details', 'children'), Input('embedding-plot', 'clickData'), - State('processed-data', 'data') + [State('processed-data', 'data'), + State('processed-prompts', 'data')] ) -def display_click_data(clickData, data): +def display_click_data(clickData, data, prompts_data): if not clickData or not data: return "Click on a point to see details" - # Get point index - try different possible keys + # Get point info from click point_data = clickData['points'][0] + trace_name = point_data.get('fullData', {}).get('name', 'Documents') + if 'pointIndex' in point_data: point_index = point_data['pointIndex'] elif 'pointNumber' in point_data: @@ -270,17 +479,37 @@ def display_click_data(clickData, data): else: return "Could not identify clicked point" - doc = data['documents'][point_index] + # Determine which dataset this point belongs to + if trace_name == 'Prompts' and prompts_data and 'prompts' in prompts_data: + item = prompts_data['prompts'][point_index] + item_type = 'Prompt' + else: + item = data['documents'][point_index] + item_type = 'Document' return dbc.Card([ dbc.CardBody([ - html.H5(f"Document: {doc['id']}", className="card-title"), - html.P(f"Text: {doc['text']}", className="card-text"), - html.P(f"Category: {doc.get('category', 'Unknown')}", className="card-text"), - html.P(f"Subcategory: {doc.get('subcategory', 'Unknown')}", className="card-text"), - html.P(f"Tags: {', '.join(doc.get('tags', [])) if doc.get('tags') else 'None'}", className="card-text") + html.H5(f"{item_type}: {item['id']}", className="card-title"), + html.P(f"Text: {item['text']}", className="card-text"), + html.P(f"Category: {item.get('category', 'Unknown')}", className="card-text"), + html.P(f"Subcategory: {item.get('subcategory', 'Unknown')}", className="card-text"), + html.P(f"Tags: {', '.join(item.get('tags', [])) if item.get('tags') else 'None'}", className="card-text"), + html.P(f"Type: {item_type}", className="card-text text-muted") ]) ]) +@callback( + [Output('processed-data', 'data', allow_duplicate=True), + Output('processed-prompts', 'data', allow_duplicate=True), + Output('point-details', 'children', allow_duplicate=True)], + Input('reset-button', 'n_clicks'), + prevent_initial_call=True +) +def reset_data(n_clicks): + if n_clicks is None or n_clicks == 0: + return dash.no_update, dash.no_update, dash.no_update + + return None, None, "Click on a point to see details" + if __name__ == '__main__': app.run(debug=True) \ No newline at end of file diff --git a/sample_prompts.ndjson b/sample_prompts.ndjson new file mode 100644 index 0000000..e30d0a7 --- /dev/null +++ b/sample_prompts.ndjson @@ -0,0 +1,10 @@ +{"id": "prompt_001", "embedding": [0.15, -0.28, 0.65, 0.42, -0.11, 0.33, 0.78, -0.52], "text": "Find articles about machine learning applications", "category": "search", "subcategory": "technology", "tags": ["AI", "research"]} +{"id": "prompt_002", "embedding": [0.72, 0.18, -0.35, 0.51, 0.09, -0.44, 0.27, 0.63], "text": "Show me product reviews for smartphones", "category": "search", "subcategory": "product", "tags": ["mobile", "reviews"]} +{"id": "prompt_003", "embedding": [-0.21, 0.59, 0.34, -0.67, 0.45, 0.12, -0.38, 0.76], "text": "What are the latest political developments?", "category": "search", "subcategory": "news", "tags": ["politics", "current events"]} +{"id": "prompt_004", "embedding": [0.48, -0.15, 0.72, 0.31, -0.58, 0.24, 0.67, -0.39], "text": "Summarize recent tech industry trends", "category": "analysis", "subcategory": "technology", "tags": ["tech", "trends", "summary"]} +{"id": "prompt_005", "embedding": [-0.33, 0.47, -0.62, 0.28, 0.71, -0.18, 0.54, 0.35], "text": "Compare different smartphone models", "category": "analysis", "subcategory": "product", "tags": ["comparison", "mobile", "evaluation"]} +{"id": "prompt_006", "embedding": [0.64, 0.21, 0.39, -0.45, 0.13, 0.58, -0.27, 0.74], "text": "Analyze voter sentiment on recent policies", "category": "analysis", "subcategory": "politics", "tags": ["sentiment", "politics", "analysis"]} +{"id": "prompt_007", "embedding": [0.29, -0.43, 0.56, 0.68, -0.22, 0.37, 0.14, -0.61], "text": "Generate a summary of machine learning research", "category": "generation", "subcategory": "technology", "tags": ["AI", "research", "summary"]} +{"id": "prompt_008", "embedding": [-0.17, 0.52, -0.48, 0.36, 0.74, -0.29, 0.61, 0.18], "text": "Create a product recommendation report", "category": "generation", "subcategory": "product", "tags": ["recommendation", "report", "analysis"]} +{"id": "prompt_009", "embedding": [0.55, 0.08, 0.41, -0.37, 0.26, 0.69, -0.14, 0.58], "text": "Write a news brief on election updates", "category": "generation", "subcategory": "news", "tags": ["election", "news", "brief"]} +{"id": "prompt_010", "embedding": [0.23, -0.59, 0.47, 0.61, -0.35, 0.18, 0.72, -0.26], "text": "Explain how neural networks work", "category": "explanation", "subcategory": "technology", "tags": ["AI", "education", "neural networks"]} \ No newline at end of file