Skip to content

Gradio + SQL

TODO: Add a link to the GitHub repository.

from fastapi import FastAPI
import gradio as gr
import os
# FastAPI application that uvicorn serves (`app:app`); the Gradio UI is
# mounted onto it at the end of the file.
app = FastAPI()

import plotly.graph_objects as go
from datasets import load_dataset
from databricks import sdk, sql


def query(request):
    """Fetch the NYC Airbnb listings table from a Databricks SQL warehouse.

    Authenticates as the app's service principal: ``sdk.config.Config()``
    resolves the workspace host and credentials, and ``cfg.authenticate``
    supplies the auth headers to the SQL connector.

    Args:
        request: The incoming Gradio request. It is not used for auth today;
            it stays in the signature so per-user auth based on request
            headers (e.g. ``x-forwarded-access-token``) can be added later
            without changing callers.

    Returns:
        A pandas DataFrame with every row of ``andre.apps.ab_nyc_2019``.

    Raises:
        RuntimeError: If the ``HTTP_PATH`` environment variable is unset or
            empty.
    """
    http_path = os.getenv("HTTP_PATH")
    if not http_path:
        # Fail fast with an actionable message instead of handing None to the
        # connector and getting an obscure connection error.
        raise RuntimeError(
            "HTTP_PATH is not set; expected the SQL warehouse HTTP path, "
            "e.g. /sql/1.0/warehouses/<warehouse-id>"
        )

    cfg = sdk.config.Config()
    print("using SP auth")
    with sql.connect(
        server_hostname=cfg.host,
        http_path=http_path,
        # The connector wants a callable returning a header factory;
        # cfg.authenticate is that factory (it returns auth headers).
        credentials_provider=lambda: cfg.authenticate,
    ) as conn, conn.cursor() as cursor:
        cursor.execute("SELECT * FROM andre.apps.ab_nyc_2019")
        return cursor.fetchall_arrow().to_pandas()


def _filter_listings(df, min_price, max_price, boroughs):
    """Return rows of *df* in *boroughs* whose price lies in [min_price, max_price].

    Both bounds are inclusive, matching the "Minimum/Maximum Price" labels.
    A bound of ``None`` (e.g. a cleared number field) leaves that side
    unbounded, and ``None`` boroughs selects nothing.
    """
    mask = df["neighbourhood_group"].isin(boroughs or [])
    if min_price is not None:
        mask &= df["price"] >= min_price
    if max_price is not None:
        mask &= df["price"] <= max_price
    return df[mask]


def _build_map(listings):
    """Return a Scattermapbox figure with one marker per row of *listings*."""
    fig = go.Figure(go.Scattermapbox(
        # Each marker carries (name, price) for the hover template below.
        customdata=list(zip(listings["name"].tolist(), listings["price"].tolist())),
        lat=listings["latitude"].tolist(),
        lon=listings["longitude"].tolist(),
        mode="markers",
        marker=go.scattermapbox.Marker(size=6),
        hoverinfo="text",
        hovertemplate='<b>Name</b>: %{customdata[0]}<br><b>Price</b>: $%{customdata[1]}',
    ))

    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode="closest",
        mapbox=dict(
            bearing=0,
            # Fixed initial view over New York City.
            center=go.layout.mapbox.Center(lat=40.67, lon=-73.90),
            pitch=0,
            zoom=9,
        ),
    )
    return fig


def filter_map(min_price, max_price, boroughs, request: gr.Request):
    """Gradio handler: load listings, filter them, and return a map figure.

    Args:
        min_price: Inclusive lower price bound, or None for no lower bound.
        max_price: Inclusive upper price bound, or None for no upper bound.
        boroughs: Borough names to keep (values of ``neighbourhood_group``).
        request: The Gradio request; Gradio injects it because of the
            ``gr.Request`` annotation. It is forwarded to ``query``.

    Returns:
        A plotly ``go.Figure`` showing the matching listings.
    """
    listings = _filter_listings(query(request), min_price, max_price, boroughs)
    return _build_map(listings)

# UI: price bounds and a borough selector drive a single map plot. The map is
# rendered once on page load and again whenever "Update Filter" is clicked.
# Component names avoid shadowing the builtin `map` and the stdlib `io` module;
# `demo` is also the name Gradio's reload mode looks for by default.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            min_price = gr.Number(value=250, label="Minimum Price")
            max_price = gr.Number(value=1000, label="Maximum Price")
        boroughs = gr.CheckboxGroup(
            choices=["Queens", "Brooklyn", "Manhattan", "Bronx", "Staten Island"],
            value=["Queens", "Brooklyn"],
            label="Select Boroughs:",
        )
        btn = gr.Button(value="Update Filter")
        price_map = gr.Plot()
    filter_inputs = [min_price, max_price, boroughs]
    demo.load(filter_map, filter_inputs, price_map)
    btn.click(filter_map, filter_inputs, price_map)

# Serve the Gradio UI at the root path of the FastAPI app that uvicorn runs.
app = gr.mount_gradio_app(app, demo, path="/")
# Deployment config (presumably this app's Databricks Apps app.yaml — confirm).
# uvicorn serves the `app` object from module `app` (the Python above saved as
# app.py) using 4 worker processes.
command: ["uvicorn", "app:app", "--workers", "4"]
env:
# Read by query() via os.getenv("HTTP_PATH"): HTTP path of the SQL warehouse.
# NOTE(review): the warehouse ID is hard-coded; consider sourcing it from an
# app resource / per-environment setting instead — confirm with deployment owner.
- name: HTTP_PATH
  value: "/sql/1.0/warehouses/dd43ee29fedd958d"