complete but low mAP
This commit is contained in:
		
							parent
							
								
									86b50fb075
								
							
						
					
					
						commit
						89c3326a59
					
				
							
								
								
									
										33
									
								
								model.py
								
								
								
								
							
							
						
						
									
										33
									
								
								model.py
								
								
								
								
							|  | @ -230,9 +230,9 @@ class FastRCNN(nn.Module): | |||
| 
 | ||||
|         # forward heads, get predicted cls scores & offsets | ||||
|         cls_scores = self.cls_head(feat) | ||||
|         print(cls_scores.shape) | ||||
|         # print(cls_scores.shape) | ||||
|         bbox_offsets = self.bbox_head(feat) | ||||
|         print(bbox_offsets.shape) | ||||
|         # print(bbox_offsets.shape) | ||||
|         # get predicted boxes & class label & confidence probability | ||||
|         proposals = generate_proposal(proposals, bbox_offsets) | ||||
| 
 | ||||
|  | @ -244,21 +244,28 @@ class FastRCNN(nn.Module): | |||
| 
 | ||||
|             # filter by threshold | ||||
|             cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) | ||||
|             print(cls_prob.shape) | ||||
|             pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes) | ||||
|             print(pos_mask.shape) | ||||
|             proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask] | ||||
|             print(proposals_img.shape) | ||||
|             # print(cls_prob) | ||||
|             # print(torch.max(cls_prob, dim=1)[1].shape) | ||||
|             # print(torch.max(cls_prob, dim=1)[0]) | ||||
|             # print(torch.max(cls_prob, dim=1)[1]) | ||||
|              | ||||
|             print(cls_prob.shape) | ||||
|             final_proposals.append(proposals_img) | ||||
|             final_conf_probs.append(cls_prob[pos_mask, 1].unsqueeze(1)) | ||||
|             pos_mask = torch.max(cls_prob, dim=1)[0] > thresh | ||||
|             not_bg_mask = torch.max(cls_prob, dim=1)[1] != self.num_classes | ||||
|             # print(pos_mask) | ||||
|             # print(not_bg_mask) | ||||
|             total_mask = pos_mask & not_bg_mask | ||||
|             # print(final_mask) | ||||
|             # print(pos_mask.shape) | ||||
|             proposals_obj = proposals[proposal_batch_ids == img_idx][total_mask] | ||||
|             conf_probs=torch.max(cls_prob, dim=1)[0][total_mask] | ||||
|             class_idx = torch.max(cls_prob, dim=1)[1][total_mask] | ||||
| 
 | ||||
| 
 | ||||
|             # nms | ||||
|             keep = torchvision.ops.nms(proposals_img, cls_prob[:, 1], nms_thresh) | ||||
|             proposals_img = proposals_img[keep] | ||||
|             cls_prob = cls_prob[keep] | ||||
|             keep = torchvision.ops.nms(proposals_obj, conf_probs, nms_thresh) | ||||
|             final_proposals.append(proposals_obj[keep]) | ||||
|             final_conf_probs.append(conf_probs[keep].unsqueeze(1)) | ||||
|             final_class.append(class_idx[keep].unsqueeze(1)) | ||||
| 
 | ||||
| 
 | ||||
|         ############################################################################## | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue