Skip to content

Streamlit Dashboard Source Code

TODO: Add a link to the GitHub repository.

TODO: Re-test to confirm the code is correct.

from databricks.sdk.core import Config
from databricks.sdk import WorkspaceClient
import streamlit as st
import numpy as np
import pandas as pd
import os

# Import the auth module to get the user's or service principal's information
import auth

# Labels for the two identities a query can run as; the sidebar selectbox
# offers these and get_sql_connection compares against them.
app_user = 'App User'
service_principal = 'Service Principal'

# Configure the browser tab title/icon and use the full page width.
st.set_page_config(layout="wide", page_title="Streamlit in Databricks!", page_icon="")

# Identity of the current viewer, resolved by the auth module from the
# session's request headers.
user_info = auth.get_user_info()

def get_sql_connection(method):
    """Open a SQL warehouse connection for the chosen authentication method.

    ``app_user`` connects with the viewer's own credentials; any other value
    connects as the app's service principal. The target warehouse is taken
    from the ``DATABRICKS_WAREHOUSE_HTTP_PATH`` environment variable.
    """
    connect = (
        auth.get_user_sql_connection
        if method == app_user
        else auth.get_app_sql_connection
    )
    return connect(os.getenv("DATABRICKS_WAREHOUSE_HTTP_PATH"))


# Choices offered in the sidebar.
pages = ['Home', 'App Info']
options = [app_user, service_principal]
default_index = options.index(app_user)

# Everything rendered inside this block goes to the sidebar, top to bottom.
with st.sidebar:
    # Page navigation; the selection drives the page dispatch below.
    selected_page = st.selectbox('Choose a page', pages)

    # Who is viewing the app.
    st.title("User Info")
    st.write(f"User Name: {user_info.get('user_name')}")
    st.write(f"User Email: {user_info.get('user_email')}")
    st.write(f"User ID: {user_info.get('user_id')}")

    # Which identity queries run as; stored in session state for the pages.
    st.title("Authentication Method")
    st.session_state.authentication_method = st.selectbox(
        'Choose the authentication method',
        options,
        index=default_index,
    )

# Display the selected page
if selected_page == 'Home':

    # Cache query results for 30 seconds. `method` is a function argument and
    # therefore part of the cache key, so switching the authentication method in
    # the sidebar does not reuse results fetched under the other identity.
    @st.cache_data(ttl=30)
    def get_scatter_data(method):
        """Query NYC taxi trips from the Databricks SQL warehouse.

        Args:
            method: authentication method label (``app_user`` or
                ``service_principal``), forwarded to ``get_sql_connection``.

        Returns:
            pandas DataFrame with ``pickup_time``, ``trip_distance`` and
            ``fare_amount`` columns.

        Raises:
            Exception: wrapping (and chained to) any error raised while creating
                the cursor, running the query, or fetching the results.
        """
        print(f"Fetching data from Databricks using {method}\n", flush=True)

        # Create a connection to the Databricks SQL warehouse
        connection = get_sql_connection(method)
        try:
            # The cursor is created inside the try so the connection is still
            # closed if cursor creation fails; the `with` closes the cursor.
            with connection.cursor() as cursor:
                cursor.execute(
                    """
                    SELECT
                        tpep_pickup_datetime as pickup_time,
                        trip_distance,
                        fare_amount
                    FROM andre.nyctaxi.trips;
                    """
                )
                return cursor.fetchall_arrow().to_pandas()
        except Exception as e:
            print(e, flush=True)
            # Enhance the error message while keeping the original as the cause.
            raise Exception(f"Error fetching data from Databricks: {e}") from e
        finally:
            connection.close()

    def get_taxi_df():
        """Fetch taxi data using the authentication method chosen in the sidebar."""
        return get_scatter_data(st.session_state.authentication_method)

    st.header("NYC Taxi Trips")
    try:
        data = get_taxi_df()
    except Exception as e:
        # get_scatter_data never returns None; it raises instead (e.g. when the
        # selected identity lacks access). Show the error and fall through to the
        # "No data available" message below rather than crashing the page.
        st.error(str(e))
        data = None

    if st.checkbox('Show data table'):
        st.subheader("Raw Data")
        st.dataframe(data=data, height=600, use_container_width=True)

    st.subheader("Number of pickups by hour")
    if data is not None:
        # 24 one-hour bins over [0, 24): one bar per pickup hour.
        hist_values = np.histogram(data['pickup_time'].dt.hour, bins=24, range=(0, 24))[0]
        st.bar_chart(hist_values)
    else:
        # Fetching failed, e.g. the user does not have access to the data
        st.write("No data available. Check your permissions.")

