Add in-browser embedding generation #4

Merged
godber merged 2 commits from add-browser-embeddings into main 2025-09-07 07:26:57 -07:00
14 changed files with 1598 additions and 4 deletions
Showing only changes of commit bced5e07ce - Show all commits

View File

@@ -3,7 +3,8 @@
"allow": [ "allow": [
"Bash(mkdir:*)", "Bash(mkdir:*)",
"Bash(uv run:*)", "Bash(uv run:*)",
"Bash(uv add:*)" "Bash(uv add:*)",
"Bash(uv sync:*)"
], ],
"deny": [], "deny": [],
"ask": [], "ask": [],

View File

@@ -9,14 +9,13 @@ from .ui.callbacks.interactions import InteractionCallbacks
def create_app(): def create_app():
import os import os
# Get the project root directory (two levels up from this file) # Get the project root directory (two levels up from this file)
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
assets_path = os.path.join(project_root, 'assets') assets_path = os.path.join(project_root, "assets")
app = dash.Dash( app = dash.Dash(
__name__, __name__, external_stylesheets=[dbc.themes.BOOTSTRAP], assets_folder=assets_path
external_stylesheets=[dbc.themes.BOOTSTRAP],
assets_folder=assets_path
) )
# Allow callbacks to components that are dynamically created in tabs # Allow callbacks to components that are dynamically created in tabs

View File

@@ -558,17 +558,17 @@ class DataProcessingCallbacks:
) )
def handle_text_input_actions(clear_clicks, load_clicks): def handle_text_input_actions(clear_clicks, load_clicks):
from dash import ctx from dash import ctx
if not ctx.triggered: if not ctx.triggered:
return no_update return no_update
button_id = ctx.triggered[0]['prop_id'].split('.')[0] button_id = ctx.triggered[0]["prop_id"].split(".")[0]
if button_id == "clear-text-btn" and clear_clicks: if button_id == "clear-text-btn" and clear_clicks:
return "" return ""
elif button_id == "load-sample-btn" and load_clicks: elif button_id == "load-sample-btn" and load_clicks:
return self._load_sample_text() return self._load_sample_text()
return no_update return no_update
# Model info callback # Model info callback
@@ -648,15 +648,19 @@ class DataProcessingCallbacks:
def _load_sample_text(self): def _load_sample_text(self):
"""Load sample text from assets/sample-txt.md file.""" """Load sample text from assets/sample-txt.md file."""
import os import os
try: try:
# Get the project root directory (four levels up from this file) # Get the project root directory (four levels up from this file)
current_file = os.path.abspath(__file__) current_file = os.path.abspath(__file__)
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))) project_root = os.path.dirname(
sample_file_path = os.path.join(project_root, 'assets', 'sample-txt.md') os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
)
)
sample_file_path = os.path.join(project_root, "assets", "sample-txt.md")
if os.path.exists(sample_file_path): if os.path.exists(sample_file_path):
with open(sample_file_path, 'r', encoding='utf-8') as file: with open(sample_file_path, "r", encoding="utf-8") as file:
return file.read() return file.read()
else: else:
# Fallback sample text if file doesn't exist # Fallback sample text if file doesn't exist
@@ -677,8 +681,8 @@ She finely chopped the garlic and sautéed it in two tablespoons of olive oil.
A pinch of saffron adds a beautiful color and aroma to traditional paella. A pinch of saffron adds a beautiful color and aroma to traditional paella.
If the soup is too salty, add a peeled potato to absorb excess sodium. If the soup is too salty, add a peeled potato to absorb excess sodium.
Let the bread dough rise for at least an hour in a warm, draft-free spot.""" Let the bread dough rise for at least an hour in a warm, draft-free spot."""
except Exception as e: except Exception:
# Return a simple fallback if there's any error # Return a simple fallback if there's any error
return "This is sample text for testing embedding generation. You can replace this with your own text." return "This is sample text for testing embedding generation. You can replace this with your own text."

View File

@@ -27,7 +27,7 @@ class AppLayout:
window.transformersPipeline = pipeline; window.transformersPipeline = pipeline;
console.log('✅ Transformers.js pipeline loaded globally'); console.log('✅ Transformers.js pipeline loaded globally');
""", """,
type="module" type="module",
), ),
], ],
width=12, width=12,

View File

