add in browser embedding generation #4
@@ -3,7 +3,8 @@
 | 
			
		||||
    "allow": [
 | 
			
		||||
      "Bash(mkdir:*)",
 | 
			
		||||
      "Bash(uv run:*)",
 | 
			
		||||
      "Bash(uv add:*)"
 | 
			
		||||
      "Bash(uv add:*)",
 | 
			
		||||
      "Bash(uv sync:*)"
 | 
			
		||||
    ],
 | 
			
		||||
    "deny": [],
 | 
			
		||||
    "ask": [],
 | 
			
		||||
 
 | 
			
		||||
@@ -9,14 +9,13 @@ from .ui.callbacks.interactions import InteractionCallbacks
 | 
			
		||||
 | 
			
		||||
def create_app():
 | 
			
		||||
    import os
 | 
			
		||||
 | 
			
		||||
    # Get the project root directory (two levels up from this file)
 | 
			
		||||
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
 | 
			
		||||
    assets_path = os.path.join(project_root, 'assets')
 | 
			
		||||
    
 | 
			
		||||
    assets_path = os.path.join(project_root, "assets")
 | 
			
		||||
 | 
			
		||||
    app = dash.Dash(
 | 
			
		||||
        __name__, 
 | 
			
		||||
        external_stylesheets=[dbc.themes.BOOTSTRAP],
 | 
			
		||||
        assets_folder=assets_path
 | 
			
		||||
        __name__, external_stylesheets=[dbc.themes.BOOTSTRAP], assets_folder=assets_path
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Allow callbacks to components that are dynamically created in tabs
 | 
			
		||||
 
 | 
			
		||||
@@ -558,17 +558,17 @@ class DataProcessingCallbacks:
 | 
			
		||||
        )
 | 
			
		||||
        def handle_text_input_actions(clear_clicks, load_clicks):
 | 
			
		||||
            from dash import ctx
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
            if not ctx.triggered:
 | 
			
		||||
                return no_update
 | 
			
		||||
                
 | 
			
		||||
            button_id = ctx.triggered[0]['prop_id'].split('.')[0]
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
            button_id = ctx.triggered[0]["prop_id"].split(".")[0]
 | 
			
		||||
 | 
			
		||||
            if button_id == "clear-text-btn" and clear_clicks:
 | 
			
		||||
                return ""
 | 
			
		||||
            elif button_id == "load-sample-btn" and load_clicks:
 | 
			
		||||
                return self._load_sample_text()
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
            return no_update
 | 
			
		||||
 | 
			
		||||
        # Model info callback
 | 
			
		||||
@@ -648,15 +648,19 @@ class DataProcessingCallbacks:
 | 
			
		||||
    def _load_sample_text(self):
 | 
			
		||||
        """Load sample text from assets/sample-txt.md file."""
 | 
			
		||||
        import os
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            # Get the project root directory (four levels up from this file)
 | 
			
		||||
            current_file = os.path.abspath(__file__)
 | 
			
		||||
            project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file)))))
 | 
			
		||||
            sample_file_path = os.path.join(project_root, 'assets', 'sample-txt.md')
 | 
			
		||||
            
 | 
			
		||||
            project_root = os.path.dirname(
 | 
			
		||||
                os.path.dirname(
 | 
			
		||||
                    os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            sample_file_path = os.path.join(project_root, "assets", "sample-txt.md")
 | 
			
		||||
 | 
			
		||||
            if os.path.exists(sample_file_path):
 | 
			
		||||
                with open(sample_file_path, 'r', encoding='utf-8') as file:
 | 
			
		||||
                with open(sample_file_path, "r", encoding="utf-8") as file:
 | 
			
		||||
                    return file.read()
 | 
			
		||||
            else:
 | 
			
		||||
                # Fallback sample text if file doesn't exist
 | 
			
		||||
@@ -677,8 +681,8 @@ She finely chopped the garlic and sautéed it in two tablespoons of olive oil.
 | 
			
		||||
A pinch of saffron adds a beautiful color and aroma to traditional paella.
 | 
			
		||||
If the soup is too salty, add a peeled potato to absorb excess sodium.
 | 
			
		||||
Let the bread dough rise for at least an hour in a warm, draft-free spot."""
 | 
			
		||||
                
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
 | 
			
		||||
        except Exception:
 | 
			
		||||
            # Return a simple fallback if there's any error
 | 
			
		||||
            return "This is sample text for testing embedding generation. You can replace this with your own text."
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -27,7 +27,7 @@ class AppLayout:
 | 
			
		||||
                            window.transformersPipeline = pipeline;
 | 
			
		||||
                            console.log('✅ Transformers.js pipeline loaded globally');
 | 
			
		||||
                            """,
 | 
			
		||||
                            type="module"
 | 
			
		||||
                            type="module",
 | 
			
		||||
                        ),
 | 
			
		||||
                    ],
 | 
			
		||||
                    width=12,
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,5 @@
 | 
			
		||||
"""Tests for client-side embedding processing functionality."""
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import numpy as np
 | 
			
		||||
 | 
			
		||||
from src.embeddingbuddy.data.processor import DataProcessor
 | 
			
		||||
@@ -23,33 +22,30 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
                    "text": "First test document",
 | 
			
		||||
                    "category": "Text Input",
 | 
			
		||||
                    "subcategory": "Generated",
 | 
			
		||||
                    "tags": []
 | 
			
		||||
                    "tags": [],
 | 
			
		||||
                },
 | 
			
		||||
                {
 | 
			
		||||
                    "id": "text_input_1", 
 | 
			
		||||
                    "id": "text_input_1",
 | 
			
		||||
                    "text": "Second test document",
 | 
			
		||||
                    "category": "Text Input",
 | 
			
		||||
                    "subcategory": "Generated",
 | 
			
		||||
                    "tags": []
 | 
			
		||||
                }
 | 
			
		||||
                    "tags": [],
 | 
			
		||||
                },
 | 
			
		||||
            ],
 | 
			
		||||
            "embeddings": [
 | 
			
		||||
                [0.1, 0.2, 0.3, 0.4],
 | 
			
		||||
                [0.5, 0.6, 0.7, 0.8]
 | 
			
		||||
            ]
 | 
			
		||||
            "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert result.error is None
 | 
			
		||||
        assert len(result.documents) == 2
 | 
			
		||||
        assert result.embeddings.shape == (2, 4)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        # Check document content
 | 
			
		||||
        assert result.documents[0].text == "First test document"
 | 
			
		||||
        assert result.documents[1].text == "Second test document"
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        # Check embeddings match
 | 
			
		||||
        np.testing.assert_array_equal(result.embeddings[0], [0.1, 0.2, 0.3, 0.4])
 | 
			
		||||
        np.testing.assert_array_equal(result.embeddings[1], [0.5, 0.6, 0.7, 0.8])
 | 
			
		||||
@@ -57,9 +53,9 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
    def test_process_client_embeddings_with_error(self):
 | 
			
		||||
        """Test processing client data with error."""
 | 
			
		||||
        client_data = {"error": "Transformers.js not loaded"}
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert result.error == "Transformers.js not loaded"
 | 
			
		||||
        assert len(result.documents) == 0
 | 
			
		||||
@@ -68,9 +64,9 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
    def test_process_client_embeddings_missing_data(self):
 | 
			
		||||
        """Test processing with missing documents or embeddings."""
 | 
			
		||||
        client_data = {"documents": []}
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert "No documents or embeddings in client data" in result.error
 | 
			
		||||
        assert len(result.documents) == 0
 | 
			
		||||
