from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Dict, Any
import os
import datetime
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# ==========================================
# 1. APP SETUP
# ==========================================
app = FastAPI(title="FunctionGemma Brain API", version="1.0.0")

MODEL_ID = "google/functiongemma-270m-it"
tokenizer = None
model = None

# ==========================================
# 2. FEW-SHOT EXAMPLES (The Teacher)
# ==========================================
# We teach the model the correct tool names here.
# This list simulates a previous conversation so the model knows what to do.
FEW_SHOT_MESSAGES = [
    # Example 1: Counting/Stats
    {"role": "user", "content": "How many regions are there?"},
    {"role": "model", "content": "call:get_aggregate_stats{target_entity:revenue_region}"},

    # Example 2: Specific Search
    {"role": "user", "content": "What is the water level in Aadale dam?"},
    {"role": "model", "content": "call:search_specific_dam{dam_name:Aadale}"},

    # Example 3: Filtering
    {"role": "user", "content": "Show me Major dams in Pune."},
    {"role": "model", "content": "call:filter_dams{district:Pune,project_type:Major}"},

    # Example 4: Irrelevant Question (teach it NOT to call functions for off-topic queries)
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "model", "content": "I cannot answer that as it is not related to the dam database."},
]

# ==========================================
# 3. STARTUP
# ==========================================
@app.on_event("startup")
async def startup():
    global tokenizer, model
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN missing")
    login(token=hf_token)
    print(f"🧠 Loading {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="cpu", torch_dtype=torch.float32
    )
    print("✅ Model Loaded.")

# ==========================================
# 4. API ENDPOINT
# ==========================================
class ChatRequest(BaseModel):
    query: str
    tools: List[Dict[str, Any]]
    include_date: bool = True

@app.post("/generate")
async def generate_function_call(request: ChatRequest):
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        # 1. System prompt
        system_content = "You are a model that can do function calling with the following functions."
        if request.include_date:
            today = datetime.date.today().isoformat()
            system_content += f" Today is {today}."

        # 2. Construct history: System -> Examples -> Current user query
        messages = [{"role": "system", "content": system_content}]
        # Inject the few-shot examples!
        messages.extend(FEW_SHOT_MESSAGES)
        # Add the actual user query
        messages.append({"role": "user", "content": request.query})

        # 3. Tokenize
        inputs = tokenizer.apply_chat_template(
            messages,
            tools=request.tools,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )

        # 4. Generate (greedy decoding keeps tool calls deterministic)
        outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
        prompt_len = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return {"response": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
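
# ==========================================
# 5. USAGE EXAMPLE (illustrative sketch)
# ==========================================
# A minimal smoke-test sketch, not part of the API itself. It calls /generate
# via FastAPI's in-process TestClient, so no separate server is needed; note
# that entering the TestClient context fires the startup event and loads the
# model, so HF_TOKEN must be set. The "filter_dams" schema below is a
# hypothetical illustration mirroring Example 3 above; substitute your real
# tool definitions in whatever format your chat template expects.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    # Hypothetical tool definition matching the few-shot messages.
    example_tools = [
        {
            "name": "filter_dams",
            "description": "Filter dams by district and/or project type.",
            "parameters": {
                "type": "object",
                "properties": {
                    "district": {"type": "string"},
                    "project_type": {"type": "string"},
                },
            },
        }
    ]

    # The context manager runs the startup handler before the first request.
    with TestClient(app) as client:
        resp = client.post(
            "/generate",
            json={"query": "Show me Major dams in Satara.", "tools": example_tools},
        )
        print(resp.json())  # expected shape: {"response": "call:filter_dams{...}"}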