#include <linux/device.h>
#include <linux/err.h>
#include <linux/module.h>
#include <linux/proc_fs.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/uaccess.h>
#include <linux/usb/composite.h>

/*
 * Descriptor overrides installed by writing to the procfs entry
 * (gadget_proc_write()).  Each member is either NULL ("no override for this
 * descriptor") or a kernel allocation released by
 * gadget_pfs_overwrite_desc_free().
 * NOTE(review): no lock is visible around this table; readers in the
 * gadget_pfs_overwrite_desc_*() helpers may race with a concurrent write.
 */
static struct gadget_desc_overwrite_table desc_overwrite;

void gadget_pfs_overwrite_desc_device(struct usb_device_descriptor *desc)
{
    struct usb_device_desc_overwrite *src = desc_overwrite.device;
    
    if (!desc || !src) return;
    
    desc->idVendor  = src->idVendor;
    desc->idProduct = src->idProduct;
    desc->bcdDevice = src->bcdDevice;
    
    return;
}
EXPORT_SYMBOL_GPL(gadget_pfs_overwrite_desc_device);

void gadget_pfs_overwrite_desc_config(struct usb_config_descriptor *config, enum usb_device_speed speed)
{
    struct usb_config_desc_overwrite *src = desc_overwrite.config;
    
    if (!config || !src) return;
    
    config->bmAttributes &= ~USB_CONFIG_ATT_SELFPOWER;
    config->bmAttributes |= src->bmAttributes;
    
    switch (speed) {
    case USB_SPEED_SUPER:
        config->bMaxPower = src->bMaxPower_ss;
        break;
    default:
        config->bMaxPower = src->bMaxPower_hs;
        break;
    }
    
    return;
}
EXPORT_SYMBOL_GPL(gadget_pfs_overwrite_desc_config);

void gadget_pfs_overwrite_desc_string(struct usb_string *table)
{
    struct usb_string_desc_overwrite *string = desc_overwrite.string;
    
    if (!table || !string)
        return;
    
    if (string->manufacturer)
        table[USB_GADGET_MANUFACTURER_IDX].s = string->manufacturer->s;
    if (string->product)
        table[USB_GADGET_PRODUCT_IDX].s      = string->product->s;
    if (string->serial)
        table[USB_GADGET_SERIAL_IDX].s       = string->serial->s;
    
    return;
}
EXPORT_SYMBOL_GPL(gadget_pfs_overwrite_desc_string);

int gadget_pfs_overwrite_desc_bos(struct usb_composite_dev *cdev)
{
    uint8_t *dest;
    enum usb_device_speed max_speed;
    struct usb_bos_descriptor       *bos;
    struct usb_ext_cap_descriptor   *cap_ext;
    struct usb_ss_cap_descriptor    *cap_ss;
    struct usb_ssp_cap_descriptor_m *cap_ssp;
    int total_length = 0;
    
    if (!cdev || !desc_overwrite.bos)
        return 0;
    
    dest = cdev->req->buf;
    /* round to the lower speed */
    max_speed = min(cdev->driver->max_speed, cdev->gadget->max_speed);
    
    if (desc_overwrite.bos->bos) {
        bos = (struct usb_bos_descriptor*)dest;
        bos->bLength         = desc_overwrite.bos->bos->bLength;
        bos->bDescriptorType = desc_overwrite.bos->bos->bDescriptorType;
        bos->bNumDeviceCaps  = 0;
        bos->wTotalLength    = bos->bLength;
        
        dest += bos->bLength;
        
        if (desc_overwrite.bos->cap_ext) {
            cap_ext                     = (struct usb_ext_cap_descriptor*)dest;
            cap_ext->bLength            = desc_overwrite.bos->cap_ext->bLength;
            cap_ext->bDescriptorType    = desc_overwrite.bos->cap_ext->bDescriptorType;
            cap_ext->bDevCapabilityType = desc_overwrite.bos->cap_ext->bDevCapabilityType;
            cap_ext->bmAttributes       = desc_overwrite.bos->cap_ext->bmAttributes;
            bos->wTotalLength += cap_ext->bLength;
            bos->bNumDeviceCaps++;
            
            dest += cap_ext->bLength;
        };
        
        if ((desc_overwrite.bos->cap_ss) && (max_speed >= USB_SPEED_SUPER)) {
            cap_ss                        = (struct usb_ss_cap_descriptor*)dest;
            cap_ss->bLength               = desc_overwrite.bos->cap_ss->bLength;
            cap_ss->bDescriptorType       = desc_overwrite.bos->cap_ss->bDescriptorType;
            cap_ss->bDevCapabilityType    = desc_overwrite.bos->cap_ss->bDevCapabilityType;
            cap_ss->bmAttributes          = desc_overwrite.bos->cap_ss->bmAttributes;
            cap_ss->wSpeedSupported       = desc_overwrite.bos->cap_ss->wSpeedSupported;
            cap_ss->bFunctionalitySupport = desc_overwrite.bos->cap_ss->bFunctionalitySupport;
            cap_ss->bU1devExitLat         = desc_overwrite.bos->cap_ss->bU1devExitLat;
            cap_ss->bU2DevExitLat         = desc_overwrite.bos->cap_ss->bU2DevExitLat;
            bos->wTotalLength += cap_ss->bLength;
            bos->bNumDeviceCaps++;
            
            dest += cap_ss->bLength;
        };
        
        if ((desc_overwrite.bos->cap_ssp) && (max_speed >= USB_SPEED_SUPER_PLUS)) {
            cap_ssp                        = (struct usb_ssp_cap_descriptor_m*)dest;
            cap_ssp->bLength               = desc_overwrite.bos->cap_ssp->bLength;
            cap_ssp->bDescriptorType       = desc_overwrite.bos->cap_ssp->bDescriptorType;
            cap_ssp->bDevCapabilityType    = desc_overwrite.bos->cap_ssp->bDevCapabilityType;
            cap_ssp->bmAttributes          = desc_overwrite.bos->cap_ssp->bmAttributes;
            cap_ssp->wFunctionalitySupport = desc_overwrite.bos->cap_ssp->wFunctionalitySupport;
            cap_ssp->wReserved             = desc_overwrite.bos->cap_ssp->wReserved;
            memcpy(cap_ssp->bmSublinkSpeedAttr, desc_overwrite.bos->cap_ssp->bmSublinkSpeedAttr,
                sizeof(cap_ssp->bmSublinkSpeedAttr));
            bos->wTotalLength += cap_ssp->bLength;
            bos->bNumDeviceCaps++;
            
            dest += cap_ssp->bLength;
        };
        
        total_length = bos->wTotalLength;
    };
    
    return total_length;
}
EXPORT_SYMBOL_GPL(gadget_pfs_overwrite_desc_bos);

static void gadget_pfs_overwrite_desc_free(void)
{
    if (desc_overwrite.device) {
        kfree(desc_overwrite.device);
        desc_overwrite.device = NULL;
    }
    if (desc_overwrite.config) {
        kfree(desc_overwrite.config);
        desc_overwrite.config = NULL;
    }
    if (desc_overwrite.string) {
        if (desc_overwrite.string->manufacturer) {
            if (desc_overwrite.string->manufacturer->s) {
                kfree(desc_overwrite.string->manufacturer->s);
                desc_overwrite.string->manufacturer->s = NULL;
            }
            kfree(desc_overwrite.string->manufacturer);
            desc_overwrite.string->manufacturer = NULL;
        }
        if (desc_overwrite.string->product) {
            if (desc_overwrite.string->product->s) {
                kfree(desc_overwrite.string->product->s);
                desc_overwrite.string->product->s = NULL;
            }
            kfree(desc_overwrite.string->product);
            desc_overwrite.string->product = NULL;
        }
        if (desc_overwrite.string->serial) {
            if (desc_overwrite.string->serial->s) {
                kfree(desc_overwrite.string->serial->s);
                desc_overwrite.string->serial->s = NULL;
            }
            kfree(desc_overwrite.string->serial);
            desc_overwrite.string->serial = NULL;
        }
        kfree(desc_overwrite.string);
        desc_overwrite.string = NULL;
    }
    
    if (desc_overwrite.bos) {
        if (desc_overwrite.bos->bos) {
            kfree(desc_overwrite.bos->bos);
            desc_overwrite.bos->bos = NULL;
        }
        if (desc_overwrite.bos->cap_ext) {
            kfree(desc_overwrite.bos->cap_ext);
            desc_overwrite.bos->cap_ext = NULL;
        }
        if (desc_overwrite.bos->cap_ss) {
            kfree(desc_overwrite.bos->cap_ss);
            desc_overwrite.bos->cap_ss = NULL;
        }
        if (desc_overwrite.bos->cap_ssp) {
            kfree(desc_overwrite.bos->cap_ssp);
            desc_overwrite.bos->cap_ssp = NULL;
        }
        kfree(desc_overwrite.bos);
        desc_overwrite.bos = NULL;
    }
}

static ssize_t gadget_proc_read(struct file *filp, char __user *buf, size_t count, loff_t *f_pos)
{
    struct usb_device_desc_overwrite *device_desc = desc_overwrite.device;
    struct usb_config_desc_overwrite *config_desc = desc_overwrite.config;
    struct usb_string_desc_overwrite *string_desc = desc_overwrite.string;
    
    if (device_desc) {
        printk(KERN_INFO "--- device descriptor ---\n");
        printk(KERN_INFO "  idVendor : 0x%04X\n", device_desc->idVendor);
        printk(KERN_INFO "  idProduct: 0x%04X\n", device_desc->idProduct);
        printk(KERN_INFO "  bcdDevice: 0x%04X\n", device_desc->bcdDevice);
    }
    
    if (config_desc) {
        printk(KERN_INFO "--- config descriptor ---\n");
        printk(KERN_INFO "  bmAttributes  : 0x%02X\n", config_desc->bmAttributes);
        printk(KERN_INFO "  bMaxPower(hs) : %d mA\n", config_desc->bMaxPower_hs);
        printk(KERN_INFO "  bMaxPower(ss) : %d mA\n", config_desc->bMaxPower_ss);
    }
    
    if (string_desc) {
        printk(KERN_INFO "--- string descriptor ---\n");
        if (string_desc->manufacturer) {
            printk(KERN_INFO "  Manufacturer : %s(%d)\n", string_desc->manufacturer->s, string_desc->manufacturer->len);
        }
        if (string_desc->product) {
            printk(KERN_INFO "  Product      : %s(%d)\n", string_desc->product->s, string_desc->product->len);
        }
        if (string_desc->serial) {
            printk(KERN_INFO "  SerialNumber : %s(%d)\n", string_desc->serial->s, string_desc->serial->len);
        }
    }
    
    return 0;
}

