package database import ( "fmt" "log" "os" "tankstopp/internal/models" "time" "golang.org/x/crypto/bcrypt" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" ) type DB struct { conn *gorm.DB } // NewDB creates a new database connection using GORM with configuration func NewDB(config *Config) (*DB, error) { // Validate configuration if err := config.Validate(); err != nil { return nil, fmt.Errorf("invalid configuration: %w", err) } // Configure GORM gormConfig := &gorm.Config{ Logger: logger.Default.LogMode(config.LogLevel), PrepareStmt: config.PrepareStmt, DisableForeignKeyConstraintWhenMigrating: config.DisableForeignKeyCheck, IgnoreRelationshipsWhenMigrating: config.IgnoreRelationshipsWhenMigrating, QueryFields: config.QueryFields, CreateBatchSize: config.CreateBatchSize, DryRun: config.DryRun, } // Configure slow query logging if config.SlowQueryLog > 0 { env := os.Getenv("ENV") isDev := env == "development" || env == "dev" || env == "" customLogger := logger.New( log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer logger.Config{ SlowThreshold: config.SlowQueryLog, LogLevel: config.LogLevel, IgnoreRecordNotFoundError: true, Colorful: isDev, }, ) gormConfig.Logger = customLogger } conn, err := gorm.Open(sqlite.Open(config.DatabasePath), gormConfig) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } // Get underlying SQL DB to configure connection pool sqlDB, err := conn.DB() if err != nil { return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err) } // Set connection pool settings from configuration sqlDB.SetMaxIdleConns(config.MaxIdleConns) sqlDB.SetMaxOpenConns(config.MaxOpenConns) sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime) sqlDB.SetConnMaxIdleTime(config.ConnMaxIdleTime) db := &DB{conn: conn} // Run migrations if enabled if config.AutoMigrate { if err := db.migrate(); err != nil { return nil, fmt.Errorf("failed to migrate database: %w", err) } } return db, nil } // NewDBWithDefaults creates a new database connection with default configuration func NewDBWithDefaults(databasePath string) (*DB, error) { config := DefaultConfig() config.DatabasePath = databasePath return NewDB(config) } // NewDBFromEnv creates a new database connection using environment variables func NewDBFromEnv() (*DB, error) { config := LoadFromEnv() return NewDB(config) } // NewDBFromConfig creates a new database connection using configuration file func NewDBFromConfig(configPath string) (*DB, error) { config := LoadFromConfig(configPath) return NewDB(config) } // Close closes the database connection func (db *DB) Close() error { sqlDB, err := db.conn.DB() if err != nil { return err } return sqlDB.Close() } // migrate runs database migrations func (db *DB) migrate() error { // Auto-migrate the schema return db.conn.AutoMigrate(&models.User{}, &models.Vehicle{}, &models.FuelStop{}) } // CreateFuelStop inserts a new fuel stop into the database func (db *DB) CreateFuelStop(stop *models.FuelStop) error { // Set timestamps now := time.Now() stop.CreatedAt = now stop.UpdatedAt = now result := db.conn.Create(stop) if result.Error != nil { return fmt.Errorf("failed to create fuel stop: %w", result.Error) } return nil } // GetFuelStops retrieves all fuel stops for a specific user from the database func (db *DB) GetFuelStops(userID uint) ([]models.FuelStop, error) { var stops []models.FuelStop result := db.conn.Preload("Vehicle").Where("user_id = ?", userID). Order("date DESC"). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to get fuel stops: %w", result.Error) } return stops, nil } // GetFuelStopsByVehicle retrieves all fuel stops for a specific vehicle func (db *DB) GetFuelStopsByVehicle(vehicleID, userID uint) ([]models.FuelStop, error) { var stops []models.FuelStop result := db.conn.Where("vehicle_id = ? AND user_id = ?", vehicleID, userID). Order("date DESC"). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to get fuel stops by vehicle: %w", result.Error) } return stops, nil } // GetFuelStopByID retrieves a fuel stop by its ID and user ID func (db *DB) GetFuelStopByID(id, userID uint) (*models.FuelStop, error) { var stop models.FuelStop result := db.conn.Where("id = ? AND user_id = ?", id, userID).First(&stop) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil // Return nil when record not found } return nil, fmt.Errorf("failed to get fuel stop: %w", result.Error) } return &stop, nil } // UpdateFuelStop updates an existing fuel stop func (db *DB) UpdateFuelStop(stop *models.FuelStop) error { // Update timestamp stop.UpdatedAt = time.Now() result := db.conn.Model(stop). Where("id = ? AND user_id = ?", stop.ID, stop.UserID). Updates(stop) if result.Error != nil { return fmt.Errorf("failed to update fuel stop: %w", result.Error) } if result.RowsAffected == 0 { return fmt.Errorf("fuel stop not found or access denied") } return nil } // DeleteFuelStop deletes a fuel stop by its ID and user ID func (db *DB) DeleteFuelStop(id, userID uint) error { result := db.conn.Where("id = ? AND user_id = ?", id, userID).Delete(&models.FuelStop{}) if result.Error != nil { return fmt.Errorf("failed to delete fuel stop: %w", result.Error) } if result.RowsAffected == 0 { return fmt.Errorf("fuel stop not found or access denied") } return nil } // CreateVehicle creates a new vehicle for a user func (db *DB) CreateVehicle(vehicle *models.Vehicle) error { // Set timestamps now := time.Now() vehicle.CreatedAt = now vehicle.UpdatedAt = now result := db.conn.Create(vehicle) if result.Error != nil { return fmt.Errorf("failed to create vehicle: %w", result.Error) } return nil } // GetVehicles retrieves all vehicles for a specific user func (db *DB) GetVehicles(userID uint) ([]models.Vehicle, error) { var vehicles []models.Vehicle result := db.conn.Where("user_id = ?", userID). Order("is_active DESC, name ASC"). Find(&vehicles) if result.Error != nil { return nil, fmt.Errorf("failed to get vehicles: %w", result.Error) } return vehicles, nil } // GetActiveVehicles retrieves only active vehicles for a specific user func (db *DB) GetActiveVehicles(userID uint) ([]models.Vehicle, error) { var vehicles []models.Vehicle result := db.conn.Where("user_id = ? AND is_active = ?", userID, true). Order("name ASC"). Find(&vehicles) if result.Error != nil { return nil, fmt.Errorf("failed to get active vehicles: %w", result.Error) } return vehicles, nil } // GetVehicleByID retrieves a vehicle by its ID and user ID func (db *DB) GetVehicleByID(id, userID uint) (*models.Vehicle, error) { var vehicle models.Vehicle result := db.conn.Where("id = ? AND user_id = ?", id, userID).First(&vehicle) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil // Return nil when vehicle not found } return nil, fmt.Errorf("failed to get vehicle: %w", result.Error) } return &vehicle, nil } // UpdateVehicle updates an existing vehicle func (db *DB) UpdateVehicle(vehicle *models.Vehicle) error { // Update timestamp vehicle.UpdatedAt = time.Now() result := db.conn.Model(vehicle). Where("id = ? AND user_id = ?", vehicle.ID, vehicle.UserID). Updates(vehicle) if result.Error != nil { return fmt.Errorf("failed to update vehicle: %w", result.Error) } if result.RowsAffected == 0 { return fmt.Errorf("vehicle not found or access denied") } return nil } // DeleteVehicle deletes a vehicle by its ID and user ID func (db *DB) DeleteVehicle(id, userID uint) error { // Check if vehicle has fuel stops var count int64 if err := db.conn.Model(&models.FuelStop{}).Where("vehicle_id = ?", id).Count(&count).Error; err != nil { return fmt.Errorf("failed to check fuel stops: %w", err) } if count > 0 { return fmt.Errorf("cannot delete vehicle with existing fuel stops") } result := db.conn.Where("id = ? AND user_id = ?", id, userID).Delete(&models.Vehicle{}) if result.Error != nil { return fmt.Errorf("failed to delete vehicle: %w", result.Error) } if result.RowsAffected == 0 { return fmt.Errorf("vehicle not found or access denied") } return nil } // GetVehicleWithFuelStops retrieves a vehicle with its fuel stops func (db *DB) GetVehicleWithFuelStops(vehicleID, userID uint) (*models.Vehicle, error) { var vehicle models.Vehicle result := db.conn.Preload("FuelStops", func(db *gorm.DB) *gorm.DB { return db.Order("date DESC") }).Where("id = ? AND user_id = ?", vehicleID, userID).First(&vehicle) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil } return nil, fmt.Errorf("failed to get vehicle with fuel stops: %w", result.Error) } return &vehicle, nil } // GetVehicleCount returns the total number of vehicles for a user func (db *DB) GetVehicleCount(userID uint) (int64, error) { var count int64 result := db.conn.Model(&models.Vehicle{}).Where("user_id = ?", userID).Count(&count) if result.Error != nil { return 0, fmt.Errorf("failed to count vehicles: %w", result.Error) } return count, nil } // ValidateVehicle validates vehicle data before creation/update func (db *DB) ValidateVehicle(vehicle *models.Vehicle) error { if vehicle.UserID == 0 { return fmt.Errorf("user ID is required") } // Check if user exists exists, err := db.UserExists(vehicle.UserID) if err != nil { return fmt.Errorf("failed to validate user: %w", err) } if !exists { return fmt.Errorf("user does not exist") } if vehicle.Name == "" { return fmt.Errorf("vehicle name is required") } if len(vehicle.Name) > 100 { return fmt.Errorf("vehicle name must be 100 characters or less") } if vehicle.Make != "" && len(vehicle.Make) > 50 { return fmt.Errorf("vehicle make must be 50 characters or less") } if vehicle.Model != "" && len(vehicle.Model) > 50 { return fmt.Errorf("vehicle model must be 50 characters or less") } if vehicle.LicensePlate != "" && len(vehicle.LicensePlate) > 20 { return fmt.Errorf("license plate must be 20 characters or less") } if vehicle.FuelType != "" && len(vehicle.FuelType) > 50 { return fmt.Errorf("fuel type must be 50 characters or less") } if vehicle.Year < 0 || vehicle.Year > time.Now().Year()+1 { return fmt.Errorf("invalid vehicle year") } return nil } // CreateVehicleWithValidation creates a vehicle with validation func (db *DB) CreateVehicleWithValidation(vehicle *models.Vehicle) error { if err := db.ValidateVehicle(vehicle); err != nil { return fmt.Errorf("validation failed: %w", err) } return db.CreateVehicle(vehicle) } // GetFuelStopStats calculates statistics for fuel consumption for a specific user func (db *DB) GetFuelStopStats(userID uint) (*models.FuelStopStats, error) { stats := &models.FuelStopStats{} // Get basic statistics var result struct { TotalStops int64 `json:"total_stops"` TotalLiters float64 `json:"total_liters"` TotalSpent float64 `json:"total_spent"` AveragePrice float64 `json:"average_price"` TotalTripKm float64 `json:"total_trip_km"` MinOdometer int `json:"min_odometer"` MaxOdometer int `json:"max_odometer"` } err := db.conn.Model(&models.FuelStop{}). Select(` COUNT(*) as total_stops, COALESCE(SUM(liters), 0) as total_liters, COALESCE(SUM(total_price), 0) as total_spent, COALESCE(AVG(price_per_l), 0) as average_price, COALESCE(SUM(trip_length), 0) as total_trip_km, COALESCE(MIN(odometer), 0) as min_odometer, COALESCE(MAX(odometer), 0) as max_odometer `). Where("user_id = ?", userID). Scan(&result).Error if err != nil { return nil, fmt.Errorf("failed to get fuel stop stats: %w", err) } stats.TotalStops = int(result.TotalStops) stats.TotalLiters = result.TotalLiters stats.TotalSpent = result.TotalSpent stats.AveragePrice = result.AveragePrice // Get the last fillup var lastStop models.FuelStop err = db.conn.Where("user_id = ?", userID). Order("date DESC"). First(&lastStop).Error if err != nil && err != gorm.ErrRecordNotFound { return nil, fmt.Errorf("failed to get last fillup: %w", err) } if err != gorm.ErrRecordNotFound { stats.LastFillup = &lastStop } // Calculate average consumption using trip length (preferred) or odometer difference (fallback) if stats.TotalStops > 1 { // Primary method: Use trip length if available if result.TotalTripKm > 0 { stats.AverageConsumption = (stats.TotalLiters / result.TotalTripKm) * 100 } else if result.MaxOdometer > result.MinOdometer { // Fallback method: Use odometer difference distanceDriven := result.MaxOdometer - result.MinOdometer if distanceDriven > 0 { stats.AverageConsumption = (stats.TotalLiters / float64(distanceDriven)) * 100 } } } return stats, nil } // CreateUser creates a new user in the database func (db *DB) CreateUser(user *models.User) error { // Hash the password hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("failed to hash password: %w", err) } user.PasswordHash = string(hashedPassword) user.Password = "" // Clear the plain text password // Set timestamps now := time.Now() user.CreatedAt = now user.UpdatedAt = now result := db.conn.Create(user) if result.Error != nil { return fmt.Errorf("failed to create user: %w", result.Error) } return nil } // GetUserByUsername retrieves a user by username func (db *DB) GetUserByUsername(username string) (*models.User, error) { var user models.User result := db.conn.Where("username = ?", username).First(&user) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil // Return nil when user not found } return nil, fmt.Errorf("failed to get user by username: %w", result.Error) } return &user, nil } // GetUserByID retrieves a user by ID func (db *DB) GetUserByID(id uint) (*models.User, error) { var user models.User result := db.conn.First(&user, id) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil // Return nil when user not found } return nil, fmt.Errorf("failed to get user by ID: %w", result.Error) } return &user, nil } // UpdateUser updates an existing user func (db *DB) UpdateUser(user *models.User) error { // Update timestamp user.UpdatedAt = time.Now() result := db.conn.Model(user). Select("email", "base_currency", "updated_at"). Updates(user) if result.Error != nil { return fmt.Errorf("failed to update user: %w", result.Error) } if result.RowsAffected == 0 { return fmt.Errorf("user not found") } return nil } // UpdateUserPassword updates a user's password func (db *DB) UpdateUserPassword(user *models.User, newPassword string) error { // Hash the new password hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return fmt.Errorf("failed to hash password: %w", err) } // Update timestamp now := time.Now() // Update only the password hash and updated_at fields result := db.conn.Model(user). Updates(map[string]interface{}{ "password_hash": string(hashedPassword), "updated_at": now, }) if result.Error != nil { return fmt.Errorf("failed to update password: %w", result.Error) } if result.RowsAffected == 0 { return fmt.Errorf("user not found") } // Update the user object with new password hash user.PasswordHash = string(hashedPassword) user.UpdatedAt = now return nil } // GetUserByEmail retrieves a user by email func (db *DB) GetUserByEmail(email string) (*models.User, error) { var user models.User result := db.conn.Where("email = ?", email).First(&user) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil // Return nil when user not found } return nil, fmt.Errorf("failed to get user by email: %w", result.Error) } return &user, nil } // GetUserWithFuelStops retrieves a user with their fuel stops func (db *DB) GetUserWithFuelStops(userID uint) (*models.User, error) { var user models.User result := db.conn.Preload("FuelStops", func(db *gorm.DB) *gorm.DB { return db.Order("date DESC") }).Preload("Vehicles").First(&user, userID) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil } return nil, fmt.Errorf("failed to get user with fuel stops: %w", result.Error) } return &user, nil } // GetFuelStopsWithPagination retrieves fuel stops with pagination func (db *DB) GetFuelStopsWithPagination(userID uint, limit, offset int) ([]models.FuelStop, int64, error) { var stops []models.FuelStop var total int64 // Get total count err := db.conn.Model(&models.FuelStop{}). Where("user_id = ?", userID). Count(&total).Error if err != nil { return nil, 0, fmt.Errorf("failed to count fuel stops: %w", err) } // Get paginated results err = db.conn.Where("user_id = ?", userID). Order("date DESC"). Limit(limit). Offset(offset). Find(&stops).Error if err != nil { return nil, 0, fmt.Errorf("failed to get fuel stops with pagination: %w", err) } return stops, total, nil } // GetFuelStopsByDateRange retrieves fuel stops within a date range func (db *DB) GetFuelStopsByDateRange(userID uint, startDate, endDate time.Time) ([]models.FuelStop, error) { var stops []models.FuelStop result := db.conn.Where("user_id = ? AND date BETWEEN ? AND ?", userID, startDate, endDate). Order("date DESC"). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to get fuel stops by date range: %w", result.Error) } return stops, nil } // GetFuelStopsByFuelType retrieves fuel stops by fuel type func (db *DB) GetFuelStopsByFuelType(userID uint, fuelType string) ([]models.FuelStop, error) { var stops []models.FuelStop result := db.conn.Where("user_id = ? AND fuel_type = ?", userID, fuelType). Order("date DESC"). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to get fuel stops by fuel type: %w", result.Error) } return stops, nil } // GetMonthlyStats retrieves monthly statistics for a user func (db *DB) GetMonthlyStats(userID uint, year int) ([]models.MonthlyStats, error) { var stats []models.MonthlyStats query := ` SELECT strftime('%m', date) as month, strftime('%Y', date) as year, COUNT(*) as total_stops, SUM(liters) as total_liters, SUM(total_price) as total_spent, AVG(price_per_l) as average_price FROM fuel_stops WHERE user_id = ? AND strftime('%Y', date) = ? GROUP BY strftime('%Y-%m', date) ORDER BY month ` err := db.conn.Raw(query, userID, fmt.Sprintf("%d", year)).Scan(&stats).Error if err != nil { return nil, fmt.Errorf("failed to get monthly stats: %w", err) } return stats, nil } // BulkCreateFuelStops creates multiple fuel stops in a single transaction func (db *DB) BulkCreateFuelStops(stops []models.FuelStop) error { if len(stops) == 0 { return nil } // Set timestamps for all stops now := time.Now() for i := range stops { stops[i].CreatedAt = now stops[i].UpdatedAt = now } // Use transaction for bulk insert return db.conn.Transaction(func(tx *gorm.DB) error { return tx.CreateInBatches(stops, 100).Error }) } // DeleteAllUserData deletes all data for a user (for account deletion) func (db *DB) DeleteAllUserData(userID uint) error { return db.conn.Transaction(func(tx *gorm.DB) error { // Delete all fuel stops first (due to foreign key constraint) if err := tx.Where("user_id = ?", userID).Delete(&models.FuelStop{}).Error; err != nil { return fmt.Errorf("failed to delete fuel stops: %w", err) } // Delete the user if err := tx.Delete(&models.User{}, userID).Error; err != nil { return fmt.Errorf("failed to delete user: %w", err) } return nil }) } // HealthCheck performs a simple health check on the database func (db *DB) HealthCheck() error { sqlDB, err := db.conn.DB() if err != nil { return fmt.Errorf("failed to get underlying sql.DB: %w", err) } return sqlDB.Ping() } // GetFuelStopCount returns the total number of fuel stops for a user func (db *DB) GetFuelStopCount(userID uint) (int64, error) { var count int64 result := db.conn.Model(&models.FuelStop{}).Where("user_id = ?", userID).Count(&count) if result.Error != nil { return 0, fmt.Errorf("failed to count fuel stops: %w", result.Error) } return count, nil } // GetLatestFuelStop returns the most recent fuel stop for a user func (db *DB) GetLatestFuelStop(userID uint) (*models.FuelStop, error) { var stop models.FuelStop result := db.conn.Where("user_id = ?", userID).Order("date DESC").First(&stop) if result.Error != nil { if result.Error == gorm.ErrRecordNotFound { return nil, nil } return nil, fmt.Errorf("failed to get latest fuel stop: %w", result.Error) } return &stop, nil } // UserExists checks if a user exists by ID func (db *DB) UserExists(userID uint) (bool, error) { var count int64 result := db.conn.Model(&models.User{}).Where("id = ?", userID).Count(&count) if result.Error != nil { return false, fmt.Errorf("failed to check user existence: %w", result.Error) } return count > 0, nil } // ValidateFuelStop validates fuel stop data before creation/update func (db *DB) ValidateFuelStop(stop *models.FuelStop) error { if stop.UserID == 0 { return fmt.Errorf("user ID is required") } // Check if user exists exists, err := db.UserExists(stop.UserID) if err != nil { return fmt.Errorf("failed to validate user: %w", err) } if !exists { return fmt.Errorf("user does not exist") } if stop.Date.IsZero() { return fmt.Errorf("date is required") } if stop.StationName == "" { return fmt.Errorf("station name is required") } if stop.Location == "" { return fmt.Errorf("location is required") } if stop.FuelType == "" { return fmt.Errorf("fuel type is required") } if stop.Liters <= 0 { return fmt.Errorf("liters must be greater than 0") } if stop.PricePerL <= 0 { return fmt.Errorf("price per liter must be greater than 0") } if stop.TotalPrice <= 0 { return fmt.Errorf("total price must be greater than 0") } if stop.Currency == "" { stop.Currency = "EUR" // Set default currency } if stop.TripLength < 0 { return fmt.Errorf("trip length cannot be negative") } if stop.VehicleID == 0 { return fmt.Errorf("vehicle is required") } // Check if vehicle exists and belongs to user vehicle, err := db.GetVehicleByID(stop.VehicleID, stop.UserID) if err != nil { return fmt.Errorf("failed to validate vehicle: %w", err) } if vehicle == nil { return fmt.Errorf("vehicle does not exist or access denied") } return nil } // CreateFuelStopWithValidation creates a fuel stop with validation func (db *DB) CreateFuelStopWithValidation(stop *models.FuelStop) error { if err := db.ValidateFuelStop(stop); err != nil { return fmt.Errorf("validation failed: %w", err) } return db.CreateFuelStop(stop) } // GetFuelStopsWithUser retrieves fuel stops with user information preloaded func (db *DB) GetFuelStopsWithUser(userID uint) ([]models.FuelStop, error) { var stops []models.FuelStop result := db.conn.Preload("User").Preload("Vehicle"). Where("user_id = ?", userID). Order("date DESC"). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to get fuel stops with user: %w", result.Error) } return stops, nil } // SearchFuelStops performs a text search across station names and locations func (db *DB) SearchFuelStops(userID uint, searchTerm string) ([]models.FuelStop, error) { var stops []models.FuelStop searchPattern := "%" + searchTerm + "%" result := db.conn.Where("user_id = ? AND (station_name LIKE ? OR location LIKE ?)", userID, searchPattern, searchPattern). Order("date DESC"). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to search fuel stops: %w", result.Error) } return stops, nil } // GetRecentFuelStops returns the most recent N fuel stops for a user func (db *DB) GetRecentFuelStops(userID uint, limit int) ([]models.FuelStop, error) { var stops []models.FuelStop result := db.conn.Where("user_id = ?", userID). Order("date DESC"). Limit(limit). Find(&stops) if result.Error != nil { return nil, fmt.Errorf("failed to get recent fuel stops: %w", result.Error) } return stops, nil }