@@ -79,16 +75,19 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
        """Test processing with mismatched document and embedding counts."""
 | 
			
		||||
        client_data = {
 | 
			
		||||
            "documents": [
 | 
			
		||||
                {"id": "test", "text": "Test document", "category": "Test", "subcategory": "Test", "tags": []}
 | 
			
		||||
                {
 | 
			
		||||
                    "id": "test",
 | 
			
		||||
                    "text": "Test document",
 | 
			
		||||
                    "category": "Test",
 | 
			
		||||
                    "subcategory": "Test",
 | 
			
		||||
                    "tags": [],
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            "embeddings": [
 | 
			
		||||
                [0.1, 0.2, 0.3, 0.4],
 | 
			
		||||
                [0.5, 0.6, 0.7, 0.8]
 | 
			
		||||
            ]
 | 
			
		||||
            "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert "Mismatch between number of documents and embeddings" in result.error
 | 
			
		||||
        assert len(result.documents) == 0
 | 
			
		||||
@@ -98,16 +97,19 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
        client_data = {
 | 
			
		||||
            "documents": [
 | 
			
		||||
                {"text": ""},  # Empty text should be skipped
 | 
			
		||||
                {"id": "test2", "text": "Valid document", "category": "Test", "subcategory": "Test", "tags": []}
 | 
			
		||||
                {
 | 
			
		||||
                    "id": "test2",
 | 
			
		||||
                    "text": "Valid document",
 | 
			
		||||
                    "category": "Test",
 | 
			
		||||
                    "subcategory": "Test",
 | 
			
		||||
                    "tags": [],
 | 
			
		||||
                },
 | 
			
		||||
            ],
 | 
			
		||||
            "embeddings": [
 | 
			
		||||
                [0.1, 0.2, 0.3, 0.4],
 | 
			
		||||
                [0.5, 0.6, 0.7, 0.8]
 | 
			
		||||
            ]
 | 
			
		||||
            "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert result.error is None
 | 
			
		||||
        assert len(result.documents) == 1  # Only valid document should be processed
 | 
			
		||||
@@ -117,15 +119,18 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
        """Test automatic ID generation for documents without IDs."""
 | 
			
		||||
        client_data = {
 | 
			
		||||
            "documents": [
 | 
			
		||||
                {"text": "Document without ID", "category": "Test", "subcategory": "Test", "tags": []}
 | 
			
		||||
                {
 | 
			
		||||
                    "text": "Document without ID",
 | 
			
		||||
                    "category": "Test",
 | 
			
		||||
                    "subcategory": "Test",
 | 
			
		||||
                    "tags": [],
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            "embeddings": [
 | 
			
		||||
                [0.1, 0.2, 0.3, 0.4]
 | 
			
		||||
            ]
 | 
			
		||||
            "embeddings": [[0.1, 0.2, 0.3, 0.4]],
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert result.error is None
 | 
			
		||||
        assert len(result.documents) == 1
 | 
			
		||||
@@ -135,13 +140,19 @@ class TestClientEmbeddingsProcessing:
 | 
			
		||||
        """Test processing with invalid embedding format."""
 | 
			
		||||
        client_data = {
 | 
			
		||||
            "documents": [
 | 
			
		||||
                {"id": "test", "text": "Test document", "category": "Test", "subcategory": "Test", "tags": []}
 | 
			
		||||
                {
 | 
			
		||||
                    "id": "test",
 | 
			
		||||
                    "text": "Test document",
 | 
			
		||||
                    "category": "Test",
 | 
			
		||||
                    "subcategory": "Test",
 | 
			
		||||
                    "tags": [],
 | 
			
		||||
                }
 | 
			
		||||
            ],
 | 
			
		||||
            "embeddings": 0.5  # Scalar instead of array
 | 
			
		||||
            "embeddings": 0.5,  # Scalar instead of array
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = self.processor.process_client_embeddings(client_data)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert isinstance(result, ProcessedData)
 | 
			
		||||
        assert result.error is not None  # Should have some error
 | 
			
		||||
        assert len(result.documents) == 0
 | 
			
		||||
        assert len(result.documents) == 0
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user