static ssize_t gadget_proc_write(struct file *filp, const char __user *buf, size_t count, loff_t *f_pos)
{
    int ret;
    struct gadget_desc_overwrite_table desc_overwrite_tmp;

    /* free table if exists */
    gadget_pfs_overwrite_desc_free();
    
    if (count < sizeof(struct gadget_desc_overwrite_table)) {
        printk(KERN_ERR "write size shortage\n");
        return -EFAULT;
    }
    
    ret = copy_from_user(&desc_overwrite_tmp, buf, count);
    if (ret < 0) {
        return -EFAULT;
    }
    
    /* Device Descriptor */
    if (desc_overwrite_tmp.device) {
        desc_overwrite.device = (struct usb_device_desc_overwrite*)kmalloc(sizeof(struct usb_device_desc_overwrite), GFP_KERNEL);
        if (!desc_overwrite.device) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
        ret = copy_from_user(desc_overwrite.device, desc_overwrite_tmp.device, sizeof(struct usb_device_desc_overwrite));
        if (ret < 0) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
    }
    
    /* Configuration Descriptor */
    if (desc_overwrite_tmp.config) {
        desc_overwrite.config = (struct usb_config_desc_overwrite*)kmalloc(sizeof(struct usb_config_desc_overwrite), GFP_KERNEL);
        if (!desc_overwrite.config) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
        ret = copy_from_user(desc_overwrite.config, desc_overwrite_tmp.config, sizeof(struct usb_config_desc_overwrite));
        if (ret < 0) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
    }
    
    /* String Descriptor table */
    if (desc_overwrite_tmp.string) {
        desc_overwrite.string = (struct usb_string_desc_overwrite*)kmalloc(sizeof(struct usb_string_desc_overwrite), GFP_KERNEL);
        if (!desc_overwrite.string) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
        ret = copy_from_user(desc_overwrite.string, desc_overwrite_tmp.string, sizeof(struct usb_string_desc_overwrite));
        if (ret < 0) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
        
        if (desc_overwrite_tmp.string->manufacturer) {
            desc_overwrite.string->manufacturer = (struct usb_string_desc*)kmalloc(sizeof(struct usb_string_desc), GFP_KERNEL);
            if (!desc_overwrite.string->manufacturer) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.string->manufacturer,
                                 desc_overwrite_tmp.string->manufacturer,
                                 sizeof(struct usb_string_desc));
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            
            desc_overwrite.string->manufacturer->s =
                (uint8_t*)kmalloc(desc_overwrite.string->manufacturer->len, GFP_KERNEL);
            if (!desc_overwrite.string->manufacturer->s) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.string->manufacturer->s,
                                 desc_overwrite_tmp.string->manufacturer->s,
                                 desc_overwrite.string->manufacturer->len);
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
        if (desc_overwrite_tmp.string->product) {
            desc_overwrite.string->product = (struct usb_string_desc*)kmalloc(sizeof(struct usb_string_desc), GFP_KERNEL);
            if (!desc_overwrite.string->product) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.string->product,
                                 desc_overwrite_tmp.string->product,
                                 sizeof(struct usb_string_desc));
            if (ret < 0) return -EFAULT;
            
            desc_overwrite.string->product->s =
                (uint8_t*)kmalloc(desc_overwrite.string->product->len, GFP_KERNEL);
            if (!desc_overwrite.string->product->s) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.string->product->s,
                                 desc_overwrite_tmp.string->product->s,
                                 desc_overwrite.string->product->len);
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
        if (desc_overwrite_tmp.string->serial) {
            desc_overwrite.string->serial = (struct usb_string_desc*)kmalloc(sizeof(struct usb_string_desc), GFP_KERNEL);
            if (!desc_overwrite.string->serial) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.string->serial,
                                 desc_overwrite_tmp.string->serial,
                                 sizeof(struct usb_string_desc));
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            
            desc_overwrite.string->serial->s =
                (uint8_t*)kmalloc(desc_overwrite.string->serial->len, GFP_KERNEL);
            if (!desc_overwrite.string->serial->s) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.string->serial->s,
                                 desc_overwrite_tmp.string->serial->s,
                                 desc_overwrite.string->serial->len);
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
    }
    
    /* BOS Descriptor */
    if (desc_overwrite_tmp.bos) {
        desc_overwrite.bos = (struct usb_bos_desc_overwrite*)kmalloc(sizeof(struct usb_bos_desc_overwrite), GFP_KERNEL);
        if (!desc_overwrite.bos) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
        ret = copy_from_user(desc_overwrite.bos, desc_overwrite_tmp.bos, sizeof(struct usb_bos_desc_overwrite));
        if (ret < 0) {
            printk(KERN_ERR "write mem error line:%d\n", __LINE__);
            return -EFAULT;
        }
        
        if (desc_overwrite_tmp.bos->bos) {
            desc_overwrite.bos->bos = (struct usb_bos_descriptor*)kmalloc(sizeof(struct usb_bos_descriptor), GFP_KERNEL);
            if (!desc_overwrite.bos->bos) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.bos->bos, desc_overwrite_tmp.bos->bos, sizeof(struct usb_bos_descriptor));
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
        
            
        if (desc_overwrite_tmp.bos->cap_ext) {
            desc_overwrite.bos->cap_ext = (struct usb_ext_cap_descriptor*)kmalloc(sizeof(struct usb_ext_cap_descriptor), GFP_KERNEL);
            if (!desc_overwrite.bos->cap_ext) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.bos->cap_ext, desc_overwrite_tmp.bos->cap_ext, sizeof(struct usb_ext_cap_descriptor));
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
        
        if (desc_overwrite_tmp.bos->cap_ss) {
            desc_overwrite.bos->cap_ss = (struct usb_ss_cap_descriptor*)kmalloc(sizeof(struct usb_ss_cap_descriptor), GFP_KERNEL);
            if (!desc_overwrite.bos->cap_ss) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.bos->cap_ss, desc_overwrite_tmp.bos->cap_ss, sizeof(struct usb_ss_cap_descriptor));
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
        
        if (desc_overwrite_tmp.bos->cap_ssp) {
            desc_overwrite.bos->cap_ssp = (struct usb_ssp_cap_descriptor_m*)kmalloc(sizeof(struct usb_ssp_cap_descriptor_m), GFP_KERNEL);
            if (!desc_overwrite.bos->cap_ssp) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
            ret = copy_from_user(desc_overwrite.bos->cap_ssp, desc_overwrite_tmp.bos->cap_ssp, sizeof(struct usb_ssp_cap_descriptor_m));
            if (ret < 0) {
                printk(KERN_ERR "write mem error line:%d\n", __LINE__);
                return -EFAULT;
            }
        }
    }
    
    return count;
}

static struct file_operations gadget_proc_fops = {
    .read  = gadget_proc_read,
    .write = gadget_proc_write,
};

/**
 * gadget_pfs_init - reset the override table and create its procfs entry
 *
 * Return: 0 on success, -ENOMEM if the procfs entry could not be created.
 */
int gadget_pfs_init(void)
{
	struct proc_dir_entry *entry;

	memset(&desc_overwrite, 0, sizeof(desc_overwrite));

	/*
	 * Writable by the owner only, not world-writable.  A write replaces
	 * the USB identity the gadget presents to the host.
	 */
	entry = proc_create(COMPOSITE_PROCFS, 0644, NULL, &gadget_proc_fops);
	if (!entry) {
		printk(KERN_ERR "composite proc_create\n");
		return -ENOMEM;
	}
	return 0;
}

/**
 * gadget_pfs_exit - tear down the procfs entry and drop all overrides
 */
void gadget_pfs_exit(void)
{
	/* remove the procfs entry before releasing the table it populates */
	remove_proc_entry(COMPOSITE_PROCFS, NULL);
	gadget_pfs_overwrite_desc_free();
}