@@ -1,6 +1,5 @@
"""Tests for client-side embedding processing functionality.""" """Tests for client-side embedding processing functionality."""
import pytest
import numpy as np import numpy as np
from src.embeddingbuddy.data.processor import DataProcessor from src.embeddingbuddy.data.processor import DataProcessor
@@ -23,33 +22,30 @@ class TestClientEmbeddingsProcessing:
"text": "First test document", "text": "First test document",
"category": "Text Input", "category": "Text Input",
"subcategory": "Generated", "subcategory": "Generated",
"tags": [] "tags": [],
}, },
{ {
"id": "text_input_1", "id": "text_input_1",
"text": "Second test document", "text": "Second test document",
"category": "Text Input", "category": "Text Input",
"subcategory": "Generated", "subcategory": "Generated",
"tags": [] "tags": [],
} },
], ],
"embeddings": [ "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8]
]
} }
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert result.error is None assert result.error is None
assert len(result.documents) == 2 assert len(result.documents) == 2
assert result.embeddings.shape == (2, 4) assert result.embeddings.shape == (2, 4)
# Check document content # Check document content
assert result.documents[0].text == "First test document" assert result.documents[0].text == "First test document"
assert result.documents[1].text == "Second test document" assert result.documents[1].text == "Second test document"
# Check embeddings match # Check embeddings match
np.testing.assert_array_equal(result.embeddings[0], [0.1, 0.2, 0.3, 0.4]) np.testing.assert_array_equal(result.embeddings[0], [0.1, 0.2, 0.3, 0.4])
np.testing.assert_array_equal(result.embeddings[1], [0.5, 0.6, 0.7, 0.8]) np.testing.assert_array_equal(result.embeddings[1], [0.5, 0.6, 0.7, 0.8])
@@ -57,9 +53,9 @@ class TestClientEmbeddingsProcessing:
def test_process_client_embeddings_with_error(self): def test_process_client_embeddings_with_error(self):
"""Test processing client data with error.""" """Test processing client data with error."""
client_data = {"error": "Transformers.js not loaded"} client_data = {"error": "Transformers.js not loaded"}
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert result.error == "Transformers.js not loaded" assert result.error == "Transformers.js not loaded"
assert len(result.documents) == 0 assert len(result.documents) == 0
@@ -68,9 +64,9 @@ class TestClientEmbeddingsProcessing:
def test_process_client_embeddings_missing_data(self): def test_process_client_embeddings_missing_data(self):
"""Test processing with missing documents or embeddings.""" """Test processing with missing documents or embeddings."""
client_data = {"documents": []} client_data = {"documents": []}
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert "No documents or embeddings in client data" in result.error assert "No documents or embeddings in client data" in result.error
assert len(result.documents) == 0 assert len(result.documents) == 0
@@ -79,16 +75,19 @@ class TestClientEmbeddingsProcessing:
"""Test processing with mismatched document and embedding counts.""" """Test processing with mismatched document and embedding counts."""
client_data = { client_data = {
"documents": [ "documents": [
{"id": "test", "text": "Test document", "category": "Test", "subcategory": "Test", "tags": []} {
"id": "test",
"text": "Test document",
"category": "Test",
"subcategory": "Test",
"tags": [],
}
], ],
"embeddings": [ "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8]
]
} }
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert "Mismatch between number of documents and embeddings" in result.error assert "Mismatch between number of documents and embeddings" in result.error
assert len(result.documents) == 0 assert len(result.documents) == 0
@@ -98,16 +97,19 @@ class TestClientEmbeddingsProcessing:
client_data = { client_data = {
"documents": [ "documents": [
{"text": ""}, # Empty text should be skipped {"text": ""}, # Empty text should be skipped
{"id": "test2", "text": "Valid document", "category": "Test", "subcategory": "Test", "tags": []} {
"id": "test2",
"text": "Valid document",
"category": "Test",
"subcategory": "Test",
"tags": [],
},
], ],
"embeddings": [ "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8]
]
} }
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert result.error is None assert result.error is None
assert len(result.documents) == 1 # Only valid document should be processed assert len(result.documents) == 1 # Only valid document should be processed
@@ -117,15 +119,18 @@ class TestClientEmbeddingsProcessing:
"""Test automatic ID generation for documents without IDs.""" """Test automatic ID generation for documents without IDs."""
client_data = { client_data = {
"documents": [ "documents": [
{"text": "Document without ID", "category": "Test", "subcategory": "Test", "tags": []} {
"text": "Document without ID",
"category": "Test",
"subcategory": "Test",
"tags": [],
}
], ],
"embeddings": [ "embeddings": [[0.1, 0.2, 0.3, 0.4]],
[0.1, 0.2, 0.3, 0.4]
]
} }
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert result.error is None assert result.error is None
assert len(result.documents) == 1 assert len(result.documents) == 1
@@ -135,13 +140,19 @@ class TestClientEmbeddingsProcessing:
"""Test processing with invalid embedding format.""" """Test processing with invalid embedding format."""
client_data = { client_data = {
"documents": [ "documents": [
{"id": "test", "text": "Test document", "category": "Test", "subcategory": "Test", "tags": []} {
"id": "test",
"text": "Test document",
"category": "Test",
"subcategory": "Test",
"tags": [],
}
], ],
"embeddings": 0.5 # Scalar instead of array "embeddings": 0.5, # Scalar instead of array
} }
result = self.processor.process_client_embeddings(client_data) result = self.processor.process_client_embeddings(client_data)
assert isinstance(result, ProcessedData) assert isinstance(result, ProcessedData)
assert result.error is not None # Should have some error assert result.error is not None # Should have some error
assert len(result.documents) == 0 assert len(result.documents) == 0