elif selected_page == 'App Info':

    # Look up this app's own metadata, authenticating as the app's service
    # principal (get_app_workspace_client), not as the viewing user.
    # NOTE(review): assumes DATABRICKS_APP_NAME is set in the app environment.
    w = auth.get_app_workspace_client()
    app = w.apps.get(os.getenv('DATABRICKS_APP_NAME'))
    # Display the app info
    st.header('App Info')
    st.write('App Name:', app.name)
    st.write('App URL:', app.url)
    st.markdown('App state: `{}`'.format(app.status.state.value))
    st.write('App status message:', app.status.message)
    st.write('App creator:', app.creator)

    st.write('Host:', os.getenv('DATABRICKS_HOST'))
    st.write('Warehouse ID:', os.getenv('DATABRICKS_WAREHOUSE_ID'))

    # Display the active deployment; the SDK models it as optional.
    st.subheader('Active Deployment')
    deployment = app.active_deployment
    if deployment is None:
        st.write('No active deployment.')
    else:
        st.markdown('Deployment ID: `{}`'.format(deployment.deployment_id))
        st.markdown('Deployment state: `{}`'.format(deployment.status.state.value))
        st.markdown('Deployment message: `{}`'.format(deployment.status.message))
        st.markdown('Deployment source code path: `{}`'.format(deployment.source_code_path))
        st.write('Deployment creator:', deployment.creator)
        st.write('Deployment create time:', deployment.create_time)

TODO: Simplify this module; there is no reason for it to be this complicated.

# This module shows how to authenticate users in a Streamlit app running in Databricks.

# get_user_info retrieves the current user's information from the WebSocket request headers.
# get_user_credentials_provider returns a credentials provider for the current user, used to authenticate as that user to the Databricks workspace.
# get_app_credentials_provider returns a credentials provider for the app's service principal, used to authenticate the app itself to the Databricks workspace.

import os
from typing import Dict, Optional
from streamlit.web.server.websocket_headers import _get_websocket_headers
from databricks.sdk.core import Config, HeaderFactory, oauth_service_principal, CredentialsProvider
from databricks.sdk import WorkspaceClient
from databricks import sql


def get_user_info() -> Dict[str, Optional[str]]:
    """Return the current viewer's identity from the session's WebSocket request headers.

    Returns:
        dict with keys ``user_name``, ``user_email``, ``user_id`` and
        ``access_token``, read from the ``X-Forwarded-Preferred-Username``,
        ``X-Forwarded-Email``, ``X-Forwarded-User`` and
        ``X-Forwarded-Access-Token`` headers (presumably injected by the
        Databricks Apps proxy). A value is ``None`` when its header is absent.
    """
    # _get_websocket_headers() can return None when there is no active Streamlit
    # session for the calling thread; fall back to an empty mapping so every
    # field is None instead of raising AttributeError on `.get`.
    headers = _get_websocket_headers() or {}

    return dict(
        user_name=headers.get("X-Forwarded-Preferred-Username"),
        user_email=headers.get("X-Forwarded-Email"),
        user_id=headers.get("X-Forwarded-User"),
        access_token=headers.get("X-Forwarded-Access-Token")
    )

def get_app_service_principal_info():
    """Return the app service principal's OAuth client credentials.

    Values come from the ``DATABRICKS_CLIENT_ID`` and
    ``DATABRICKS_CLIENT_SECRET`` environment variables; each is ``None``
    when its variable is unset.
    """
    env = os.environ
    return {
        'client_id': env.get('DATABRICKS_CLIENT_ID'),
        'client_secret': env.get('DATABRICKS_CLIENT_SECRET'),
    }

def get_user_credentials_provider() -> CredentialsProvider:
    """Return a credentials provider that authenticates as the current app user.

    The same provider object is used with both the Databricks SDK
    (``Config(credentials_provider=...)``) and the Databricks SQL connector
    (``sql.connect(credentials_provider=...)``). When invoked, it reads the
    user's forwarded access token via ``get_user_info`` and returns a header
    factory that always yields the same ``Authorization: Bearer <token>``
    header. The token is not refreshed afterwards.

    Raises (when the returned provider is invoked):
        ValueError: if no user access token is available.
    """

    def inner(cfg: Optional[Config] = None) -> HeaderFactory:
        access_token = get_user_info().get("access_token")
        if not access_token:
            raise ValueError("User access token not found. Please make sure the feature is enabled in your workspace.")
        static_credentials = {'Authorization': f'Bearer {access_token}'}
        # Zero-argument header factory returning the fixed headers above.
        return lambda: static_credentials

    # Label identifying this authentication strategy.
    inner.auth_type = lambda: 'app-user-oauth'

    return inner

def get_app_credentials_provider() -> CredentialsProvider:
    """Return a credentials provider that authenticates as the app's service principal.

    The same provider object is used with both the Databricks SDK
    (``Config(credentials_provider=...)``) and the Databricks SQL connector
    (``sql.connect(credentials_provider=...)``). When invoked, it delegates to
    the SDK's ``oauth_service_principal`` OAuth strategy.
    """

    def inner(cfg: Optional[Config] = None) -> HeaderFactory:
        # Callers that pass no Config (presumably the SQL connector) get a
        # default Config(), which presumably resolves host and client
        # credentials from the DATABRICKS_* environment variables.
        if cfg is None:
            cfg = Config()
        # NOTE(review): depending on the SDK version, the strategy may return
        # None when required settings are missing; confirm against the SDK in use.
        return oauth_service_principal(cfg)

    # Label identifying this authentication strategy.
    inner.auth_type = lambda: 'app-service-principal-oauth'

    return inner

def get_user_workspace_client() -> WorkspaceClient:
    """Build a Databricks SDK WorkspaceClient that acts as the current app user.

    Authentication is delegated to ``get_user_credentials_provider()``.
    """
    user_config = Config(credentials_provider=get_user_credentials_provider())
    return WorkspaceClient(config=user_config)


def get_app_workspace_client() -> WorkspaceClient:
    """Build a Databricks SDK WorkspaceClient that acts as the app's service principal.

    Authentication is delegated to ``get_app_credentials_provider()`` (OAuth).
    """
    app_config = Config(credentials_provider=get_app_credentials_provider())
    return WorkspaceClient(config=app_config)


def get_user_sql_connection(http_path):
    """Open a Databricks SQL connection authenticated as the current app user.

    Args:
        http_path: HTTP path of the SQL warehouse to connect to.

    The user credentials provider hands out a fixed bearer token that is never
    refreshed, so keep these connections short-lived. For long-lived
    connections, use ``get_app_sql_connection``.
    """
    host = os.getenv('DATABRICKS_HOST')
    user_provider = get_user_credentials_provider()
    return sql.connect(
        server_hostname=host,
        http_path=http_path,
        credentials_provider=user_provider,
    )

def get_app_sql_connection(http_path):
    """Open a Databricks SQL connection authenticated as the app's service principal.

    Args:
        http_path: HTTP path of the SQL warehouse to connect to.
    """
    host = os.getenv('DATABRICKS_HOST')
    app_provider = get_app_credentials_provider()
    return sql.connect(
        server_hostname=host,
        http_path=http_path,
        credentials_provider=app_provider,
    )

# Command used to launch the app: `streamlit run app.py`.
command: [
  "streamlit", 
  "run", 
  "app.py"
]
# Environment variables made available to the app process.
env:
  # SQL warehouse HTTP path, read via os.getenv by get_sql_connection.
  # NOTE(review): the App Info page also reads DATABRICKS_WAREHOUSE_ID, which is
  # not set here; confirm the runtime provides it, or add it.
  - name: "DATABRICKS_WAREHOUSE_HTTP_PATH"
    value: "/sql/1.0/warehouses/58aa1b363649e722"