add ci workflows #1
@@ -1,3 +1,3 @@
 | 
			
		||||
"""EmbeddingBuddy - Interactive exploration and visualization of embedding vectors."""
 | 
			
		||||
 | 
			
		||||
__version__ = "0.1.0"
 | 
			
		||||
__version__ = "0.1.0"
 | 
			
		||||
 
 | 
			
		||||
@@ -8,32 +8,29 @@ from .ui.callbacks.interactions import InteractionCallbacks
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_app():
    """Build the Dash application with its layout and callbacks attached.

    Returns:
        The configured ``dash.Dash`` instance. The server is not started
        here; pass the result to ``run_app``.
    """
    app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

    layout_manager = AppLayout()
    app.layout = layout_manager.create_layout()

    # Each callbacks class registers its Dash callbacks from __init__, so the
    # instances are created for that side effect and need not be kept.
    DataProcessingCallbacks()
    VisualizationCallbacks()
    InteractionCallbacks()

    return app
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_app(app=None, debug=None, host=None, port=None):
    """Start the Dash development server.

    Args:
        app: Application to run. A new one is built with ``create_app`` when
            this is ``None``.
        debug: Debug flag; falls back to ``AppSettings.DEBUG`` when ``None``.
        host: Bind address; falls back to ``AppSettings.HOST`` when ``None``.
        port: Port number; falls back to ``AppSettings.PORT`` when ``None``.
    """
    if app is None:
        app = create_app()

    # Compare against None explicitly so falsy overrides such as debug=False
    # are honoured rather than replaced by the configured default.
    app.run(
        debug=debug if debug is not None else AppSettings.DEBUG,
        host=host if host is not None else AppSettings.HOST,
        port=port if port is not None else AppSettings.PORT,
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Script entry point: build the app and start the server once.
if __name__ == "__main__":
    app = create_app()
    run_app(app)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,105 +3,100 @@ import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AppSettings:
    """Central configuration constants for the application.

    All values are class attributes. ``DEBUG``, ``HOST`` and ``PORT`` are read
    from ``EMBEDDINGBUDDY_*`` environment variables once, when the class body
    is executed.
    """

    # UI Configuration
    UPLOAD_STYLE = {
        "width": "100%",
        "height": "60px",
        "lineHeight": "60px",
        "borderWidth": "1px",
        "borderStyle": "dashed",
        "borderRadius": "5px",
        "textAlign": "center",
        "margin-bottom": "20px",
    }

    # Same as UPLOAD_STYLE, with an added border colour.
    PROMPTS_UPLOAD_STYLE = {**UPLOAD_STYLE, "borderColor": "#28a745"}

    PLOT_CONFIG = {"responsive": True, "displayModeBar": True}

    PLOT_STYLE = {"height": "85vh", "width": "100%"}

    PLOT_LAYOUT_CONFIG = {
        "height": None,
        "autosize": True,
        "margin": dict(l=0, r=0, t=50, b=0),
    }

    # Dimensionality Reduction Settings
    DEFAULT_N_COMPONENTS_3D = 3
    DEFAULT_N_COMPONENTS_2D = 2
    DEFAULT_RANDOM_STATE = 42

    # Available Methods
    REDUCTION_METHODS = [
        {"label": "PCA", "value": "pca"},
        {"label": "t-SNE", "value": "tsne"},
        {"label": "UMAP", "value": "umap"},
    ]

    COLOR_OPTIONS = [
        {"label": "Category", "value": "category"},
        {"label": "Subcategory", "value": "subcategory"},
        {"label": "Tags", "value": "tags"},
    ]

    DIMENSION_OPTIONS = [{"label": "2D", "value": "2d"}, {"label": "3D", "value": "3d"}]

    # Default Values
    DEFAULT_METHOD = "pca"
    DEFAULT_COLOR_BY = "category"
    DEFAULT_DIMENSIONS = "3d"
    DEFAULT_SHOW_PROMPTS = ["show"]

    # Plot Marker Settings
    DOCUMENT_MARKER_SIZE_2D = 8
    DOCUMENT_MARKER_SIZE_3D = 5
    PROMPT_MARKER_SIZE_2D = 10
    PROMPT_MARKER_SIZE_3D = 6

    DOCUMENT_MARKER_SYMBOL = "circle"
    PROMPT_MARKER_SYMBOL = "diamond"

    DOCUMENT_OPACITY = 1.0
    PROMPT_OPACITY = 0.8

    # Text Processing
    TEXT_PREVIEW_LENGTH = 100

    # App Configuration
    # Only the (case-insensitive) string "true" enables debug mode.
    DEBUG = os.getenv("EMBEDDINGBUDDY_DEBUG", "True").lower() == "true"
    HOST = os.getenv("EMBEDDINGBUDDY_HOST", "127.0.0.1")
    PORT = int(os.getenv("EMBEDDINGBUDDY_PORT", "8050"))

    # Bootstrap Theme
    EXTERNAL_STYLESHEETS = [
        "https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css"
    ]

    @classmethod
    def get_plot_marker_config(
        cls, dimensions: str, is_prompt: bool = False
    ) -> Dict[str, Any]:
        """Return marker size, symbol and opacity for a plot trace.

        Args:
            dimensions: ``"3d"`` selects the 3D marker size; any other value
                selects the 2D size.
            is_prompt: Use the prompt marker settings instead of the document
                marker settings.

        Returns:
            A dict with the keys ``"size"``, ``"symbol"`` and ``"opacity"``.
        """
        if is_prompt:
            size = (
                cls.PROMPT_MARKER_SIZE_3D
                if dimensions == "3d"
                else cls.PROMPT_MARKER_SIZE_2D
            )
            symbol = cls.PROMPT_MARKER_SYMBOL
            opacity = cls.PROMPT_OPACITY
        else:
            size = (
                cls.DOCUMENT_MARKER_SIZE_3D
                if dimensions == "3d"
                else cls.DOCUMENT_MARKER_SIZE_2D
            )
            symbol = cls.DOCUMENT_MARKER_SYMBOL
            opacity = cls.DOCUMENT_OPACITY

        return {"size": size, "symbol": symbol, "opacity": opacity}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,34 +6,33 @@ from ..models.schemas import Document
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NDJSONParser:
    """Parse newline-delimited JSON (one JSON object per line) into Documents."""

    @staticmethod
    def parse_upload_contents(contents: str) -> List[Document]:
        """Decode an upload payload and parse the result as NDJSON.

        Args:
            contents: String of the form ``"<content type>,<base64 data>"``.

        Returns:
            One ``Document`` per non-blank line of the decoded text.

        Raises:
            ValueError: If ``contents`` does not split into exactly two parts
                on ``","``.
        """
        # The content-type prefix is not needed; only the payload is decoded.
        _content_type, content_string = contents.split(",")
        decoded = base64.b64decode(content_string)
        text_content = decoded.decode("utf-8")
        return NDJSONParser.parse_text(text_content)

    @staticmethod
    def parse_text(text_content: str) -> List[Document]:
        """Parse NDJSON text, skipping blank lines.

        Args:
            text_content: Text with one JSON object per line.

        Returns:
            The parsed documents, in input order.
        """
        documents = []
        for line in text_content.strip().split("\n"):
            if line.strip():
                doc_dict = json.loads(line)
                doc = NDJSONParser._dict_to_document(doc_dict)
                documents.append(doc)
        return documents

    @staticmethod
    def _dict_to_document(doc_dict: dict) -> Document:
        """Build a ``Document`` from one parsed JSON object.

        ``text`` and ``embedding`` are required (``KeyError`` if missing);
        ``category``, ``subcategory`` and ``tags`` default to ``None``. A
        missing ``id`` is filled with a random UUID, which is also written
        back into ``doc_dict``.
        """
        if "id" not in doc_dict:
            doc_dict["id"] = str(uuid.uuid4())

        return Document(
            id=doc_dict["id"],
            text=doc_dict["text"],
            embedding=doc_dict["embedding"],
            category=doc_dict.get("category"),
            subcategory=doc_dict.get("subcategory"),
            tags=doc_dict.get("tags"),
        )
 | 
			
		||||
 
 | 
			
		||||
@@ -5,18 +5,19 @@ from .parser import NDJSONParser
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DataProcessor:
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
        """Create the processor with its own NDJSON parser."""
        self.parser = NDJSONParser()
 | 
			
		||||
    
 | 
			
		||||
    def process_upload(
        self, contents: str, filename: Optional[str] = None
    ) -> ProcessedData:
        """Parse an uploaded payload into documents and an embedding matrix.

        Args:
            contents: Upload payload, passed to the parser's
                ``parse_upload_contents``.
            filename: Accepted for the caller's convenience; not used here.

        Returns:
            A ``ProcessedData`` with the parsed documents and embeddings. On
            any failure the result has no documents, an empty array and
            ``error`` set to the exception message; nothing is raised.
        """
        try:
            documents = self.parser.parse_upload_contents(contents)
            embeddings = self._extract_embeddings(documents)
            return ProcessedData(documents=documents, embeddings=embeddings)
        except Exception as e:
            # Deliberately broad: failures are reported through the `error`
            # field of the result instead of propagating to the caller.
            return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def process_text(self, text_content: str) -> ProcessedData:
 | 
			
		||||
        try:
 | 
			
		||||
            documents = self.parser.parse_text(text_content)
 | 
			
		||||
@@ -24,31 +25,35 @@ class DataProcessor:
 | 
			
		||||
            return ProcessedData(documents=documents, embeddings=embeddings)
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            return ProcessedData(documents=[], embeddings=np.array([]), error=str(e))
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _extract_embeddings(self, documents: List[Document]) -> np.ndarray:
        """Collect each document's ``embedding`` into a single NumPy array.

        Returns an empty array when ``documents`` is empty.
        """
        if not documents:
            return np.array([])
        return np.array([doc.embedding for doc in documents])
 | 
			
		||||
    
 | 
			
		||||
    def combine_data(
        self, doc_data: ProcessedData, prompt_data: Optional[ProcessedData] = None
    ) -> Tuple[np.ndarray, List[Document], Optional[List[Document]]]:
        """Merge document and prompt embeddings into one array.

        Args:
            doc_data: Processed documents; must be truthy and error-free.
            prompt_data: Optional processed prompts. Ignored when missing,
                when it carries an error, or when it has no documents.

        Returns:
            ``(all_embeddings, documents, prompts)``. When prompts are used,
            their embeddings are stacked below the document embeddings;
            otherwise ``prompts`` is ``None`` and only the document
            embeddings are returned.

        Raises:
            ValueError: If ``doc_data`` is falsy or has an error.
        """
        if not doc_data or doc_data.error:
            raise ValueError("Invalid document data")

        all_embeddings = doc_data.embeddings
        documents = doc_data.documents
        prompts = None

        if prompt_data and not prompt_data.error and prompt_data.documents:
            # Document rows come first so callers can split the combined
            # array back apart by row count.
            all_embeddings = np.vstack([doc_data.embeddings, prompt_data.embeddings])
            prompts = prompt_data.documents

        return all_embeddings, documents, prompts
 | 
			
		||||
    
 | 
			
		||||
    def split_reduced_data(
        self, reduced_embeddings: np.ndarray, n_documents: int, n_prompts: int = 0
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Split a combined array back into document rows and prompt rows.

        Args:
            reduced_embeddings: Array whose first ``n_documents`` rows are
                documents, followed by the prompt rows.
            n_documents: Number of leading document rows.
            n_prompts: Number of prompt rows that follow the documents.

        Returns:
            ``(doc_reduced, prompt_reduced)``; ``prompt_reduced`` is ``None``
            when ``n_prompts`` is not positive.
        """
        doc_reduced = reduced_embeddings[:n_documents]
        prompt_reduced = None

        if n_prompts > 0:
            prompt_reduced = reduced_embeddings[n_documents : n_documents + n_prompts]

        return doc_reduced, prompt_reduced
 | 
			
		||||
 
 | 
			
		||||
@@ -7,88 +7,89 @@ from .schemas import ReducedData
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DimensionalityReducer(ABC):
    """Abstract base for dimensionality-reduction strategies."""

    def __init__(self, n_components: int = 3, random_state: int = 42):
        """Store the target dimensionality and the random seed.

        Args:
            n_components: Number of output dimensions.
            random_state: Seed available to subclasses.
        """
        self.n_components = n_components
        self.random_state = random_state
        # Set by subclasses when fit_transform builds the underlying model.
        self._reducer = None

    @abstractmethod
    def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
        """Reduce ``embeddings`` and return the result as ``ReducedData``."""

    @abstractmethod
    def get_method_name(self) -> str:
        """Return the display name of the reduction method."""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PCAReducer(DimensionalityReducer):
    """Dimensionality reduction using PCA."""

    def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
        """Fit PCA on ``embeddings`` and return the projected data.

        The result carries the fitted model's ``explained_variance_ratio_``
        as ``variance_explained``.
        """
        self._reducer = PCA(n_components=self.n_components)
        reduced = self._reducer.fit_transform(embeddings)
        variance_explained = self._reducer.explained_variance_ratio_

        return ReducedData(
            reduced_embeddings=reduced,
            variance_explained=variance_explained,
            method=self.get_method_name(),
            n_components=self.n_components,
        )

    def get_method_name(self) -> str:
        """Return the display name of this method."""
        return "PCA"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TSNEReducer(DimensionalityReducer):
    """Dimensionality reduction using t-SNE."""

    def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
        """Fit t-SNE on ``embeddings`` and return the embedded coordinates.

        ``variance_explained`` is ``None`` in the result.
        """
        self._reducer = TSNE(
            n_components=self.n_components, random_state=self.random_state
        )
        # fit_transform returns the embedded coordinates; fit() alone returns
        # the estimator object, which must not be stored as the embedding.
        reduced = self._reducer.fit_transform(embeddings)

        return ReducedData(
            reduced_embeddings=reduced,
            variance_explained=None,
            method=self.get_method_name(),
            n_components=self.n_components,
        )

    def get_method_name(self) -> str:
        """Return the display name of this method."""
        return "t-SNE"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UMAPReducer(DimensionalityReducer):
    """Dimensionality reduction using UMAP."""

    def fit_transform(self, embeddings: np.ndarray) -> ReducedData:
        """Fit UMAP on ``embeddings`` and return the embedded coordinates.

        ``variance_explained`` is ``None`` in the result.
        """
        self._reducer = umap.UMAP(
            n_components=self.n_components, random_state=self.random_state
        )
        reduced = self._reducer.fit_transform(embeddings)

        return ReducedData(
            reduced_embeddings=reduced,
            variance_explained=None,
            method=self.get_method_name(),
            n_components=self.n_components,
        )

    def get_method_name(self) -> str:
        """Return the display name of this method."""
        return "UMAP"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReducerFactory:
    """Create reducer instances from a method name."""

    @staticmethod
    def create_reducer(
        method: str, n_components: int = 3, random_state: int = 42
    ) -> DimensionalityReducer:
        """Instantiate the reducer registered under ``method``.

        Args:
            method: ``"pca"``, ``"tsne"`` or ``"umap"`` (case-insensitive).
            n_components: Number of output dimensions.
            random_state: Seed handed to the reducer.

        Raises:
            ValueError: If ``method`` is not one of the known names.
        """
        reducer_classes = {
            "pca": PCAReducer,
            "tsne": TSNEReducer,
            "umap": UMAPReducer,
        }
        reducer_cls = reducer_classes.get(method.lower())
        if reducer_cls is None:
            raise ValueError(f"Unknown reduction method: {method}")
        return reducer_cls(n_components=n_components, random_state=random_state)

    @staticmethod
    def get_available_methods() -> list:
        """Return the method names accepted by ``create_reducer``."""
        return ["pca", "tsne", "umap"]
 | 
			
		||||
 
 | 
			
		||||
@@ -50,9 +50,11 @@ class PlotData:
 | 
			
		||||
    coordinates: np.ndarray
 | 
			
		||||
    prompts: Optional[List[Document]] = None
 | 
			
		||||
    prompt_coordinates: Optional[np.ndarray] = None
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if not isinstance(self.coordinates, np.ndarray):
 | 
			
		||||
            self.coordinates = np.array(self.coordinates)
 | 
			
		||||
        if self.prompt_coordinates is not None and not isinstance(self.prompt_coordinates, np.ndarray):
 | 
			
		||||
            self.prompt_coordinates = np.array(self.prompt_coordinates)
 | 
			
		||||
        if self.prompt_coordinates is not None and not isinstance(
 | 
			
		||||
            self.prompt_coordinates, np.ndarray
 | 
			
		||||
        ):
 | 
			
		||||
            self.prompt_coordinates = np.array(self.prompt_coordinates)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,58 +3,60 @@ from ...data.processor import DataProcessor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DataProcessingCallbacks:
    """Registers the callbacks that turn uploads into stored, processed data."""

    def __init__(self):
        """Create the data processor and register the callbacks."""
        self.processor = DataProcessor()
        self._register_callbacks()

    def _register_callbacks(self):
        """Register the document-upload and prompt-upload callbacks."""

        @callback(
            Output("processed-data", "data"),
            Input("upload-data", "contents"),
            State("upload-data", "filename"),
        )
        def process_uploaded_file(contents, filename):
            return self._build_store_payload(contents, filename, "documents")

        @callback(
            Output("processed-prompts", "data"),
            Input("upload-prompts", "contents"),
            State("upload-prompts", "filename"),
        )
        def process_uploaded_prompts(contents, filename):
            return self._build_store_payload(contents, filename, "prompts")

    def _build_store_payload(self, contents, filename, items_key):
        """Process one upload into the dict written to its store.

        Args:
            contents: Upload contents, or ``None`` when nothing was uploaded.
            filename: Name of the uploaded file, forwarded to the processor.
            items_key: Key under which the serialised items are stored
                (``"documents"`` or ``"prompts"``).

        Returns:
            ``None`` when there is no upload, ``{"error": message}`` when
            processing failed, otherwise a dict with ``items_key`` and
            ``"embeddings"``.
        """
        if contents is None:
            return None

        processed_data = self.processor.process_upload(contents, filename)

        if processed_data.error:
            return {"error": processed_data.error}

        return {
            items_key: [
                self._document_to_dict(doc) for doc in processed_data.documents
            ],
            "embeddings": processed_data.embeddings.tolist(),
        }

    @staticmethod
    def _document_to_dict(doc):
        """Convert a document object into a plain dict of its fields."""
        return {
            "id": doc.id,
            "text": doc.text,
            "embedding": doc.embedding,
            "category": doc.category,
            "subcategory": doc.subcategory,
            "tags": doc.tags,
        }
 | 
			
		||||
 
 | 
			
		||||
@@ -4,63 +4,79 @@ import dash_bootstrap_components as dbc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InteractionCallbacks:
    """Registers the callbacks for point clicks and the reset button."""

    def __init__(self):
        """Register the callbacks on construction."""
        self._register_callbacks()

    def _register_callbacks(self):
        """Register the click-details and reset callbacks."""

        @callback(
            Output("point-details", "children"),
            Input("embedding-plot", "clickData"),
            [State("processed-data", "data"), State("processed-prompts", "data")],
        )
        def display_click_data(clickData, data, prompts_data):
            if not clickData or not data:
                return "Click on a point to see details"

            point_data = clickData["points"][0]
            trace_name = point_data.get("fullData", {}).get("name", "Documents")

            if "pointIndex" in point_data:
                point_index = point_data["pointIndex"]
            elif "pointNumber" in point_data:
                point_index = point_data["pointNumber"]
            else:
                return "Could not identify clicked point"

            # Traces whose name starts with "Prompts" index into the prompt
            # list; every other trace indexes into the documents.
            if (
                trace_name.startswith("Prompts")
                and prompts_data
                and "prompts" in prompts_data
            ):
                item = prompts_data["prompts"][point_index]
                item_type = "Prompt"
            else:
                item = data["documents"][point_index]
                item_type = "Document"

            return self._create_detail_card(item, item_type)

        @callback(
            [
                Output("processed-data", "data", allow_duplicate=True),
                Output("processed-prompts", "data", allow_duplicate=True),
                Output("point-details", "children", allow_duplicate=True),
            ],
            Input("reset-button", "n_clicks"),
            prevent_initial_call=True,
        )
        def reset_data(n_clicks):
            if n_clicks is None or n_clicks == 0:
                return dash.no_update, dash.no_update, dash.no_update

            return None, None, "Click on a point to see details"

    @staticmethod
    def _create_detail_card(item, item_type):
        """Build the card that shows one item's details.

        Args:
            item: Dict with ``id`` and ``text`` and optionally ``category``,
                ``subcategory`` and ``tags``.
            item_type: Label shown for the item, e.g. ``"Document"``.

        Returns:
            A ``dbc.Card`` containing the item's fields.
        """
        # The stored dicts always carry these keys, possibly with a None
        # value, so fall back on falsy values rather than on a missing key.
        category = item.get("category") or "Unknown"
        subcategory = item.get("subcategory") or "Unknown"
        tags = item.get("tags")
        tags_text = ", ".join(tags) if tags else "None"

        return dbc.Card(
            [
                dbc.CardBody(
                    [
                        html.H5(f"{item_type}: {item['id']}", className="card-title"),
                        html.P(f"Text: {item['text']}", className="card-text"),
                        html.P(f"Category: {category}", className="card-text"),
                        html.P(f"Subcategory: {subcategory}", className="card-text"),
                        html.P(f"Tags: {tags_text}", className="card-text"),
                        html.P(f"Type: {item_type}", className="card-text text-muted"),
                    ]
                )
            ]
        )
 | 
			
		||||
 
 | 
			
		||||
@@ -7,81 +7,102 @@ from ...visualization.plots import PlotFactory
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VisualizationCallbacks:
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.plot_factory = PlotFactory()
 | 
			
		||||
        self._register_callbacks()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _register_callbacks(self):
 | 
			
		||||
        
 | 
			
		||||
        @callback(
 | 
			
		||||
            Output('embedding-plot', 'figure'),
 | 
			
		||||
            [Input('processed-data', 'data'),
 | 
			
		||||
             Input('processed-prompts', 'data'),
 | 
			
		||||
             Input('method-dropdown', 'value'),
 | 
			
		||||
             Input('color-dropdown', 'value'),
 | 
			
		||||
             Input('dimension-toggle', 'value'),
 | 
			
		||||
             Input('show-prompts-toggle', 'value')]
 | 
			
		||||
            Output("embedding-plot", "figure"),
 | 
			
		||||
            [
 | 
			
		||||
                Input("processed-data", "data"),
 | 
			
		||||
                Input("processed-prompts", "data"),
 | 
			
		||||
                Input("method-dropdown", "value"),
 | 
			
		||||
                Input("color-dropdown", "value"),
 | 
			
		||||
                Input("dimension-toggle", "value"),
 | 
			
		||||
                Input("show-prompts-toggle", "value"),
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
        def update_plot(data, prompts_data, method, color_by, dimensions, show_prompts):
 | 
			
		||||
            if not data or 'error' in data:
 | 
			
		||||
            if not data or "error" in data:
 | 
			
		||||
                return go.Figure().add_annotation(
 | 
			
		||||
                    text="Upload a valid NDJSON file to see visualization",
 | 
			
		||||
                    xref="paper", yref="paper",
 | 
			
		||||
                    x=0.5, y=0.5, xanchor='center', yanchor='middle',
 | 
			
		||||
                    showarrow=False, font=dict(size=16)
 | 
			
		||||
                    xref="paper",
 | 
			
		||||
                    yref="paper",
 | 
			
		||||
                    x=0.5,
 | 
			
		||||
                    y=0.5,
 | 
			
		||||
                    xanchor="center",
 | 
			
		||||
                    yanchor="middle",
 | 
			
		||||
                    showarrow=False,
 | 
			
		||||
                    font=dict(size=16),
 | 
			
		||||
                )
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                doc_embeddings = np.array(data['embeddings'])
 | 
			
		||||
                doc_embeddings = np.array(data["embeddings"])
 | 
			
		||||
                all_embeddings = doc_embeddings
 | 
			
		||||
                has_prompts = prompts_data and 'error' not in prompts_data and prompts_data.get('prompts')
 | 
			
		||||
                
 | 
			
		||||
                has_prompts = (
 | 
			
		||||
                    prompts_data
 | 
			
		||||
                    and "error" not in prompts_data
 | 
			
		||||
                    and prompts_data.get("prompts")
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if has_prompts:
 | 
			
		||||
                    prompt_embeddings = np.array(prompts_data['embeddings'])
 | 
			
		||||
                    prompt_embeddings = np.array(prompts_data["embeddings"])
 | 
			
		||||
                    all_embeddings = np.vstack([doc_embeddings, prompt_embeddings])
 | 
			
		||||
                
 | 
			
		||||
                n_components = 3 if dimensions == '3d' else 2
 | 
			
		||||
                
 | 
			
		||||
                reducer = ReducerFactory.create_reducer(method, n_components=n_components)
 | 
			
		||||
 | 
			
		||||
                n_components = 3 if dimensions == "3d" else 2
 | 
			
		||||
 | 
			
		||||
                reducer = ReducerFactory.create_reducer(
 | 
			
		||||
                    method, n_components=n_components
 | 
			
		||||
                )
 | 
			
		||||
                reduced_data = reducer.fit_transform(all_embeddings)
 | 
			
		||||
                
 | 
			
		||||
                doc_reduced = reduced_data.reduced_embeddings[:len(doc_embeddings)]
 | 
			
		||||
 | 
			
		||||
                doc_reduced = reduced_data.reduced_embeddings[: len(doc_embeddings)]
 | 
			
		||||
                prompt_reduced = None
 | 
			
		||||
                if has_prompts:
 | 
			
		||||
                    prompt_reduced = reduced_data.reduced_embeddings[len(doc_embeddings):]
 | 
			
		||||
                
 | 
			
		||||
                documents = [self._dict_to_document(doc) for doc in data['documents']]
 | 
			
		||||
                    prompt_reduced = reduced_data.reduced_embeddings[
 | 
			
		||||
                        len(doc_embeddings) :
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                documents = [self._dict_to_document(doc) for doc in data["documents"]]
 | 
			
		||||
                prompts = None
 | 
			
		||||
                if has_prompts:
 | 
			
		||||
                    prompts = [self._dict_to_document(prompt) for prompt in prompts_data['prompts']]
 | 
			
		||||
                
 | 
			
		||||
                    prompts = [
 | 
			
		||||
                        self._dict_to_document(prompt)
 | 
			
		||||
                        for prompt in prompts_data["prompts"]
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                plot_data = PlotData(
 | 
			
		||||
                    documents=documents,
 | 
			
		||||
                    coordinates=doc_reduced,
 | 
			
		||||
                    prompts=prompts,
 | 
			
		||||
                    prompt_coordinates=prompt_reduced
 | 
			
		||||
                    prompt_coordinates=prompt_reduced,
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
 | 
			
		||||
                return self.plot_factory.create_plot(
 | 
			
		||||
                    plot_data, dimensions, color_by, reduced_data.method, show_prompts
 | 
			
		||||
                )
 | 
			
		||||
                
 | 
			
		||||
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                return go.Figure().add_annotation(
 | 
			
		||||
                    text=f"Error creating visualization: {str(e)}",
 | 
			
		||||
                    xref="paper", yref="paper",
 | 
			
		||||
                    x=0.5, y=0.5, xanchor='center', yanchor='middle',
 | 
			
		||||
                    showarrow=False, font=dict(size=16)
 | 
			
		||||
                    xref="paper",
 | 
			
		||||
                    yref="paper",
 | 
			
		||||
                    x=0.5,
 | 
			
		||||
                    y=0.5,
 | 
			
		||||
                    xanchor="center",
 | 
			
		||||
                    yanchor="middle",
 | 
			
		||||
                    showarrow=False,
 | 
			
		||||
                    font=dict(size=16),
 | 
			
		||||
                )
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _dict_to_document(doc_dict):
 | 
			
		||||
        return Document(
 | 
			
		||||
            id=doc_dict['id'],
 | 
			
		||||
            text=doc_dict['text'],
 | 
			
		||||
            embedding=doc_dict['embedding'],
 | 
			
		||||
            category=doc_dict.get('category'),
 | 
			
		||||
            subcategory=doc_dict.get('subcategory'),
 | 
			
		||||
            tags=doc_dict.get('tags', [])
 | 
			
		||||
        )
 | 
			
		||||
            id=doc_dict["id"],
 | 
			
		||||
            text=doc_dict["text"],
 | 
			
		||||
            embedding=doc_dict["embedding"],
 | 
			
		||||
            category=doc_dict.get("category"),
 | 
			
		||||
            subcategory=doc_dict.get("subcategory"),
 | 
			
		||||
            tags=doc_dict.get("tags", []),
 | 
			
		||||
        )
 | 
			
		||||
 
 | 
			
		||||
@@ -4,79 +4,81 @@ from .upload import UploadComponent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SidebarComponent:
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.upload_component = UploadComponent()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def create_layout(self):
 | 
			
		||||
        return dbc.Col([
 | 
			
		||||
            html.H5("Upload Data", className="mb-3"),
 | 
			
		||||
            self.upload_component.create_data_upload(),
 | 
			
		||||
            self.upload_component.create_prompts_upload(),
 | 
			
		||||
            self.upload_component.create_reset_button(),
 | 
			
		||||
            
 | 
			
		||||
            html.H5("Visualization Controls", className="mb-3"),
 | 
			
		||||
            self._create_method_dropdown(),
 | 
			
		||||
            self._create_color_dropdown(),
 | 
			
		||||
            self._create_dimension_toggle(),
 | 
			
		||||
            self._create_prompts_toggle(),
 | 
			
		||||
            
 | 
			
		||||
            html.H5("Point Details", className="mb-3"),
 | 
			
		||||
            html.Div(id='point-details', children="Click on a point to see details")
 | 
			
		||||
            
 | 
			
		||||
        ], width=3, style={'padding-right': '20px'})
 | 
			
		||||
    
 | 
			
		||||
        return dbc.Col(
 | 
			
		||||
            [
 | 
			
		||||
                html.H5("Upload Data", className="mb-3"),
 | 
			
		||||
                self.upload_component.create_data_upload(),
 | 
			
		||||
                self.upload_component.create_prompts_upload(),
 | 
			
		||||
                self.upload_component.create_reset_button(),
 | 
			
		||||
                html.H5("Visualization Controls", className="mb-3"),
 | 
			
		||||
                self._create_method_dropdown(),
 | 
			
		||||
                self._create_color_dropdown(),
 | 
			
		||||
                self._create_dimension_toggle(),
 | 
			
		||||
                self._create_prompts_toggle(),
 | 
			
		||||
                html.H5("Point Details", className="mb-3"),
 | 
			
		||||
                html.Div(
 | 
			
		||||
                    id="point-details", children="Click on a point to see details"
 | 
			
		||||
                ),
 | 
			
		||||
            ],
 | 
			
		||||
            width=3,
 | 
			
		||||
            style={"padding-right": "20px"},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _create_method_dropdown(self):
 | 
			
		||||
        return [
 | 
			
		||||
            dbc.Label("Method:"),
 | 
			
		||||
            dcc.Dropdown(
 | 
			
		||||
                id='method-dropdown',
 | 
			
		||||
                id="method-dropdown",
 | 
			
		||||
                options=[
 | 
			
		||||
                    {'label': 'PCA', 'value': 'pca'},
 | 
			
		||||
                    {'label': 't-SNE', 'value': 'tsne'},
 | 
			
		||||
                    {'label': 'UMAP', 'value': 'umap'}
 | 
			
		||||
                    {"label": "PCA", "value": "pca"},
 | 
			
		||||
                    {"label": "t-SNE", "value": "tsne"},
 | 
			
		||||
                    {"label": "UMAP", "value": "umap"},
 | 
			
		||||
                ],
 | 
			
		||||
                value='pca',
 | 
			
		||||
                style={'margin-bottom': '15px'}
 | 
			
		||||
            )
 | 
			
		||||
                value="pca",
 | 
			
		||||
                style={"margin-bottom": "15px"},
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _create_color_dropdown(self):
 | 
			
		||||
        return [
 | 
			
		||||
            dbc.Label("Color by:"),
 | 
			
		||||
            dcc.Dropdown(
 | 
			
		||||
                id='color-dropdown',
 | 
			
		||||
                id="color-dropdown",
 | 
			
		||||
                options=[
 | 
			
		||||
                    {'label': 'Category', 'value': 'category'},
 | 
			
		||||
                    {'label': 'Subcategory', 'value': 'subcategory'},
 | 
			
		||||
                    {'label': 'Tags', 'value': 'tags'}
 | 
			
		||||
                    {"label": "Category", "value": "category"},
 | 
			
		||||
                    {"label": "Subcategory", "value": "subcategory"},
 | 
			
		||||
                    {"label": "Tags", "value": "tags"},
 | 
			
		||||
                ],
 | 
			
		||||
                value='category',
 | 
			
		||||
                style={'margin-bottom': '15px'}
 | 
			
		||||
            )
 | 
			
		||||
                value="category",
 | 
			
		||||
                style={"margin-bottom": "15px"},
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _create_dimension_toggle(self):
 | 
			
		||||
        return [
 | 
			
		||||
            dbc.Label("Dimensions:"),
 | 
			
		||||
            dcc.RadioItems(
 | 
			
		||||
                id='dimension-toggle',
 | 
			
		||||
                id="dimension-toggle",
 | 
			
		||||
                options=[
 | 
			
		||||
                    {'label': '2D', 'value': '2d'},
 | 
			
		||||
                    {'label': '3D', 'value': '3d'}
 | 
			
		||||
                    {"label": "2D", "value": "2d"},
 | 
			
		||||
                    {"label": "3D", "value": "3d"},
 | 
			
		||||
                ],
 | 
			
		||||
                value='3d',
 | 
			
		||||
                style={'margin-bottom': '20px'}
 | 
			
		||||
            )
 | 
			
		||||
                value="3d",
 | 
			
		||||
                style={"margin-bottom": "20px"},
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def _create_prompts_toggle(self):
 | 
			
		||||
        return [
 | 
			
		||||
            dbc.Label("Show Prompts:"),
 | 
			
		||||
            dcc.Checklist(
 | 
			
		||||
                id='show-prompts-toggle',
 | 
			
		||||
                options=[{'label': 'Show prompts on plot', 'value': 'show'}],
 | 
			
		||||
                value=['show'],
 | 
			
		||||
                style={'margin-bottom': '20px'}
 | 
			
		||||
            )
 | 
			
		||||
        ]
 | 
			
		||||
                id="show-prompts-toggle",
 | 
			
		||||
                options=[{"label": "Show prompts on plot", "value": "show"}],
 | 
			
		||||
                value=["show"],
 | 
			
		||||
                style={"margin-bottom": "20px"},
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
 
 | 
			
		||||
@@ -3,58 +3,51 @@ import dash_bootstrap_components as dbc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UploadComponent:
 | 
			
		||||
    
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_data_upload():
 | 
			
		||||
        return dcc.Upload(
 | 
			
		||||
            id='upload-data',
 | 
			
		||||
            children=html.Div([
 | 
			
		||||
                'Drag and Drop or ',
 | 
			
		||||
                html.A('Select Files')
 | 
			
		||||
            ]),
 | 
			
		||||
            id="upload-data",
 | 
			
		||||
            children=html.Div(["Drag and Drop or ", html.A("Select Files")]),
 | 
			
		||||
            style={
 | 
			
		||||
                'width': '100%',
 | 
			
		||||
                'height': '60px',
 | 
			
		||||
                'lineHeight': '60px',
 | 
			
		||||
                'borderWidth': '1px',
 | 
			
		||||
                'borderStyle': 'dashed',
 | 
			
		||||
                'borderRadius': '5px',
 | 
			
		||||
                'textAlign': 'center',
 | 
			
		||||
                'margin-bottom': '20px'
 | 
			
		||||
                "width": "100%",
 | 
			
		||||
                "height": "60px",
 | 
			
		||||
                "lineHeight": "60px",
 | 
			
		||||
                "borderWidth": "1px",
 | 
			
		||||
                "borderStyle": "dashed",
 | 
			
		||||
                "borderRadius": "5px",
 | 
			
		||||
                "textAlign": "center",
 | 
			
		||||
                "margin-bottom": "20px",
 | 
			
		||||
            },
 | 
			
		||||
            multiple=False
 | 
			
		||||
            multiple=False,
 | 
			
		||||
        )
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_prompts_upload():
 | 
			
		||||
        return dcc.Upload(
 | 
			
		||||
            id='upload-prompts',
 | 
			
		||||
            children=html.Div([
 | 
			
		||||
                'Drag and Drop Prompts or ',
 | 
			
		||||
                html.A('Select Files')
 | 
			
		||||
            ]),
 | 
			
		||||
            id="upload-prompts",
 | 
			
		||||
            children=html.Div(["Drag and Drop Prompts or ", html.A("Select Files")]),
 | 
			
		||||
            style={
 | 
			
		||||
                'width': '100%',
 | 
			
		||||
                'height': '60px',
 | 
			
		||||
                'lineHeight': '60px',
 | 
			
		||||
                'borderWidth': '1px',
 | 
			
		||||
                'borderStyle': 'dashed',
 | 
			
		||||
                'borderRadius': '5px',
 | 
			
		||||
                'textAlign': 'center',
 | 
			
		||||
                'margin-bottom': '20px',
 | 
			
		||||
                'borderColor': '#28a745'
 | 
			
		||||
                "width": "100%",
 | 
			
		||||
                "height": "60px",
 | 
			
		||||
                "lineHeight": "60px",
 | 
			
		||||
                "borderWidth": "1px",
 | 
			
		||||
                "borderStyle": "dashed",
 | 
			
		||||
                "borderRadius": "5px",
 | 
			
		||||
                "textAlign": "center",
 | 
			
		||||
                "margin-bottom": "20px",
 | 
			
		||||
                "borderColor": "#28a745",
 | 
			
		||||
            },
 | 
			
		||||
            multiple=False
 | 
			
		||||
            multiple=False,
 | 
			
		||||
        )
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_reset_button():
 | 
			
		||||
        return dbc.Button(
 | 
			
		||||
            "Reset All Data",
 | 
			
		||||
            id='reset-button',
 | 
			
		||||
            color='danger',
 | 
			
		||||
            id="reset-button",
 | 
			
		||||
            color="danger",
 | 
			
		||||
            outline=True,
 | 
			
		||||
            size='sm',
 | 
			
		||||
            className='mb-3',
 | 
			
		||||
            style={'width': '100%'}
 | 
			
		||||
        )
 | 
			
		||||
            size="sm",
 | 
			
		||||
            className="mb-3",
 | 
			
		||||
            style={"width": "100%"},
 | 
			
		||||
        )
 | 
			
		||||
 
 | 
			
		||||
@@ -4,41 +4,43 @@ from .components.sidebar import SidebarComponent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AppLayout:
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.sidebar = SidebarComponent()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def create_layout(self):
 | 
			
		||||
        return dbc.Container([
 | 
			
		||||
            self._create_header(),
 | 
			
		||||
            self._create_main_content(),
 | 
			
		||||
            self._create_stores()
 | 
			
		||||
        ], fluid=True)
 | 
			
		||||
    
 | 
			
		||||
        return dbc.Container(
 | 
			
		||||
            [self._create_header(), self._create_main_content(), self._create_stores()],
 | 
			
		||||
            fluid=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _create_header(self):
 | 
			
		||||
        return dbc.Row([
 | 
			
		||||
            dbc.Col([
 | 
			
		||||
                html.H1("EmbeddingBuddy", className="text-center mb-4"),
 | 
			
		||||
            ], width=12)
 | 
			
		||||
        ])
 | 
			
		||||
    
 | 
			
		||||
        return dbc.Row(
 | 
			
		||||
            [
 | 
			
		||||
                dbc.Col(
 | 
			
		||||
                    [
 | 
			
		||||
                        html.H1("EmbeddingBuddy", className="text-center mb-4"),
 | 
			
		||||
                    ],
 | 
			
		||||
                    width=12,
 | 
			
		||||
                )
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _create_main_content(self):
 | 
			
		||||
        return dbc.Row([
 | 
			
		||||
            self.sidebar.create_layout(),
 | 
			
		||||
            self._create_visualization_area()
 | 
			
		||||
        ])
 | 
			
		||||
    
 | 
			
		||||
        return dbc.Row(
 | 
			
		||||
            [self.sidebar.create_layout(), self._create_visualization_area()]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _create_visualization_area(self):
 | 
			
		||||
        return dbc.Col([
 | 
			
		||||
            dcc.Graph(
 | 
			
		||||
                id='embedding-plot',
 | 
			
		||||
                style={'height': '85vh', 'width': '100%'},
 | 
			
		||||
                config={'responsive': True, 'displayModeBar': True}
 | 
			
		||||
            )
 | 
			
		||||
        ], width=9)
 | 
			
		||||
    
 | 
			
		||||
        return dbc.Col(
 | 
			
		||||
            [
 | 
			
		||||
                dcc.Graph(
 | 
			
		||||
                    id="embedding-plot",
 | 
			
		||||
                    style={"height": "85vh", "width": "100%"},
 | 
			
		||||
                    config={"responsive": True, "displayModeBar": True},
 | 
			
		||||
                )
 | 
			
		||||
            ],
 | 
			
		||||
            width=9,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _create_stores(self):
 | 
			
		||||
        return [
 | 
			
		||||
            dcc.Store(id='processed-data'),
 | 
			
		||||
            dcc.Store(id='processed-prompts')
 | 
			
		||||
        ]
 | 
			
		||||
        return [dcc.Store(id="processed-data"), dcc.Store(id="processed-prompts")]
 | 
			
		||||
 
 | 
			
		||||
@@ -4,30 +4,33 @@ from ..models.schemas import Document
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ColorMapper:
 | 
			
		||||
    
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def create_color_mapping(documents: List[Document], color_by: str) -> List[str]:
 | 
			
		||||
        if color_by == 'category':
 | 
			
		||||
        if color_by == "category":
 | 
			
		||||
            return [doc.category for doc in documents]
 | 
			
		||||
        elif color_by == 'subcategory':
 | 
			
		||||
        elif color_by == "subcategory":
 | 
			
		||||
            return [doc.subcategory for doc in documents]
 | 
			
		||||
        elif color_by == 'tags':
 | 
			
		||||
            return [', '.join(doc.tags) if doc.tags else 'No tags' for doc in documents]
 | 
			
		||||
        elif color_by == "tags":
 | 
			
		||||
            return [", ".join(doc.tags) if doc.tags else "No tags" for doc in documents]
 | 
			
		||||
        else:
 | 
			
		||||
            return ['All'] * len(documents)
 | 
			
		||||
    
 | 
			
		||||
            return ["All"] * len(documents)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def to_grayscale_hex(color_str: str) -> str:
 | 
			
		||||
        try:
 | 
			
		||||
            if color_str.startswith('#'):
 | 
			
		||||
                rgb = tuple(int(color_str[i:i+2], 16) for i in (1, 3, 5))
 | 
			
		||||
            if color_str.startswith("#"):
 | 
			
		||||
                rgb = tuple(int(color_str[i : i + 2], 16) for i in (1, 3, 5))
 | 
			
		||||
            else:
 | 
			
		||||
                rgb = pc.hex_to_rgb(pc.convert_colors_to_same_type([color_str], colortype='hex')[0][0])
 | 
			
		||||
            
 | 
			
		||||
                rgb = pc.hex_to_rgb(
 | 
			
		||||
                    pc.convert_colors_to_same_type([color_str], colortype="hex")[0][0]
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            gray_value = int(0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2])
 | 
			
		||||
            gray_rgb = (gray_value * 0.7 + rgb[0] * 0.3, 
 | 
			
		||||
                       gray_value * 0.7 + rgb[1] * 0.3, 
 | 
			
		||||
                       gray_value * 0.7 + rgb[2] * 0.3)
 | 
			
		||||
            return f'rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})'
 | 
			
		||||
        except: # noqa: E722
 | 
			
		||||
            return 'rgb(128,128,128)'
 | 
			
		||||
            gray_rgb = (
 | 
			
		||||
                gray_value * 0.7 + rgb[0] * 0.3,
 | 
			
		||||
                gray_value * 0.7 + rgb[1] * 0.3,
 | 
			
		||||
                gray_value * 0.7 + rgb[2] * 0.3,
 | 
			
		||||
            )
 | 
			
		||||
            return f"rgb({int(gray_rgb[0])},{int(gray_rgb[1])},{int(gray_rgb[2])})"
 | 
			
		||||
        except:  # noqa: E722
 | 
			
		||||
            return "rgb(128,128,128)"
 | 
			
		||||
 
 | 
			
		||||
@@ -7,139 +7,172 @@ from .colors import ColorMapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PlotFactory:
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.color_mapper = ColorMapper()
 | 
			
		||||
    
 | 
			
		||||
    def create_plot(self, plot_data: PlotData, dimensions: str = '3d', 
 | 
			
		||||
                   color_by: str = 'category', method: str = 'PCA',
 | 
			
		||||
                   show_prompts: Optional[List[str]] = None) -> go.Figure:
 | 
			
		||||
        
 | 
			
		||||
        if plot_data.prompts and show_prompts and 'show' in show_prompts:
 | 
			
		||||
 | 
			
		||||
    def create_plot(
 | 
			
		||||
        self,
 | 
			
		||||
        plot_data: PlotData,
 | 
			
		||||
        dimensions: str = "3d",
 | 
			
		||||
        color_by: str = "category",
 | 
			
		||||
        method: str = "PCA",
 | 
			
		||||
        show_prompts: Optional[List[str]] = None,
 | 
			
		||||
    ) -> go.Figure:
 | 
			
		||||
        if plot_data.prompts and show_prompts and "show" in show_prompts:
 | 
			
		||||
            return self._create_dual_plot(plot_data, dimensions, color_by, method)
 | 
			
		||||
        else:
 | 
			
		||||
            return self._create_single_plot(plot_data, dimensions, color_by, method)
 | 
			
		||||
    
 | 
			
		||||
    def _create_single_plot(self, plot_data: PlotData, dimensions: str, 
 | 
			
		||||
                           color_by: str, method: str) -> go.Figure:
 | 
			
		||||
        df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
 | 
			
		||||
        color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
 | 
			
		||||
        
 | 
			
		||||
        hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
 | 
			
		||||
        
 | 
			
		||||
        if dimensions == '3d':
 | 
			
		||||
 | 
			
		||||
    def _create_single_plot(
 | 
			
		||||
        self, plot_data: PlotData, dimensions: str, color_by: str, method: str
 | 
			
		||||
    ) -> go.Figure:
 | 
			
		||||
        df = self._prepare_dataframe(
 | 
			
		||||
            plot_data.documents, plot_data.coordinates, dimensions
 | 
			
		||||
        )
 | 
			
		||||
        color_values = self.color_mapper.create_color_mapping(
 | 
			
		||||
            plot_data.documents, color_by
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
 | 
			
		||||
 | 
			
		||||
        if dimensions == "3d":
 | 
			
		||||
            fig = px.scatter_3d(
 | 
			
		||||
                df, x='dim_1', y='dim_2', z='dim_3',
 | 
			
		||||
                df,
 | 
			
		||||
                x="dim_1",
 | 
			
		||||
                y="dim_2",
 | 
			
		||||
                z="dim_3",
 | 
			
		||||
                color=color_values,
 | 
			
		||||
                hover_data=hover_fields,
 | 
			
		||||
                title=f'3D Embedding Visualization - {method} (colored by {color_by})'
 | 
			
		||||
                title=f"3D Embedding Visualization - {method} (colored by {color_by})",
 | 
			
		||||
            )
 | 
			
		||||
            fig.update_traces(marker=dict(size=5))
 | 
			
		||||
        else:
 | 
			
		||||
            fig = px.scatter(
 | 
			
		||||
                df, x='dim_1', y='dim_2',
 | 
			
		||||
                df,
 | 
			
		||||
                x="dim_1",
 | 
			
		||||
                y="dim_2",
 | 
			
		||||
                color=color_values,
 | 
			
		||||
                hover_data=hover_fields,
 | 
			
		||||
                title=f'2D Embedding Visualization - {method} (colored by {color_by})'
 | 
			
		||||
                title=f"2D Embedding Visualization - {method} (colored by {color_by})",
 | 
			
		||||
            )
 | 
			
		||||
            fig.update_traces(marker=dict(size=8))
 | 
			
		||||
        
 | 
			
		||||
        fig.update_layout(
 | 
			
		||||
            height=None,
 | 
			
		||||
            autosize=True,
 | 
			
		||||
            margin=dict(l=0, r=0, t=50, b=0)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        fig.update_layout(height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0))
 | 
			
		||||
        return fig
 | 
			
		||||
    
 | 
			
		||||
    def _create_dual_plot(self, plot_data: PlotData, dimensions: str, 
 | 
			
		||||
                         color_by: str, method: str) -> go.Figure:
 | 
			
		||||
 | 
			
		||||
    def _create_dual_plot(
 | 
			
		||||
        self, plot_data: PlotData, dimensions: str, color_by: str, method: str
 | 
			
		||||
    ) -> go.Figure:
 | 
			
		||||
        fig = go.Figure()
 | 
			
		||||
        
 | 
			
		||||
        doc_df = self._prepare_dataframe(plot_data.documents, plot_data.coordinates, dimensions)
 | 
			
		||||
        doc_color_values = self.color_mapper.create_color_mapping(plot_data.documents, color_by)
 | 
			
		||||
        
 | 
			
		||||
        hover_fields = ['id', 'text_preview', 'category', 'subcategory', 'tags_str']
 | 
			
		||||
        
 | 
			
		||||
        if dimensions == '3d':
 | 
			
		||||
 | 
			
		||||
        doc_df = self._prepare_dataframe(
 | 
			
		||||
            plot_data.documents, plot_data.coordinates, dimensions
 | 
			
		||||
        )
 | 
			
		||||
        doc_color_values = self.color_mapper.create_color_mapping(
 | 
			
		||||
            plot_data.documents, color_by
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        hover_fields = ["id", "text_preview", "category", "subcategory", "tags_str"]
 | 
			
		||||
 | 
			
		||||
        if dimensions == "3d":
 | 
			
		||||
            doc_fig = px.scatter_3d(
 | 
			
		||||
                doc_df, x='dim_1', y='dim_2', z='dim_3',
 | 
			
		||||
                doc_df,
 | 
			
		||||
                x="dim_1",
 | 
			
		||||
                y="dim_2",
 | 
			
		||||
                z="dim_3",
 | 
			
		||||
                color=doc_color_values,
 | 
			
		||||
                hover_data=hover_fields
 | 
			
		||||
                hover_data=hover_fields,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            doc_fig = px.scatter(
 | 
			
		||||
                doc_df, x='dim_1', y='dim_2',
 | 
			
		||||
                doc_df,
 | 
			
		||||
                x="dim_1",
 | 
			
		||||
                y="dim_2",
 | 
			
		||||
                color=doc_color_values,
 | 
			
		||||
                hover_data=hover_fields
 | 
			
		||||
                hover_data=hover_fields,
 | 
			
		||||
            )
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        for trace in doc_fig.data:
 | 
			
		||||
            trace.name = f'Documents - {trace.name}'
 | 
			
		||||
            if dimensions == '3d':
 | 
			
		||||
            trace.name = f"Documents - {trace.name}"
 | 
			
		||||
            if dimensions == "3d":
 | 
			
		||||
                trace.marker.size = 5
 | 
			
		||||
                trace.marker.symbol = 'circle'
 | 
			
		||||
                trace.marker.symbol = "circle"
 | 
			
		||||
            else:
 | 
			
		||||
                trace.marker.size = 8
 | 
			
		||||
                trace.marker.symbol = 'circle'
 | 
			
		||||
                trace.marker.symbol = "circle"
 | 
			
		||||
            trace.marker.opacity = 1.0
 | 
			
		||||
            fig.add_trace(trace)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        if plot_data.prompts and plot_data.prompt_coordinates is not None:
 | 
			
		||||
            prompt_df = self._prepare_dataframe(plot_data.prompts, plot_data.prompt_coordinates, dimensions)
 | 
			
		||||
            prompt_color_values = self.color_mapper.create_color_mapping(plot_data.prompts, color_by)
 | 
			
		||||
            
 | 
			
		||||
            if dimensions == '3d':
 | 
			
		||||
            prompt_df = self._prepare_dataframe(
 | 
			
		||||
                plot_data.prompts, plot_data.prompt_coordinates, dimensions
 | 
			
		||||
            )
 | 
			
		||||
            prompt_color_values = self.color_mapper.create_color_mapping(
 | 
			
		||||
                plot_data.prompts, color_by
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if dimensions == "3d":
 | 
			
		||||
                prompt_fig = px.scatter_3d(
 | 
			
		||||
                    prompt_df, x='dim_1', y='dim_2', z='dim_3',
 | 
			
		||||
                    prompt_df,
 | 
			
		||||
                    x="dim_1",
 | 
			
		||||
                    y="dim_2",
 | 
			
		||||
                    z="dim_3",
 | 
			
		||||
                    color=prompt_color_values,
 | 
			
		||||
                    hover_data=hover_fields
 | 
			
		||||
                    hover_data=hover_fields,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                prompt_fig = px.scatter(
 | 
			
		||||
                    prompt_df, x='dim_1', y='dim_2',
 | 
			
		||||
                    prompt_df,
 | 
			
		||||
                    x="dim_1",
 | 
			
		||||
                    y="dim_2",
 | 
			
		||||
                    color=prompt_color_values,
 | 
			
		||||
                    hover_data=hover_fields
 | 
			
		||||
                    hover_data=hover_fields,
 | 
			
		||||
                )
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
            for trace in prompt_fig.data:
 | 
			
		||||
                if hasattr(trace.marker, 'color') and isinstance(trace.marker.color, str):
 | 
			
		||||
                    trace.marker.color = self.color_mapper.to_grayscale_hex(trace.marker.color)
 | 
			
		||||
                
 | 
			
		||||
                trace.name = f'Prompts - {trace.name}'
 | 
			
		||||
                if dimensions == '3d':
 | 
			
		||||
                if hasattr(trace.marker, "color") and isinstance(
 | 
			
		||||
                    trace.marker.color, str
 | 
			
		||||
                ):
 | 
			
		||||
                    trace.marker.color = self.color_mapper.to_grayscale_hex(
 | 
			
		||||
                        trace.marker.color
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                trace.name = f"Prompts - {trace.name}"
 | 
			
		||||
                if dimensions == "3d":
 | 
			
		||||
                    trace.marker.size = 6
 | 
			
		||||
                    trace.marker.symbol = 'diamond'
 | 
			
		||||
                    trace.marker.symbol = "diamond"
 | 
			
		||||
                else:
 | 
			
		||||
                    trace.marker.size = 10
 | 
			
		||||
                    trace.marker.symbol = 'diamond'
 | 
			
		||||
                    trace.marker.symbol = "diamond"
 | 
			
		||||
                trace.marker.opacity = 0.8
 | 
			
		||||
                fig.add_trace(trace)
 | 
			
		||||
        
 | 
			
		||||
        title = f'{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})'
 | 
			
		||||
 | 
			
		||||
        title = f"{dimensions.upper()} Embedding Visualization - {method} (colored by {color_by})"
 | 
			
		||||
        fig.update_layout(
 | 
			
		||||
            title=title,
 | 
			
		||||
            height=None,
 | 
			
		||||
            autosize=True,
 | 
			
		||||
            margin=dict(l=0, r=0, t=50, b=0)
 | 
			
		||||
            title=title, height=None, autosize=True, margin=dict(l=0, r=0, t=50, b=0)
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        return fig
 | 
			
		||||
    
 | 
			
		||||
    def _prepare_dataframe(self, documents: List[Document], coordinates, dimensions: str) -> pd.DataFrame:
 | 
			
		||||
 | 
			
		||||
    def _prepare_dataframe(
 | 
			
		||||
        self, documents: List[Document], coordinates, dimensions: str
 | 
			
		||||
    ) -> pd.DataFrame:
 | 
			
		||||
        df_data = []
 | 
			
		||||
        for i, doc in enumerate(documents):
 | 
			
		||||
            row = {
 | 
			
		||||
                'id': doc.id,
 | 
			
		||||
                'text': doc.text,
 | 
			
		||||
                'text_preview': doc.text[:100] + "..." if len(doc.text) > 100 else doc.text,
 | 
			
		||||
                'category': doc.category,
 | 
			
		||||
                'subcategory': doc.subcategory,
 | 
			
		||||
                'tags_str': ', '.join(doc.tags) if doc.tags else 'None',
 | 
			
		||||
                'dim_1': coordinates[i, 0],
 | 
			
		||||
                'dim_2': coordinates[i, 1],
 | 
			
		||||
                "id": doc.id,
 | 
			
		||||
                "text": doc.text,
 | 
			
		||||
                "text_preview": doc.text[:100] + "..."
 | 
			
		||||
                if len(doc.text) > 100
 | 
			
		||||
                else doc.text,
 | 
			
		||||
                "category": doc.category,
 | 
			
		||||
                "subcategory": doc.subcategory,
 | 
			
		||||
                "tags_str": ", ".join(doc.tags) if doc.tags else "None",
 | 
			
		||||
                "dim_1": coordinates[i, 0],
 | 
			
		||||
                "dim_2": coordinates[i, 1],
 | 
			
		||||
            }
 | 
			
		||||
            if dimensions == '3d':
 | 
			
		||||
                row['dim_3'] = coordinates[i, 2]
 | 
			
		||||
            if dimensions == "3d":
 | 
			
		||||
                row["dim_3"] = coordinates[i, 2]
 | 
			
		||||
            df_data.append(row)
 | 
			
		||||
        
 | 
			
		||||
        return pd.DataFrame(df_data)
 | 
			
		||||
 | 
			
		||||
        return pd.DataFrame(df_data)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,62 +6,64 @@ from src.embeddingbuddy.models.schemas import Document
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestNDJSONParser:
 | 
			
		||||
    
 | 
			
		||||
    def test_parse_text_basic(self):
 | 
			
		||||
        text_content = '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
 | 
			
		||||
        text_content = (
 | 
			
		||||
            '{"id": "test1", "text": "Hello world", "embedding": [0.1, 0.2, 0.3]}'
 | 
			
		||||
        )
 | 
			
		||||
        documents = NDJSONParser.parse_text(text_content)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert len(documents) == 1
 | 
			
		||||
        assert documents[0].id == "test1"
 | 
			
		||||
        assert documents[0].text == "Hello world"
 | 
			
		||||
        assert documents[0].embedding == [0.1, 0.2, 0.3]
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_parse_text_with_metadata(self):
 | 
			
		||||
        text_content = '{"id": "test1", "text": "Hello", "embedding": [0.1, 0.2], "category": "greeting", "tags": ["test"]}'
 | 
			
		||||
        documents = NDJSONParser.parse_text(text_content)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert documents[0].category == "greeting"
 | 
			
		||||
        assert documents[0].tags == ["test"]
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_parse_text_missing_id(self):
 | 
			
		||||
        text_content = '{"text": "Hello", "embedding": [0.1, 0.2]}'
 | 
			
		||||
        documents = NDJSONParser.parse_text(text_content)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert len(documents) == 1
 | 
			
		||||
        assert documents[0].id is not None  # Should be auto-generated
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDataProcessor:
 | 
			
		||||
    
 | 
			
		||||
    def test_extract_embeddings(self):
 | 
			
		||||
        documents = [
 | 
			
		||||
            Document(id="1", text="test1", embedding=[0.1, 0.2]),
 | 
			
		||||
            Document(id="2", text="test2", embedding=[0.3, 0.4])
 | 
			
		||||
            Document(id="2", text="test2", embedding=[0.3, 0.4]),
 | 
			
		||||
        ]
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        processor = DataProcessor()
 | 
			
		||||
        embeddings = processor._extract_embeddings(documents)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert embeddings.shape == (2, 2)
 | 
			
		||||
        assert np.allclose(embeddings[0], [0.1, 0.2])
 | 
			
		||||
        assert np.allclose(embeddings[1], [0.3, 0.4])
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_combine_data(self):
 | 
			
		||||
        from src.embeddingbuddy.models.schemas import ProcessedData
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        doc_data = ProcessedData(
 | 
			
		||||
            documents=[Document(id="1", text="doc", embedding=[0.1, 0.2])],
 | 
			
		||||
            embeddings=np.array([[0.1, 0.2]])
 | 
			
		||||
            embeddings=np.array([[0.1, 0.2]]),
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        prompt_data = ProcessedData(
 | 
			
		||||
            documents=[Document(id="p1", text="prompt", embedding=[0.3, 0.4])],
 | 
			
		||||
            embeddings=np.array([[0.3, 0.4]])
 | 
			
		||||
            embeddings=np.array([[0.3, 0.4]]),
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        processor = DataProcessor()
 | 
			
		||||
        all_embeddings, documents, prompts = processor.combine_data(doc_data, prompt_data)
 | 
			
		||||
        
 | 
			
		||||
        all_embeddings, documents, prompts = processor.combine_data(
 | 
			
		||||
            doc_data, prompt_data
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        assert all_embeddings.shape == (2, 2)
 | 
			
		||||
        assert len(documents) == 1
 | 
			
		||||
        assert len(prompts) == 1
 | 
			
		||||
@@ -70,4 +72,4 @@ class TestDataProcessor:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    pytest.main([__file__])
 | 
			
		||||
    pytest.main([__file__])
 | 
			
		||||
 
 | 
			
		||||
@@ -1,89 +1,90 @@
 | 
			
		||||
import pytest
 | 
			
		||||
import numpy as np
 | 
			
		||||
from src.embeddingbuddy.models.reducers import ReducerFactory, PCAReducer, TSNEReducer, UMAPReducer
 | 
			
		||||
from src.embeddingbuddy.models.reducers import (
 | 
			
		||||
    ReducerFactory,
 | 
			
		||||
    PCAReducer,
 | 
			
		||||
    TSNEReducer,
 | 
			
		||||
    UMAPReducer,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestReducerFactory:
 | 
			
		||||
    
 | 
			
		||||
    def test_create_pca_reducer(self):
 | 
			
		||||
        reducer = ReducerFactory.create_reducer('pca', n_components=2)
 | 
			
		||||
        reducer = ReducerFactory.create_reducer("pca", n_components=2)
 | 
			
		||||
        assert isinstance(reducer, PCAReducer)
 | 
			
		||||
        assert reducer.n_components == 2
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_create_tsne_reducer(self):
 | 
			
		||||
        reducer = ReducerFactory.create_reducer('tsne', n_components=3)
 | 
			
		||||
        reducer = ReducerFactory.create_reducer("tsne", n_components=3)
 | 
			
		||||
        assert isinstance(reducer, TSNEReducer)
 | 
			
		||||
        assert reducer.n_components == 3
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_create_umap_reducer(self):
 | 
			
		||||
        reducer = ReducerFactory.create_reducer('umap', n_components=2)
 | 
			
		||||
        reducer = ReducerFactory.create_reducer("umap", n_components=2)
 | 
			
		||||
        assert isinstance(reducer, UMAPReducer)
 | 
			
		||||
        assert reducer.n_components == 2
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_invalid_method(self):
 | 
			
		||||
        with pytest.raises(ValueError, match="Unknown reduction method"):
 | 
			
		||||
            ReducerFactory.create_reducer('invalid_method')
 | 
			
		||||
    
 | 
			
		||||
            ReducerFactory.create_reducer("invalid_method")
 | 
			
		||||
 | 
			
		||||
    def test_available_methods(self):
 | 
			
		||||
        methods = ReducerFactory.get_available_methods()
 | 
			
		||||
        assert 'pca' in methods
 | 
			
		||||
        assert 'tsne' in methods
 | 
			
		||||
        assert 'umap' in methods
 | 
			
		||||
        assert "pca" in methods
 | 
			
		||||
        assert "tsne" in methods
 | 
			
		||||
        assert "umap" in methods
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPCAReducer:
 | 
			
		||||
    
 | 
			
		||||
    def test_fit_transform(self):
 | 
			
		||||
        embeddings = np.random.rand(100, 512)
 | 
			
		||||
        reducer = PCAReducer(n_components=2)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = reducer.fit_transform(embeddings)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert result.reduced_embeddings.shape == (100, 2)
 | 
			
		||||
        assert result.variance_explained is not None
 | 
			
		||||
        assert result.method == "PCA"
 | 
			
		||||
        assert result.n_components == 2
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_method_name(self):
 | 
			
		||||
        reducer = PCAReducer()
 | 
			
		||||
        assert reducer.get_method_name() == "PCA"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTSNEReducer:
 | 
			
		||||
    
 | 
			
		||||
    def test_fit_transform_small_dataset(self):
 | 
			
		||||
        embeddings = np.random.rand(30, 10)  # Small dataset for faster testing
 | 
			
		||||
        reducer = TSNEReducer(n_components=2)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = reducer.fit_transform(embeddings)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert result.reduced_embeddings.shape == (30, 2)
 | 
			
		||||
        assert result.variance_explained is None  # t-SNE doesn't provide this
 | 
			
		||||
        assert result.method == "t-SNE"
 | 
			
		||||
        assert result.n_components == 2
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_method_name(self):
 | 
			
		||||
        reducer = TSNEReducer()
 | 
			
		||||
        assert reducer.get_method_name() == "t-SNE"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestUMAPReducer:
 | 
			
		||||
    
 | 
			
		||||
    def test_fit_transform(self):
 | 
			
		||||
        embeddings = np.random.rand(50, 10)
 | 
			
		||||
        reducer = UMAPReducer(n_components=2)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        result = reducer.fit_transform(embeddings)
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
        assert result.reduced_embeddings.shape == (50, 2)
 | 
			
		||||
        assert result.variance_explained is None  # UMAP doesn't provide this
 | 
			
		||||
        assert result.method == "UMAP"
 | 
			
		||||
        assert result.n_components == 2
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def test_method_name(self):
 | 
			
		||||
        reducer = UMAPReducer()
 | 
			
		||||
        assert reducer.get_method_name() == "UMAP"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    pytest.main([__file__])
 | 
			
		||||
    pytest.main([__file__])
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user