first commit
This commit is contained in:
@@ -0,0 +1,495 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// Config holds database configuration settings
|
||||
type Config struct {
|
||||
// Database file path
|
||||
DatabasePath string
|
||||
|
||||
// Connection pool settings
|
||||
MaxIdleConns int
|
||||
MaxOpenConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
|
||||
// Logging settings
|
||||
LogLevel logger.LogLevel
|
||||
SlowQueryLog time.Duration
|
||||
|
||||
// Migration settings
|
||||
AutoMigrate bool
|
||||
DropTableFirst bool
|
||||
CreateBatchSize int
|
||||
|
||||
// Performance settings
|
||||
PrepareStmt bool
|
||||
DisableForeignKeyCheck bool
|
||||
IgnoreRelationshipsWhenMigrating bool
|
||||
|
||||
// Development settings
|
||||
Debug bool
|
||||
DryRun bool
|
||||
QueryFields bool
|
||||
CreateInBatches int
|
||||
}
|
||||
|
||||
// DefaultConfig returns a configuration with sensible defaults
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
DatabasePath: "fuel_stops.db",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: time.Hour,
|
||||
ConnMaxIdleTime: 30 * time.Minute,
|
||||
LogLevel: logger.Silent,
|
||||
SlowQueryLog: 200 * time.Millisecond,
|
||||
AutoMigrate: true,
|
||||
DropTableFirst: false,
|
||||
CreateBatchSize: 1000,
|
||||
PrepareStmt: true,
|
||||
DisableForeignKeyCheck: false,
|
||||
IgnoreRelationshipsWhenMigrating: false,
|
||||
Debug: false,
|
||||
DryRun: false,
|
||||
QueryFields: false,
|
||||
CreateInBatches: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromConfig loads configuration from config file using Viper
|
||||
func LoadFromConfig(configPath string) *Config {
|
||||
config := DefaultConfig()
|
||||
|
||||
// Initialize Viper
|
||||
v := viper.New()
|
||||
|
||||
// Set config file path if provided
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
} else {
|
||||
// Search for config file in multiple locations
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./config")
|
||||
v.AddConfigPath("$HOME/.tankstopp")
|
||||
v.AddConfigPath("/etc/tankstopp")
|
||||
}
|
||||
|
||||
// Try to read config file
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
// If config file not found, fall back to environment variables
|
||||
return LoadFromEnv()
|
||||
}
|
||||
|
||||
// Load database configuration from Viper
|
||||
if v.IsSet("database.path") {
|
||||
config.DatabasePath = v.GetString("database.path")
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if v.IsSet("database.connection_pool.max_idle_connections") {
|
||||
config.MaxIdleConns = v.GetInt("database.connection_pool.max_idle_connections")
|
||||
}
|
||||
if v.IsSet("database.connection_pool.max_open_connections") {
|
||||
config.MaxOpenConns = v.GetInt("database.connection_pool.max_open_connections")
|
||||
}
|
||||
if v.IsSet("database.connection_pool.connection_max_lifetime") {
|
||||
config.ConnMaxLifetime = v.GetDuration("database.connection_pool.connection_max_lifetime")
|
||||
}
|
||||
if v.IsSet("database.connection_pool.connection_max_idle_time") {
|
||||
config.ConnMaxIdleTime = v.GetDuration("database.connection_pool.connection_max_idle_time")
|
||||
}
|
||||
|
||||
// Logging settings
|
||||
if v.IsSet("database.logging.level") {
|
||||
config.LogLevel = getLogLevelFromString(v.GetString("database.logging.level"))
|
||||
}
|
||||
if v.IsSet("database.logging.debug") {
|
||||
config.Debug = v.GetBool("database.logging.debug")
|
||||
}
|
||||
if v.IsSet("database.logging.slow_query_threshold") {
|
||||
config.SlowQueryLog = v.GetDuration("database.logging.slow_query_threshold")
|
||||
}
|
||||
|
||||
// Migration settings
|
||||
if v.IsSet("database.migration.auto_migrate") {
|
||||
config.AutoMigrate = v.GetBool("database.migration.auto_migrate")
|
||||
}
|
||||
if v.IsSet("database.migration.drop_tables_first") {
|
||||
config.DropTableFirst = v.GetBool("database.migration.drop_tables_first")
|
||||
}
|
||||
if v.IsSet("database.migration.create_batch_size") {
|
||||
config.CreateBatchSize = v.GetInt("database.migration.create_batch_size")
|
||||
}
|
||||
|
||||
// Performance settings
|
||||
if v.IsSet("database.performance.prepare_statements") {
|
||||
config.PrepareStmt = v.GetBool("database.performance.prepare_statements")
|
||||
}
|
||||
if v.IsSet("database.performance.disable_foreign_key_check") {
|
||||
config.DisableForeignKeyCheck = v.GetBool("database.performance.disable_foreign_key_check")
|
||||
}
|
||||
if v.IsSet("database.performance.ignore_relationships_when_migrating") {
|
||||
config.IgnoreRelationshipsWhenMigrating = v.GetBool("database.performance.ignore_relationships_when_migrating")
|
||||
}
|
||||
if v.IsSet("database.performance.query_fields") {
|
||||
config.QueryFields = v.GetBool("database.performance.query_fields")
|
||||
}
|
||||
if v.IsSet("database.performance.dry_run") {
|
||||
config.DryRun = v.GetBool("database.performance.dry_run")
|
||||
}
|
||||
if v.IsSet("database.performance.create_in_batches") {
|
||||
config.CreateInBatches = v.GetInt("database.performance.create_in_batches")
|
||||
}
|
||||
|
||||
// Environment variables still take precedence over config file
|
||||
config = mergeWithEnvVars(config)
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// LoadFromEnv loads configuration from environment variables
|
||||
func LoadFromEnv() *Config {
|
||||
config := DefaultConfig()
|
||||
|
||||
// Database path
|
||||
if dbPath := os.Getenv("DB_PATH"); dbPath != "" {
|
||||
config.DatabasePath = dbPath
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if maxIdle := getEnvInt("DB_MAX_IDLE_CONNS", config.MaxIdleConns); maxIdle > 0 {
|
||||
config.MaxIdleConns = maxIdle
|
||||
}
|
||||
|
||||
if maxOpen := getEnvInt("DB_MAX_OPEN_CONNS", config.MaxOpenConns); maxOpen > 0 {
|
||||
config.MaxOpenConns = maxOpen
|
||||
}
|
||||
|
||||
if lifetime := getEnvDuration("DB_CONN_MAX_LIFETIME", config.ConnMaxLifetime); lifetime > 0 {
|
||||
config.ConnMaxLifetime = lifetime
|
||||
}
|
||||
|
||||
if idleTime := getEnvDuration("DB_CONN_MAX_IDLE_TIME", config.ConnMaxIdleTime); idleTime > 0 {
|
||||
config.ConnMaxIdleTime = idleTime
|
||||
}
|
||||
|
||||
// Logging settings
|
||||
config.LogLevel = getLogLevel()
|
||||
config.Debug = getEnvBool("DB_DEBUG", config.Debug)
|
||||
|
||||
if slowLog := getEnvDuration("DB_SLOW_QUERY_LOG", config.SlowQueryLog); slowLog > 0 {
|
||||
config.SlowQueryLog = slowLog
|
||||
}
|
||||
|
||||
// Migration settings
|
||||
config.AutoMigrate = getEnvBool("DB_AUTO_MIGRATE", config.AutoMigrate)
|
||||
config.DropTableFirst = getEnvBool("DB_DROP_TABLE_FIRST", config.DropTableFirst)
|
||||
|
||||
if batchSize := getEnvInt("DB_CREATE_BATCH_SIZE", config.CreateBatchSize); batchSize > 0 {
|
||||
config.CreateBatchSize = batchSize
|
||||
}
|
||||
|
||||
// Performance settings
|
||||
config.PrepareStmt = getEnvBool("DB_PREPARE_STMT", config.PrepareStmt)
|
||||
config.DisableForeignKeyCheck = getEnvBool("DB_DISABLE_FOREIGN_KEY_CHECK", config.DisableForeignKeyCheck)
|
||||
config.IgnoreRelationshipsWhenMigrating = getEnvBool("DB_IGNORE_RELATIONSHIPS_WHEN_MIGRATING", config.IgnoreRelationshipsWhenMigrating)
|
||||
|
||||
// Development settings
|
||||
config.DryRun = getEnvBool("DB_DRY_RUN", config.DryRun)
|
||||
config.QueryFields = getEnvBool("DB_QUERY_FIELDS", config.QueryFields)
|
||||
|
||||
if inBatches := getEnvInt("DB_CREATE_IN_BATCHES", config.CreateInBatches); inBatches > 0 {
|
||||
config.CreateInBatches = inBatches
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// mergeWithEnvVars merges environment variables into existing config
|
||||
// Environment variables take precedence over config file values
|
||||
func mergeWithEnvVars(config *Config) *Config {
|
||||
// Database path
|
||||
if dbPath := os.Getenv("DB_PATH"); dbPath != "" {
|
||||
config.DatabasePath = dbPath
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if maxIdle := getEnvInt("DB_MAX_IDLE_CONNS", config.MaxIdleConns); maxIdle > 0 {
|
||||
config.MaxIdleConns = maxIdle
|
||||
}
|
||||
|
||||
if maxOpen := getEnvInt("DB_MAX_OPEN_CONNS", config.MaxOpenConns); maxOpen > 0 {
|
||||
config.MaxOpenConns = maxOpen
|
||||
}
|
||||
|
||||
if lifetime := getEnvDuration("DB_CONN_MAX_LIFETIME", config.ConnMaxLifetime); lifetime > 0 {
|
||||
config.ConnMaxLifetime = lifetime
|
||||
}
|
||||
|
||||
if idleTime := getEnvDuration("DB_CONN_MAX_IDLE_TIME", config.ConnMaxIdleTime); idleTime > 0 {
|
||||
config.ConnMaxIdleTime = idleTime
|
||||
}
|
||||
|
||||
// Logging settings
|
||||
if envLogLevel := getLogLevel(); envLogLevel != logger.Silent {
|
||||
config.LogLevel = envLogLevel
|
||||
}
|
||||
if envDebug := os.Getenv("DB_DEBUG"); envDebug != "" {
|
||||
config.Debug = getEnvBool("DB_DEBUG", config.Debug)
|
||||
}
|
||||
|
||||
if slowLog := getEnvDuration("DB_SLOW_QUERY_LOG", config.SlowQueryLog); slowLog > 0 {
|
||||
config.SlowQueryLog = slowLog
|
||||
}
|
||||
|
||||
// Migration settings
|
||||
if envAutoMigrate := os.Getenv("DB_AUTO_MIGRATE"); envAutoMigrate != "" {
|
||||
config.AutoMigrate = getEnvBool("DB_AUTO_MIGRATE", config.AutoMigrate)
|
||||
}
|
||||
if envDropFirst := os.Getenv("DB_DROP_TABLE_FIRST"); envDropFirst != "" {
|
||||
config.DropTableFirst = getEnvBool("DB_DROP_TABLE_FIRST", config.DropTableFirst)
|
||||
}
|
||||
|
||||
if batchSize := getEnvInt("DB_CREATE_BATCH_SIZE", config.CreateBatchSize); batchSize > 0 {
|
||||
config.CreateBatchSize = batchSize
|
||||
}
|
||||
|
||||
// Performance settings
|
||||
if envPrepare := os.Getenv("DB_PREPARE_STMT"); envPrepare != "" {
|
||||
config.PrepareStmt = getEnvBool("DB_PREPARE_STMT", config.PrepareStmt)
|
||||
}
|
||||
if envFKCheck := os.Getenv("DB_DISABLE_FOREIGN_KEY_CHECK"); envFKCheck != "" {
|
||||
config.DisableForeignKeyCheck = getEnvBool("DB_DISABLE_FOREIGN_KEY_CHECK", config.DisableForeignKeyCheck)
|
||||
}
|
||||
if envIgnoreRel := os.Getenv("DB_IGNORE_RELATIONSHIPS_WHEN_MIGRATING"); envIgnoreRel != "" {
|
||||
config.IgnoreRelationshipsWhenMigrating = getEnvBool("DB_IGNORE_RELATIONSHIPS_WHEN_MIGRATING", config.IgnoreRelationshipsWhenMigrating)
|
||||
}
|
||||
|
||||
// Development settings
|
||||
if envDryRun := os.Getenv("DB_DRY_RUN"); envDryRun != "" {
|
||||
config.DryRun = getEnvBool("DB_DRY_RUN", config.DryRun)
|
||||
}
|
||||
if envQueryFields := os.Getenv("DB_QUERY_FIELDS"); envQueryFields != "" {
|
||||
config.QueryFields = getEnvBool("DB_QUERY_FIELDS", config.QueryFields)
|
||||
}
|
||||
|
||||
if inBatches := getEnvInt("DB_CREATE_IN_BATCHES", config.CreateInBatches); inBatches > 0 {
|
||||
config.CreateInBatches = inBatches
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *Config) Validate() error {
|
||||
if c.DatabasePath == "" {
|
||||
return fmt.Errorf("database path cannot be empty")
|
||||
}
|
||||
|
||||
if c.MaxIdleConns < 0 {
|
||||
return fmt.Errorf("max idle connections cannot be negative")
|
||||
}
|
||||
|
||||
if c.MaxOpenConns < 0 {
|
||||
return fmt.Errorf("max open connections cannot be negative")
|
||||
}
|
||||
|
||||
if c.MaxIdleConns > c.MaxOpenConns && c.MaxOpenConns > 0 {
|
||||
return fmt.Errorf("max idle connections (%d) cannot be greater than max open connections (%d)",
|
||||
c.MaxIdleConns, c.MaxOpenConns)
|
||||
}
|
||||
|
||||
if c.ConnMaxLifetime < 0 {
|
||||
return fmt.Errorf("connection max lifetime cannot be negative")
|
||||
}
|
||||
|
||||
if c.ConnMaxIdleTime < 0 {
|
||||
return fmt.Errorf("connection max idle time cannot be negative")
|
||||
}
|
||||
|
||||
if c.SlowQueryLog < 0 {
|
||||
return fmt.Errorf("slow query log threshold cannot be negative")
|
||||
}
|
||||
|
||||
if c.CreateBatchSize <= 0 {
|
||||
return fmt.Errorf("create batch size must be greater than 0")
|
||||
}
|
||||
|
||||
if c.CreateInBatches <= 0 {
|
||||
return fmt.Errorf("create in batches size must be greater than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns a string representation of the configuration
|
||||
func (c *Config) String() string {
|
||||
return fmt.Sprintf(`Database Configuration:
|
||||
Database Path: %s
|
||||
Max Idle Connections: %d
|
||||
Max Open Connections: %d
|
||||
Connection Max Lifetime: %v
|
||||
Connection Max Idle Time: %v
|
||||
Log Level: %v
|
||||
Slow Query Log Threshold: %v
|
||||
Auto Migrate: %t
|
||||
Prepare Statements: %t
|
||||
Debug Mode: %t
|
||||
Dry Run: %t
|
||||
Create Batch Size: %d
|
||||
Create In Batches: %d`,
|
||||
c.DatabasePath,
|
||||
c.MaxIdleConns,
|
||||
c.MaxOpenConns,
|
||||
c.ConnMaxLifetime,
|
||||
c.ConnMaxIdleTime,
|
||||
c.LogLevel,
|
||||
c.SlowQueryLog,
|
||||
c.AutoMigrate,
|
||||
c.PrepareStmt,
|
||||
c.Debug,
|
||||
c.DryRun,
|
||||
c.CreateBatchSize,
|
||||
c.CreateInBatches,
|
||||
)
|
||||
}
|
||||
|
||||
// IsProduction returns true if running in production environment
|
||||
func (c *Config) IsProduction() bool {
|
||||
env := os.Getenv("ENV")
|
||||
return env == "production" || env == "prod"
|
||||
}
|
||||
|
||||
// IsDevelopment returns true if running in development environment
|
||||
func (c *Config) IsDevelopment() bool {
|
||||
env := os.Getenv("ENV")
|
||||
return env == "development" || env == "dev" || env == ""
|
||||
}
|
||||
|
||||
// IsTest returns true if running in test environment
|
||||
func (c *Config) IsTest() bool {
|
||||
env := os.Getenv("ENV")
|
||||
return env == "test" || env == "testing"
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getEnvInt(key string, defaultValue int) int {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvBool(key string, defaultValue bool) bool {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if boolValue, err := strconv.ParseBool(value); err == nil {
|
||||
return boolValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
if duration, err := time.ParseDuration(value); err == nil {
|
||||
return duration
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getLogLevel() logger.LogLevel {
|
||||
debug := getEnvBool("DB_DEBUG", false)
|
||||
env := os.Getenv("ENV")
|
||||
logLevel := os.Getenv("DB_LOG_LEVEL")
|
||||
|
||||
switch {
|
||||
case debug:
|
||||
return logger.Info
|
||||
case env == "development" || env == "dev":
|
||||
return logger.Warn
|
||||
case env == "test" || env == "testing":
|
||||
return logger.Silent
|
||||
case logLevel == "silent":
|
||||
return logger.Silent
|
||||
case logLevel == "error":
|
||||
return logger.Error
|
||||
case logLevel == "warn":
|
||||
return logger.Warn
|
||||
case logLevel == "info":
|
||||
return logger.Info
|
||||
default:
|
||||
return logger.Silent
|
||||
}
|
||||
}
|
||||
|
||||
// getLogLevelFromString converts string log level to GORM logger level
|
||||
func getLogLevelFromString(level string) logger.LogLevel {
|
||||
switch strings.ToLower(level) {
|
||||
case "silent":
|
||||
return logger.Silent
|
||||
case "error":
|
||||
return logger.Error
|
||||
case "warn", "warning":
|
||||
return logger.Warn
|
||||
case "info":
|
||||
return logger.Info
|
||||
default:
|
||||
return logger.Silent
|
||||
}
|
||||
}
|
||||
|
||||
// Environment variable documentation
|
||||
/*
|
||||
Available Environment Variables:
|
||||
|
||||
Database Settings:
|
||||
DB_PATH - Database file path (default: "fuel_stops.db")
|
||||
DB_AUTO_MIGRATE - Enable automatic migrations (default: true)
|
||||
DB_DROP_TABLE_FIRST - Drop tables before migration (default: false)
|
||||
|
||||
Connection Pool Settings:
|
||||
DB_MAX_IDLE_CONNS - Maximum idle connections (default: 10)
|
||||
DB_MAX_OPEN_CONNS - Maximum open connections (default: 100)
|
||||
DB_CONN_MAX_LIFETIME - Connection maximum lifetime (default: "1h")
|
||||
DB_CONN_MAX_IDLE_TIME - Connection maximum idle time (default: "30m")
|
||||
|
||||
Logging Settings:
|
||||
DB_DEBUG - Enable debug logging (default: false)
|
||||
DB_LOG_LEVEL - Log level: silent, error, warn, info (default: silent)
|
||||
DB_SLOW_QUERY_LOG - Slow query threshold (default: "200ms")
|
||||
|
||||
Performance Settings:
|
||||
DB_PREPARE_STMT - Use prepared statements (default: true)
|
||||
DB_CREATE_BATCH_SIZE - Batch size for migrations (default: 1000)
|
||||
DB_CREATE_IN_BATCHES - Batch size for bulk operations (default: 100)
|
||||
DB_QUERY_FIELDS - Select only required fields (default: false)
|
||||
|
||||
Development Settings:
|
||||
ENV - Environment: development, production, test
|
||||
DB_DRY_RUN - Enable dry run mode (default: false)
|
||||
DB_DISABLE_FOREIGN_KEY_CHECK - Disable FK checks (default: false)
|
||||
DB_IGNORE_RELATIONSHIPS_WHEN_MIGRATING - Ignore relationships in migration (default: false)
|
||||
|
||||
Examples:
|
||||
export DB_DEBUG=true
|
||||
export DB_MAX_OPEN_CONNS=200
|
||||
export DB_CONN_MAX_LIFETIME=2h
|
||||
export DB_LOG_LEVEL=info
|
||||
export ENV=development
|
||||
*/
|
||||
@@ -0,0 +1,894 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"tankstopp/internal/models"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
conn *gorm.DB
|
||||
}
|
||||
|
||||
// NewDB creates a new database connection using GORM with configuration
|
||||
func NewDB(config *Config) (*DB, error) {
|
||||
// Validate configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
// Configure GORM
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(config.LogLevel),
|
||||
PrepareStmt: config.PrepareStmt,
|
||||
DisableForeignKeyConstraintWhenMigrating: config.DisableForeignKeyCheck,
|
||||
IgnoreRelationshipsWhenMigrating: config.IgnoreRelationshipsWhenMigrating,
|
||||
QueryFields: config.QueryFields,
|
||||
CreateBatchSize: config.CreateBatchSize,
|
||||
DryRun: config.DryRun,
|
||||
}
|
||||
|
||||
// Configure slow query logging
|
||||
if config.SlowQueryLog > 0 {
|
||||
env := os.Getenv("ENV")
|
||||
isDev := env == "development" || env == "dev" || env == ""
|
||||
|
||||
customLogger := logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
||||
logger.Config{
|
||||
SlowThreshold: config.SlowQueryLog,
|
||||
LogLevel: config.LogLevel,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
Colorful: isDev,
|
||||
},
|
||||
)
|
||||
gormConfig.Logger = customLogger
|
||||
}
|
||||
|
||||
conn, err := gorm.Open(sqlite.Open(config.DatabasePath), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Get underlying SQL DB to configure connection pool
|
||||
sqlDB, err := conn.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool settings from configuration
|
||||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||||
sqlDB.SetConnMaxIdleTime(config.ConnMaxIdleTime)
|
||||
|
||||
db := &DB{conn: conn}
|
||||
|
||||
// Run migrations if enabled
|
||||
if config.AutoMigrate {
|
||||
if err := db.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// NewDBWithDefaults creates a new database connection with default configuration
|
||||
func NewDBWithDefaults(databasePath string) (*DB, error) {
|
||||
config := DefaultConfig()
|
||||
config.DatabasePath = databasePath
|
||||
return NewDB(config)
|
||||
}
|
||||
|
||||
// NewDBFromEnv creates a new database connection using environment variables
|
||||
func NewDBFromEnv() (*DB, error) {
|
||||
config := LoadFromEnv()
|
||||
return NewDB(config)
|
||||
}
|
||||
|
||||
// NewDBFromConfig creates a new database connection using configuration file
|
||||
func NewDBFromConfig(configPath string) (*DB, error) {
|
||||
config := LoadFromConfig(configPath)
|
||||
return NewDB(config)
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (db *DB) Close() error {
|
||||
sqlDB, err := db.conn.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// migrate runs database migrations
|
||||
func (db *DB) migrate() error {
|
||||
// Auto-migrate the schema
|
||||
return db.conn.AutoMigrate(&models.User{}, &models.Vehicle{}, &models.FuelStop{})
|
||||
}
|
||||
|
||||
// CreateFuelStop inserts a new fuel stop into the database
|
||||
func (db *DB) CreateFuelStop(stop *models.FuelStop) error {
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
stop.CreatedAt = now
|
||||
stop.UpdatedAt = now
|
||||
|
||||
result := db.conn.Create(stop)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to create fuel stop: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFuelStops retrieves all fuel stops for a specific user from the database
|
||||
func (db *DB) GetFuelStops(userID uint) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
result := db.conn.Preload("Vehicle").Where("user_id = ?", userID).
|
||||
Order("date DESC").
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get fuel stops: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
|
||||
// GetFuelStopsByVehicle retrieves all fuel stops for a specific vehicle
|
||||
func (db *DB) GetFuelStopsByVehicle(vehicleID, userID uint) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
result := db.conn.Where("vehicle_id = ? AND user_id = ?", vehicleID, userID).
|
||||
Order("date DESC").
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get fuel stops by vehicle: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
|
||||
// GetFuelStopByID retrieves a fuel stop by its ID and user ID
|
||||
func (db *DB) GetFuelStopByID(id, userID uint) (*models.FuelStop, error) {
|
||||
var stop models.FuelStop
|
||||
|
||||
result := db.conn.Where("id = ? AND user_id = ?", id, userID).First(&stop)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil // Return nil when record not found
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get fuel stop: %w", result.Error)
|
||||
}
|
||||
|
||||
return &stop, nil
|
||||
}
|
||||
|
||||
// UpdateFuelStop updates an existing fuel stop
|
||||
func (db *DB) UpdateFuelStop(stop *models.FuelStop) error {
|
||||
// Update timestamp
|
||||
stop.UpdatedAt = time.Now()
|
||||
|
||||
result := db.conn.Model(stop).
|
||||
Where("id = ? AND user_id = ?", stop.ID, stop.UserID).
|
||||
Updates(stop)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update fuel stop: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("fuel stop not found or access denied")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteFuelStop deletes a fuel stop by its ID and user ID
|
||||
func (db *DB) DeleteFuelStop(id, userID uint) error {
|
||||
result := db.conn.Where("id = ? AND user_id = ?", id, userID).Delete(&models.FuelStop{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete fuel stop: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("fuel stop not found or access denied")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateVehicle creates a new vehicle for a user
|
||||
func (db *DB) CreateVehicle(vehicle *models.Vehicle) error {
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
vehicle.CreatedAt = now
|
||||
vehicle.UpdatedAt = now
|
||||
|
||||
result := db.conn.Create(vehicle)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to create vehicle: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVehicles retrieves all vehicles for a specific user
|
||||
func (db *DB) GetVehicles(userID uint) ([]models.Vehicle, error) {
|
||||
var vehicles []models.Vehicle
|
||||
|
||||
result := db.conn.Where("user_id = ?", userID).
|
||||
Order("is_active DESC, name ASC").
|
||||
Find(&vehicles)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get vehicles: %w", result.Error)
|
||||
}
|
||||
|
||||
return vehicles, nil
|
||||
}
|
||||
|
||||
// GetActiveVehicles retrieves only active vehicles for a specific user
|
||||
func (db *DB) GetActiveVehicles(userID uint) ([]models.Vehicle, error) {
|
||||
var vehicles []models.Vehicle
|
||||
|
||||
result := db.conn.Where("user_id = ? AND is_active = ?", userID, true).
|
||||
Order("name ASC").
|
||||
Find(&vehicles)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get active vehicles: %w", result.Error)
|
||||
}
|
||||
|
||||
return vehicles, nil
|
||||
}
|
||||
|
||||
// GetVehicleByID retrieves a vehicle by its ID and user ID
|
||||
func (db *DB) GetVehicleByID(id, userID uint) (*models.Vehicle, error) {
|
||||
var vehicle models.Vehicle
|
||||
|
||||
result := db.conn.Where("id = ? AND user_id = ?", id, userID).First(&vehicle)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil // Return nil when vehicle not found
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get vehicle: %w", result.Error)
|
||||
}
|
||||
|
||||
return &vehicle, nil
|
||||
}
|
||||
|
||||
// UpdateVehicle updates an existing vehicle
|
||||
func (db *DB) UpdateVehicle(vehicle *models.Vehicle) error {
|
||||
// Update timestamp
|
||||
vehicle.UpdatedAt = time.Now()
|
||||
|
||||
result := db.conn.Model(vehicle).
|
||||
Where("id = ? AND user_id = ?", vehicle.ID, vehicle.UserID).
|
||||
Updates(vehicle)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update vehicle: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("vehicle not found or access denied")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteVehicle deletes a vehicle by its ID and user ID
|
||||
func (db *DB) DeleteVehicle(id, userID uint) error {
|
||||
// Check if vehicle has fuel stops
|
||||
var count int64
|
||||
if err := db.conn.Model(&models.FuelStop{}).Where("vehicle_id = ?", id).Count(&count).Error; err != nil {
|
||||
return fmt.Errorf("failed to check fuel stops: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return fmt.Errorf("cannot delete vehicle with existing fuel stops")
|
||||
}
|
||||
|
||||
result := db.conn.Where("id = ? AND user_id = ?", id, userID).Delete(&models.Vehicle{})
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to delete vehicle: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("vehicle not found or access denied")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVehicleWithFuelStops retrieves a vehicle with its fuel stops
|
||||
func (db *DB) GetVehicleWithFuelStops(vehicleID, userID uint) (*models.Vehicle, error) {
|
||||
var vehicle models.Vehicle
|
||||
|
||||
result := db.conn.Preload("FuelStops", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("date DESC")
|
||||
}).Where("id = ? AND user_id = ?", vehicleID, userID).First(&vehicle)
|
||||
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get vehicle with fuel stops: %w", result.Error)
|
||||
}
|
||||
|
||||
return &vehicle, nil
|
||||
}
|
||||
|
||||
// GetVehicleCount returns the total number of vehicles for a user
|
||||
func (db *DB) GetVehicleCount(userID uint) (int64, error) {
|
||||
var count int64
|
||||
result := db.conn.Model(&models.Vehicle{}).Where("user_id = ?", userID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to count vehicles: %w", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// ValidateVehicle validates vehicle data before creation/update
|
||||
func (db *DB) ValidateVehicle(vehicle *models.Vehicle) error {
|
||||
if vehicle.UserID == 0 {
|
||||
return fmt.Errorf("user ID is required")
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
exists, err := db.UserExists(vehicle.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate user: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user does not exist")
|
||||
}
|
||||
|
||||
if vehicle.Name == "" {
|
||||
return fmt.Errorf("vehicle name is required")
|
||||
}
|
||||
|
||||
if len(vehicle.Name) > 100 {
|
||||
return fmt.Errorf("vehicle name must be 100 characters or less")
|
||||
}
|
||||
|
||||
if vehicle.Make != "" && len(vehicle.Make) > 50 {
|
||||
return fmt.Errorf("vehicle make must be 50 characters or less")
|
||||
}
|
||||
|
||||
if vehicle.Model != "" && len(vehicle.Model) > 50 {
|
||||
return fmt.Errorf("vehicle model must be 50 characters or less")
|
||||
}
|
||||
|
||||
if vehicle.LicensePlate != "" && len(vehicle.LicensePlate) > 20 {
|
||||
return fmt.Errorf("license plate must be 20 characters or less")
|
||||
}
|
||||
|
||||
if vehicle.FuelType != "" && len(vehicle.FuelType) > 50 {
|
||||
return fmt.Errorf("fuel type must be 50 characters or less")
|
||||
}
|
||||
|
||||
if vehicle.Year < 0 || vehicle.Year > time.Now().Year()+1 {
|
||||
return fmt.Errorf("invalid vehicle year")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateVehicleWithValidation creates a vehicle with validation
|
||||
func (db *DB) CreateVehicleWithValidation(vehicle *models.Vehicle) error {
|
||||
if err := db.ValidateVehicle(vehicle); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
return db.CreateVehicle(vehicle)
|
||||
}
|
||||
|
||||
// GetFuelStopStats calculates statistics for fuel consumption for a specific user
|
||||
func (db *DB) GetFuelStopStats(userID uint) (*models.FuelStopStats, error) {
|
||||
stats := &models.FuelStopStats{}
|
||||
|
||||
// Get basic statistics
|
||||
var result struct {
|
||||
TotalStops int64 `json:"total_stops"`
|
||||
TotalLiters float64 `json:"total_liters"`
|
||||
TotalSpent float64 `json:"total_spent"`
|
||||
AveragePrice float64 `json:"average_price"`
|
||||
TotalTripKm float64 `json:"total_trip_km"`
|
||||
MinOdometer int `json:"min_odometer"`
|
||||
MaxOdometer int `json:"max_odometer"`
|
||||
}
|
||||
|
||||
err := db.conn.Model(&models.FuelStop{}).
|
||||
Select(`
|
||||
COUNT(*) as total_stops,
|
||||
COALESCE(SUM(liters), 0) as total_liters,
|
||||
COALESCE(SUM(total_price), 0) as total_spent,
|
||||
COALESCE(AVG(price_per_l), 0) as average_price,
|
||||
COALESCE(SUM(trip_length), 0) as total_trip_km,
|
||||
COALESCE(MIN(odometer), 0) as min_odometer,
|
||||
COALESCE(MAX(odometer), 0) as max_odometer
|
||||
`).
|
||||
Where("user_id = ?", userID).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fuel stop stats: %w", err)
|
||||
}
|
||||
|
||||
stats.TotalStops = int(result.TotalStops)
|
||||
stats.TotalLiters = result.TotalLiters
|
||||
stats.TotalSpent = result.TotalSpent
|
||||
stats.AveragePrice = result.AveragePrice
|
||||
|
||||
// Get the last fillup
|
||||
var lastStop models.FuelStop
|
||||
err = db.conn.Where("user_id = ?", userID).
|
||||
Order("date DESC").
|
||||
First(&lastStop).Error
|
||||
|
||||
if err != nil && err != gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("failed to get last fillup: %w", err)
|
||||
}
|
||||
|
||||
if err != gorm.ErrRecordNotFound {
|
||||
stats.LastFillup = &lastStop
|
||||
}
|
||||
|
||||
// Calculate average consumption using trip length (preferred) or odometer difference (fallback)
|
||||
if stats.TotalStops > 1 {
|
||||
// Primary method: Use trip length if available
|
||||
if result.TotalTripKm > 0 {
|
||||
stats.AverageConsumption = (stats.TotalLiters / result.TotalTripKm) * 100
|
||||
} else if result.MaxOdometer > result.MinOdometer {
|
||||
// Fallback method: Use odometer difference
|
||||
distanceDriven := result.MaxOdometer - result.MinOdometer
|
||||
if distanceDriven > 0 {
|
||||
stats.AverageConsumption = (stats.TotalLiters / float64(distanceDriven)) * 100
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in the database
|
||||
func (db *DB) CreateUser(user *models.User) error {
|
||||
// Hash the password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
user.PasswordHash = string(hashedPassword)
|
||||
user.Password = "" // Clear the plain text password
|
||||
|
||||
// Set timestamps
|
||||
now := time.Now()
|
||||
user.CreatedAt = now
|
||||
user.UpdatedAt = now
|
||||
|
||||
result := db.conn.Create(user)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to create user: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByUsername retrieves a user by username
|
||||
func (db *DB) GetUserByUsername(username string) (*models.User, error) {
|
||||
var user models.User
|
||||
|
||||
result := db.conn.Where("username = ?", username).First(&user)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil // Return nil when user not found
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by username: %w", result.Error)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by ID
|
||||
func (db *DB) GetUserByID(id uint) (*models.User, error) {
|
||||
var user models.User
|
||||
|
||||
result := db.conn.First(&user, id)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil // Return nil when user not found
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by ID: %w", result.Error)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates an existing user
|
||||
func (db *DB) UpdateUser(user *models.User) error {
|
||||
// Update timestamp
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
result := db.conn.Model(user).
|
||||
Select("email", "base_currency", "updated_at").
|
||||
Updates(user)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update user: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateUserPassword updates a user's password
|
||||
func (db *DB) UpdateUserPassword(user *models.User, newPassword string) error {
|
||||
// Hash the new password
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update timestamp
|
||||
now := time.Now()
|
||||
|
||||
// Update only the password hash and updated_at fields
|
||||
result := db.conn.Model(user).
|
||||
Updates(map[string]interface{}{
|
||||
"password_hash": string(hashedPassword),
|
||||
"updated_at": now,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("failed to update password: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
// Update the user object with new password hash
|
||||
user.PasswordHash = string(hashedPassword)
|
||||
user.UpdatedAt = now
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email
|
||||
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
|
||||
var user models.User
|
||||
|
||||
result := db.conn.Where("email = ?", email).First(&user)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil // Return nil when user not found
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by email: %w", result.Error)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserWithFuelStops retrieves a user with their fuel stops
|
||||
func (db *DB) GetUserWithFuelStops(userID uint) (*models.User, error) {
|
||||
var user models.User
|
||||
|
||||
result := db.conn.Preload("FuelStops", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("date DESC")
|
||||
}).Preload("Vehicles").First(&user, userID)
|
||||
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user with fuel stops: %w", result.Error)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetFuelStopsWithPagination retrieves fuel stops with pagination
|
||||
func (db *DB) GetFuelStopsWithPagination(userID uint, limit, offset int) ([]models.FuelStop, int64, error) {
|
||||
var stops []models.FuelStop
|
||||
var total int64
|
||||
|
||||
// Get total count
|
||||
err := db.conn.Model(&models.FuelStop{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&total).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count fuel stops: %w", err)
|
||||
}
|
||||
|
||||
// Get paginated results
|
||||
err = db.conn.Where("user_id = ?", userID).
|
||||
Order("date DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&stops).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to get fuel stops with pagination: %w", err)
|
||||
}
|
||||
|
||||
return stops, total, nil
|
||||
}
|
||||
|
||||
// GetFuelStopsByDateRange retrieves fuel stops within a date range
|
||||
func (db *DB) GetFuelStopsByDateRange(userID uint, startDate, endDate time.Time) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
result := db.conn.Where("user_id = ? AND date BETWEEN ? AND ?", userID, startDate, endDate).
|
||||
Order("date DESC").
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get fuel stops by date range: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
|
||||
// GetFuelStopsByFuelType retrieves fuel stops by fuel type
|
||||
func (db *DB) GetFuelStopsByFuelType(userID uint, fuelType string) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
result := db.conn.Where("user_id = ? AND fuel_type = ?", userID, fuelType).
|
||||
Order("date DESC").
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get fuel stops by fuel type: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStats retrieves monthly statistics for a user
|
||||
func (db *DB) GetMonthlyStats(userID uint, year int) ([]models.MonthlyStats, error) {
|
||||
var stats []models.MonthlyStats
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
strftime('%m', date) as month,
|
||||
strftime('%Y', date) as year,
|
||||
COUNT(*) as total_stops,
|
||||
SUM(liters) as total_liters,
|
||||
SUM(total_price) as total_spent,
|
||||
AVG(price_per_l) as average_price
|
||||
FROM fuel_stops
|
||||
WHERE user_id = ? AND strftime('%Y', date) = ?
|
||||
GROUP BY strftime('%Y-%m', date)
|
||||
ORDER BY month
|
||||
`
|
||||
|
||||
err := db.conn.Raw(query, userID, fmt.Sprintf("%d", year)).Scan(&stats).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get monthly stats: %w", err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// BulkCreateFuelStops creates multiple fuel stops in a single transaction
|
||||
func (db *DB) BulkCreateFuelStops(stops []models.FuelStop) error {
|
||||
if len(stops) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set timestamps for all stops
|
||||
now := time.Now()
|
||||
for i := range stops {
|
||||
stops[i].CreatedAt = now
|
||||
stops[i].UpdatedAt = now
|
||||
}
|
||||
|
||||
// Use transaction for bulk insert
|
||||
return db.conn.Transaction(func(tx *gorm.DB) error {
|
||||
return tx.CreateInBatches(stops, 100).Error
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteAllUserData deletes all data for a user (for account deletion)
|
||||
func (db *DB) DeleteAllUserData(userID uint) error {
|
||||
return db.conn.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete all fuel stops first (due to foreign key constraint)
|
||||
if err := tx.Where("user_id = ?", userID).Delete(&models.FuelStop{}).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete fuel stops: %w", err)
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
if err := tx.Delete(&models.User{}, userID).Error; err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// HealthCheck performs a simple health check on the database
|
||||
func (db *DB) HealthCheck() error {
|
||||
sqlDB, err := db.conn.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
||||
}
|
||||
|
||||
return sqlDB.Ping()
|
||||
}
|
||||
|
||||
// GetFuelStopCount returns the total number of fuel stops for a user
|
||||
func (db *DB) GetFuelStopCount(userID uint) (int64, error) {
|
||||
var count int64
|
||||
result := db.conn.Model(&models.FuelStop{}).Where("user_id = ?", userID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to count fuel stops: %w", result.Error)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetLatestFuelStop returns the most recent fuel stop for a user
|
||||
func (db *DB) GetLatestFuelStop(userID uint) (*models.FuelStop, error) {
|
||||
var stop models.FuelStop
|
||||
result := db.conn.Where("user_id = ?", userID).Order("date DESC").First(&stop)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get latest fuel stop: %w", result.Error)
|
||||
}
|
||||
return &stop, nil
|
||||
}
|
||||
|
||||
// UserExists checks if a user exists by ID
|
||||
func (db *DB) UserExists(userID uint) (bool, error) {
|
||||
var count int64
|
||||
result := db.conn.Model(&models.User{}).Where("id = ?", userID).Count(&count)
|
||||
if result.Error != nil {
|
||||
return false, fmt.Errorf("failed to check user existence: %w", result.Error)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ValidateFuelStop validates fuel stop data before creation/update
|
||||
func (db *DB) ValidateFuelStop(stop *models.FuelStop) error {
|
||||
if stop.UserID == 0 {
|
||||
return fmt.Errorf("user ID is required")
|
||||
}
|
||||
|
||||
// Check if user exists
|
||||
exists, err := db.UserExists(stop.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate user: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("user does not exist")
|
||||
}
|
||||
|
||||
if stop.Date.IsZero() {
|
||||
return fmt.Errorf("date is required")
|
||||
}
|
||||
|
||||
if stop.StationName == "" {
|
||||
return fmt.Errorf("station name is required")
|
||||
}
|
||||
|
||||
if stop.Location == "" {
|
||||
return fmt.Errorf("location is required")
|
||||
}
|
||||
|
||||
if stop.FuelType == "" {
|
||||
return fmt.Errorf("fuel type is required")
|
||||
}
|
||||
|
||||
if stop.Liters <= 0 {
|
||||
return fmt.Errorf("liters must be greater than 0")
|
||||
}
|
||||
|
||||
if stop.PricePerL <= 0 {
|
||||
return fmt.Errorf("price per liter must be greater than 0")
|
||||
}
|
||||
|
||||
if stop.TotalPrice <= 0 {
|
||||
return fmt.Errorf("total price must be greater than 0")
|
||||
}
|
||||
|
||||
if stop.Currency == "" {
|
||||
stop.Currency = "EUR" // Set default currency
|
||||
}
|
||||
|
||||
if stop.TripLength < 0 {
|
||||
return fmt.Errorf("trip length cannot be negative")
|
||||
}
|
||||
|
||||
if stop.VehicleID == 0 {
|
||||
return fmt.Errorf("vehicle is required")
|
||||
}
|
||||
|
||||
// Check if vehicle exists and belongs to user
|
||||
vehicle, err := db.GetVehicleByID(stop.VehicleID, stop.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to validate vehicle: %w", err)
|
||||
}
|
||||
if vehicle == nil {
|
||||
return fmt.Errorf("vehicle does not exist or access denied")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFuelStopWithValidation creates a fuel stop with validation
|
||||
func (db *DB) CreateFuelStopWithValidation(stop *models.FuelStop) error {
|
||||
if err := db.ValidateFuelStop(stop); err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
return db.CreateFuelStop(stop)
|
||||
}
|
||||
|
||||
// GetFuelStopsWithUser retrieves fuel stops with user information preloaded
|
||||
func (db *DB) GetFuelStopsWithUser(userID uint) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
result := db.conn.Preload("User").Preload("Vehicle").
|
||||
Where("user_id = ?", userID).
|
||||
Order("date DESC").
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get fuel stops with user: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
|
||||
// SearchFuelStops performs a text search across station names and locations
|
||||
func (db *DB) SearchFuelStops(userID uint, searchTerm string) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
result := db.conn.Where("user_id = ? AND (station_name LIKE ? OR location LIKE ?)",
|
||||
userID, searchPattern, searchPattern).
|
||||
Order("date DESC").
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to search fuel stops: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
|
||||
// GetRecentFuelStops returns the most recent N fuel stops for a user
|
||||
func (db *DB) GetRecentFuelStops(userID uint, limit int) ([]models.FuelStop, error) {
|
||||
var stops []models.FuelStop
|
||||
|
||||
result := db.conn.Where("user_id = ?", userID).
|
||||
Order("date DESC").
|
||||
Limit(limit).
|
||||
Find(&stops)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("failed to get recent fuel stops: %w", result.Error)
|
||||
}
|
||||
|
||||
return stops, nil
|
||||
}
|
||||
Reference in New Issue
Block a user