complete but low mAP
This commit is contained in:
		
							parent
							
								
									86b50fb075
								
							
						
					
					
						commit
						89c3326a59
					
				
							
								
								
									
										33
									
								
								model.py
								
								
								
								
							
							
						
						
									
										33
									
								
								model.py
								
								
								
								
							|  | @ -230,9 +230,9 @@ class FastRCNN(nn.Module): | ||||||
| 
 | 
 | ||||||
|         # forward heads, get predicted cls scores & offsets |         # forward heads, get predicted cls scores & offsets | ||||||
|         cls_scores = self.cls_head(feat) |         cls_scores = self.cls_head(feat) | ||||||
|         print(cls_scores.shape) |         # print(cls_scores.shape) | ||||||
|         bbox_offsets = self.bbox_head(feat) |         bbox_offsets = self.bbox_head(feat) | ||||||
|         print(bbox_offsets.shape) |         # print(bbox_offsets.shape) | ||||||
|         # get predicted boxes & class label & confidence probability |         # get predicted boxes & class label & confidence probability | ||||||
|         proposals = generate_proposal(proposals, bbox_offsets) |         proposals = generate_proposal(proposals, bbox_offsets) | ||||||
| 
 | 
 | ||||||
|  | @ -244,21 +244,28 @@ class FastRCNN(nn.Module): | ||||||
| 
 | 
 | ||||||
|             # filter by threshold |             # filter by threshold | ||||||
|             cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) |             cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) | ||||||
|             print(cls_prob.shape) |             # print(cls_prob) | ||||||
|             pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes) |             # print(torch.max(cls_prob, dim=1)[1].shape) | ||||||
|             print(pos_mask.shape) |             # print(torch.max(cls_prob, dim=1)[0]) | ||||||
|             proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask] |             # print(torch.max(cls_prob, dim=1)[1]) | ||||||
|             print(proposals_img.shape) |  | ||||||
|              |              | ||||||
|             print(cls_prob.shape) |             pos_mask = torch.max(cls_prob, dim=1)[0] > thresh | ||||||
|             final_proposals.append(proposals_img) |             not_bg_mask = torch.max(cls_prob, dim=1)[1] != self.num_classes | ||||||
|             final_conf_probs.append(cls_prob[pos_mask, 1].unsqueeze(1)) |             # print(pos_mask) | ||||||
|  |             # print(not_bg_mask) | ||||||
|  |             total_mask = pos_mask & not_bg_mask | ||||||
|  |             # print(final_mask) | ||||||
|  |             # print(pos_mask.shape) | ||||||
|  |             proposals_obj = proposals[proposal_batch_ids == img_idx][total_mask] | ||||||
|  |             conf_probs=torch.max(cls_prob, dim=1)[0][total_mask] | ||||||
|  |             class_idx = torch.max(cls_prob, dim=1)[1][total_mask] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|             # nms |             # nms | ||||||
|             keep = torchvision.ops.nms(proposals_img, cls_prob[:, 1], nms_thresh) |             keep = torchvision.ops.nms(proposals_obj, conf_probs, nms_thresh) | ||||||
|             proposals_img = proposals_img[keep] |             final_proposals.append(proposals_obj[keep]) | ||||||
|             cls_prob = cls_prob[keep] |             final_conf_probs.append(conf_probs[keep].unsqueeze(1)) | ||||||
|  |             final_class.append(class_idx[keep].unsqueeze(1)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         ############################################################################## |         ############################################################################## | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue