complete but low mAP
This commit is contained in:
		
							parent
							
								
									86b50fb075
								
							
						
					
					
						commit
						89c3326a59
					
				
							
								
								
									
										35
									
								
								model.py
								
								
								
								
							
							
						
						
									
										35
									
								
								model.py
								
								
								
								
							| 
						 | 
				
			
			@ -230,9 +230,9 @@ class FastRCNN(nn.Module):
 | 
			
		|||
 | 
			
		||||
        # forward heads, get predicted cls scores & offsets
 | 
			
		||||
        cls_scores = self.cls_head(feat)
 | 
			
		||||
        print(cls_scores.shape)
 | 
			
		||||
        # print(cls_scores.shape)
 | 
			
		||||
        bbox_offsets = self.bbox_head(feat)
 | 
			
		||||
        print(bbox_offsets.shape)
 | 
			
		||||
        # print(bbox_offsets.shape)
 | 
			
		||||
        # get predicted boxes & class label & confidence probability
 | 
			
		||||
        proposals = generate_proposal(proposals, bbox_offsets)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -244,21 +244,28 @@ class FastRCNN(nn.Module):
 | 
			
		|||
 | 
			
		||||
            # filter by threshold
 | 
			
		||||
            cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
 | 
			
		||||
            print(cls_prob.shape)
 | 
			
		||||
            pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes)
 | 
			
		||||
            print(pos_mask.shape)
 | 
			
		||||
            proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask]
 | 
			
		||||
            print(proposals_img.shape)
 | 
			
		||||
 | 
			
		||||
            print(cls_prob.shape)
 | 
			
		||||
            final_proposals.append(proposals_img)
 | 
			
		||||
            final_conf_probs.append(cls_prob[pos_mask, 1].unsqueeze(1))
 | 
			
		||||
            # print(cls_prob)
 | 
			
		||||
            # print(torch.max(cls_prob, dim=1)[1].shape)
 | 
			
		||||
            # print(torch.max(cls_prob, dim=1)[0])
 | 
			
		||||
            # print(torch.max(cls_prob, dim=1)[1])
 | 
			
		||||
            
 | 
			
		||||
            pos_mask = torch.max(cls_prob, dim=1)[0] > thresh
 | 
			
		||||
            not_bg_mask = torch.max(cls_prob, dim=1)[1] != self.num_classes
 | 
			
		||||
            # print(pos_mask)
 | 
			
		||||
            # print(not_bg_mask)
 | 
			
		||||
            total_mask = pos_mask & not_bg_mask
 | 
			
		||||
            # print(final_mask)
 | 
			
		||||
            # print(pos_mask.shape)
 | 
			
		||||
            proposals_obj = proposals[proposal_batch_ids == img_idx][total_mask]
 | 
			
		||||
            conf_probs=torch.max(cls_prob, dim=1)[0][total_mask]
 | 
			
		||||
            class_idx = torch.max(cls_prob, dim=1)[1][total_mask]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            # nms
 | 
			
		||||
            keep = torchvision.ops.nms(proposals_img, cls_prob[:, 1], nms_thresh)
 | 
			
		||||
            proposals_img = proposals_img[keep]
 | 
			
		||||
            cls_prob = cls_prob[keep]
 | 
			
		||||
            keep = torchvision.ops.nms(proposals_obj, conf_probs, nms_thresh)
 | 
			
		||||
            final_proposals.append(proposals_obj[keep])
 | 
			
		||||
            final_conf_probs.append(conf_probs[keep].unsqueeze(1))
 | 
			
		||||
            final_class.append(class_idx[keep].unsqueeze(1))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        ##############################################################################
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue